├── .Rbuildignore ├── .github ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ └── issue_template.md ├── SUPPORT.md └── workflows │ ├── R-CMD-check-hard.yaml │ ├── R-CMD-check.yaml │ ├── lock.yaml │ ├── pkgdown.yaml │ ├── pr-commands.yaml │ └── test-coverage.yaml ├── .gitignore ├── .vscode ├── extensions.json └── settings.json ├── DESCRIPTION ├── LICENSE ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── 0_imports.R ├── aorsf_data.R ├── bonsai_package.R ├── import-standalone-obj-type.R ├── import-standalone-types-check.R ├── lightgbm.R ├── lightgbm_data.R ├── partykit_data.R └── zzz.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── air.toml ├── bonsai.Rproj ├── codecov.yml ├── cran-comments.md ├── inst └── figs │ └── hex.svg ├── man ├── bonsai-package.Rd ├── figures │ └── logo.png ├── lightgbm_helpers.Rd ├── reexports.Rd └── train_lightgbm.Rd ├── pkgdown └── favicon │ ├── apple-touch-icon-120x120.png │ ├── apple-touch-icon-152x152.png │ ├── apple-touch-icon-180x180.png │ ├── apple-touch-icon-60x60.png │ ├── apple-touch-icon-76x76.png │ ├── apple-touch-icon.png │ ├── favicon-16x16.png │ ├── favicon-32x32.png │ └── favicon.ico ├── tests ├── testthat.R └── testthat │ ├── _snaps │ ├── lightgbm.md │ └── partykit.md │ ├── test-aorsf.R │ ├── test-lightgbm.R │ └── test-partykit.R └── vignettes ├── .gitignore └── bonsai.Rmd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^\.github$ 4 | ^codecov\.yml$ 5 | ^_pkgdown\.yml$ 6 | ^docs$ 7 | ^pkgdown$ 8 | README.Rmd 9 | cran-comments.md 10 | ^LICENSE\.md$ 11 | inst/figs/ 12 | ^README\.Rmd$ 13 | ^[\.]?air\.toml$ 14 | ^\.vscode$ 15 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- 
/.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and 
fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at codeofconduct@posit.co. 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 
86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.1, available at 118 | . 119 | 120 | Community Impact Guidelines were inspired by 121 | [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). 122 | 123 | For answers to common questions about this code of conduct, see the FAQ at 124 | . Translations are available at .
125 | 126 | [homepage]: https://www.contributor-covenant.org 127 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to bonsai 2 | 3 | This outlines how to propose a change to bonsai. 4 | For more detailed info about contributing to this, and other tidyverse packages, please see the 5 | [**development contributing guide**](https://rstd.io/tidy-contrib). 6 | 7 | ## Fixing typos 8 | 9 | You can fix typos, spelling mistakes, or grammatical errors in the documentation directly using the GitHub web interface, as long as the changes are made in the _source_ file. 10 | This generally means you'll need to edit [roxygen2 comments](https://roxygen2.r-lib.org/articles/roxygen2.html) in an `.R`, not a `.Rd` file. 11 | You can find the `.R` file that generates the `.Rd` by reading the comment in the first line. 12 | 13 | ## Bigger changes 14 | 15 | If you want to make a bigger change, it's a good idea to first file an issue and make sure someone from the team agrees that it’s needed. 16 | If you’ve found a bug, please file an issue that illustrates the bug with a minimal 17 | [reprex](https://www.tidyverse.org/help/#reprex) (this will also help you write a unit test, if needed). 18 | 19 | ### Pull request process 20 | 21 | * Fork the package and clone onto your computer. If you haven't done this before, we recommend using `usethis::create_from_github("tidymodels/bonsai", fork = TRUE)`. 22 | 23 | * Install all development dependencies with `devtools::install_dev_deps()`, and then make sure the package passes R CMD check by running `devtools::check()`. 24 | If R CMD check doesn't pass cleanly, it's a good idea to ask for help before continuing. 25 | * Create a Git branch for your pull request (PR). We recommend using `usethis::pr_init("brief-description-of-change")`. 
26 | 27 | * Make your changes, commit to git, and then create a PR by running `usethis::pr_push()`, and following the prompts in your browser. 28 | The title of your PR should briefly describe the change. 29 | The body of your PR should contain `Fixes #issue-number`. 30 | 31 | * For user-facing changes, add a bullet to the top of `NEWS.md` (i.e. just below the first header). Follow the style described in . 32 | 33 | ### Code style 34 | 35 | * New code should follow the tidyverse [style guide](https://style.tidyverse.org). 36 | You can use the [styler](https://CRAN.R-project.org/package=styler) package to apply these styles, but please don't restyle code that has nothing to do with your PR. 37 | 38 | * We use [roxygen2](https://cran.r-project.org/package=roxygen2), with [Markdown syntax](https://cran.r-project.org/web/packages/roxygen2/vignettes/rd-formatting.html), for documentation. 39 | 40 | * We use [testthat](https://cran.r-project.org/package=testthat) for unit tests. 41 | Contributions with test cases included are easier to accept. 42 | 43 | ## Code of Conduct 44 | 45 | Please note that the bonsai project is released with a 46 | [Contributor Code of Conduct](CODE_OF_CONDUCT.md). By contributing to this 47 | project you agree to abide by its terms. 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/issue_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report or feature request 3 | about: Describe a bug you've seen or make a case for a new feature 4 | --- 5 | 6 | Please briefly describe your problem and what output you expect. If you have a question, please don't use this form. Instead, ask on or . 7 | 8 | Please include a minimal reproducible example (AKA a reprex). If you've never heard of a [reprex](http://reprex.tidyverse.org/) before, start by reading . 
9 | 10 | Brief description of the problem 11 | 12 | ```r 13 | # insert reprex here 14 | ``` 15 | -------------------------------------------------------------------------------- /.github/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Getting help with bonsai 2 | 3 | Thanks for using bonsai! 4 | Before filing an issue, there are a few places to explore and pieces to put together to make the process as smooth as possible. 5 | 6 | ## Make a reprex 7 | 8 | Start by making a minimal **repr**oducible **ex**ample using the [reprex](https://reprex.tidyverse.org/) package. 9 | If you haven't heard of or used reprex before, you're in for a treat! 10 | Seriously, reprex will make all of your R-question-asking endeavors easier (which is a pretty insane ROI for the five to ten minutes it'll take you to learn what it's all about). 11 | For additional reprex pointers, check out the [Get help!](https://www.tidyverse.org/help/) section of the tidyverse site. 12 | 13 | ## Where to ask? 14 | 15 | Armed with your reprex, the next step is to figure out [where to ask](https://www.tidyverse.org/help/#where-to-ask). 16 | 17 | * If it's a question: start with [community.rstudio.com](https://community.rstudio.com/), and/or StackOverflow. There are more people there to answer questions. 18 | 19 | * If it's a bug: you're in the right place, [file an issue](https://github.com/tidymodels/bonsai/issues/new). 20 | 21 | * If you're not sure: let the community help you figure it out! 22 | If your problem _is_ a bug or a feature request, you can easily return here and report it. 23 | 24 | Before opening a new issue, be sure to [search issues and pull requests](https://github.com/tidymodels/bonsai/issues) to make sure the bug hasn't been reported and/or already fixed in the development version. 25 | By default, the search will be pre-populated with `is:issue is:open`. 
26 | You can [edit the qualifiers](https://help.github.com/articles/searching-issues-and-pull-requests/) (e.g. `is:pr`, `is:closed`) as needed. 27 | For example, you'd simply remove `is:open` to search _all_ issues in the repo, open or closed. 28 | 29 | ## What happens next? 30 | 31 | To be as efficient as possible, development of tidyverse packages tends to be very bursty, so you shouldn't worry if you don't get an immediate response. 32 | Typically we don't look at a repo until a sufficient quantity of issues accumulates, then there’s a burst of intense activity as we focus our efforts. 33 | That makes development more efficient because it avoids expensive context switching between problems, at the cost of taking longer to get back to you. 34 | This process makes a good reprex particularly important because it might be multiple months between your initial report and when we start working on it. 35 | If we can’t reproduce the bug, we can’t fix it! 36 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check-hard.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: [main] 4 | pull_request: 5 | branches: [main] 6 | 7 | name: R-CMD-check-hard 8 | 9 | jobs: 10 | R-CMD-check: 11 | runs-on: ${{ matrix.config.os }} 12 | 13 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 14 | 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | config: 19 | - {os: ubuntu-latest, r: 'release'} 20 | 21 | env: 22 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 23 | R_KEEP_PKG_SOURCE: yes 24 | 25 | steps: 26 | - uses: actions/checkout@v2 27 | 28 | - uses: r-lib/actions/setup-pandoc@v2 29 | 30 | - uses: r-lib/actions/setup-r@v2 31 | with: 32 | r-version: ${{ matrix.config.r }} 33 | http-user-agent: ${{ matrix.config.http-user-agent }} 34 | use-public-rspm: true 35 | 36 | - uses: r-lib/actions/setup-r-dependencies@v2 37 | with: 38 | dependencies: '"hard"' 39 | cache: 
false 40 | extra-packages: | 41 | any::rcmdcheck 42 | any::testthat 43 | any::knitr 44 | any::rmarkdown 45 | needs: check 46 | 47 | - uses: r-lib/actions/check-r-package@v2 48 | with: 49 | upload-snapshots: true 50 | -------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | # 4 | # NOTE: This workflow is overkill for most R packages and 5 | # check-standard.yaml is likely a better choice. 6 | # usethis::use_github_action("check-standard") will install it. 7 | on: 8 | push: 9 | branches: [main, master] 10 | pull_request: 11 | 12 | name: R-CMD-check.yaml 13 | 14 | permissions: read-all 15 | 16 | jobs: 17 | R-CMD-check: 18 | runs-on: ${{ matrix.config.os }} 19 | 20 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | config: 26 | - {os: macos-latest, r: 'release'} 27 | 28 | - {os: windows-latest, r: 'release'} 29 | # use 4.0 or 4.1 to check with rtools40's older compiler 30 | - {os: windows-latest, r: 'oldrel-4'} 31 | 32 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 33 | - {os: ubuntu-latest, r: 'release'} 34 | - {os: ubuntu-latest, r: 'oldrel-1'} 35 | - {os: ubuntu-latest, r: 'oldrel-2'} 36 | - {os: ubuntu-latest, r: 'oldrel-3'} 37 | - {os: ubuntu-latest, r: 'oldrel-4'} 38 | 39 | env: 40 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 41 | R_KEEP_PKG_SOURCE: yes 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | 46 | - uses: r-lib/actions/setup-pandoc@v2 47 | 48 | - uses: r-lib/actions/setup-r@v2 49 | with: 50 | r-version: ${{ matrix.config.r }} 51 | http-user-agent: ${{ matrix.config.http-user-agent }} 52 | use-public-rspm: true 53 | 54 | - uses: 
r-lib/actions/setup-r-dependencies@v2 55 | with: 56 | extra-packages: any::rcmdcheck 57 | needs: check 58 | 59 | - uses: r-lib/actions/check-r-package@v2 60 | with: 61 | upload-snapshots: true 62 | build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")' 63 | -------------------------------------------------------------------------------- /.github/workflows/lock.yaml: -------------------------------------------------------------------------------- 1 | name: 'Lock Threads' 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | jobs: 8 | lock: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: dessant/lock-threads@v2 12 | with: 13 | github-token: ${{ github.token }} 14 | issue-lock-inactive-days: '14' 15 | # issue-exclude-labels: '' 16 | # issue-lock-labels: 'outdated' 17 | issue-lock-comment: > 18 | This issue has been automatically locked. If you believe you have 19 | found a related problem, please file a new issue (with a reprex: 20 | ) and link to this issue. 21 | issue-lock-reason: '' 22 | pr-lock-inactive-days: '14' 23 | # pr-exclude-labels: 'wip' 24 | pr-lock-labels: '' 25 | pr-lock-comment: > 26 | This pull request has been automatically locked. If you believe you 27 | have found a related problem, please file a new issue (with a reprex: 28 | ) and link to this pull request. 29 | pr-lock-reason: '' 30 | # process-only: 'issues' 31 | -------------------------------------------------------------------------------- /.github/workflows/pkgdown.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures?
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | release: 8 | types: [published] 9 | workflow_dispatch: 10 | 11 | name: pkgdown.yaml 12 | 13 | permissions: read-all 14 | 15 | jobs: 16 | pkgdown: 17 | runs-on: ubuntu-latest 18 | # Only restrict concurrency for non-PR jobs 19 | concurrency: 20 | group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }} 21 | env: 22 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 23 | permissions: 24 | contents: write 25 | steps: 26 | - uses: actions/checkout@v4 27 | 28 | - uses: r-lib/actions/setup-pandoc@v2 29 | 30 | - uses: r-lib/actions/setup-r@v2 31 | with: 32 | use-public-rspm: true 33 | 34 | - uses: r-lib/actions/setup-r-dependencies@v2 35 | with: 36 | extra-packages: any::pkgdown, local::. 37 | needs: website 38 | 39 | - name: Build site 40 | run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE) 41 | shell: Rscript {0} 42 | 43 | - name: Deploy to GitHub pages 🚀 44 | if: github.event_name != 'pull_request' 45 | uses: JamesIves/github-pages-deploy-action@v4.5.0 46 | with: 47 | clean: false 48 | branch: gh-pages 49 | folder: docs 50 | -------------------------------------------------------------------------------- /.github/workflows/pr-commands.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? 
Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | name: pr-commands.yaml 8 | 9 | permissions: read-all 10 | 11 | jobs: 12 | document: 13 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/document') }} 14 | name: document 15 | runs-on: ubuntu-latest 16 | env: 17 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 18 | permissions: 19 | contents: write 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: r-lib/actions/pr-fetch@v2 24 | with: 25 | repo-token: ${{ secrets.GITHUB_TOKEN }} 26 | 27 | - uses: r-lib/actions/setup-r@v2 28 | with: 29 | use-public-rspm: true 30 | 31 | - uses: r-lib/actions/setup-r-dependencies@v2 32 | with: 33 | extra-packages: any::roxygen2 34 | needs: pr-document 35 | 36 | - name: Document 37 | run: roxygen2::roxygenise() 38 | shell: Rscript {0} 39 | 40 | - name: commit 41 | run: | 42 | git config --local user.name "$GITHUB_ACTOR" 43 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 44 | git add man/\* NAMESPACE 45 | git commit -m 'Document' 46 | 47 | - uses: r-lib/actions/pr-push@v2 48 | with: 49 | repo-token: ${{ secrets.GITHUB_TOKEN }} 50 | 51 | style: 52 | if: ${{ github.event.issue.pull_request && (github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') && startsWith(github.event.comment.body, '/style') }} 53 | name: style 54 | runs-on: ubuntu-latest 55 | env: 56 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 57 | permissions: 58 | contents: write 59 | steps: 60 | - uses: actions/checkout@v4 61 | 62 | - uses: r-lib/actions/pr-fetch@v2 63 | with: 64 | repo-token: ${{ secrets.GITHUB_TOKEN }} 65 | 66 | - uses: r-lib/actions/setup-r@v2 67 | 68 | - name: Install dependencies 69 | run: install.packages("styler") 70 | shell: Rscript {0} 71 | 72 | - name: Style 73 | 
run: styler::style_pkg() 74 | shell: Rscript {0} 75 | 76 | - name: commit 77 | run: | 78 | git config --local user.name "$GITHUB_ACTOR" 79 | git config --local user.email "$GITHUB_ACTOR@users.noreply.github.com" 80 | git add \*.R 81 | git commit -m 'Style' 82 | 83 | - uses: r-lib/actions/pr-push@v2 84 | with: 85 | repo-token: ${{ secrets.GITHUB_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/test-coverage.yaml: -------------------------------------------------------------------------------- 1 | # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples 2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help 3 | on: 4 | push: 5 | branches: [main, master] 6 | pull_request: 7 | 8 | name: test-coverage.yaml 9 | 10 | permissions: read-all 11 | 12 | jobs: 13 | test-coverage: 14 | runs-on: ubuntu-latest 15 | env: 16 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - uses: r-lib/actions/setup-r@v2 22 | with: 23 | use-public-rspm: true 24 | 25 | - uses: r-lib/actions/setup-r-dependencies@v2 26 | with: 27 | extra-packages: any::covr, any::xml2 28 | needs: coverage 29 | 30 | - name: Test coverage 31 | run: | 32 | cov <- covr::package_coverage( 33 | quiet = FALSE, 34 | clean = FALSE, 35 | install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package") 36 | ) 37 | print(cov) 38 | covr::to_cobertura(cov) 39 | shell: Rscript {0} 40 | 41 | - uses: codecov/codecov-action@v5 42 | with: 43 | # Fail on error if not on a PR, or if on a PR and a token is given 44 | fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }} 45 | files: ./cobertura.xml 46 | plugins: noop 47 | disable_search: true 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | 50 | - name: Show testthat output 51 | if: always() 52 | run: | 53 | ##
-------------------------------------------------------------------- 54 | find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true 55 | shell: bash 56 | 57 | - name: Upload test results 58 | if: failure() 59 | uses: actions/upload-artifact@v4 60 | with: 61 | name: coverage-test-failures 62 | path: ${{ runner.temp }}/package 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | .Rproj.user 3 | .Rhistory 4 | .RData 5 | .Ruserdata 6 | docs 7 | inst/doc 8 | lib/ 9 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "Posit.air-vscode" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[r]": { 3 | "editor.formatOnSave": true, 4 | "editor.defaultFormatter": "Posit.air-vscode" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: bonsai 2 | Title: Model Wrappers for Tree-Based Models 3 | Version: 0.3.2.9000 4 | Authors@R: c( 5 | person("Daniel", "Falbel", , "dfalbel@curso-r.com", role = "aut"), 6 | person("Athos", "Damiani", , "adamiani@curso-r.com", role = "aut"), 7 | person("Roel M.", "Hogervorst", , "hogervorst.rm@gmail.com", role = "aut", 8 | comment = c(ORCID = "0000-0001-7509-0328")), 9 | person("Max", "Kuhn", , "max@posit.co", role = "aut", 10 | comment = c(ORCID = "0000-0003-2402-136X")), 11 | person("Simon", "Couch", , "simon.couch@posit.co", role = c("aut", "cre"), 12 | comment = c(ORCID = "0000-0001-5676-5107")), 13 | 
person("Posit Software, PBC", role = c("cph", "fnd"), 14 | comment = c(ROR = "03wc8by49")) 15 | ) 16 | Description: Bindings for additional tree-based model engines for use with 17 | the 'parsnip' package. Models include gradient boosted decision trees 18 | with 'LightGBM' (Ke et al, 2017.), conditional inference trees and 19 | conditional random forests with 'partykit' (Hothorn and Zeileis, 2015. 20 | and Hothorn et al, 2006. ), and 21 | accelerated oblique random forests with 'aorsf' (Jaeger et al, 2022 22 | ). 23 | License: MIT + file LICENSE 24 | URL: https://bonsai.tidymodels.org/, https://github.com/tidymodels/bonsai 25 | BugReports: https://github.com/tidymodels/bonsai/issues 26 | Depends: 27 | parsnip (>= 1.0.1), 28 | R (>= 4.1) 29 | Imports: 30 | cli, 31 | dials, 32 | dplyr, 33 | glue, 34 | purrr, 35 | rlang (>= 1.1.0), 36 | stats, 37 | tibble, 38 | utils, 39 | withr 40 | Suggests: 41 | aorsf (>= 0.1.5), 42 | covr, 43 | knitr, 44 | lightgbm, 45 | Matrix, 46 | modeldata, 47 | partykit, 48 | rmarkdown, 49 | rsample, 50 | testthat (>= 3.0.0), 51 | tune 52 | VignetteBuilder: 53 | knitr 54 | Config/Needs/website: tidyverse/tidytemplate 55 | Config/testthat/edition: 3 56 | Config/usethis/last-upkeep: 2025-04-25 57 | Encoding: UTF-8 58 | Roxygen: list(markdown = TRUE) 59 | RoxygenNote: 7.3.2 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2023 2 | COPYRIGHT HOLDER: bonsai authors 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2025 bonsai authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(multi_predict,"_lgb.Booster") 4 | export("%>%") 5 | export(predict_lightgbm_classification_class) 6 | export(predict_lightgbm_classification_prob) 7 | export(predict_lightgbm_classification_raw) 8 | export(predict_lightgbm_regression_numeric) 9 | export(train_lightgbm) 10 | import(rlang) 11 | importFrom(dials,min_n) 12 | importFrom(parsnip,"%>%") 13 | importFrom(parsnip,boost_tree) 14 | importFrom(parsnip,decision_tree) 15 | importFrom(parsnip,fit) 16 | importFrom(parsnip,multi_predict) 17 | importFrom(parsnip,rand_forest) 18 | importFrom(parsnip,set_engine) 19 | importFrom(parsnip,set_mode) 20 | importFrom(purrr,map_df) 21 | importFrom(purrr,map_dfr) 22 | importFrom(rlang,call2) 23 | importFrom(rlang,empty_env) 24 | importFrom(rlang,enquo) 25 | importFrom(rlang,enquos) 26 | importFrom(rlang,eval_tidy) 27 | importFrom(rlang,expr) 28 | 
importFrom(rlang,new_quosure) 29 | importFrom(stats,predict) 30 | importFrom(tibble,as_tibble) 31 | importFrom(tibble,tibble) 32 | importFrom(utils,packageVersion) 33 | importFrom(withr,defer) 34 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # bonsai (development version) 2 | 3 | * Increased the minimum R version to R 4.1. 4 | 5 | # bonsai 0.3.2 6 | 7 | * Resolves a test failure ahead of an upcoming parsnip release (#95). 8 | 9 | * lightgbm models can now accept sparse matrices for training and prediction (#91). 10 | 11 | # bonsai 0.3.1 12 | 13 | * Fixed a bug where `"aorsf"` models would not successfully fit in socket cluster workers (i.e. with `plan(multisession)`) unless another engine requiring bonsai had been fitted in the worker (#85). 14 | 15 | # bonsai 0.3.0 16 | 17 | * Introduced support for accelerated oblique random forests for the `"classification"` and `"regression"` modes using the new [`"aorsf"` engine](https://github.com/ropensci/aorsf) (#78 by `@bcjaeger`). 18 | 19 | * Enabled passing [Dataset Parameters](https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters) to the `"lightgbm"` engine. To pass an argument that would usually be passed as an element of the `param` argument in `lightgbm::lgb.Dataset()`, pass the argument directly through the ellipses in `set_engine()`, e.g. `boost_tree() %>% set_engine("lightgbm", linear_tree = TRUE)` (#77). 20 | 21 | * Enabled case weights with the `"lightgbm"` engine (#72 by `@p-schaefer`). 22 | 23 | * Fixed issues in metadata for the `"partykit"` engine for `rand_forest()` where some engine arguments were mistakenly protected (#74). 24 | 25 | * Addressed a type check error when fitting lightgbm model specifications with arguments mistakenly left as `tune()` (#79).
26 | 27 | # bonsai 0.2.1 28 | 29 | * The most recent dials and parsnip releases introduced tuning integration for the lightgbm `num_leaves` engine argument! The `num_leaves` parameter sets the maximum number of leaves per tree, and is an [important tuning parameter for lightgbm](https://lightgbm.readthedocs.io/en/latest/Parameters-Tuning.html) ([tidymodels/dials#256](https://github.com/tidymodels/dials/pull/256), [tidymodels/parsnip#838](https://github.com/tidymodels/parsnip/pull/838)). With the newest version of each of dials, parsnip, and bonsai installed, tune this argument by marking the `num_leaves` engine argument for tuning when defining your model specification: 30 | 31 | ``` r 32 | boost_tree() %>% set_engine("lightgbm", num_leaves = tune()) 33 | ``` 34 | 35 | * Fixed a bug where lightgbm's parallelism argument `num_threads` was overridden when passed via `params` rather than as a main argument. As a result, lightgbm now fits sequentially by default rather than with `num_threads = foreach::getDoParWorkers()`. The user can still set `num_threads` via engine arguments with `engine = "lightgbm"`: 36 | 37 | ``` r 38 | boost_tree() %>% set_engine("lightgbm", num_threads = x) 39 | ``` 40 | 41 | Note that, when tuning hyperparameters with the tune package, detection of a parallel backend will still work [as usual](https://tune.tidymodels.org/articles/extras/optimizations.html). 42 | 43 | * The `boost_tree` argument `stop_iter` now maps to the `lightgbm::lgb.train()` argument `early_stopping_round` rather than its alias `early_stopping_rounds`. This does not affect parsnip's interface to lightgbm (i.e. via `boost_tree() %>% set_engine("lightgbm")`), though it will introduce errors for code that uses the `train_lightgbm()` wrapper directly and sets the `lightgbm::lgb.train()` argument `early_stopping_round` by its alias `early_stopping_rounds` via `train_lightgbm()`'s `...`.
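The renaming in these entries, such as `stop_iter` mapping to `early_stopping_round`, can be pictured as a lookup from `boost_tree()` main arguments to lightgbm parameter names. The base R sketch below is a hypothetical illustration rather than bonsai's internal code: `lightgbm_arg_map` and `to_lightgbm_names()` are assumed names, and the pairings follow the argument names of `train_lightgbm()` and the mappings documented for the lightgbm engine.

``` r
# Hypothetical lookup: parsnip main argument -> lightgbm parameter name.
lightgbm_arg_map <- c(
  trees          = "num_iterations",
  tree_depth     = "max_depth",
  learn_rate     = "learning_rate",
  mtry           = "feature_fraction_bynode",
  min_n          = "min_data_in_leaf",
  loss_reduction = "min_gain_to_split",
  sample_size    = "bagging_fraction",
  stop_iter      = "early_stopping_round" # not the alias `early_stopping_rounds`
)

# Rename any parsnip-named elements of a named list; other names
# (e.g. engine arguments such as `num_leaves`) are left untouched.
to_lightgbm_names <- function(args) {
  nms <- names(args)
  hit <- nms %in% names(lightgbm_arg_map)
  nms[hit] <- unname(lightgbm_arg_map[nms[hit]])
  names(args) <- nms
  args
}

to_lightgbm_names(list(trees = 500, stop_iter = 10, num_leaves = 31))
```

The call above returns a list named `num_iterations`, `early_stopping_round`, and `num_leaves`.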
44 | 45 | * Disallowed passing main model arguments as engine arguments to `set_engine("lightgbm", ...)` via aliases. That is, if a main argument is marked for tuning and a lightgbm alias is supplied as an engine argument, bonsai will now error, rather than supplying both to lightgbm and allowing the package to handle aliases. Users can still interface with non-main `boost_tree()` arguments via their lightgbm aliases ([#53](https://github.com/tidymodels/bonsai/issues/53)). 46 | 47 | # bonsai 0.2.0 48 | 49 | * Enabled bagging with lightgbm via the `sample_size` argument to `boost_tree` 50 | (#32 and tidymodels/parsnip#768). The following docs, now available in 51 | `?details_boost_tree_lightgbm`, describe the interface in detail: 52 | 53 | > The `sample_size` argument is translated to the `bagging_fraction` parameter in the `param` argument of `lgb.train`. The argument is interpreted by lightgbm as a _proportion_ rather than a count, so bonsai internally reparameterizes the `sample_size` argument with [dials::sample_prop()] during tuning. 54 | > 55 | > To effectively enable bagging, the user would also need to set the `bagging_freq` argument to lightgbm. `bagging_freq` defaults to 0, which means bagging is disabled, and a `bagging_freq` argument of `k` means that the booster will perform bagging at every `k`th boosting iteration. Thus, by default, the `sample_size` argument would be ignored without setting this argument manually. Other boosting libraries, like xgboost, do not have an analogous argument to `bagging_freq` and use `k = 1` when the analogue to `bagging_fraction` is in $(0, 1)$. _bonsai will thus automatically set_ `bagging_freq = 1` _in_ `set_engine("lightgbm", ...)` if `sample_size` (i.e. `bagging_fraction`) is not equal to 1 and no `bagging_freq` value is supplied. This default can be overridden by setting the `bagging_freq` argument to `set_engine()` manually.
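The `bagging_freq` default described in the quoted docs reduces to a single conditional. The base R sketch below restates that rule for illustration only. It is not bonsai's implementation, and the helper `resolve_bagging_freq()` is hypothetical. It uses `[[` rather than `$` so that list names are matched exactly rather than partially.

``` r
# Hypothetical sketch of the documented default: if bagging_fraction
# (i.e. sample_size) is not 1 and no bagging_freq was supplied, set
# bagging_freq = 1 so that bagging happens at every boosting iteration.
resolve_bagging_freq <- function(params) {
  frac <- params[["bagging_fraction"]]
  if (!is.null(frac) && frac != 1 && is.null(params[["bagging_freq"]])) {
    params[["bagging_freq"]] <- 1
  }
  params
}

resolve_bagging_freq(list(bagging_fraction = 0.8))[["bagging_freq"]]
resolve_bagging_freq(list(bagging_fraction = 0.8, bagging_freq = 5))[["bagging_freq"]]
```

The first call fills in `bagging_freq = 1`. The second keeps the user-supplied value of 5, mirroring the note that the default can be overridden.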
56 | 57 | * Corrected mapping of the `mtry` argument in `boost_tree` with the lightgbm 58 | engine. `mtry` previously mapped to the `feature_fraction` argument to 59 | `lgb.train` but was documented as mapping to an argument more closely 60 | resembling `feature_fraction_bynode`. `mtry` now maps 61 | to `feature_fraction_bynode`. 62 | 63 | This means that code that set `feature_fraction_bynode` as an argument to 64 | `set_engine()` will now error, and the user can now pass `feature_fraction` 65 | to `set_engine()` without raising an error. 66 | 67 | * Fixed error in lightgbm with engine argument `objective = "tweedie"` and 68 | response values less than 1. 69 | 70 | * A number of documentation improvements, increases in testing coverage, and 71 | changes to internals in anticipation of the 4.0.0 release of the lightgbm 72 | package. Thank you to `@jameslamb` for the effort and expertise! 73 | 74 | # bonsai 0.1.0 75 | 76 | Initial release! 77 | -------------------------------------------------------------------------------- /R/0_imports.R: -------------------------------------------------------------------------------- 1 | #' @importFrom rlang enquo call2 eval_tidy new_quosure empty_env enquos expr 2 | #' @importFrom purrr map_dfr map_df 3 | #' @importFrom tibble as_tibble tibble 4 | #' @importFrom parsnip multi_predict set_mode set_engine fit 5 | #' @importFrom parsnip decision_tree boost_tree rand_forest 6 | #' @importFrom stats predict 7 | #' @importFrom utils packageVersion 8 | 9 | # ------------------------------------------------------------------------------ 10 | 11 | #' @importFrom parsnip %>% 12 | #' @export 13 | parsnip::`%>%` 14 | 15 | # quiet R CMD CHECK warning re: declared Imports 16 | #' @importFrom dials min_n 17 | dials::min_n() 18 | 19 | # ------------------------------------------------------------------------------ 20 | 21 | utils::globalVariables( 22 | c( 23 | "categorical_columns", 24 | "categorical_features_to_int", 25 | "new_data", 26 | 
"object" 27 | ) 28 | ) 29 | 30 | # ------------------------------------------------------------------------------ 31 | 32 | # lightgbm had significant breaking changes following release v3.3.2. 33 | # this function is used by patches that make bonsai backward-compatible with 34 | # older lightgbm versions 35 | using_newer_lightgbm_version <- function() { 36 | utils::packageVersion("lightgbm") > package_version("3.3.2") 37 | } 38 | -------------------------------------------------------------------------------- /R/aorsf_data.R: -------------------------------------------------------------------------------- 1 | # nocov start 2 | 3 | make_rand_forest_aorsf <- function() { 4 | parsnip::set_model_engine("rand_forest", "classification", "aorsf") 5 | parsnip::set_model_engine("rand_forest", "regression", "aorsf") 6 | 7 | parsnip::set_dependency( 8 | "rand_forest", 9 | "aorsf", 10 | "aorsf", 11 | mode = "classification" 12 | ) 13 | parsnip::set_dependency( 14 | "rand_forest", 15 | "aorsf", 16 | "bonsai", 17 | mode = "classification" 18 | ) 19 | 20 | parsnip::set_dependency("rand_forest", "aorsf", "aorsf", mode = "regression") 21 | parsnip::set_dependency("rand_forest", "aorsf", "bonsai", mode = "regression") 22 | 23 | parsnip::set_model_arg( 24 | model = "rand_forest", 25 | eng = "aorsf", 26 | parsnip = "mtry", 27 | original = "mtry", 28 | func = list(pkg = "dials", fun = "mtry"), 29 | has_submodel = FALSE 30 | ) 31 | 32 | parsnip::set_model_arg( 33 | model = "rand_forest", 34 | eng = "aorsf", 35 | parsnip = "trees", 36 | original = "n_tree", 37 | func = list(pkg = "dials", fun = "trees"), 38 | has_submodel = FALSE 39 | ) 40 | 41 | parsnip::set_model_arg( 42 | model = "rand_forest", 43 | eng = "aorsf", 44 | parsnip = "min_n", 45 | original = "leaf_min_obs", 46 | func = list(pkg = "dials", fun = "min_n"), 47 | has_submodel = FALSE 48 | ) 49 | 50 | parsnip::set_model_arg( 51 | model = "rand_forest", 52 | eng = "aorsf", 53 | parsnip = "mtry", 54 | original = "mtry", 55 | 
func = list(pkg = "dials", fun = "mtry"), 56 | has_submodel = FALSE 57 | ) 58 | 59 | parsnip::set_fit( 60 | model = "rand_forest", 61 | eng = "aorsf", 62 | mode = "classification", 63 | value = list( 64 | interface = "formula", 65 | protect = c("formula", "data", "weights"), 66 | func = c(pkg = "aorsf", fun = "orsf"), 67 | defaults = list( 68 | n_thread = 1, 69 | verbose_progress = FALSE 70 | ) 71 | ) 72 | ) 73 | 74 | parsnip::set_encoding( 75 | model = "rand_forest", 76 | eng = "aorsf", 77 | mode = "classification", 78 | options = list( 79 | predictor_indicators = "none", 80 | compute_intercept = FALSE, 81 | remove_intercept = FALSE, 82 | allow_sparse_x = FALSE 83 | ) 84 | ) 85 | 86 | parsnip::set_fit( 87 | model = "rand_forest", 88 | eng = "aorsf", 89 | mode = "regression", 90 | value = list( 91 | interface = "formula", 92 | protect = c("formula", "data", "weights"), 93 | func = c(pkg = "aorsf", fun = "orsf"), 94 | defaults = list( 95 | n_thread = 1, 96 | verbose_progress = FALSE 97 | ) 98 | ) 99 | ) 100 | 101 | parsnip::set_encoding( 102 | model = "rand_forest", 103 | eng = "aorsf", 104 | mode = "regression", 105 | options = list( 106 | predictor_indicators = "none", 107 | compute_intercept = FALSE, 108 | remove_intercept = FALSE, 109 | allow_sparse_x = FALSE 110 | ) 111 | ) 112 | 113 | parsnip::set_pred( 114 | model = "rand_forest", 115 | eng = "aorsf", 116 | mode = "classification", 117 | type = "class", 118 | value = list( 119 | pre = NULL, 120 | # makes prob preds consistent with class ones. 121 | # note: the class predict method in aorsf uses the standard 'each tree 122 | # gets one vote' approach, which is usually consistent with probability 123 | # but not all the time. I opted to make predicted probability totally 124 | # consistent with predicted class in the parsnip bindings for aorsf b/c 125 | # I think it's really confusing when predicted probs do not align with 126 | # predicted classes. 
I'm fine with this in aorsf but in bonsai I want 127 | # to minimize confusion (#78). 128 | post = function(results, object) { 129 | missings <- apply(results, 1, function(x) any(is.na(x))) 130 | 131 | if (!any(missings)) { 132 | return(colnames(results)[apply(results, 1, which.max)]) 133 | } 134 | 135 | obs <- which(!missings) 136 | 137 | out <- rep(NA_character_, nrow(results)) 138 | out[obs] <- colnames(results)[apply(results[obs, ], 1, which.max)] 139 | out 140 | }, 141 | func = c(fun = "predict"), 142 | args = list( 143 | object = quote(object$fit), 144 | new_data = quote(new_data), 145 | pred_type = "prob", 146 | verbose_progress = FALSE, 147 | na_action = 'pass' 148 | ) 149 | ) 150 | ) 151 | 152 | parsnip::set_pred( 153 | model = "rand_forest", 154 | eng = "aorsf", 155 | mode = "classification", 156 | type = "prob", 157 | value = list( 158 | pre = NULL, 159 | post = function(x, object) { 160 | as_tibble(x) 161 | }, 162 | func = c(fun = "predict"), 163 | args = list( 164 | object = quote(object$fit), 165 | new_data = quote(new_data), 166 | pred_type = 'prob', 167 | verbose_progress = FALSE, 168 | na_action = 'pass' 169 | ) 170 | ) 171 | ) 172 | 173 | parsnip::set_pred( 174 | model = "rand_forest", 175 | eng = "aorsf", 176 | mode = "classification", 177 | type = "raw", 178 | value = list( 179 | pre = NULL, 180 | post = NULL, 181 | func = c(fun = "predict"), 182 | args = list( 183 | object = quote(object$fit), 184 | new_data = quote(new_data), 185 | verbose_progress = FALSE, 186 | na_action = 'pass' 187 | ) 188 | ) 189 | ) 190 | 191 | parsnip::set_pred( 192 | model = "rand_forest", 193 | eng = "aorsf", 194 | mode = "regression", 195 | type = "numeric", 196 | value = list( 197 | pre = NULL, 198 | post = as.numeric, 199 | func = c(fun = "predict"), 200 | args = list( 201 | object = quote(object$fit), 202 | new_data = quote(new_data), 203 | pred_type = "mean", 204 | verbose_progress = FALSE, 205 | na_action = 'pass' 206 | ) 207 | ) 208 | ) 209 | 210 | 
parsnip::set_pred( 211 | model = "rand_forest", 212 | eng = "aorsf", 213 | mode = "regression", 214 | type = "raw", 215 | value = list( 216 | pre = NULL, 217 | post = as.numeric, 218 | func = c(fun = "predict"), 219 | args = list( 220 | object = quote(object$fit), 221 | new_data = quote(new_data), 222 | pred_type = "mean", 223 | verbose_progress = FALSE, 224 | na_action = 'pass' 225 | ) 226 | ) 227 | ) 228 | } 229 | 230 | # nocov end 231 | -------------------------------------------------------------------------------- /R/bonsai_package.R: -------------------------------------------------------------------------------- 1 | #' bonsai: Model Wrappers for Tree-Based Models 2 | #' 3 | #' @docType package 4 | #' @aliases bonsai 5 | "_PACKAGE" 6 | 7 | #' @importFrom withr defer 8 | #' @import rlang 9 | NULL 10 | -------------------------------------------------------------------------------- /R/import-standalone-obj-type.R: -------------------------------------------------------------------------------- 1 | # Standalone file: do not edit by hand 2 | # Source: https://github.com/r-lib/rlang/blob/HEAD/R/standalone-obj-type.R 3 | # Generated by: usethis::use_standalone("r-lib/rlang", "obj-type") 4 | # ---------------------------------------------------------------------- 5 | # 6 | # --- 7 | # repo: r-lib/rlang 8 | # file: standalone-obj-type.R 9 | # last-updated: 2024-02-14 10 | # license: https://unlicense.org 11 | # imports: rlang (>= 1.1.0) 12 | # --- 13 | # 14 | # ## Changelog 15 | # 16 | # 2024-02-14: 17 | # - `obj_type_friendly()` now works for S7 objects. 18 | # 19 | # 2023-05-01: 20 | # - `obj_type_friendly()` now only displays the first class of S3 objects. 21 | # 22 | # 2023-03-30: 23 | # - `stop_input_type()` now handles `I()` input literally in `arg`. 24 | # 25 | # 2022-10-04: 26 | # - `obj_type_friendly(value = TRUE)` now shows numeric scalars 27 | # literally. 
28 | # - `stop_friendly_type()` now takes `show_value`, passed to 29 | # `obj_type_friendly()` as the `value` argument. 30 | # 31 | # 2022-10-03: 32 | # - Added `allow_na` and `allow_null` arguments. 33 | # - `NULL` is now backticked. 34 | # - Better friendly type for infinities and `NaN`. 35 | # 36 | # 2022-09-16: 37 | # - Unprefixed usage of rlang functions with `rlang::` to 38 | # avoid onLoad issues when called from rlang (#1482). 39 | # 40 | # 2022-08-11: 41 | # - Prefixed usage of rlang functions with `rlang::`. 42 | # 43 | # 2022-06-22: 44 | # - `friendly_type_of()` is now `obj_type_friendly()`. 45 | # - Added `obj_type_oo()`. 46 | # 47 | # 2021-12-20: 48 | # - Added support for scalar values and empty vectors. 49 | # - Added `stop_input_type()` 50 | # 51 | # 2021-06-30: 52 | # - Added support for missing arguments. 53 | # 54 | # 2021-04-19: 55 | # - Added support for matrices and arrays (#141). 56 | # - Added documentation. 57 | # - Added changelog. 58 | # 59 | # nocov start 60 | 61 | #' Return English-friendly type 62 | #' @param x Any R object. 63 | #' @param value Whether to describe the value of `x`. Special values 64 | #' like `NA` or `""` are always described. 65 | #' @param length Whether to mention the length of vectors and lists. 66 | #' @return A string describing the type. Starts with an indefinite 67 | #' article, e.g. "an integer vector". 
68 | #' @noRd 69 | obj_type_friendly <- function(x, value = TRUE) { 70 | if (is_missing(x)) { 71 | return("absent") 72 | } 73 | 74 | if (is.object(x)) { 75 | if (inherits(x, "quosure")) { 76 | type <- "quosure" 77 | } else { 78 | type <- class(x)[[1L]] 79 | } 80 | return(sprintf("a <%s> object", type)) 81 | } 82 | 83 | if (!is_vector(x)) { 84 | return(.rlang_as_friendly_type(typeof(x))) 85 | } 86 | 87 | n_dim <- length(dim(x)) 88 | 89 | if (!n_dim) { 90 | if (!is_list(x) && length(x) == 1) { 91 | if (is_na(x)) { 92 | return(switch( 93 | typeof(x), 94 | logical = "`NA`", 95 | integer = "an integer `NA`", 96 | double = 97 | if (is.nan(x)) { 98 | "`NaN`" 99 | } else { 100 | "a numeric `NA`" 101 | }, 102 | complex = "a complex `NA`", 103 | character = "a character `NA`", 104 | .rlang_stop_unexpected_typeof(x) 105 | )) 106 | } 107 | 108 | show_infinites <- function(x) { 109 | if (x > 0) { 110 | "`Inf`" 111 | } else { 112 | "`-Inf`" 113 | } 114 | } 115 | str_encode <- function(x, width = 30, ...) { 116 | if (nchar(x) > width) { 117 | x <- substr(x, 1, width - 3) 118 | x <- paste0(x, "...") 119 | } 120 | encodeString(x, ...) 
121 | } 122 | 123 | if (value) { 124 | if (is.numeric(x) && is.infinite(x)) { 125 | return(show_infinites(x)) 126 | } 127 | 128 | if (is.numeric(x) || is.complex(x)) { 129 | number <- as.character(round(x, 2)) 130 | what <- if (is.complex(x)) "the complex number" else "the number" 131 | return(paste(what, number)) 132 | } 133 | 134 | return(switch( 135 | typeof(x), 136 | logical = if (x) "`TRUE`" else "`FALSE`", 137 | character = { 138 | what <- if (nzchar(x)) "the string" else "the empty string" 139 | paste(what, str_encode(x, quote = "\"")) 140 | }, 141 | raw = paste("the raw value", as.character(x)), 142 | .rlang_stop_unexpected_typeof(x) 143 | )) 144 | } 145 | 146 | return(switch( 147 | typeof(x), 148 | logical = "a logical value", 149 | integer = "an integer", 150 | double = if (is.infinite(x)) show_infinites(x) else "a number", 151 | complex = "a complex number", 152 | character = if (nzchar(x)) "a string" else "\"\"", 153 | raw = "a raw value", 154 | .rlang_stop_unexpected_typeof(x) 155 | )) 156 | } 157 | 158 | if (length(x) == 0) { 159 | return(switch( 160 | typeof(x), 161 | logical = "an empty logical vector", 162 | integer = "an empty integer vector", 163 | double = "an empty numeric vector", 164 | complex = "an empty complex vector", 165 | character = "an empty character vector", 166 | raw = "an empty raw vector", 167 | list = "an empty list", 168 | .rlang_stop_unexpected_typeof(x) 169 | )) 170 | } 171 | } 172 | 173 | vec_type_friendly(x) 174 | } 175 | 176 | vec_type_friendly <- function(x, length = FALSE) { 177 | if (!is_vector(x)) { 178 | abort("`x` must be a vector.") 179 | } 180 | type <- typeof(x) 181 | n_dim <- length(dim(x)) 182 | 183 | add_length <- function(type) { 184 | if (length && !n_dim) { 185 | paste0(type, sprintf(" of length %s", length(x))) 186 | } else { 187 | type 188 | } 189 | } 190 | 191 | if (type == "list") { 192 | if (n_dim < 2) { 193 | return(add_length("a list")) 194 | } else if (is.data.frame(x)) { 195 | return("a data frame") 
196 | } else if (n_dim == 2) { 197 | return("a list matrix") 198 | } else { 199 | return("a list array") 200 | } 201 | } 202 | 203 | type <- switch( 204 | type, 205 | logical = "a logical %s", 206 | integer = "an integer %s", 207 | numeric = , 208 | double = "a double %s", 209 | complex = "a complex %s", 210 | character = "a character %s", 211 | raw = "a raw %s", 212 | type = paste0("a ", type, " %s") 213 | ) 214 | 215 | if (n_dim < 2) { 216 | kind <- "vector" 217 | } else if (n_dim == 2) { 218 | kind <- "matrix" 219 | } else { 220 | kind <- "array" 221 | } 222 | out <- sprintf(type, kind) 223 | 224 | if (n_dim >= 2) { 225 | out 226 | } else { 227 | add_length(out) 228 | } 229 | } 230 | 231 | .rlang_as_friendly_type <- function(type) { 232 | switch( 233 | type, 234 | 235 | list = "a list", 236 | 237 | NULL = "`NULL`", 238 | environment = "an environment", 239 | externalptr = "a pointer", 240 | weakref = "a weak reference", 241 | S4 = "an S4 object", 242 | 243 | name = , 244 | symbol = "a symbol", 245 | language = "a call", 246 | pairlist = "a pairlist node", 247 | expression = "an expression vector", 248 | 249 | char = "an internal string", 250 | promise = "an internal promise", 251 | ... = "an internal dots object", 252 | any = "an internal `any` object", 253 | bytecode = "an internal bytecode object", 254 | 255 | primitive = , 256 | builtin = , 257 | special = "a primitive function", 258 | closure = "a function", 259 | 260 | type 261 | ) 262 | } 263 | 264 | .rlang_stop_unexpected_typeof <- function(x, call = caller_env()) { 265 | abort( 266 | sprintf("Unexpected type <%s>.", typeof(x)), 267 | call = call 268 | ) 269 | } 270 | 271 | #' Return OO type 272 | #' @param x Any R object. 273 | #' @return One of `"bare"` (for non-OO objects), `"S3"`, `"S4"`, 274 | #' `"R6"`, or `"S7"`. 
275 | #' @noRd 276 | obj_type_oo <- function(x) { 277 | if (!is.object(x)) { 278 | return("bare") 279 | } 280 | 281 | class <- inherits(x, c("R6", "S7_object"), which = TRUE) 282 | 283 | if (class[[1]]) { 284 | "R6" 285 | } else if (class[[2]]) { 286 | "S7" 287 | } else if (isS4(x)) { 288 | "S4" 289 | } else { 290 | "S3" 291 | } 292 | } 293 | 294 | #' @param x The object type which does not conform to `what`. Its 295 | #' `obj_type_friendly()` is taken and mentioned in the error message. 296 | #' @param what The friendly expected type as a string. Can be a 297 | #' character vector of expected types, in which case the error 298 | #' message mentions all of them in an "or" enumeration. 299 | #' @param show_value Passed to `value` argument of `obj_type_friendly()`. 300 | #' @param ... Arguments passed to [abort()]. 301 | #' @inheritParams args_error_context 302 | #' @noRd 303 | stop_input_type <- function(x, 304 | what, 305 | ..., 306 | allow_na = FALSE, 307 | allow_null = FALSE, 308 | show_value = TRUE, 309 | arg = caller_arg(x), 310 | call = caller_env()) { 311 | # From standalone-cli.R 312 | cli <- env_get_list( 313 | nms = c("format_arg", "format_code"), 314 | last = topenv(), 315 | default = function(x) sprintf("`%s`", x), 316 | inherit = TRUE 317 | ) 318 | 319 | if (allow_na) { 320 | what <- c(what, cli$format_code("NA")) 321 | } 322 | if (allow_null) { 323 | what <- c(what, cli$format_code("NULL")) 324 | } 325 | if (length(what)) { 326 | what <- oxford_comma(what) 327 | } 328 | if (inherits(arg, "AsIs")) { 329 | format_arg <- identity 330 | } else { 331 | format_arg <- cli$format_arg 332 | } 333 | 334 | message <- sprintf( 335 | "%s must be %s, not %s.", 336 | format_arg(arg), 337 | what, 338 | obj_type_friendly(x, value = show_value) 339 | ) 340 | 341 | abort(message, ..., call = call, arg = arg) 342 | } 343 | 344 | oxford_comma <- function(chr, sep = ", ", final = "or") { 345 | n <- length(chr) 346 | 347 | if (n < 2) { 348 | return(chr) 349 | } 350 | 351 | 
head <- chr[seq_len(n - 1)] 352 | last <- chr[n] 353 | 354 | head <- paste(head, collapse = sep) 355 | 356 | # Write a or b. But a, b, or c. 357 | if (n > 2) { 358 | paste0(head, sep, final, " ", last) 359 | } else { 360 | paste0(head, " ", final, " ", last) 361 | } 362 | } 363 | 364 | # nocov end 365 | -------------------------------------------------------------------------------- /R/import-standalone-types-check.R: -------------------------------------------------------------------------------- 1 | # Standalone file: do not edit by hand 2 | # Source: https://github.com/r-lib/rlang/blob/HEAD/R/standalone-types-check.R 3 | # Generated by: usethis::use_standalone("r-lib/rlang", "types-check") 4 | # ---------------------------------------------------------------------- 5 | # 6 | # --- 7 | # repo: r-lib/rlang 8 | # file: standalone-types-check.R 9 | # last-updated: 2023-03-13 10 | # license: https://unlicense.org 11 | # dependencies: standalone-obj-type.R 12 | # imports: rlang (>= 1.1.0) 13 | # --- 14 | # 15 | # ## Changelog 16 | # 17 | # 2024-08-15: 18 | # - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724) 19 | # 20 | # 2023-03-13: 21 | # - Improved error messages of number checkers (@teunbrand) 22 | # - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich). 23 | # - Added `check_data_frame()` (@mgirlich). 24 | # 25 | # 2023-03-07: 26 | # - Added dependency on rlang (>= 1.1.0). 27 | # 28 | # 2023-02-15: 29 | # - Added `check_logical()`. 30 | # 31 | # - `check_bool()`, `check_number_whole()`, and 32 | # `check_number_decimal()` are now implemented in C. 33 | # 34 | # - For efficiency, `check_number_whole()` and 35 | # `check_number_decimal()` now take a `NULL` default for `min` and 36 | # `max`. This makes it possible to bypass unnecessary type-checking 37 | # and comparisons in the default case of no bounds checks. 
38 | # 39 | # 2022-10-07: 40 | # - `check_number_whole()` and `_decimal()` no longer treat 41 | # non-numeric types such as factors or dates as numbers. Numeric 42 | # types are detected with `is.numeric()`. 43 | # 44 | # 2022-10-04: 45 | # - Added `check_name()` that forbids the empty string. 46 | # `check_string()` allows the empty string by default. 47 | # 48 | # 2022-09-28: 49 | # - Removed `what` arguments. 50 | # - Added `allow_na` and `allow_null` arguments. 51 | # - Added `allow_decimal` and `allow_infinite` arguments. 52 | # - Improved errors with absent arguments. 53 | # 54 | # 55 | # 2022-09-16: 56 | # - Unprefixed usage of rlang functions with `rlang::` to 57 | # avoid onLoad issues when called from rlang (#1482). 58 | # 59 | # 2022-08-11: 60 | # - Added changelog. 61 | # 62 | # nocov start 63 | 64 | # Scalars ----------------------------------------------------------------- 65 | 66 | .standalone_types_check_dot_call <- .Call 67 | 68 | check_bool <- function(x, 69 | ..., 70 | allow_na = FALSE, 71 | allow_null = FALSE, 72 | arg = caller_arg(x), 73 | call = caller_env()) { 74 | if (!missing(x) && .standalone_types_check_dot_call(ffi_standalone_is_bool_1.0.7, x, allow_na, allow_null)) { 75 | return(invisible(NULL)) 76 | } 77 | 78 | stop_input_type( 79 | x, 80 | c("`TRUE`", "`FALSE`"), 81 | ..., 82 | allow_na = allow_na, 83 | allow_null = allow_null, 84 | arg = arg, 85 | call = call 86 | ) 87 | } 88 | 89 | check_string <- function(x, 90 | ..., 91 | allow_empty = TRUE, 92 | allow_na = FALSE, 93 | allow_null = FALSE, 94 | arg = caller_arg(x), 95 | call = caller_env()) { 96 | if (!missing(x)) { 97 | is_string <- .rlang_check_is_string( 98 | x, 99 | allow_empty = allow_empty, 100 | allow_na = allow_na, 101 | allow_null = allow_null 102 | ) 103 | if (is_string) { 104 | return(invisible(NULL)) 105 | } 106 | } 107 | 108 | stop_input_type( 109 | x, 110 | "a single string", 111 | ..., 112 | allow_na = allow_na, 113 | allow_null = allow_null, 114 | arg = arg, 115 | 
call = call 116 | ) 117 | } 118 | 119 | .rlang_check_is_string <- function(x, 120 | allow_empty, 121 | allow_na, 122 | allow_null) { 123 | if (is_string(x)) { 124 | if (allow_empty || !is_string(x, "")) { 125 | return(TRUE) 126 | } 127 | } 128 | 129 | if (allow_null && is_null(x)) { 130 | return(TRUE) 131 | } 132 | 133 | if (allow_na && (identical(x, NA) || identical(x, na_chr))) { 134 | return(TRUE) 135 | } 136 | 137 | FALSE 138 | } 139 | 140 | check_name <- function(x, 141 | ..., 142 | allow_null = FALSE, 143 | arg = caller_arg(x), 144 | call = caller_env()) { 145 | if (!missing(x)) { 146 | is_string <- .rlang_check_is_string( 147 | x, 148 | allow_empty = FALSE, 149 | allow_na = FALSE, 150 | allow_null = allow_null 151 | ) 152 | if (is_string) { 153 | return(invisible(NULL)) 154 | } 155 | } 156 | 157 | stop_input_type( 158 | x, 159 | "a valid name", 160 | ..., 161 | allow_na = FALSE, 162 | allow_null = allow_null, 163 | arg = arg, 164 | call = call 165 | ) 166 | } 167 | 168 | IS_NUMBER_true <- 0 169 | IS_NUMBER_false <- 1 170 | IS_NUMBER_oob <- 2 171 | 172 | check_number_decimal <- function(x, 173 | ..., 174 | min = NULL, 175 | max = NULL, 176 | allow_infinite = TRUE, 177 | allow_na = FALSE, 178 | allow_null = FALSE, 179 | arg = caller_arg(x), 180 | call = caller_env()) { 181 | if (missing(x)) { 182 | exit_code <- IS_NUMBER_false 183 | } else if (0 == (exit_code <- .standalone_types_check_dot_call( 184 | ffi_standalone_check_number_1.0.7, 185 | x, 186 | allow_decimal = TRUE, 187 | min, 188 | max, 189 | allow_infinite, 190 | allow_na, 191 | allow_null 192 | ))) { 193 | return(invisible(NULL)) 194 | } 195 | 196 | .stop_not_number( 197 | x, 198 | ..., 199 | exit_code = exit_code, 200 | allow_decimal = TRUE, 201 | min = min, 202 | max = max, 203 | allow_na = allow_na, 204 | allow_null = allow_null, 205 | arg = arg, 206 | call = call 207 | ) 208 | } 209 | 210 | check_number_whole <- function(x, 211 | ..., 212 | min = NULL, 213 | max = NULL, 214 | allow_infinite = 
FALSE, 215 | allow_na = FALSE, 216 | allow_null = FALSE, 217 | arg = caller_arg(x), 218 | call = caller_env()) { 219 | if (missing(x)) { 220 | exit_code <- IS_NUMBER_false 221 | } else if (0 == (exit_code <- .standalone_types_check_dot_call( 222 | ffi_standalone_check_number_1.0.7, 223 | x, 224 | allow_decimal = FALSE, 225 | min, 226 | max, 227 | allow_infinite, 228 | allow_na, 229 | allow_null 230 | ))) { 231 | return(invisible(NULL)) 232 | } 233 | 234 | .stop_not_number( 235 | x, 236 | ..., 237 | exit_code = exit_code, 238 | allow_decimal = FALSE, 239 | min = min, 240 | max = max, 241 | allow_na = allow_na, 242 | allow_null = allow_null, 243 | arg = arg, 244 | call = call 245 | ) 246 | } 247 | 248 | .stop_not_number <- function(x, 249 | ..., 250 | exit_code, 251 | allow_decimal, 252 | min, 253 | max, 254 | allow_na, 255 | allow_null, 256 | arg, 257 | call) { 258 | if (allow_decimal) { 259 | what <- "a number" 260 | } else { 261 | what <- "a whole number" 262 | } 263 | 264 | if (exit_code == IS_NUMBER_oob) { 265 | min <- min %||% -Inf 266 | max <- max %||% Inf 267 | 268 | if (min > -Inf && max < Inf) { 269 | what <- sprintf("%s between %s and %s", what, min, max) 270 | } else if (x < min) { 271 | what <- sprintf("%s larger than or equal to %s", what, min) 272 | } else if (x > max) { 273 | what <- sprintf("%s smaller than or equal to %s", what, max) 274 | } else { 275 | abort("Unexpected state in OOB check", .internal = TRUE) 276 | } 277 | } 278 | 279 | stop_input_type( 280 | x, 281 | what, 282 | ..., 283 | allow_na = allow_na, 284 | allow_null = allow_null, 285 | arg = arg, 286 | call = call 287 | ) 288 | } 289 | 290 | check_symbol <- function(x, 291 | ..., 292 | allow_null = FALSE, 293 | arg = caller_arg(x), 294 | call = caller_env()) { 295 | if (!missing(x)) { 296 | if (is_symbol(x)) { 297 | return(invisible(NULL)) 298 | } 299 | if (allow_null && is_null(x)) { 300 | return(invisible(NULL)) 301 | } 302 | } 303 | 304 | stop_input_type( 305 | x, 306 | "a symbol", 
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_arg <- function(x,
                      ...,
                      allow_null = FALSE,
                      arg = caller_arg(x),
                      call = caller_env()) {
  if (!missing(x)) {
    if (is_symbol(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "an argument name",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_call <- function(x,
                       ...,
                       allow_null = FALSE,
                       arg = caller_arg(x),
                       call = caller_env()) {
  if (!missing(x)) {
    if (is_call(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "a defused call",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_environment <- function(x,
                              ...,
                              allow_null = FALSE,
                              arg = caller_arg(x),
                              call = caller_env()) {
  if (!missing(x)) {
    if (is_environment(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "an environment",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_function <- function(x,
                           ...,
                           allow_null = FALSE,
                           arg = caller_arg(x),
                           call = caller_env()) {
  if (!missing(x)) {
    if (is_function(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "a function",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_closure <- function(x,
                          ...,
                          allow_null = FALSE,
                          arg = caller_arg(x),
                          call = caller_env()) {
  if (!missing(x)) {
    if (is_closure(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "an R function",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_formula <- function(x,
                          ...,
                          allow_null = FALSE,
                          arg = caller_arg(x),
                          call = caller_env()) {
  if (!missing(x)) {
    if (is_formula(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "a formula",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}


# Vectors -----------------------------------------------------------------

# TODO: Figure out what to do with logical `NA` and `allow_na = TRUE`

check_character <- function(x,
                            ...,
                            allow_na = TRUE,
                            allow_null = FALSE,
                            arg = caller_arg(x),
                            call = caller_env()) {

  if (!missing(x)) {
    if (is_character(x)) {
      if (!allow_na && any(is.na(x))) {
        abort(
          sprintf("`%s` can't contain NA values.", arg),
          arg = arg,
          call = call
        )
      }

      return(invisible(NULL))
    }

    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "a character vector",
    ...,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_logical <- function(x,
                          ...,
                          allow_null = FALSE,
                          arg = caller_arg(x),
                          call = caller_env()) {
  if (!missing(x)) {
    if (is_logical(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "a logical vector",
    ...,
    allow_na = FALSE,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

check_data_frame <- function(x,
                             ...,
                             allow_null = FALSE,
                             arg = caller_arg(x),
                             call = caller_env()) {
  if (!missing(x)) {
    if (is.data.frame(x)) {
      return(invisible(NULL))
    }
    if (allow_null && is_null(x)) {
      return(invisible(NULL))
    }
  }

  stop_input_type(
    x,
    "a data frame",
    ...,
    allow_null = allow_null,
    arg = arg,
    call = call
  )
}

# nocov end

--------------------------------------------------------------------------------
/R/lightgbm.R:
--------------------------------------------------------------------------------
#' Boosted trees with lightgbm
#'
#' `train_lightgbm` is a wrapper for `lightgbm` tree-based models
#' where all of the model arguments are in the main function.
#'
#' This is an internal function, not meant to be directly called by the user.
#'
#' @param x A data frame or matrix of predictors.
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param weights A numeric vector of sample weights.
#' @param max_depth An integer for the maximum depth of the tree.
#' @param num_iterations An integer for the number of boosting iterations.
#' @param learning_rate A numeric value between zero and one to control the learning rate.
#' @param feature_fraction_bynode Fraction of predictors that will be randomly sampled
#'   at each split.
#' @param min_data_in_leaf A numeric value for the minimum sum of instances needed
#'   in a child to continue to split.
#' @param min_gain_to_split A number for the minimum loss reduction required to make a
#'   further partition on a leaf node of the tree.
#' @param bagging_fraction Subsampling proportion of rows. Setting this argument
#'   to a non-default value will also set `bagging_freq = 1`. See the Bagging
#'   section in `?details_boost_tree_lightgbm` for more details.
#' @param early_stopping_round The number of iterations without an improvement in
#'   the objective function that can occur before training is halted.
#' @param validation The _proportion_ of the training data that are used for
#'   performance assessment and potential early stopping.
#' @param counts A logical; should `feature_fraction_bynode` be interpreted as the
#'   _number_ of predictors that will be randomly sampled at each split?
#'   `TRUE` indicates that `mtry` will be interpreted in its sense as a _count_;
#'   `FALSE` indicates that the argument will be interpreted in its sense as a
#'   _proportion_.
#' @param quiet A logical; should logging by [lightgbm::lgb.train()] be muted?
#' @param ... Other options to pass to [lightgbm::lgb.train()]. Arguments
#'   will be routed to the `params` argument or passed as a main argument,
#'   depending on their name.
#' @return A fitted `lgb.Booster` object.
#' @keywords internal
#' @export
train_lightgbm <- function(
  x,
  y,
  weights = NULL,
  max_depth = -1,
  num_iterations = 100,
  learning_rate = 0.1,
  feature_fraction_bynode = 1,
  min_data_in_leaf = 20,
  min_gain_to_split = 0,
  bagging_fraction = 1,
  early_stopping_round = NULL,
  validation = 0,
  counts = TRUE,
  quiet = FALSE,
  ...
) {
  force(x)
  force(y)

  call <- call2("fit")

  check_number_whole(max_depth, call = call)
  check_number_whole(num_iterations, call = call)
  check_number_decimal(learning_rate, call = call)
  check_number_decimal(feature_fraction_bynode, call = call)
  check_number_whole(min_data_in_leaf, call = call)
  check_number_decimal(min_gain_to_split, call = call)
  check_number_decimal(bagging_fraction, call = call)
  check_number_decimal(early_stopping_round, allow_null = TRUE, call = call)
  check_bool(counts, call = call)
  check_bool(quiet, call = call)

  feature_fraction_bynode <-
    process_mtry(
      feature_fraction_bynode = feature_fraction_bynode,
      counts = counts,
      x = x,
      is_missing = missing(feature_fraction_bynode)
    )

  check_lightgbm_aliases(...)

  # bonsai should be able to differentiate between
  # 1) main arguments to `lgb.train()` (as in `names(formals(lgb.train))` other
  #    than `params`),
  # 2) main arguments to `lgb.Dataset()` (as in `names(formals(lgb.Dataset))`
  #    other than `params`), and
  # 3) arguments to pass to `lgb.train(params)` OR `lgb.Dataset(params)`.
  #    Arguments to the `params` argument of either function can be
  #    concatenated together and passed to both (#77).
  args <-
    list(
      num_iterations = num_iterations,
      learning_rate = learning_rate,
      max_depth = max_depth,
      feature_fraction_bynode = feature_fraction_bynode,
      min_data_in_leaf = min_data_in_leaf,
      min_gain_to_split = min_gain_to_split,
      bagging_fraction = bagging_fraction,
      early_stopping_round = early_stopping_round,
      ...
    )

  args <- process_bagging(args)
  args <- process_parallelism(args)
  args <- process_objective_function(args, x, y)

  args <- sort_args(args)

  if (!is.numeric(y)) {
    y <- as.numeric(y) - 1
  }

  args <- process_data(args, x, y, weights, validation, missing(validation))

  compacted <- c(list(params = args$params), args$main_args_train)

  call <- rlang::call2("lgb.train", !!!compacted, .ns = "lightgbm")

  if (quiet) {
    junk <- utils::capture.output(
      res <- rlang::eval_tidy(call, env = rlang::current_env())
    )
  } else {
    res <- rlang::eval_tidy(call, env = rlang::current_env())
  }

  res
}

process_mtry <- function(
  feature_fraction_bynode,
  counts,
  x,
  is_missing,
  call = call2("fit")
) {
  check_bool(counts, call = call)

  ineq <- if (counts) {
    "greater"
  } else {
    "less"
  }
  interp <- if (counts) {
    "count"
  } else {
    "proportion"
  }
  opp <- if (!counts) {
    "count"
  } else {
    "proportion"
  }

  if (
    (feature_fraction_bynode < 1 & counts) |
      (feature_fraction_bynode > 1 & !counts)
  ) {
    cli::cli_abort(
      c(
        "{.arg mtry} must be {ineq} than or equal to 1, not {feature_fraction_bynode}.",
        "i" = "{.arg mtry} is currently being interpreted as a {interp}
               rather than a {opp}.",
        "i" = "Supply {.code counts = {!counts}} to {.fn set_engine} to supply
               this argument as a {opp} rather than a {interp}.",
        "i" = "See {.help train_lightgbm} for more details."
      ),
      call = call
    )
  }

  if (counts && !is_missing) {
    feature_fraction_bynode <- feature_fraction_bynode / ncol(x)
  }

  feature_fraction_bynode
}

process_objective_function <- function(args, x, y) {
  # set the "objective" param argument, clear it out from main args
  if (!any(names(args) %in% c("objective"))) {
    if (is.numeric(y)) {
      args$objective <- "regression"
    } else {
      lvl <- levels(y)
      lvls <- length(lvl)
      if (lvls == 2) {
        args$num_class <- 1
        args$objective <- "binary"
      } else {
        args$num_class <- lvls
        args$objective <- "multiclass"
      }
    }
  }

  args
}

# supply the number of threads as num_threads in params, clear out
# any other thread args that might be passed as main arguments
process_parallelism <- function(args) {
  if (!is.null(args["num_threads"])) {
    args$num_threads <- args[names(args) == "num_threads"]
    args[names(args) == "num_threads"] <- NULL
  }

  args
}

process_bagging <- function(args) {
  if (
    args$bagging_fraction != 1 &&
      (!"bagging_freq" %in% names(args))
  ) {
    args$bagging_freq <- 1
  }

  args
}

process_data <- function(args, x, y, weights, validation, missing_validation) {
  #                                          trn_index              | val_index
  #                                        ----------------------------------------------------
  #  needs_validation &  missing_validation | 1:n                    | 1:n
  #  needs_validation & !missing_validation | sample(1:n, max(m, 2)) | setdiff(1:n, trn_index)
  # !needs_validation &  missing_validation | 1:n                    | NULL
  # !needs_validation & !missing_validation | sample(1:n, max(m, 2)) | setdiff(1:n, trn_index)

  n <- nrow(x)
  needs_validation <- !is.null(args$params$early_stopping_round)
  if (!needs_validation) {
    # If early_stopping_round isn't set, clear it from arguments actually
    # passed to LightGBM.
    args$params$early_stopping_round <- NULL
  }

  if (missing_validation) {
    trn_index <- 1:n
    if (needs_validation) {
      val_index <- trn_index
    } else {
      val_index <- NULL
    }
  } else {
    m <- min(floor(n * (1 - validation)) + 1, n - 1)
    trn_index <- sample(1:n, size = max(m, 2))
    val_index <- setdiff(1:n, trn_index)
  }

  data_args <-
    c(
      list(
        data = prepare_df_lgbm(x[trn_index, , drop = FALSE]),
        label = y[trn_index],
        categorical_feature = categorical_columns(x[trn_index, , drop = FALSE]),
        params = c(list(feature_pre_filter = FALSE), args$params),
        weight = weights[trn_index]
      ),
      args$main_args_dataset
    )

  args$main_args_train$data <-
    rlang::eval_bare(
      rlang::call2("lgb.Dataset", !!!data_args, .ns = "lightgbm")
    )

  if (!is.null(val_index)) {
    valids_args <-
      c(
        list(
          data = prepare_df_lgbm(x[val_index, , drop = FALSE]),
          label = y[val_index],
          categorical_feature = categorical_columns(x[
            val_index,
            ,
            drop = FALSE
          ]),
          # concatenate (rather than nest) the params, as for the training set
          params = c(list(feature_pre_filter = FALSE), args$params),
          weight = weights[val_index]
        ),
        args$main_args_dataset
      )

    args$main_args_train$valids <-
      list(
        validation = rlang::eval_bare(
          rlang::call2("lgb.Dataset", !!!valids_args, .ns = "lightgbm")
        )
      )
  }

  args
}

# identifies supplied arguments as destined for `lgb.Dataset()`, `lgb.train()`,
# or the `params` argument to both of the above (#77).
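The routing described in the comment above can be illustrated with a small base-R sketch. It is a simplified stand-in, not bonsai's implementation. `fake_lgb_dataset()` and `fake_lgb_train()` are hypothetical functions whose formals mimic only a subset of those of `lightgbm::lgb.Dataset()` and `lightgbm::lgb.train()`, so the example runs without lightgbm installed. A name that matches a formal of either function (other than `params`) is treated as a main argument; every other name is collected into `params`.

```r
# Hypothetical stand-ins: their formals loosely mimic a subset of the real
# lightgbm::lgb.Dataset() and lightgbm::lgb.train() signatures.
fake_lgb_dataset <- function(data, params = list(), label = NULL, weight = NULL) NULL
fake_lgb_train <- function(params = list(), data, nrounds = 100L, valids = list()) NULL

# Formal argument names of `fn`, excluding `params`.
main_arg_names <- function(fn) {
  res <- names(formals(fn))
  res[res != "params"]
}

# Split a named list of arguments into dataset main args, train main args,
# and everything else (destined for `params`).
route_args <- function(args, dataset_fn, train_fn) {
  ds <- main_arg_names(dataset_fn)
  tr <- main_arg_names(train_fn)
  list(
    main_args_dataset = args[names(args) %in% ds],
    main_args_train = args[names(args) %in% tr],
    params = args[!names(args) %in% c(ds, tr)]
  )
}

routed <- route_args(
  list(learning_rate = 0.1, max_depth = 3L, nrounds = 50L, weight = c(1, 2)),
  fake_lgb_dataset,
  fake_lgb_train
)
str(routed)
```

In this sketch, `learning_rate` and `max_depth` are not formals of either stand-in, so they land in `params`. The same fall-through appears to be how `train_lightgbm()` forwards LightGBM-named settings such as `num_iterations`, since `lgb.train()` names its iteration count `nrounds`.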
sort_args <- function(args) {
  # warn on arguments that won't be passed along
  protected <- c(
    "obj",
    "init_model",
    "colnames",
    "categorical_feature",
    "callbacks",
    "reset_data"
  )

  if (any(names(args) %in% protected)) {
    protected_args <- names(args[names(args) %in% protected])

    rlang::warn(
      glue::glue(
        "The following argument(s) are guarded by bonsai and will not ",
        "be passed to LightGBM: {glue::glue_collapse(protected_args, sep = ', ')}"
      )
    )

    args[protected_args] <- NULL
  }

  main_args_dataset <- main_args(lightgbm::lgb.Dataset)
  main_args_train <- main_args(lightgbm::lgb.train)

  args <-
    list(
      main_args_dataset = args[names(args) %in% main_args_dataset],
      main_args_train = args[names(args) %in% main_args_train],
      params = args[!names(args) %in% c(main_args_dataset, main_args_train)]
    )

  args
}

main_args <- function(fn) {
  res <- names(formals(fn))
  res[res != "params"]
}

# in lightgbm <= 3.3.2, predict() for multiclass classification produced a single
# vector of length num_observations * num_classes, in row-major order
#
# in versions after that release, lightgbm produces a numeric matrix with shape
# [num_observations, num_classes]
#
# this function ensures that multiclass classification predictions are always
# returned as a [num_observations, num_classes] matrix, regardless of lightgbm version
reshape_lightgbm_multiclass_preds <- function(preds, num_rows) {
  n_preds_per_case <- length(preds) / num_rows
  if (is.vector(preds) && n_preds_per_case > 1) {
    preds <- matrix(preds, ncol = n_preds_per_case, byrow = TRUE)
  }
  preds
}

#' Internal functions
#'
#' Not intended for direct use.
#'
#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_classification_prob <- function(object, new_data, ...) {
  p <- stats::predict(object$fit, prepare_df_lgbm(new_data), ...)
  p <- reshape_lightgbm_multiclass_preds(preds = p, num_rows = nrow(new_data))

  if (is.vector(p)) {
    p <- tibble::tibble(p1 = 1 - p, p2 = p)
  }

  colnames(p) <- object$lvl

  tibble::as_tibble(p)
}

#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_classification_class <- function(object, new_data, ...) {
  p <- predict_lightgbm_classification_prob(
    object,
    prepare_df_lgbm(new_data),
    ...
  )

  q <- apply(p, 1, function(x) which.max(x))

  names(p)[q]
}

#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_classification_raw <- function(object, new_data, ...) {
  if (using_newer_lightgbm_version()) {
    p <- stats::predict(
      object$fit,
      prepare_df_lgbm(new_data),
      type = "raw",
      ...
    )
  } else {
    p <- stats::predict(
      object$fit,
      prepare_df_lgbm(new_data),
      rawscore = TRUE,
      ...
    )
  }
  reshape_lightgbm_multiclass_preds(preds = p, num_rows = nrow(new_data))
}

#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
predict_lightgbm_regression_numeric <- function(object, new_data, ...) {
  p <-
    stats::predict(
      object$fit,
      prepare_df_lgbm(new_data),
      params = list(predict_disable_shape_check = TRUE),
      ...
    )
  p
}


#' @keywords internal
#' @export
#' @rdname lightgbm_helpers
multi_predict._lgb.Booster <- function(
  object,
  new_data,
  type = NULL,
  trees = NULL,
  ...
) {
  if (any(names(rlang::enquos(...)) == "newdata")) {
    cli::cli_abort(
      "Did you mean to use {.code new_data} instead of {.code newdata}?"
    )
  }

  trees <- sort(trees)

  res <- map_df(
    trees,
    lightgbm_by_tree,
    object = object,
    new_data = new_data,
    type = type
  )
  res <- dplyr::arrange(res, .row, trees)
  res <- split(res[, -1], res$.row)
  names(res) <- NULL

  tibble::tibble(.pred = res)
}

lightgbm_by_tree <- function(tree, object, new_data, type = NULL) {
  # switch based on prediction type
  if (object$spec$mode == "regression") {
    pred <- predict_lightgbm_regression_numeric(
      object,
      new_data,
      num_iteration = tree
    )

    pred <- tibble::tibble(.pred = pred)

    nms <- names(pred)
  } else {
    if (is.null(type) || type == "class") {
      pred <- predict_lightgbm_classification_class(
        object,
        new_data,
        num_iteration = tree
      )

      pred <- tibble::tibble(.pred_class = factor(pred, levels = object$lvl))
    } else {
      pred <- predict_lightgbm_classification_prob(
        object,
        new_data,
        num_iteration = tree
      )

      names(pred) <- paste0(".pred_", names(pred))
    }

    nms <- names(pred)
  }

  pred[["trees"]] <- tree
  pred[[".row"]] <- 1:nrow(new_data)
  pred[, c(".row", "trees", nms)]
}

prepare_df_lgbm <- function(x, y = NULL) {
  categorical_cols <- categorical_columns(x)

  x <- categorical_features_to_int(x, categorical_cols)

  x <- parsnip::maybe_matrix(x)

  return(x)
}

categorical_columns <- function(x) {
  categorical_cols <- NULL
  if (inherits(x, c("matrix", "Matrix"))) {
    return(categorical_cols)
  }
  for (i in seq_along(x)) {
    if (is.factor(x[[i]])) {
      categorical_cols <- c(categorical_cols, i)
    }
  }
  categorical_cols
}

categorical_features_to_int <- function(x, cat_indices) {
  if (inherits(x, c("matrix", "Matrix"))) {
    return(x)
  }
  for (i in cat_indices) {
    x[[i]] <- as.integer(x[[i]]) - 1
  }
  x
}

check_lightgbm_aliases <- function(...) {
  dots <- rlang::list2(...)

  for (param in names(dots)) {
    uses_alias <- lightgbm_aliases$alias %in% param
    if (any(uses_alias)) {
      main <- lightgbm_aliases$lightgbm[uses_alias]
      parsnip <- lightgbm_aliases$parsnip[uses_alias]
      cli::cli_abort(
        c(
          "!" = "The {.var {param}} argument passed to \\
                 {.help [`set_engine()`](parsnip::set_engine)} is an alias for \\
                 a main model argument.",
          "i" = "Please instead pass this argument via the {.var {parsnip}} \\
                 argument to {.help [`boost_tree()`](parsnip::boost_tree)}."
        ),
        call = rlang::call2("fit")
      )
    }
  }

  invisible(TRUE)
}

lightgbm_aliases <-
  tibble::tribble(
    ~parsnip,         ~lightgbm,                 ~alias,
    # note that "tree_depth" -> "max_depth" has no aliases
    "trees",          "num_iterations",          "num_iteration",
    "trees",          "num_iterations",          "n_iter",
    "trees",          "num_iterations",          "num_tree",
    "trees",          "num_iterations",          "num_trees",
    "trees",          "num_iterations",          "num_round",
    "trees",          "num_iterations",          "num_rounds",
    "trees",          "num_iterations",          "nrounds",
    "trees",          "num_iterations",          "num_boost_round",
    "trees",          "num_iterations",          "n_estimators",
    "trees",          "num_iterations",          "max_iter",
    "learn_rate",     "learning_rate",           "shrinkage_rate",
    "learn_rate",     "learning_rate",           "eta",
    "mtry",           "feature_fraction_bynode", "sub_feature_bynode",
    "mtry",           "feature_fraction_bynode", "colsample_bynode",
    "min_n",          "min_data_in_leaf",        "min_data_per_leaf",
    "min_n",          "min_data_in_leaf",        "min_data",
    "min_n",          "min_data_in_leaf",        "min_child_samples",
    "min_n",          "min_data_in_leaf",        "min_samples_leaf",
    "loss_reduction", "min_gain_to_split",       "min_split_gain",
    "sample_size",    "bagging_fraction",        "sub_row",
    "sample_size",    "bagging_fraction",        "subsample",
    "sample_size",    "bagging_fraction",        "bagging",
    "stop_iter",      "early_stopping_round",    "early_stopping_rounds",
    "stop_iter",      "early_stopping_round",    "early_stopping",
    "stop_iter",      "early_stopping_round",    "n_iter_no_change"
  )

--------------------------------------------------------------------------------
/R/lightgbm_data.R:
--------------------------------------------------------------------------------
# nocov start

make_boost_tree_lightgbm <- function() {
  parsnip::set_model_engine(
    model = "boost_tree",
    mode = "regression",
    eng = "lightgbm"
  )

  parsnip::set_model_engine(
    model = "boost_tree",
    mode = "classification",
    eng = "lightgbm"
  )

  parsnip::set_dependency(
    model = "boost_tree",
    eng = "lightgbm",
    pkg = "lightgbm",
    mode = "regression"
  )

  parsnip::set_dependency(
    model = "boost_tree",
    eng = "lightgbm",
    pkg = "bonsai",
    mode = "regression"
  )

  parsnip::set_dependency(
    model = "boost_tree",
    eng = "lightgbm",
    pkg = "lightgbm",
    mode = "classification"
  )

  parsnip::set_dependency(
    model = "boost_tree",
    eng = "lightgbm",
    pkg = "bonsai",
    mode = "classification"
  )

  parsnip::set_fit(
    model = "boost_tree",
    eng = "lightgbm",
    mode = "regression",
    value = list(
      interface = "data.frame",
      protect = c("x", "y", "weights"),
      func = c(pkg = "bonsai", fun = "train_lightgbm"),
      defaults = list(
        verbose = -1,
        num_threads = 0,
        seed = quote(sample.int(10^5, 1)),
        deterministic = TRUE
      )
    )
  )

  parsnip::set_encoding(
    model = "boost_tree",
    mode = "regression",
    eng = "lightgbm",
    options = list(
      predictor_indicators = "none",
      compute_intercept = FALSE,
      remove_intercept = FALSE,
      allow_sparse_x = TRUE
    )
  )

  parsnip::set_pred(
    model = "boost_tree",
    eng = "lightgbm",
    mode = "regression",
    type = "numeric",
    value = list(
      pre = NULL,
      post = NULL,
      func = c(pkg = "bonsai", fun = "predict_lightgbm_regression_numeric"),
      args = list(
        object = quote(object),
        new_data = quote(new_data)
      )
    )
  )

  parsnip::set_fit(
    model = "boost_tree",
    eng = "lightgbm",
    mode = "classification",
    value = list(
      interface = "data.frame",
      protect = c("x", "y", "weights"),
      func = c(pkg = "bonsai", fun = "train_lightgbm"),
      defaults = list(
        verbose = -1,
        num_threads = 0,
        seed = quote(sample.int(10^5, 1)),
        deterministic = TRUE
      )
    )
  )

  parsnip::set_encoding(
    model = "boost_tree",
    mode = "classification",
    eng = "lightgbm",
    options = list(
      predictor_indicators = "none",
      compute_intercept = FALSE,
      remove_intercept = FALSE,
      allow_sparse_x = TRUE
    )
  )

  parsnip::set_pred(
    model = "boost_tree",
    eng = "lightgbm",
    mode = "classification",
    type = "class",
    value = parsnip::pred_value_template(
      pre = NULL,
      post = NULL,
      func = c(pkg = "bonsai", fun = "predict_lightgbm_classification_class"),
      object = quote(object),
      new_data = quote(new_data)
    )
  )

  parsnip::set_pred(
    model = "boost_tree",
    eng = "lightgbm",
    mode = "classification",
    type = "prob",
    value = parsnip::pred_value_template(
      pre = NULL,
      post = NULL,
      func = c(pkg = "bonsai", fun = "predict_lightgbm_classification_prob"),
      object = quote(object),
      new_data = quote(new_data)
    )
  )

  parsnip::set_pred(
    model = "boost_tree",
    eng = "lightgbm",
    mode = "classification",
    type = "raw",
    value = parsnip::pred_value_template(
      pre = NULL,
      post = NULL,
      func = c(pkg = "bonsai", fun = "predict_lightgbm_classification_raw"),
      object = quote(object),
      new_data = quote(new_data)
    )
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "tree_depth",
    original = "max_depth",
    func = list(pkg = "dials", fun = "tree_depth"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "trees",
    original = "num_iterations",
    func = list(pkg = "dials", fun = "trees"),
    has_submodel = TRUE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "learn_rate",
    original = "learning_rate",
    func = list(pkg = "dials", fun = "learn_rate"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "mtry",
    original = "feature_fraction_bynode",
    func = list(pkg = "dials", fun = "mtry"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "min_n",
    original = "min_data_in_leaf",
    func = list(pkg = "dials", fun = "min_n"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "loss_reduction",
    original = "min_gain_to_split",
    func = list(pkg = "dials", fun = "loss_reduction"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "sample_size",
    original = "bagging_fraction",
    func = list(pkg = "dials", fun = "sample_size"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "boost_tree",
    eng = "lightgbm",
    parsnip = "stop_iter",
    original = "early_stopping_round",
    func = list(pkg = "dials", fun = "stop_iter"),
    has_submodel = FALSE
  )
}

# nocov end

--------------------------------------------------------------------------------
/R/partykit_data.R:
--------------------------------------------------------------------------------
# nocov start

make_decision_tree_partykit <- function() {
  parsnip::set_model_engine(
    "decision_tree",
    mode = "regression",
    eng = "partykit"
  )

  parsnip::set_dependency(
    model = "decision_tree",
    eng = "partykit",
    pkg = "partykit",
    mode = "regression"
  )

  parsnip::set_dependency(
    model = "decision_tree",
    eng = "partykit",
    pkg = "bonsai",
    mode = "regression"
  )

  parsnip::set_encoding(
    model = "decision_tree",
    eng = "partykit",
    mode = "regression",
    options = list(
      predictor_indicators = "none",
      compute_intercept = FALSE,
      remove_intercept = FALSE,
      allow_sparse_x = FALSE
    )
  )

  parsnip::set_fit(
    model = "decision_tree",
    eng = "partykit",
    mode = "regression",
    value = list(
      interface = "formula",
      protect = c("formula", "data", "weights"),
      func = c(pkg = "parsnip", fun = "ctree_train"),
      defaults = list()
    )
  )

  parsnip::set_model_arg(
    model = "decision_tree",
    eng = "partykit",
    parsnip = "min_n",
    original = "minsplit",
    func = list(pkg = "dials", fun = "min_n"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "decision_tree",
    eng = "partykit",
    parsnip = "tree_depth",
    original = "maxdepth",
    func = list(pkg = "dials", fun = "tree_depth"),
    has_submodel = FALSE
  )

  parsnip::set_pred(
    model = "decision_tree",
    eng = "partykit",
    mode = "regression",
    type = "numeric",
    value = list(
      pre = NULL,
      post = NULL,
      func = c(fun = "predict"),
      args = list(
        object = rlang::expr(object$fit),
        newdata = rlang::expr(new_data),
        type = "response"
      )
    )
  )

  # ----------------------------------------------------------------------------

  parsnip::set_model_engine(
    "decision_tree",
    mode = "classification",
    eng = "partykit"
  )

  parsnip::set_dependency(
    model = "decision_tree",
    eng = "partykit",
    pkg = "partykit",
    mode = "classification"
  )

  parsnip::set_dependency(
    model = "decision_tree",
    eng = "partykit",
    pkg = "bonsai",
    mode = "classification"
  )

  parsnip::set_encoding(
    model = "decision_tree",
    eng = "partykit",
    mode = "classification",
    options = list(
      predictor_indicators = "none",
      compute_intercept = FALSE,
      remove_intercept = FALSE,
      allow_sparse_x = FALSE
    )
  )

  parsnip::set_fit(
    model = "decision_tree",
    eng = "partykit",
    mode = "classification",
    value = list(
      interface = "formula",
      protect = c("formula", "data", "weights"),
      func = c(pkg = "parsnip", fun = "ctree_train"),
      defaults = list()
    )
  )

  parsnip::set_model_arg(
    model = "decision_tree",
    eng = "partykit",
    parsnip = "min_n",
    original = "minsplit",
    func = list(pkg = "dials", fun = "min_n"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "decision_tree",
    eng = "partykit",
    parsnip = "tree_depth",
    original = "maxdepth",
    func = list(pkg = "dials", fun = "tree_depth"),
    has_submodel = FALSE
  )

  parsnip::set_pred(
    model = "decision_tree",
    eng = "partykit",
    mode = "classification",
    type = "class",
    value = list(
      pre = NULL,
      post = NULL,
      func = c(fun = "predict"),
      args = list(
        object = rlang::expr(object$fit),
        newdata = rlang::expr(new_data),
        type = "response"
      )
    )
  )

  parsnip::set_pred(
    model = "decision_tree",
    eng = "partykit",
    mode = "classification",
    type = "prob",
    value = list(
      pre = NULL,
      post = function(result, object) tibble::as_tibble(result),
      func = c(fun = "predict"),
      args = list(
        object = rlang::expr(object$fit),
        newdata = rlang::expr(new_data),
        type = "prob"
      )
    )
  )
}

make_rand_forest_partykit <- function() {
  parsnip::set_model_engine(
    "rand_forest",
    mode = "regression",
    eng = "partykit"
  )

  parsnip::set_dependency(
    "rand_forest",
    eng = "partykit",
    pkg = "partykit",
    mode = "regression"
  )

  parsnip::set_dependency(
    "rand_forest",
    eng = "partykit",
    pkg = "bonsai",
    mode = "regression"
  )

  parsnip::set_encoding(
    model = "rand_forest",
    eng = "partykit",
    mode = "regression",
    options = list(
      predictor_indicators = "none",
      compute_intercept = FALSE,
      remove_intercept = FALSE,
      allow_sparse_x = FALSE
    )
  )

  parsnip::set_fit(
    model = "rand_forest",
    eng = "partykit",
    mode = "regression",
    value = list(
      interface = "formula",
      protect = c("formula", "data", "weights"),
      func = c(pkg = "parsnip", fun = "cforest_train"),
      defaults = list()
    )
  )

  parsnip::set_model_arg(
    model = "rand_forest",
    eng = "partykit",
    parsnip = "min_n",
    original = "minsplit",
    func = list(pkg = "dials", fun = "min_n"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "rand_forest",
    eng = "partykit",
    parsnip = "mtry",
    original = "mtry",
    func = list(pkg = "dials", fun = "mtry"),
    has_submodel = FALSE
  )

  parsnip::set_model_arg(
    model = "rand_forest",
    eng = "partykit",
    parsnip = "trees",
    original = "ntree",
    func = list(pkg = "dials", fun = "trees"),
    has_submodel = FALSE
  )

  parsnip::set_pred(
    model = "rand_forest",
    eng = "partykit",
    mode = "regression",
    type = "numeric",
    value = list(
      pre = NULL,
      post = NULL,
      func = c(fun = "predict"),
      args = list(
        object = rlang::expr(object$fit),
        newdata = rlang::expr(new_data),
        type = "response"
      )
    )
  )

  # ----------------------------------------------------------------------------

  parsnip::set_model_engine(
    "rand_forest",
    mode = "classification",
    eng = "partykit"
  )

  parsnip::set_dependency(
    "rand_forest",
    eng = "partykit",
    pkg = "partykit",
    mode = "classification"
  )

  parsnip::set_dependency(
    "rand_forest",
    eng = "partykit",
    pkg = "bonsai",
    mode = "classification"
  )

  parsnip::set_encoding(
    model = "rand_forest",
    eng = "partykit",
    mode = "classification",
    options = list(
      predictor_indicators = "none",
      compute_intercept = FALSE,
      remove_intercept = FALSE,
      allow_sparse_x = FALSE
    )
  )

  parsnip::set_fit(
    model = "rand_forest",
304 | eng = "partykit", 305 | mode = "classification", 306 | value = list( 307 | interface = "formula", 308 | protect = c("formula", "data", "weights"), 309 | func = c(pkg = "parsnip", fun = "cforest_train"), 310 | defaults = list() 311 | ) 312 | ) 313 | 314 | parsnip::set_model_arg( 315 | model = "rand_forest", 316 | eng = "partykit", 317 | parsnip = "min_n", 318 | original = "minsplit", 319 | func = list(pkg = "dials", fun = "min_n"), 320 | has_submodel = FALSE 321 | ) 322 | 323 | parsnip::set_model_arg( 324 | model = "rand_forest", 325 | eng = "partykit", 326 | parsnip = "tree_depth", 327 | original = "maxdepth", 328 | func = list(pkg = "dials", fun = "tree_depth"), 329 | has_submodel = FALSE 330 | ) 331 | 332 | parsnip::set_pred( 333 | model = "rand_forest", 334 | eng = "partykit", 335 | mode = "classification", 336 | type = "class", 337 | value = list( 338 | pre = NULL, 339 | post = NULL, 340 | func = c(fun = "predict"), 341 | args = list( 342 | object = rlang::expr(object$fit), 343 | newdata = rlang::expr(new_data), 344 | type = "response" 345 | ) 346 | ) 347 | ) 348 | 349 | parsnip::set_pred( 350 | model = "rand_forest", 351 | eng = "partykit", 352 | mode = "classification", 353 | type = "prob", 354 | value = list( 355 | pre = NULL, 356 | post = function(result, object) tibble::as_tibble(result), 357 | func = c(fun = "predict"), 358 | args = list( 359 | object = rlang::expr(object$fit), 360 | newdata = rlang::expr(new_data), 361 | type = "prob" 362 | ) 363 | ) 364 | ) 365 | } 366 | 367 | # nocov end 368 | -------------------------------------------------------------------------------- /R/zzz.R: -------------------------------------------------------------------------------- 1 | # nocov start 2 | 3 | # The functions below define the model information. These access the model 4 | # environment inside of parsnip so they have to be executed once parsnip has 5 | # been loaded. 
6 | 7 | .onLoad <- function(libname, pkgname) { 8 | make_boost_tree_lightgbm() 9 | 10 | make_decision_tree_partykit() 11 | make_rand_forest_partykit() 12 | 13 | make_rand_forest_aorsf() 14 | } 15 | 16 | # nocov end 17 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | # bonsai 17 | 18 | 19 | [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) 20 | [![CRAN status](https://www.r-pkg.org/badges/version/bonsai)](https://CRAN.R-project.org/package=bonsai) 21 | [![Codecov test coverage](https://codecov.io/gh/tidymodels/bonsai/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/bonsai?branch=main) 22 | [![R-CMD-check](https://github.com/tidymodels/bonsai/workflows/R-CMD-check/badge.svg)](https://github.com/tidymodels/bonsai/actions) 23 | 24 | 25 | bonsai provides bindings for additional tree-based model engines for use with the [parsnip](https://parsnip.tidymodels.org/) package. 26 | 27 | This package is based on the work done in the [treesnip repository](https://github.com/curso-r/treesnip) by Athos Damiani, Daniel Falbel, and Roel Hogervorst. bonsai is the official CRAN version of the package; new development will reside here.
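As a minimal sketch of typical use (assuming the partykit package is installed; the built-in `mtcars` data is used only for illustration), loading bonsai registers its engines with parsnip, after which they can be selected with `set_engine()` like any other parsnip engine:

``` r
library(parsnip)
library(bonsai) # loading bonsai registers its engines with parsnip

# list the engines now available for decision_tree()
show_engines("decision_tree")

# specify a conditional inference tree using the partykit engine
tree_spec <- decision_tree() |>
  set_engine("partykit") |>
  set_mode("regression")

# fit on a built-in data set, then predict on a few rows
tree_fit <- fit(tree_spec, mpg ~ ., data = mtcars)
predict(tree_fit, new_data = head(mtcars))
```

Following parsnip's usual conventions, the numeric predictions come back as a tibble with a `.pred` column.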
28 | 29 | ## Installation 30 | 31 | You can install the most recent official release of bonsai with: 32 | 33 | ``` r 34 | install.packages("bonsai") 35 | ``` 36 | 37 | You can install the development version of bonsai from [GitHub](https://github.com/) with: 38 | 39 | ``` r 40 | # install.packages("pak") 41 | pak::pak("tidymodels/bonsai") 42 | ``` 43 | 44 | 45 | ## Available Engines 46 | 47 | The bonsai package provides additional engines for the models in the following table: 48 | 49 | ```{r, echo = FALSE, message = FALSE} 50 | library(parsnip) 51 | 52 | parsnip_models <- 53 | setNames(nm = get_from_env("models")) |> 54 | purrr::map_dfr(get_from_env, .id = "model") 55 | 56 | library(bonsai) 57 | 58 | bonsai_models <- 59 | setNames(nm = get_from_env("models")) |> 60 | purrr::map_dfr(get_from_env, .id = "model") 61 | 62 | dplyr::anti_join( 63 | bonsai_models, parsnip_models, 64 | by = c("model", "engine", "mode") 65 | ) |> 66 | knitr::kable() 67 | ``` 68 | 69 | ## Code of Conduct 70 | 71 | Please note that the bonsai project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. 
72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # bonsai 5 | 6 | 7 | 8 | [![Lifecycle: 9 | experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) 10 | [![CRAN 11 | status](https://www.r-pkg.org/badges/version/bonsai)](https://CRAN.R-project.org/package=bonsai) 12 | [![Codecov test 13 | coverage](https://codecov.io/gh/tidymodels/bonsai/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tidymodels/bonsai?branch=main) 14 | [![R-CMD-check](https://github.com/tidymodels/bonsai/workflows/R-CMD-check/badge.svg)](https://github.com/tidymodels/bonsai/actions) 15 | 16 | 17 | bonsai provides bindings for additional tree-based model engines for use 18 | with the [parsnip](https://parsnip.tidymodels.org/) package. 19 | 20 | This package is based on the work done in the [treesnip 21 | repository](https://github.com/curso-r/treesnip) by Athos Damiani, 22 | Daniel Falbel, and Roel Hogervorst. bonsai is the official CRAN version 23 | of the package; new development will reside here.
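As a minimal sketch of typical use (assuming the partykit package is installed; the built-in `mtcars` data is used only for illustration), loading bonsai registers its engines with parsnip, after which they can be selected with `set_engine()` like any other parsnip engine:

``` r
library(parsnip)
library(bonsai) # loading bonsai registers its engines with parsnip

# list the engines now available for decision_tree()
show_engines("decision_tree")

# specify a conditional inference tree using the partykit engine
tree_spec <- decision_tree() |>
  set_engine("partykit") |>
  set_mode("regression")

# fit on a built-in data set, then predict on a few rows
tree_fit <- fit(tree_spec, mpg ~ ., data = mtcars)
predict(tree_fit, new_data = head(mtcars))
```

Following parsnip's usual conventions, the numeric predictions come back as a tibble with a `.pred` column.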
24 | 25 | ## Installation 26 | 27 | You can install the most recent official release of bonsai with: 28 | 29 | ``` r 30 | install.packages("bonsai") 31 | ``` 32 | 33 | You can install the development version of bonsai from 34 | [GitHub](https://github.com/) with: 35 | 36 | ``` r 37 | # install.packages("pak") 38 | pak::pak("tidymodels/bonsai") 39 | ``` 40 | 41 | ## Available Engines 42 | 43 | The bonsai package provides additional engines for the models in the 44 | following table: 45 | 46 | | model | engine | mode | 47 | |:--------------|:---------|:---------------| 48 | | boost_tree | lightgbm | regression | 49 | | boost_tree | lightgbm | classification | 50 | | decision_tree | partykit | regression | 51 | | decision_tree | partykit | classification | 52 | | rand_forest | partykit | regression | 53 | | rand_forest | partykit | classification | 54 | | rand_forest | aorsf | classification | 55 | | rand_forest | aorsf | regression | 56 | 57 | 58 | ## Code of Conduct 59 | 60 | Please note that the bonsai project is released with a [Contributor Code 61 | of 62 | Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). 63 | By contributing to this project, you agree to abide by its terms. 
64 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://bonsai.tidymodels.org/ 2 | 3 | template: 4 | package: tidytemplate 5 | bootstrap: 5 6 | bslib: 7 | danger: "#CA225E" 8 | primary: "#CA225E" 9 | includes: 10 | in_header: | 11 | 12 | development: 13 | mode: auto 14 | 15 | figures: 16 | fig.width: 8 17 | fig.height: 5.75 18 | -------------------------------------------------------------------------------- /air.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/air.toml -------------------------------------------------------------------------------- /bonsai.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | PackageRoxygenize: rd,collate,namespace 22 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /cran-comments.md: 
-------------------------------------------------------------------------------- 1 | ## R CMD check results 2 | 3 | I see no ERRORs, WARNINGs, or NOTEs. 4 | 5 | We checked 5 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package, and saw no new problems. 6 | -------------------------------------------------------------------------------- /inst/figs/hex.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /man/bonsai-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/bonsai_package.R 3 | \docType{package} 4 | \name{bonsai-package} 5 | \alias{bonsai-package} 6 | \alias{bonsai} 7 | \title{bonsai: Model Wrappers for Tree-Based Models} 8 | \description{ 9 | \if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} 10 | 11 | Bindings for additional tree-based model engines for use with the 'parsnip' package. Models include gradient boosted decision trees with 'LightGBM' (Ke et al, 2017.), conditional inference trees and conditional random forests with 'partykit' (Hothorn and Zeileis, 2015. and Hothorn et al, 2006.
\doi{10.1198/106186006X133933}), and accelerated oblique random forests with 'aorsf' (Jaeger et al, 2022 \doi{10.5281/zenodo.7116854}). 12 | } 13 | \seealso{ 14 | Useful links: 15 | \itemize{ 16 | \item \url{https://bonsai.tidymodels.org/} 17 | \item \url{https://github.com/tidymodels/bonsai} 18 | \item Report bugs at \url{https://github.com/tidymodels/bonsai/issues} 19 | } 20 | 21 | } 22 | \author{ 23 | \strong{Maintainer}: Simon Couch \email{simon.couch@posit.co} (\href{https://orcid.org/0000-0001-5676-5107}{ORCID}) 24 | 25 | Authors: 26 | \itemize{ 27 | \item Daniel Falbel \email{dfalbel@curso-r.com} 28 | \item Athos Damiani \email{adamiani@curso-r.com} 29 | \item Roel M. Hogervorst \email{hogervorst.rm@gmail.com} (\href{https://orcid.org/0000-0001-7509-0328}{ORCID}) 30 | \item Max Kuhn \email{max@posit.co} (\href{https://orcid.org/0000-0003-2402-136X}{ORCID}) 31 | } 32 | 33 | Other contributors: 34 | \itemize{ 35 | \item Posit Software, PBC (03wc8by49) [copyright holder, funder] 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/man/figures/logo.png -------------------------------------------------------------------------------- /man/lightgbm_helpers.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lightgbm.R 3 | \name{predict_lightgbm_classification_prob} 4 | \alias{predict_lightgbm_classification_prob} 5 | \alias{predict_lightgbm_classification_class} 6 | \alias{predict_lightgbm_classification_raw} 7 | \alias{predict_lightgbm_regression_numeric} 8 | \alias{multi_predict._lgb.Booster} 9 | \title{Internal functions} 10 | \usage{ 11 | predict_lightgbm_classification_prob(object, new_data, ...) 
12 | 13 | predict_lightgbm_classification_class(object, new_data, ...) 14 | 15 | predict_lightgbm_classification_raw(object, new_data, ...) 16 | 17 | predict_lightgbm_regression_numeric(object, new_data, ...) 18 | 19 | \method{multi_predict}{`_lgb.Booster`}(object, new_data, type = NULL, trees = NULL, ...) 20 | } 21 | \description{ 22 | Not intended for direct use. 23 | } 24 | \keyword{internal} 25 | -------------------------------------------------------------------------------- /man/reexports.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/0_imports.R 3 | \docType{import} 4 | \name{reexports} 5 | \alias{reexports} 6 | \alias{\%>\%} 7 | \title{Objects exported from other packages} 8 | \keyword{internal} 9 | \description{ 10 | These objects are imported from other packages. Follow the links 11 | below to see their documentation. 12 | 13 | \describe{ 14 | \item{parsnip}{\code{\link[parsnip:reexports]{\%>\%}}} 15 | }} 16 | 17 | -------------------------------------------------------------------------------- /man/train_lightgbm.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/lightgbm.R 3 | \name{train_lightgbm} 4 | \alias{train_lightgbm} 5 | \title{Boosted trees with lightgbm} 6 | \usage{ 7 | train_lightgbm( 8 | x, 9 | y, 10 | weights = NULL, 11 | max_depth = -1, 12 | num_iterations = 100, 13 | learning_rate = 0.1, 14 | feature_fraction_bynode = 1, 15 | min_data_in_leaf = 20, 16 | min_gain_to_split = 0, 17 | bagging_fraction = 1, 18 | early_stopping_round = NULL, 19 | validation = 0, 20 | counts = TRUE, 21 | quiet = FALSE, 22 | ... 
23 | ) 24 | } 25 | \arguments{ 26 | \item{x}{A data frame or matrix of predictors.} 27 | 28 | \item{y}{A vector (factor or numeric) or matrix (numeric) of outcome data.} 29 | 30 | \item{weights}{A numeric vector of sample weights.} 31 | 32 | \item{max_depth}{An integer for the maximum depth of the tree.} 33 | 34 | \item{num_iterations}{An integer for the number of boosting iterations.} 35 | 36 | \item{learning_rate}{A numeric value between zero and one to control the learning rate.} 37 | 38 | \item{feature_fraction_bynode}{Fraction of predictors that will be randomly sampled 39 | at each split.} 40 | 41 | \item{min_data_in_leaf}{A numeric value for the minimum number of instances 42 | required in a leaf.} 43 | 44 | \item{min_gain_to_split}{A number for the minimum loss reduction required to make a 45 | further partition on a leaf node of the tree.} 46 | 47 | \item{bagging_fraction}{Subsampling proportion of rows. Setting this argument 48 | to a non-default value will also set \code{bagging_freq = 1}. See the Bagging 49 | section in \code{?details_boost_tree_lightgbm} for more details.} 50 | 51 | \item{early_stopping_round}{Number of consecutive iterations without an improvement in 52 | the objective function that can occur before training is halted.} 53 | 54 | \item{validation}{The \emph{proportion} of the training data that are used for 55 | performance assessment and potential early stopping.} 56 | 57 | \item{counts}{A logical; should \code{feature_fraction_bynode} be interpreted as the 58 | \emph{number} of predictors that will be randomly sampled at each split?
59 | \code{TRUE} indicates that \code{mtry} will be interpreted in its sense as a \emph{count}, 60 | \code{FALSE} indicates that the argument will be interpreted in its sense as a 61 | \emph{proportion}.} 62 | 63 | \item{quiet}{A logical; should logging by \code{\link[lightgbm:lgb.train]{lightgbm::lgb.train()}} be muted?} 64 | 65 | \item{...}{Other options to pass to \code{\link[lightgbm:lgb.train]{lightgbm::lgb.train()}}. Arguments 66 | will be correctly routed to the \code{params} argument, or as a main argument, 67 | depending on their name.} 68 | } 69 | \value{ 70 | A fitted \code{lgb.Booster} object. 71 | } 72 | \description{ 73 | \code{train_lightgbm} is a wrapper for \code{lightgbm} tree-based models 74 | where all of the model arguments are in the main function. 75 | } 76 | \details{ 77 | This is an internal function, not meant to be directly called by the user. 78 | } 79 | \keyword{internal} 80 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-152x152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/apple-touch-icon-152x152.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/apple-touch-icon-180x180.png
-------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/apple-touch-icon-76x76.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tidymodels/bonsai/4add5af4fa4cfcdf895851ec27fd6497b8c84e6f/pkgdown/favicon/favicon.ico 
-------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | library(bonsai) 3 | 4 | test_check("bonsai") 5 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/lightgbm.md: -------------------------------------------------------------------------------- 1 | # boost_tree with lightgbm 2 | 3 | Code 4 | set_mode(set_engine(boost_tree(), "lightgbm"), "regression") 5 | Output 6 | Boosted Tree Model Specification (regression) 7 | 8 | Computational engine: lightgbm 9 | 10 | 11 | --- 12 | 13 | Code 14 | set_mode(set_engine(boost_tree(), "lightgbm", nrounds = 100), "classification") 15 | Output 16 | Boosted Tree Model Specification (classification) 17 | 18 | Engine-Specific Arguments: 19 | nrounds = 100 20 | 21 | Computational engine: lightgbm 22 | 23 | 24 | # bonsai handles mtry vs mtry_prop gracefully 25 | 26 | `mtry` must be greater than or equal to 1, not 0.5. 27 | i `mtry` is currently being interpreted as a count rather than a proportion. 28 | i Supply `counts = FALSE` to `set_engine()` to supply this argument as a proportion rather than a count. 29 | i See `?train_lightgbm()` for more details. 30 | 31 | --- 32 | 33 | `mtry` must be less than or equal to 1, not 3. 34 | i `mtry` is currently being interpreted as a proportion rather than a count. 35 | i Supply `counts = TRUE` to `set_engine()` to supply this argument as a count rather than a proportion. 36 | i See `?train_lightgbm()` for more details. 37 | 38 | # tuning mtry vs mtry_prop 39 | 40 | Code 41 | fit(set_mode(set_engine(boost_tree(mtry = tune::tune()), "lightgbm"), 42 | "regression"), bill_length_mm ~ ., data = penguins) 43 | Condition 44 | Error in `fit()`: 45 | ! `feature_fraction_bynode` must be a number, not a call. 
46 | 47 | # training wrapper warns on protected arguments 48 | 49 | Code 50 | .res <- fit(set_mode(set_engine(boost_tree(), "lightgbm", colnames = paste0("X", 51 | 1:ncol(penguins))), "regression"), bill_length_mm ~ ., data = penguins) 52 | Condition 53 | Warning: 54 | The following argument(s) are guarded by bonsai and will not be passed to LightGBM: colnames 55 | 56 | --- 57 | 58 | Code 59 | .res <- fit(set_mode(set_engine(boost_tree(), "lightgbm", colnames = paste0("X", 60 | 1:ncol(penguins)), callbacks = list(p = print)), "regression"), 61 | bill_length_mm ~ ., data = penguins) 62 | Condition 63 | Warning: 64 | The following argument(s) are guarded by bonsai and will not be passed to LightGBM: colnames, callbacks 65 | 66 | --- 67 | 68 | Code 69 | .res <- fit(set_mode(set_engine(boost_tree(), "lightgbm", colnames = paste0("X", 70 | 1:ncol(penguins))), "regression"), bill_length_mm ~ ., data = penguins) 71 | Condition 72 | Warning: 73 | The following argument(s) are guarded by bonsai and will not be passed to LightGBM: colnames 74 | 75 | --- 76 | 77 | Code 78 | fit(set_mode(set_engine(boost_tree(), "lightgbm", n_iter = 10), "regression"), 79 | bill_length_mm ~ ., data = penguins) 80 | Condition 81 | Error in `fit()`: 82 | ! The `n_iter` argument passed to `set_engine()` (`?parsnip::set_engine()`) is an alias for a main model argument. 83 | i Please instead pass this argument via the `trees` argument to `boost_tree()` (`?parsnip::boost_tree()`). 84 | 85 | --- 86 | 87 | Code 88 | fit(set_mode(set_engine(boost_tree(), "lightgbm", num_tree = 10), "regression"), 89 | bill_length_mm ~ ., data = penguins) 90 | Condition 91 | Error in `fit()`: 92 | ! The `num_tree` argument passed to `set_engine()` (`?parsnip::set_engine()`) is an alias for a main model argument. 93 | i Please instead pass this argument via the `trees` argument to `boost_tree()` (`?parsnip::boost_tree()`). 
94 | 95 | --- 96 | 97 | Code 98 | fit(set_mode(set_engine(boost_tree(), "lightgbm", min_split_gain = 2), 99 | "regression"), bill_length_mm ~ ., data = penguins) 100 | Condition 101 | Error in `fit()`: 102 | ! The `min_split_gain` argument passed to `set_engine()` (`?parsnip::set_engine()`) is an alias for a main model argument. 103 | i Please instead pass this argument via the `loss_reduction` argument to `boost_tree()` (`?parsnip::boost_tree()`). 104 | 105 | --- 106 | 107 | Code 108 | fit(set_mode(set_engine(boost_tree(), "lightgbm", min_split_gain = 2, 109 | lambda_l2 = 0.5), "regression"), bill_length_mm ~ ., data = penguins) 110 | Condition 111 | Error in `fit()`: 112 | ! The `min_split_gain` argument passed to `set_engine()` (`?parsnip::set_engine()`) is an alias for a main model argument. 113 | i Please instead pass this argument via the `loss_reduction` argument to `boost_tree()` (`?parsnip::boost_tree()`). 114 | 115 | -------------------------------------------------------------------------------- /tests/testthat/_snaps/partykit.md: -------------------------------------------------------------------------------- 1 | # condition inference trees 2 | 3 | Code 4 | set_mode(set_engine(decision_tree(), "partykit"), "regression") 5 | Output 6 | Decision Tree Model Specification (regression) 7 | 8 | Computational engine: partykit 9 | 10 | 11 | --- 12 | 13 | Code 14 | set_mode(set_engine(decision_tree(), "partykit", teststat = "maximum"), 15 | "classification") 16 | Output 17 | Decision Tree Model Specification (classification) 18 | 19 | Engine-Specific Arguments: 20 | teststat = maximum 21 | 22 | Computational engine: partykit 23 | 24 | 25 | # condition inference forests 26 | 27 | Code 28 | set_mode(set_engine(rand_forest(), "partykit"), "regression") 29 | Output 30 | Random Forest Model Specification (regression) 31 | 32 | Computational engine: partykit 33 | 34 | 35 | --- 36 | 37 | Code 38 | set_mode(set_engine(rand_forest(), "partykit", teststat = "maximum"), 39 
| "classification") 40 | Output 41 | Random Forest Model Specification (classification) 42 | 43 | Engine-Specific Arguments: 44 | teststat = maximum 45 | 46 | Computational engine: partykit 47 | 48 | 49 | -------------------------------------------------------------------------------- /tests/testthat/test-aorsf.R: -------------------------------------------------------------------------------- 1 | library(testthat) 2 | 3 | withr::local_envvar("OMP_THREAD_LIMIT" = 1) 4 | 5 | mtcars_orsf <- mtcars 6 | mtcars_orsf$vs <- factor(mtcars_orsf$vs) 7 | mtcars_na <- mtcars_orsf 8 | mtcars_na$cyl[1] <- NA 9 | 10 | test_that("regression model object", { 11 | skip_if_not_installed("aorsf") 12 | 13 | set.seed(321) 14 | wts <- sample(0:5, size = nrow(mtcars_orsf), replace = TRUE) 15 | 16 | set.seed(1234) 17 | aorsf_regr_fit <- 18 | aorsf::orsf( 19 | # everyone's favorite 20 | data = mtcars_orsf, 21 | formula = mpg ~ ., 22 | # faster 23 | n_tree = 10, 24 | # requested default from tidymodels 25 | n_thread = 1 26 | ) 27 | 28 | set.seed(1234) 29 | aorsf_regr_fit_wtd <- aorsf::orsf_update(aorsf_regr_fit, weights = wts) 30 | 31 | # formula method 32 | regr_spec <- 33 | rand_forest(trees = 10) |> 34 | set_engine("aorsf") |> 35 | set_mode("regression") 36 | 37 | set.seed(1234) 38 | expect_no_condition( 39 | bonsai_regr_fit <- 40 | fit( 41 | regr_spec, 42 | data = mtcars_orsf, 43 | formula = mpg ~ . 
44 | ) 45 | ) 46 | 47 | set.seed(1234) 48 | expect_no_condition( 49 | bonsai_regr_fit_wtd <- 50 | fit( 51 | regr_spec, 52 | data = mtcars_orsf, 53 | formula = mpg ~ ., 54 | case_weights = importance_weights(wts) 55 | ) 56 | ) 57 | 58 | expect_equal( 59 | bonsai_regr_fit$fit, 60 | aorsf_regr_fit, 61 | ignore_formula_env = TRUE 62 | ) 63 | 64 | expect_equal( 65 | bonsai_regr_fit_wtd$fit, 66 | aorsf_regr_fit_wtd, 67 | ignore_formula_env = TRUE 68 | ) 69 | }) 70 | 71 | test_that("classification model object", { 72 | skip_if_not_installed("aorsf") 73 | 74 | set.seed(321) 75 | wts <- sample(0:5, size = nrow(mtcars_orsf), replace = TRUE) 76 | 77 | set.seed(1234) 78 | aorsf_clsf_fit <- 79 | aorsf::orsf( 80 | data = mtcars_orsf, 81 | formula = vs ~ ., 82 | n_tree = 10, 83 | n_thread = 1 84 | ) 85 | 86 | set.seed(1234) 87 | aorsf_clsf_fit_wtd <- aorsf::orsf_update(aorsf_clsf_fit, weights = wts) 88 | 89 | # formula method 90 | clsf_spec <- rand_forest(trees = 10) |> 91 | set_engine("aorsf") |> 92 | set_mode("classification") 93 | 94 | set.seed(1234) 95 | expect_no_condition( 96 | bonsai_clsf_fit <- 97 | fit( 98 | clsf_spec, 99 | data = mtcars_orsf, 100 | formula = vs ~ . 
101 | ) 102 | ) 103 | 104 | set.seed(1234) 105 | expect_no_condition( 106 | bonsai_clsf_fit_wtd <- 107 | fit( 108 | clsf_spec, 109 | data = mtcars_orsf, 110 | formula = vs ~ ., 111 | case_weights = importance_weights(wts) 112 | ) 113 | ) 114 | 115 | expect_equal( 116 | bonsai_clsf_fit$fit, 117 | aorsf_clsf_fit, 118 | ignore_formula_env = TRUE 119 | ) 120 | 121 | expect_equal( 122 | bonsai_clsf_fit_wtd$fit, 123 | aorsf_clsf_fit_wtd, 124 | ignore_formula_env = TRUE 125 | ) 126 | }) 127 | 128 | test_that("regression predictions", { 129 | skip_if_not_installed("aorsf") 130 | 131 | set.seed(1234) 132 | aorsf_regr_fit <- 133 | aorsf::orsf( 134 | data = mtcars_orsf, 135 | formula = mpg ~ ., 136 | n_tree = 10 137 | ) 138 | 139 | aorsf_regr_pred <- 140 | predict( 141 | aorsf_regr_fit, 142 | new_data = mtcars_na, 143 | na_action = 'pass' 144 | ) 145 | 146 | # formula method 147 | regr_spec <- 148 | rand_forest(trees = 10) |> 149 | set_engine("aorsf") |> 150 | set_mode("regression") 151 | 152 | set.seed(1234) 153 | bonsai_regr_fit <- fit(regr_spec, mpg ~ ., data = mtcars_orsf) 154 | bonsai_regr_pred <- predict(bonsai_regr_fit, new_data = mtcars_na) 155 | 156 | expect_s3_class(bonsai_regr_pred, "tbl_df") 157 | expect_true(all(names(bonsai_regr_pred) == ".pred")) 158 | expect_equal(bonsai_regr_pred$.pred, as.vector(aorsf_regr_pred)) 159 | expect_equal(nrow(bonsai_regr_pred), nrow(mtcars_orsf)) 160 | 161 | # single observation 162 | pred_1row <- predict(bonsai_regr_fit, mtcars_orsf[2, ]) 163 | expect_identical(nrow(pred_1row), 1L) 164 | }) 165 | 166 | test_that("classification predictions", { 167 | skip_if_not_installed("aorsf") 168 | 169 | set.seed(1234) 170 | aorsf_clsf_fit <- 171 | aorsf::orsf( 172 | data = mtcars_orsf, 173 | formula = vs ~ ., 174 | n_tree = 10 175 | ) 176 | 177 | aorsf_clsf_pred <- 178 | predict( 179 | aorsf_clsf_fit, 180 | new_data = mtcars_na, 181 | pred_type = 'prob', 182 | na_action = 'pass' 183 | ) 184 | 185 | aorsf_probs <- aorsf_clsf_pred 186 | 187 | 
# see #78--do not expect predictions to align exactly 188 | aorsf_class <- colnames(aorsf_probs)[apply( 189 | aorsf_clsf_pred[-1, ], 190 | 1, 191 | which.max 192 | )] 193 | # inserting the NA from first row 194 | aorsf_class <- c(NA_character_, aorsf_class) 195 | 196 | # formula method 197 | clsf_spec <- rand_forest(trees = 10) |> 198 | set_engine("aorsf") |> 199 | set_mode("classification") 200 | 201 | set.seed(1234) 202 | bonsai_clsf_fit <- fit(clsf_spec, vs ~ ., data = mtcars_orsf) 203 | 204 | bonsai_clsf_pred_prob <- 205 | predict( 206 | bonsai_clsf_fit, 207 | new_data = mtcars_na, 208 | type = 'prob' 209 | ) 210 | 211 | bonsai_clsf_pred_class <- 212 | predict( 213 | bonsai_clsf_fit, 214 | new_data = mtcars_na, 215 | type = 'class' 216 | ) 217 | 218 | expect_s3_class(bonsai_clsf_pred_prob, "tbl_df") 219 | expect_s3_class(bonsai_clsf_pred_class, "tbl_df") 220 | 221 | expect_true(all(names(bonsai_clsf_pred_prob) == c(".pred_0", ".pred_1"))) 222 | expect_true(all(names(bonsai_clsf_pred_class) == ".pred_class")) 223 | 224 | expect_equal(bonsai_clsf_pred_prob$.pred_0, as.vector(aorsf_probs[, 1])) 225 | expect_equal(bonsai_clsf_pred_prob$.pred_1, as.vector(aorsf_probs[, 2])) 226 | 227 | expect_equal(bonsai_clsf_pred_class$.pred_class, factor(aorsf_class)) 228 | 229 | expect_equal(nrow(bonsai_clsf_pred_prob), nrow(mtcars_orsf)) 230 | expect_equal(nrow(bonsai_clsf_pred_class), nrow(mtcars_orsf)) 231 | 232 | # single observation 233 | pred_1row <- predict(bonsai_clsf_fit, mtcars_orsf[2, ]) 234 | expect_identical(nrow(pred_1row), 1L) 235 | }) 236 | -------------------------------------------------------------------------------- /tests/testthat/test-lightgbm.R: -------------------------------------------------------------------------------- 1 | withr::local_envvar("OMP_THREAD_LIMIT" = 1) 2 | 3 | test_that("boost_tree with lightgbm", { 4 | skip_if_not_installed("lightgbm") 5 | skip_if_not_installed("modeldata") 6 | 7 | suppressPackageStartupMessages({ 8 | library(lightgbm) 
9 | library(dplyr) 10 | }) 11 | 12 | data("penguins", package = "modeldata") 13 | 14 | penguins <- penguins[complete.cases(penguins), ] 15 | 16 | expect_snapshot( 17 | boost_tree() |> set_engine("lightgbm") |> set_mode("regression") 18 | ) 19 | expect_snapshot( 20 | boost_tree() |> 21 | set_engine("lightgbm", nrounds = 100) |> 22 | set_mode("classification") 23 | ) 24 | 25 | # regression ----------------------------------------------------------------- 26 | expect_no_error({ 27 | pars_fit_1 <- 28 | boost_tree() |> 29 | set_engine("lightgbm") |> 30 | set_mode("regression") |> 31 | fit(bill_length_mm ~ ., data = penguins) 32 | }) 33 | 34 | expect_no_error({ 35 | pars_preds_1 <- 36 | predict(pars_fit_1, penguins) 37 | }) 38 | 39 | peng <- 40 | penguins |> 41 | mutate(across(where(is.character), \(x) as.factor(x))) |> 42 | mutate(across(where(is.factor), \(x) as.integer(x) - 1)) 43 | 44 | peng_y <- peng$bill_length_mm 45 | 46 | peng_m <- peng |> 47 | select(-bill_length_mm) |> 48 | as.matrix() 49 | 50 | peng_x <- 51 | lgb.Dataset( 52 | data = peng_m, 53 | label = peng_y, 54 | params = list(feature_pre_filter = FALSE), 55 | categorical_feature = c(1L, 2L, 6L) 56 | ) 57 | 58 | params_1 <- list( 59 | objective = "regression" 60 | ) 61 | 62 | lgbm_fit_1 <- 63 | lightgbm::lgb.train( 64 | data = peng_x, 65 | params = params_1, 66 | verbose = -1 67 | ) 68 | 69 | lgbm_preds_1 <- predict(lgbm_fit_1, peng_m) 70 | 71 | expect_equal(pars_preds_1$.pred, lgbm_preds_1) 72 | 73 | # regression, adjusting a primary argument 74 | expect_no_error({ 75 | pars_fit_2 <- 76 | boost_tree(trees = 20) |> 77 | set_engine("lightgbm") |> 78 | set_mode("regression") |> 79 | fit(bill_length_mm ~ ., data = penguins) 80 | }) 81 | 82 | expect_no_error({ 83 | pars_preds_2 <- 84 | predict(pars_fit_2, penguins) 85 | }) 86 | 87 | params_2 <- list( 88 | objective = "regression", 89 | num_iterations = 20 90 | ) 91 | 92 | lgbm_fit_2 <- 93 | lightgbm::lgb.train( 94 | data = peng_x, 95 | params = params_2, 96 | 
verbose = -1 97 | ) 98 | 99 | lgbm_preds_2 <- predict(lgbm_fit_2, peng_m) 100 | 101 | expect_equal(pars_preds_2$.pred, lgbm_preds_2) 102 | 103 | # regression, adjusting an engine argument 104 | expect_no_error({ 105 | pars_fit_3 <- 106 | boost_tree() |> 107 | set_engine("lightgbm", lambda_l2 = .5) |> 108 | set_mode("regression") |> 109 | fit(bill_length_mm ~ ., data = penguins) 110 | }) 111 | 112 | expect_no_error({ 113 | pars_preds_3 <- 114 | predict(pars_fit_3, penguins) 115 | }) 116 | 117 | params_3 <- list( 118 | objective = "regression", 119 | lambda_l2 = .5 120 | ) 121 | 122 | lgbm_fit_3 <- 123 | lightgbm::lgb.train( 124 | data = peng_x, 125 | params = params_3, 126 | verbose = -1 127 | ) 128 | 129 | lgbm_preds_3 <- predict(lgbm_fit_3, peng_m) 130 | 131 | expect_equal(pars_preds_3$.pred, lgbm_preds_3) 132 | 133 | # classification ------------------------------------------------------------- 134 | 135 | # multiclass 136 | expect_no_error({ 137 | pars_fit_4 <- 138 | boost_tree() |> 139 | set_engine("lightgbm") |> 140 | set_mode("classification") |> 141 | fit(species ~ ., data = penguins) 142 | }) 143 | 144 | expect_no_error({ 145 | pars_preds_4 <- 146 | predict(pars_fit_4, penguins, type = "prob") 147 | pars_preds_raw_4 <- 148 | predict(pars_fit_4, penguins, type = "raw") 149 | }) 150 | 151 | expect_equal(nrow(pars_preds_raw_4), nrow(penguins)) 152 | expect_equal(ncol(pars_preds_raw_4), 3) 153 | 154 | pars_preds_4_mtx <- as.matrix(pars_preds_4) 155 | dimnames(pars_preds_4_mtx) <- NULL 156 | 157 | peng_y_c <- peng$species 158 | 159 | peng_m_c <- peng |> 160 | select(-species) |> 161 | as.matrix() 162 | 163 | peng_x_c <- 164 | lgb.Dataset( 165 | data = peng_m_c, 166 | label = peng_y_c, 167 | params = list(feature_pre_filter = FALSE), 168 | categorical_feature = c(1L, 6L), 169 | ) 170 | 171 | params_4 <- list( 172 | objective = "multiclass", 173 | num_class = 3 174 | ) 175 | 176 | lgbm_fit_4 <- 177 | lightgbm::lgb.train( 178 | data = peng_x_c, 179 | params = 
params_4, 180 | verbose = -1 181 | ) 182 | 183 | lgbm_preds_4 <- 184 | predict(lgbm_fit_4, peng_m_c) |> 185 | reshape_lightgbm_multiclass_preds(num_rows = nrow(peng_m_c)) 186 | 187 | expect_equal(pars_preds_4_mtx, lgbm_preds_4) 188 | 189 | # check class predictions 190 | pars_preds_5 <- 191 | predict(pars_fit_4, penguins, type = "class") |> 192 | (\(x) x[[".pred_class"]])() |> 193 | as.character() 194 | 195 | lgbm_preds_5 <- apply(pars_preds_4_mtx, 1, function(x) which.max(x)) |> 196 | factor(labels = c("Adelie", "Chinstrap", "Gentoo")) |> 197 | as.character() 198 | 199 | expect_equal(pars_preds_5, lgbm_preds_5) 200 | 201 | # classification on a two-level outcome 202 | expect_no_error({ 203 | pars_fit_6 <- 204 | boost_tree() |> 205 | set_engine("lightgbm") |> 206 | set_mode("classification") |> 207 | fit(sex ~ ., data = penguins) 208 | }) 209 | 210 | expect_no_error({ 211 | pars_preds_6 <- 212 | predict(pars_fit_6, penguins, type = "prob") 213 | pars_preds_raw_6 <- 214 | predict(pars_fit_6, penguins, type = "raw") 215 | }) 216 | 217 | expect_equal(length(pars_preds_raw_6), nrow(penguins)) 218 | expect_false(identical(pars_preds_6, pars_preds_raw_6)) 219 | 220 | pars_preds_6_b <- pars_preds_6$.pred_male 221 | 222 | peng_y_b <- peng$sex 223 | 224 | peng_m_b <- peng |> 225 | select(-sex) |> 226 | as.matrix() 227 | 228 | peng_x_b <- 229 | lgb.Dataset( 230 | data = peng_m_b, 231 | label = peng_y_b, 232 | params = list(feature_pre_filter = FALSE), 233 | categorical_feature = c(1L, 2L), 234 | ) 235 | 236 | params_6 <- list( 237 | objective = "binary", 238 | num_class = 1 239 | ) 240 | 241 | lgbm_fit_6 <- 242 | lightgbm::lgb.train( 243 | data = peng_x_b, 244 | params = params_6, 245 | verbose = -1 246 | ) 247 | 248 | lgbm_preds_6 <- predict(lgbm_fit_6, peng_m_b) 249 | 250 | expect_equal(pars_preds_6_b, lgbm_preds_6) 251 | }) 252 | 253 | test_that("bonsai applies dataset parameters (#77)", { 254 | skip_if_not_installed("lightgbm") 255 | skip_if_not_installed("modeldata") 
256 | 257 | suppressPackageStartupMessages({ 258 | library(lightgbm) 259 | library(dplyr) 260 | }) 261 | 262 | data("penguins", package = "modeldata") 263 | 264 | penguins <- penguins[complete.cases(penguins), ] 265 | 266 | # regression ----------------------------------------------------------------- 267 | expect_no_error({ 268 | pars_fit_1 <- 269 | boost_tree() |> 270 | set_engine("lightgbm", linear_tree = TRUE) |> 271 | set_mode("regression") |> 272 | fit(bill_length_mm ~ ., data = penguins) 273 | }) 274 | 275 | expect_no_error({ 276 | pars_preds_1 <- 277 | predict(pars_fit_1, penguins) 278 | }) 279 | 280 | peng <- 281 | penguins |> 282 | mutate(across(where(is.character), \(x) as.factor(x))) |> 283 | mutate(across(where(is.factor), \(x) as.integer(x) - 1)) 284 | 285 | peng_y <- peng$bill_length_mm 286 | 287 | peng_m <- peng |> 288 | select(-bill_length_mm) |> 289 | as.matrix() 290 | 291 | peng_x <- 292 | lgb.Dataset( 293 | data = peng_m, 294 | label = peng_y, 295 | params = list(feature_pre_filter = FALSE, linear_tree = TRUE), 296 | categorical_feature = c(1L, 2L, 6L) 297 | ) 298 | 299 | params_1 <- list( 300 | objective = "regression" 301 | ) 302 | 303 | lgbm_fit_1 <- 304 | lightgbm::lgb.train( 305 | data = peng_x, 306 | params = params_1, 307 | verbose = -1 308 | ) 309 | 310 | lgbm_preds_1 <- predict(lgbm_fit_1, peng_m) 311 | 312 | expect_equal(pars_preds_1$.pred, lgbm_preds_1) 313 | expect_true(pars_fit_1$fit$params$linear_tree) 314 | }) 315 | 316 | test_that("bonsai correctly determines objective when label is a factor", { 317 | skip_if_not_installed("lightgbm") 318 | skip_if_not_installed("modeldata") 319 | 320 | suppressPackageStartupMessages({ 321 | library(lightgbm) 322 | library(dplyr) 323 | }) 324 | 325 | data("penguins", package = "modeldata") 326 | penguins <- penguins[complete.cases(penguins), ] 327 | 328 | expect_no_error({ 329 | bst <- train_lightgbm( 330 | x = penguins[, c("bill_length_mm", "bill_depth_mm")], 331 | y = penguins[["sex"]], 332 | 
num_iterations = 5, 333 | verbose = -1L 334 | ) 335 | }) 336 | expect_equal(bst$params$objective, "binary") 337 | expect_equal(bst$params$num_class, 1) 338 | 339 | expect_no_error({ 340 | bst <- train_lightgbm( 341 | x = penguins[, c("bill_length_mm", "bill_depth_mm")], 342 | y = penguins[["species"]], 343 | num_iterations = 5, 344 | verbose = -1L 345 | ) 346 | }) 347 | expect_equal(bst$params$objective, "multiclass") 348 | expect_equal(bst$params$num_class, 3) 349 | }) 350 | 351 | 352 | test_that("bonsai handles mtry vs mtry_prop gracefully", { 353 | skip_if_not_installed("modeldata") 354 | 355 | data("penguins", package = "modeldata") 356 | 357 | penguins <- penguins[complete.cases(penguins), ] 358 | 359 | # supply no mtry 360 | expect_no_error({ 361 | pars_fit_1 <- 362 | boost_tree() |> 363 | set_engine("lightgbm") |> 364 | set_mode("regression") |> 365 | fit(bill_length_mm ~ ., data = penguins) 366 | }) 367 | 368 | expect_equal( 369 | extract_fit_engine(pars_fit_1)$params$feature_fraction_bynode, 370 | 1 371 | ) 372 | 373 | # supply mtry = 1 (edge cases) 374 | expect_no_error({ 375 | pars_fit_2 <- 376 | boost_tree(mtry = 1) |> 377 | set_engine("lightgbm", counts = TRUE) |> 378 | set_mode("regression") |> 379 | fit(bill_length_mm ~ ., data = penguins) 380 | }) 381 | 382 | expect_equal( 383 | extract_fit_engine(pars_fit_2)$params$feature_fraction_bynode, 384 | 1 / (ncol(penguins) - 1) 385 | ) 386 | 387 | expect_no_error({ 388 | pars_fit_3 <- 389 | boost_tree(mtry = 1) |> 390 | set_engine("lightgbm", counts = FALSE) |> 391 | set_mode("regression") |> 392 | fit(bill_length_mm ~ ., data = penguins) 393 | }) 394 | 395 | expect_equal( 396 | extract_fit_engine(pars_fit_3)$params$feature_fraction_bynode, 397 | 1 398 | ) 399 | 400 | # supply a count (with default counts = TRUE) 401 | expect_no_error({ 402 | pars_fit_4 <- 403 | boost_tree(mtry = 3) |> 404 | set_engine("lightgbm") |> 405 | set_mode("regression") |> 406 | fit(bill_length_mm ~ ., data = penguins) 407 | }) 
408 | 409 | expect_equal( 410 | extract_fit_engine(pars_fit_4)$params$feature_fraction_bynode, 411 | 3 / (ncol(penguins) - 1) 412 | ) 413 | 414 | # supply a proportion when count expected 415 | expect_snapshot_error({ 416 | pars_fit_5 <- 417 | boost_tree(mtry = .5) |> 418 | set_engine("lightgbm") |> 419 | set_mode("regression") |> 420 | fit(bill_length_mm ~ ., data = penguins) 421 | }) 422 | 423 | # supply a count when proportion expected 424 | expect_snapshot_error({ 425 | pars_fit_6 <- 426 | boost_tree(mtry = 3) |> 427 | set_engine("lightgbm", counts = FALSE) |> 428 | set_mode("regression") |> 429 | fit(bill_length_mm ~ ., data = penguins) 430 | }) 431 | 432 | # supply a feature fraction argument rather than mtry 433 | # TODO: is there any way to extend parsnip's warning here to 434 | # point users to mtry? 435 | # will see "The argument `feature_fraction_bynode` cannot be..." (#95) 436 | suppressWarnings( 437 | pars_fit_7 <- 438 | boost_tree() |> 439 | set_engine("lightgbm", feature_fraction_bynode = .5) |> 440 | set_mode("regression") |> 441 | fit(bill_length_mm ~ ., data = penguins) 442 | ) 443 | 444 | expect_equal( 445 | extract_fit_engine(pars_fit_7)$params$feature_fraction_bynode, 446 | 1 447 | ) 448 | 449 | # supply both feature fraction and mtry (#95) 450 | suppressWarnings(expect_error({ 451 | pars_fit_8 <- 452 | boost_tree(mtry = .5) |> 453 | set_engine("lightgbm", feature_fraction_bynode = .5) |> 454 | set_mode("regression") |> 455 | fit(bill_length_mm ~ ., data = penguins) 456 | })) 457 | 458 | # will see "The argument `feature_fraction_bynode` cannot be..." 
(#95) 459 | suppressWarnings( 460 | pars_fit_9 <- 461 | boost_tree(mtry = 2) |> 462 | set_engine("lightgbm", feature_fraction_bynode = .5) |> 463 | set_mode("regression") |> 464 | fit(bill_length_mm ~ ., data = penguins) 465 | ) 466 | 467 | expect_equal( 468 | extract_fit_engine(pars_fit_9)$params$feature_fraction_bynode, 469 | 2 / (ncol(penguins) - 1) 470 | ) 471 | }) 472 | 473 | test_that("tuning mtry vs mtry_prop", { 474 | skip_if_not_installed("tune") 475 | skip_if_not_installed("rsample") 476 | skip_if_not_installed("modeldata") 477 | 478 | data("penguins", package = "modeldata") 479 | 480 | penguins <- penguins[complete.cases(penguins), ] 481 | 482 | set.seed(1) 483 | 484 | suppressMessages( 485 | expect_no_error({ 486 | gbm_tune <- tune::tune_grid( 487 | boost_tree(mtry = tune::tune()) |> 488 | set_engine("lightgbm") |> 489 | set_mode("regression"), 490 | grid = 4, 491 | preprocessor = bill_length_mm ~ ., 492 | resamples = rsample::bootstraps(penguins, times = 5) 493 | ) 494 | }) 495 | ) 496 | 497 | mtrys <- unique(tune::collect_metrics(gbm_tune)$mtry) 498 | 499 | expect_equal(length(mtrys), 4) 500 | expect_true(all(mtrys >= 1)) 501 | 502 | # supply tune() without tuning 503 | expect_snapshot( 504 | { 505 | boost_tree(mtry = tune::tune()) |> 506 | set_engine("lightgbm") |> 507 | set_mode("regression") |> 508 | fit(bill_length_mm ~ ., data = penguins) 509 | }, 510 | error = TRUE 511 | ) 512 | }) 513 | 514 | test_that("training wrapper warns on protected arguments", { 515 | skip_if_not_installed("lightgbm") 516 | skip_if_not_installed("modeldata") 517 | 518 | data("penguins", package = "modeldata") 519 | 520 | penguins <- penguins[complete.cases(penguins), ] 521 | 522 | expect_snapshot( 523 | .res <- boost_tree() |> 524 | set_engine("lightgbm", colnames = paste0("X", 1:ncol(penguins))) |> 525 | set_mode("regression") |> 526 | fit(bill_length_mm ~ ., data = penguins) 527 | ) 528 | 529 | expect_snapshot( 530 | .res <- boost_tree() |> 531 | set_engine( 532 | 
"lightgbm", 533 | colnames = paste0("X", 1:ncol(penguins)), 534 | callbacks = list(p = print) 535 | ) |> 536 | set_mode("regression") |> 537 | fit(bill_length_mm ~ ., data = penguins) 538 | ) 539 | 540 | expect_snapshot( 541 | .res <- 542 | boost_tree() |> 543 | set_engine( 544 | "lightgbm", 545 | colnames = paste0("X", 1:ncol(penguins)) 546 | ) |> 547 | set_mode("regression") |> 548 | fit(bill_length_mm ~ ., data = penguins) 549 | ) 550 | 551 | expect_snapshot( 552 | error = TRUE, 553 | boost_tree() |> 554 | set_engine("lightgbm", n_iter = 10) |> 555 | set_mode("regression") |> 556 | fit(bill_length_mm ~ ., data = penguins) 557 | ) 558 | 559 | expect_snapshot( 560 | error = TRUE, 561 | boost_tree() |> 562 | set_engine("lightgbm", num_tree = 10) |> 563 | set_mode("regression") |> 564 | fit(bill_length_mm ~ ., data = penguins) 565 | ) 566 | 567 | expect_snapshot( 568 | error = TRUE, 569 | boost_tree() |> 570 | set_engine("lightgbm", min_split_gain = 2) |> 571 | set_mode("regression") |> 572 | fit(bill_length_mm ~ ., data = penguins) 573 | ) 574 | 575 | expect_snapshot( 576 | error = TRUE, 577 | boost_tree() |> 578 | set_engine("lightgbm", min_split_gain = 2, lambda_l2 = .5) |> 579 | set_mode("regression") |> 580 | fit(bill_length_mm ~ ., data = penguins) 581 | ) 582 | }) 583 | 584 | test_that("training wrapper passes stop_iter correctly", { 585 | skip_if_not_installed("lightgbm") 586 | skip_if_not_installed("modeldata") 587 | 588 | data("penguins", package = "modeldata") 589 | 590 | penguins <- penguins[complete.cases(penguins), ] 591 | 592 | expect_no_error( 593 | pars_fit_1 <- 594 | boost_tree(stop_iter = 10) |> 595 | set_engine("lightgbm") |> 596 | set_mode("regression") |> 597 | fit(bill_length_mm ~ ., data = penguins) 598 | ) 599 | 600 | # will see "The argument `early_stopping_round` cannot be..." 
(#95) 601 | suppressWarnings( 602 | pars_fit_2 <- 603 | boost_tree() |> 604 | set_engine("lightgbm", early_stopping_round = 10) |> 605 | set_mode("regression") |> 606 | fit(bill_length_mm ~ ., data = penguins) 607 | ) 608 | 609 | expect_no_error( 610 | pars_fit_3 <- 611 | boost_tree() |> 612 | set_engine("lightgbm") |> 613 | set_mode("regression") |> 614 | fit(bill_length_mm ~ ., data = penguins) 615 | ) 616 | 617 | expect_no_error( 618 | pars_fit_4 <- 619 | boost_tree() |> 620 | set_engine("lightgbm", validation = .2) |> 621 | set_mode("regression") |> 622 | fit(bill_length_mm ~ ., data = penguins) 623 | ) 624 | 625 | expect_no_error( 626 | pars_fit_5 <- 627 | boost_tree(stop_iter = 10) |> 628 | set_engine("lightgbm", validation = .2) |> 629 | set_mode("regression") |> 630 | fit(bill_length_mm ~ ., data = penguins) 631 | ) 632 | 633 | # detect early_stopping round in the model fit 634 | expect_equal(pars_fit_1$fit$params$early_stopping_round, 10) 635 | expect_null(pars_fit_2$fit$params$early_stopping_round) 636 | expect_null(pars_fit_3$fit$params$early_stopping_round) 637 | expect_null(pars_fit_4$fit$params$early_stopping_round) 638 | expect_equal(pars_fit_5$fit$params$early_stopping_round, 10) 639 | 640 | # detect validation in the model fit 641 | expect_true(!is.na(pars_fit_1$fit$best_score)) 642 | expect_true(is.na(pars_fit_2$fit$best_score)) 643 | expect_true(is.na(pars_fit_3$fit$best_score)) 644 | expect_true(!is.na(pars_fit_4$fit$best_score)) 645 | expect_true(!is.na(pars_fit_5$fit$best_score)) 646 | }) 647 | 648 | test_that("training wrapper handles bagging correctly", { 649 | skip_if_not_installed("lightgbm") 650 | skip_if_not_installed("modeldata") 651 | 652 | data("penguins", package = "modeldata") 653 | 654 | penguins <- penguins[complete.cases(penguins), ] 655 | 656 | pars_fit_1 <- 657 | boost_tree() |> 658 | set_engine("lightgbm") |> 659 | set_mode("regression") |> 660 | fit(bill_length_mm ~ ., data = penguins) 661 | 662 | pars_fit_2 <- 663 | 
boost_tree(sample_size = .5) |> 664 | set_engine("lightgbm") |> 665 | set_mode("regression") |> 666 | fit(bill_length_mm ~ ., data = penguins) 667 | 668 | pars_fit_3 <- 669 | boost_tree(sample_size = .5) |> 670 | set_engine("lightgbm", bagging_freq = 2) |> 671 | set_mode("regression") |> 672 | fit(bill_length_mm ~ ., data = penguins) 673 | 674 | expect_equal(pars_fit_1$fit$params$bagging_fraction, 1) 675 | expect_null(pars_fit_1$fit$params$bagging_freq) 676 | 677 | expect_equal(pars_fit_2$fit$params$bagging_fraction, .5) 678 | expect_equal(pars_fit_2$fit$params$bagging_freq, 1) 679 | 680 | expect_equal(pars_fit_3$fit$params$bagging_fraction, .5) 681 | expect_equal(pars_fit_3$fit$params$bagging_freq, 2) 682 | }) 683 | 684 | test_that("multi_predict() predicts classes if 'type' not given ", { 685 | skip_if_not_installed("lightgbm") 686 | skip_if_not_installed("modeldata") 687 | 688 | suppressPackageStartupMessages({ 689 | library(lightgbm) 690 | library(dplyr) 691 | }) 692 | 693 | data("penguins", package = "modeldata") 694 | penguins <- penguins[complete.cases(penguins), ] 695 | penguins_subset <- penguins[1:10, ] 696 | penguins_subset_numeric <- 697 | penguins_subset |> 698 | mutate(across(where(is.character), \(x) as.factor(x))) |> 699 | mutate(across(where(is.factor), \(x) as.integer(x) - 1)) 700 | 701 | num_iterations <- 5 702 | 703 | # classification (multiclass) ------------------------------------------------ 704 | expect_no_error({ 705 | clf_multiclass_fit <- 706 | boost_tree(trees = num_iterations) |> 707 | set_engine("lightgbm") |> 708 | set_mode("classification") |> 709 | fit(species ~ ., data = penguins) 710 | }) 711 | expect_equal(clf_multiclass_fit$fit$current_iter(), num_iterations) 712 | 713 | new_data <- 714 | penguins_subset_numeric |> 715 | select(-species) |> 716 | as.matrix() 717 | 718 | multi_preds <- 719 | multi_predict( 720 | clf_multiclass_fit, 721 | new_data = new_data[1, , drop = FALSE], 722 | trees = seq_len(num_iterations) 723 | ) 724 | 
725 | # should be a tibble 726 | pred_tbl <- multi_preds$.pred[[1]] 727 | expect_s3_class(pred_tbl, "tbl_df") 728 | 729 | # should look like class predictions 730 | expect_named(pred_tbl, c("trees", ".pred_class")) 731 | expect_s3_class(pred_tbl[[".pred_class"]], "factor") 732 | expect_true(all( 733 | as.character(pred_tbl[[".pred_class"]]) %in% levels(penguins[["species"]]) 734 | )) 735 | 736 | # classification (binary) ------------------------------------------------ 737 | expect_no_error({ 738 | clf_binary_fit <- 739 | boost_tree(trees = num_iterations) |> 740 | set_engine("lightgbm") |> 741 | set_mode("classification") |> 742 | fit(sex ~ ., data = penguins) 743 | }) 744 | expect_equal(clf_binary_fit$fit$current_iter(), num_iterations) 745 | 746 | new_data <- 747 | penguins_subset_numeric |> 748 | select(-sex) |> 749 | as.matrix() 750 | 751 | multi_preds <- 752 | multi_predict( 753 | clf_binary_fit, 754 | new_data = new_data[1, , drop = FALSE], 755 | trees = seq_len(num_iterations) 756 | ) 757 | 758 | # should be a tibble 759 | pred_tbl <- multi_preds$.pred[[1]] 760 | expect_s3_class(pred_tbl, "tbl_df") 761 | 762 | # should look like class predictions 763 | expect_named(pred_tbl, c("trees", ".pred_class")) 764 | expect_s3_class(pred_tbl[[".pred_class"]], "factor") 765 | expect_true(all( 766 | as.character(pred_tbl[[".pred_class"]]) %in% levels(penguins[["sex"]]) 767 | )) 768 | }) 769 | 770 | test_that("lightgbm with case weights", { 771 | skip_if_not_installed("lightgbm") 772 | skip_if_not_installed("modeldata") 773 | 774 | suppressPackageStartupMessages({ 775 | library(lightgbm) 776 | library(dplyr) 777 | }) 778 | 779 | data("penguins", package = "modeldata") 780 | 781 | penguins <- penguins[complete.cases(penguins), ] 782 | 783 | set.seed(1) 784 | penguins_wts <- runif(nrow(penguins)) 785 | 786 | # regression ----------------------------------------------------------------- 787 | expect_no_error({ 788 | pars_fit_1 <- 789 | boost_tree() |> 790 | 
set_engine("lightgbm") |> 791 | set_mode("regression") |> 792 | fit( 793 | bill_length_mm ~ ., 794 | data = penguins, 795 | case_weights = importance_weights(penguins_wts) 796 | ) 797 | }) 798 | 799 | pars_preds_1 <- predict(pars_fit_1, penguins) 800 | 801 | peng <- 802 | penguins |> 803 | mutate(across(where(is.character), \(x) as.factor(x))) |> 804 | mutate(across(where(is.factor), \(x) as.integer(x) - 1)) 805 | 806 | peng_y <- peng$bill_length_mm 807 | 808 | peng_m <- peng |> 809 | select(-bill_length_mm) |> 810 | as.matrix() 811 | 812 | peng_x <- 813 | lgb.Dataset( 814 | data = peng_m, 815 | label = peng_y, 816 | params = list(feature_pre_filter = FALSE), 817 | categorical_feature = c(1L, 2L, 6L), 818 | weight = penguins_wts 819 | ) 820 | 821 | params_1 <- list( 822 | objective = "regression" 823 | ) 824 | 825 | lgbm_fit_1 <- 826 | lightgbm::lgb.train( 827 | data = peng_x, 828 | params = params_1, 829 | verbose = -1 830 | ) 831 | 832 | lgbm_preds_1 <- predict(lgbm_fit_1, peng_m) 833 | 834 | expect_equal(pars_preds_1$.pred, lgbm_preds_1) 835 | }) 836 | 837 | test_that("sparse data with lightgbm", { 838 | skip_on_cran() 839 | skip_if_not_installed("lightgbm") 840 | skip_if_not_installed("modeldata") 841 | skip_if_not_installed("Matrix") 842 | 843 | library(Matrix) 844 | library(dplyr) 845 | 846 | hep <- modeldata::hepatic_injury_qsar 847 | 848 | lgb_spec <- boost_tree() |> 849 | set_mode("classification") |> 850 | set_engine("lightgbm") 851 | 852 | # ------------------------------------------------------------------------------ 853 | 854 | hepatic_x_sp <- as.matrix(hep[, -1]) 855 | hepatic_x_sp <- as(hepatic_x_sp, "sparseMatrix") 856 | 857 | sprs_fit <- fit_xy(lgb_spec, hepatic_x_sp, hep$class) 858 | sprs_prob <- predict(sprs_fit, new_data = hepatic_x_sp, type = "prob") 859 | sprs_cls <- predict(sprs_fit, new_data = hepatic_x_sp, type = "class") 860 | 861 | # ------------------------------------------------------------------------------ 862 | 863 | hepatic_x <- 
hep[, -1] 864 | 865 | dens_fit <- fit_xy(lgb_spec, hepatic_x, hep$class) 866 | dens_prob <- predict(dens_fit, new_data = hepatic_x, type = "prob") 867 | dens_cls <- predict(dens_fit, new_data = hepatic_x, type = "class") 868 | 869 | # ------------------------------------------------------------------------------ 870 | 871 | # very small differences in lightgbm probabilities 872 | expect_equal(sprs_prob, dens_prob, tolerance = .001) 873 | expect_equal(sprs_cls, dens_cls) 874 | }) 875 | -------------------------------------------------------------------------------- /tests/testthat/test-partykit.R: -------------------------------------------------------------------------------- 1 | withr::local_envvar("OMP_THREAD_LIMIT" = 1) 2 | 3 | test_that("condition inference trees", { 4 | skip_if_not_installed("partykit") 5 | skip_if_not_installed("modeldata") 6 | 7 | suppressPackageStartupMessages(library(partykit)) 8 | 9 | expect_snapshot( 10 | decision_tree() |> set_engine("partykit") |> set_mode("regression") 11 | ) 12 | expect_snapshot( 13 | decision_tree() |> 14 | set_engine("partykit", teststat = "maximum") |> 15 | set_mode("classification") 16 | ) 17 | 18 | # ---------------------------------------------------------------------------- 19 | # regression 20 | 21 | expect_no_error({ 22 | ct_fit_1 <- 23 | decision_tree() |> 24 | set_engine("partykit") |> 25 | set_mode("regression") |> 26 | fit(mpg ~ ., data = mtcars) 27 | }) 28 | pk_fit_1 <- ctree(mpg ~ ., data = mtcars) 29 | expect_equal(pk_fit_1$fitted, ct_fit_1$fit$fitted) 30 | 31 | expect_no_error(ct_pred_1 <- predict(ct_fit_1, mtcars)$.pred) 32 | pk_pred_1 <- unname(predict(pk_fit_1, mtcars)) 33 | expect_equal(pk_pred_1, ct_pred_1) 34 | 35 | expect_no_error({ 36 | ct_fit_2 <- 37 | decision_tree(tree_depth = 1) |> 38 | set_engine("partykit") |> 39 | set_mode("regression") |> 40 | fit(mpg ~ ., data = mtcars) 41 | }) 42 | pk_fit_2 <- ctree( 43 | mpg ~ ., 44 | data = mtcars, 45 | control = ctree_control(maxdepth = 1) 46 | ) 
47 | expect_equal(pk_fit_2$fitted, ct_fit_2$fit$fitted) 48 | 49 | expect_no_error(ct_pred_2 <- predict(ct_fit_2, mtcars)$.pred) 50 | pk_pred_2 <- unname(predict(pk_fit_2, mtcars)) 51 | expect_equal(pk_pred_2, ct_pred_2) 52 | 53 | expect_no_error({ 54 | ct_fit_3 <- 55 | decision_tree() |> 56 | set_engine("partykit", mincriterion = .99) |> 57 | set_mode("regression") |> 58 | fit(mpg ~ ., data = mtcars) 59 | }) 60 | pk_fit_3 <- ctree( 61 | mpg ~ ., 62 | data = mtcars, 63 | control = ctree_control(mincriterion = .99) 64 | ) 65 | expect_equal(pk_fit_3$fitted, ct_fit_3$fit$fitted) 66 | 67 | expect_no_error(ct_pred_3 <- predict(ct_fit_3, mtcars)$.pred) 68 | pk_pred_3 <- unname(predict(pk_fit_3, mtcars)) 69 | expect_equal(pk_pred_3, ct_pred_3) 70 | 71 | # ---------------------------------------------------------------------------- 72 | # classification 73 | 74 | data(ad_data, package = "modeldata") 75 | 76 | expect_no_error({ 77 | ct_fit_4 <- 78 | decision_tree() |> 79 | set_engine("partykit") |> 80 | set_mode("classification") |> 81 | fit(Class ~ ., data = ad_data) 82 | }) 83 | pk_fit_4 <- ctree(Class ~ ., data = ad_data) 84 | expect_equal(pk_fit_4$fitted, ct_fit_4$fit$fitted) 85 | 86 | expect_no_error(ct_pred_4 <- predict(ct_fit_4, ad_data)$.pred_class) 87 | pk_pred_4 <- unname(predict(pk_fit_4, ad_data)) 88 | expect_equal(pk_pred_4, ct_pred_4) 89 | 90 | expect_no_error(ct_prob_4 <- predict(ct_fit_4, ad_data, type = "prob")[[2]]) 91 | pk_prob_4 <- unname(predict(pk_fit_4, ad_data, type = "prob")[, 2]) 92 | expect_equal(pk_prob_4, ct_prob_4) 93 | }) 94 | 95 | 96 | test_that("condition inference forests", { 97 | skip_if_not_installed("partykit") 98 | skip_if_not_installed("modeldata") 99 | 100 | suppressPackageStartupMessages(library(partykit)) 101 | 102 | expect_snapshot( 103 | rand_forest() |> set_engine("partykit") |> set_mode("regression") 104 | ) 105 | expect_snapshot( 106 | rand_forest() |> 107 | set_engine("partykit", teststat = "maximum") |> 108 | 
set_mode("classification") 109 | ) 110 | 111 | # ---------------------------------------------------------------------------- 112 | # regression 113 | 114 | expect_no_error({ 115 | set.seed(1) 116 | cf_fit_1 <- 117 | rand_forest(trees = 5) |> 118 | set_engine("partykit") |> 119 | set_mode("regression") |> 120 | fit(mpg ~ ., data = mtcars) 121 | }) 122 | set.seed(1) 123 | pk_fit_1 <- cforest(mpg ~ ., data = mtcars, ntree = 5) 124 | expect_equal(pk_fit_1$fitted, cf_fit_1$fit$fitted) 125 | 126 | expect_no_error(cf_pred_1 <- predict(cf_fit_1, mtcars)$.pred) 127 | pk_pred_1 <- unname(predict(pk_fit_1, mtcars)) 128 | expect_equal(pk_pred_1, cf_pred_1) 129 | 130 | expect_no_error({ 131 | set.seed(1) 132 | cf_fit_2 <- 133 | rand_forest(trees = 5, mtry = 2) |> 134 | set_engine("partykit") |> 135 | set_mode("regression") |> 136 | fit(mpg ~ ., data = mtcars) 137 | }) 138 | set.seed(1) 139 | pk_fit_2 <- cforest(mpg ~ ., data = mtcars, ntree = 5, mtry = 2) 140 | expect_equal(pk_fit_2$fitted, cf_fit_2$fit$fitted) 141 | 142 | expect_no_error(cf_pred_2 <- predict(cf_fit_2, mtcars)$.pred) 143 | pk_pred_2 <- unname(predict(pk_fit_2, mtcars)) 144 | expect_equal(pk_pred_2, cf_pred_2) 145 | 146 | expect_no_error({ 147 | set.seed(1) 148 | cf_fit_3 <- 149 | rand_forest(trees = 5) |> 150 | set_engine("partykit", mincriterion = .99) |> 151 | set_mode("regression") |> 152 | fit(mpg ~ ., data = mtcars) 153 | }) 154 | set.seed(1) 155 | pk_fit_3 <- cforest( 156 | mpg ~ ., 157 | data = mtcars, 158 | ntree = 5, 159 | control = ctree_control(mincriterion = .99) 160 | ) 161 | expect_equal(pk_fit_3$fitted, cf_fit_3$fit$fitted) 162 | 163 | expect_no_error(cf_pred_3 <- predict(cf_fit_3, mtcars)$.pred) 164 | pk_pred_3 <- unname(predict(pk_fit_3, mtcars)) 165 | expect_equal(pk_pred_3, cf_pred_3) 166 | 167 | # ---------------------------------------------------------------------------- 168 | # classification 169 | 170 | data(ad_data, package = "modeldata") 171 | 172 | expect_no_error({ 173 | set.seed(1) 
174 | cf_fit_4 <- 175 | rand_forest(trees = 5) |> 176 | set_engine("partykit") |> 177 | set_mode("classification") |> 178 | fit(Class ~ ., data = ad_data) 179 | }) 180 | set.seed(1) 181 | pk_fit_4 <- cforest(Class ~ ., data = ad_data, ntree = 5) 182 | expect_equal(pk_fit_4$fitted, cf_fit_4$fit$fitted) 183 | 184 | expect_no_error(cf_pred_4 <- predict(cf_fit_4, ad_data)$.pred_class) 185 | pk_pred_4 <- unname(predict(pk_fit_4, ad_data)) 186 | expect_equal(pk_pred_4, cf_pred_4) 187 | 188 | expect_no_error(cf_prob_4 <- predict(cf_fit_4, ad_data, type = "prob")[[2]]) 189 | pk_prob_4 <- unname(predict(pk_fit_4, ad_data, type = "prob")[, 2]) 190 | expect_equal(pk_prob_4, cf_prob_4) 191 | }) 192 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/bonsai.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Introduction to bonsai" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Introduction to bonsai} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r} 11 | #| include: false 12 | knitr::opts_chunk$set( 13 | collapse = TRUE, 14 | comment = "#>" 15 | ) 16 | 17 | withr::local_envvar("OMP_THREAD_LIMIT" = 1) 18 | 19 | if (rlang::is_installed("partykit") && 20 | rlang::is_installed("lightgbm") && 21 | rlang::is_installed("modeldata")) { 22 | run <- TRUE 23 | } else { 24 | run <- FALSE 25 | } 26 | 27 | knitr::opts_chunk$set( 28 | eval = run 29 | ) 30 | ``` 31 | 32 | The goal of bonsai is to provide bindings for additional tree-based model engines for use with the {parsnip} package. 33 | 34 | If you're not familiar with parsnip, you can read more about the package on its [website](https://parsnip.tidymodels.org).
35 | 36 | To get started, load bonsai with: 37 | 38 | ```{r} 39 | #| label: setup 40 | library(bonsai) 41 | ``` 42 | 43 | To illustrate how to use the package, we'll fit some models to a dataset containing measurements on 3 different species of penguins. Loading in that data and checking it out: 44 | 45 | ```{r} 46 | library(modeldata) 47 | 48 | data(penguins) 49 | 50 | str(penguins) 51 | ``` 52 | 53 | Specifically, we will use each penguin's island and flipper length to predict its species with a decision tree. We'll first do so using the engine `"rpart"`, which is supported with parsnip alone: 54 | 55 | ```{r} 56 | # set seed for reproducibility 57 | set.seed(1) 58 | 59 | # specify and fit model 60 | dt_mod <- 61 | decision_tree() |> 62 | set_engine(engine = "rpart") |> 63 | set_mode(mode = "classification") |> 64 | fit( 65 | formula = species ~ flipper_length_mm + island, 66 | data = penguins 67 | ) 68 | 69 | dt_mod 70 | ``` 71 | 72 | From this output, we can see that the model first splits on `island`, and then uses a mix of flipper length and island to arrive at its final species prediction. 73 | 74 | A benefit of using parsnip and bonsai is that, to use a different implementation of decision trees, we only change the `engine` argument to `set_engine()`; all other elements of the interface stay the same. For instance, here is the same model fit with the `"partykit"` engine, which implements a type of decision tree called a _conditional inference tree_: 75 | 76 | ```{r} 77 | decision_tree() |> 78 | set_engine(engine = "partykit") |> 79 | set_mode(mode = "classification") |> 80 | fit( 81 | formula = species ~ flipper_length_mm + island, 82 | data = penguins 83 | ) 84 | ``` 85 | 86 | This model, unlike the first, uses recursive conditional inference to choose its splits, so its results differ slightly from those of the `"rpart"` tree.
Read more about this implementation of decision trees in `?details_decision_tree_partykit`. 87 | 88 | One generalization of a decision tree is a _random forest_, which fits a large number of decision trees, each independently of the others. The fitted random forest model combines predictions from the individual decision trees to generate its predictions. 89 | 90 | bonsai introduces support for random forests using the `partykit` engine, which implements an algorithm called a _conditional random forest_. Conditional random forests are a type of random forest that uses conditional inference trees (like the one we fit above!) for its constituent decision trees. 91 | 92 | To fit a conditional random forest with partykit, our code looks very similar to the code we needed to fit a conditional inference tree. Just swap `decision_tree()` for `rand_forest()` and keep the engine set to `"partykit"`: 93 | 94 | ```{r} 95 | rf_mod <- 96 | rand_forest() |> 97 | set_engine(engine = "partykit") |> 98 | set_mode(mode = "classification") |> 99 | fit( 100 | formula = species ~ flipper_length_mm + island, 101 | data = penguins 102 | ) 103 | ``` 104 | 105 | Read more about this implementation of random forests in `?details_rand_forest_partykit`. 106 | 107 | Another generalization of a decision tree is a series of decision trees where _each tree depends on the results of previous trees_; this is called a _boosted tree_. bonsai implements an additional parsnip engine for this model type called `lightgbm`. To make use of it, start out with a `boost_tree` model spec and set `engine = "lightgbm"`: 108 | 109 | ```{r} 110 | bt_mod <- 111 | boost_tree() |> 112 | set_engine(engine = "lightgbm") |> 113 | set_mode(mode = "classification") |> 114 | fit( 115 | formula = species ~ flipper_length_mm + island, 116 | data = penguins 117 | ) 118 | 119 | bt_mod 120 | ``` 121 | 122 | Read more about this implementation of boosted trees in `?details_boost_tree_lightgbm`.
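The boosted tree above uses default argument values. As a sketch of how to override them, the chunk below sets three of `boost_tree()`'s main arguments (`trees`, `learn_rate`, and `tree_depth`) and then requests class probabilities with `predict()`. The argument values are arbitrary illustrations, not tuned recommendations, and `bt_mod_small` is a hypothetical name. The `library()` and `data()` calls repeat the setup above so that the chunk also runs on its own.

```{r}
library(bonsai)
library(modeldata)
data(penguins)

set.seed(1)

# illustrative (not tuned) values for three of boost_tree()'s main arguments
bt_mod_small <-
  boost_tree(trees = 50, learn_rate = 0.1, tree_depth = 3) |>
  set_engine(engine = "lightgbm") |>
  set_mode(mode = "classification") |>
  fit(
    formula = species ~ flipper_length_mm + island,
    data = penguins
  )

# class probabilities for the first three penguins,
# returned as one .pred_<species> column per species
predict(bt_mod_small, new_data = head(penguins, 3), type = "prob")
```

Because parsnip standardizes prediction output across engines, the same `predict()` call also works on `dt_mod`, `rf_mod`, and `bt_mod`.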
123 | 124 | Each of these model specs and engines has several arguments and tuning parameters that can greatly affect both user experience and results. We recommend reading about each of these parameters and tuning those that are relevant to your modeling use case. 125 | --------------------------------------------------------------------------------