├── .Rbuildignore ├── .github └── workflows │ ├── r-cmd-check-paradox.yml │ └── tic.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── DESCRIPTION ├── LICENSE.md ├── NAMESPACE ├── NEWS.md ├── R ├── AuditorFitters.R ├── MCBoost.R ├── PipeOpLearnerPred.R ├── PipeOpMCBoost.R ├── PipelineMCBoost.R ├── Predictor.R ├── ProbRange.R ├── helpers.R └── zzz.R ├── README.md ├── _pkgdown.yml ├── attic ├── MCBoostSurv.R ├── PipeOpMCBoostSurv.R ├── ProbRange2D.R ├── helpers_survival.R ├── mcboost_step_by_step.Rmd ├── mcboostsurv_basics.Rmd ├── test_mcboostsurv.R ├── test_pipeop_learner_pred.R ├── test_pipeop_mcboostsurv.R └── test_probrange2d.R ├── codecov.yml ├── codemeta.json ├── contributions.md ├── cran-comments.md ├── inst └── CITATION ├── man-roxygen ├── params_data_label.R ├── params_data_resid.R ├── params_mask.R ├── params_subpops.R ├── return_auditor.R ├── return_fit.R └── return_predictor.R ├── man ├── AuditorFitter.Rd ├── CVLearnerAuditorFitter.Rd ├── LearnerAuditorFitter.Rd ├── MCBoost.Rd ├── SubgroupAuditorFitter.Rd ├── SubpopAuditorFitter.Rd ├── figures │ ├── lifecycle-archived.svg │ ├── lifecycle-defunct.svg │ ├── lifecycle-deprecated.svg │ ├── lifecycle-experimental.svg │ ├── lifecycle-maturing.svg │ ├── lifecycle-questioning.svg │ ├── lifecycle-stable.svg │ └── lifecycle-superseded.svg ├── mcboost-package.Rd ├── mlr3_init_predictor.Rd ├── mlr_pipeops_mcboost.Rd ├── one_hot.Rd └── ppl_mcboost.Rd ├── mcboost.Rproj ├── paper ├── MCBoost.drawio ├── MCBoost.png ├── paper.bib └── paper.md ├── tests ├── testthat.R └── testthat │ ├── setup.R │ ├── teardown.R │ ├── test_auditor_fitters.R │ ├── test_cv_predictors.R │ ├── test_mcboost.R │ ├── test_mcboost_low_degree.R │ ├── test_pipeop_mcboost.R │ ├── test_predictor.R │ ├── test_probrange.R │ └── test_sonar_usecase.R ├── tic.R └── vignettes ├── .gitignore ├── mcboost_basics_extensions.Rmd └── mcboost_example.Rmd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | ^LICENSE\.md$ 4 | man-roxygen 5 | attic 6 | test/testthat/_snaps 7 | ^\.ccache$ 8 | ^\.github$ 9 | ^tic\.R$ 10 | ^_pkgdown\.yml$ 11 | ^docs$ 12 | ^pkgdown$ 13 | ^doc$ 14 | ^Meta$ 15 | ^codecov\.yml$ 16 | ^CODE_OF_CONDUCT\.md$ 17 | ^paper$ 18 | ^contributions.md$ 19 | ^cran-comments\.md$ 20 | ^mcboost-manual.tex$ 21 | ^CRAN-RELEASE$ 22 | ^codemeta\.json$ 23 | ^lastMiKTeXException$ 24 | ^CRAN-SUBMISSION$ 25 | -------------------------------------------------------------------------------- /.github/workflows/r-cmd-check-paradox.yml: -------------------------------------------------------------------------------- 1 | # r cmd check workflow of the mlr3 ecosystem v0.1.0 2 | # https://github.com/mlr-org/actions 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | name: r-cmd-check-paradox 13 | 14 | jobs: 15 | r-cmd-check: 16 | runs-on: ${{ matrix.config.os }} 17 | 18 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 19 | 20 | env: 21 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 22 | 23 | strategy: 24 | fail-fast: false 25 | matrix: 26 | config: 27 | - {os: ubuntu-latest, r: 'devel'} 28 | - {os: ubuntu-latest, r: 'release'} 29 | 30 | steps: 31 | - uses: actions/checkout@v3 32 | 33 | - name: paradox 34 | run: 'echo -e "Remotes:\n mlr-org/paradox,\n mlr-org/mlr3learners,\n mlr-org/mlr3pipelines,\n mlr-org/mlr3oml" >> DESCRIPTION' 35 | 36 | - uses: r-lib/actions/setup-r@v2 37 | with: 38 | r-version: ${{ matrix.config.r }} 39 | 40 | - uses: r-lib/actions/setup-r-dependencies@v2 41 | with: 42 | extra-packages: any::rcmdcheck 43 | needs: check 44 | - uses: r-lib/actions/check-r-package@v2 45 | -------------------------------------------------------------------------------- /.github/workflows/tic.yml: -------------------------------------------------------------------------------- 1 | ## tic GitHub Actions template: linux-deploy 2 | ## revision date: 2020-12-11 3 | on: 4 | workflow_dispatch: 5 | push: 6 | pull_request: 7 | # for now, CRON jobs only run on the default branch of the repo (i.e. usually on master) 8 | schedule: 9 | # * is a special character in YAML so you have to quote this string 10 | - cron: "0 4 * * *" 11 | 12 | name: tic 13 | 14 | jobs: 15 | all: 16 | runs-on: ${{ matrix.config.os }} 17 | 18 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | config: 24 | # use a different tic template type if you do not want to build on all listed platforms 25 | - { os: ubuntu-latest, r: "release", pkgdown: "true" } 26 | - { os: ubuntu-latest, r: "devel" } 27 | env: 28 | # otherwise remotes::fun() errors cause the build to fail. Example: Unavailability of binaries 29 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 30 | CRAN: ${{ matrix.config.cran }} 31 | # make sure to run `tic::use_ghactions_deploy()` to set up deployment 32 | TIC_DEPLOY_KEY: ${{ secrets.TIC_DEPLOY_KEY }} 33 | # prevent rgl issues because no X11 display is available 34 | RGL_USE_NULL: true 35 | # if you use bookdown or blogdown, replace "PKGDOWN" by the respective 36 | # capitalized term. This also might need to be done in tic.R 37 | BUILD_PKGDOWN: ${{ matrix.config.pkgdown }} 38 | # macOS >= 10.15.4 linking 39 | SDKROOT: /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk 40 | # use GITHUB_TOKEN from GitHub to workaround rate limits in {remotes} 41 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 42 | 43 | steps: 44 | - uses: actions/checkout@v2.3.4 45 | 46 | - uses: r-lib/actions/setup-r@master 47 | with: 48 | r-version: ${{ matrix.config.r }} 49 | Ncpus: 4 50 | 51 | # LaTeX. Installation time: 52 | # Linux: ~ 1 min 53 | # macOS: ~ 1 min 30s 54 | # Windows: never finishes 55 | - uses: r-lib/actions/setup-tinytex@master 56 | if: matrix.config.latex == 'true' 57 | 58 | - uses: r-lib/actions/setup-pandoc@master 59 | 60 | # set date/week for use in cache creation 61 | # https://github.community/t5/GitHub-Actions/How-to-set-and-access-a-Workflow-variable/m-p/42970 62 | # - cache R packages daily 63 | - name: "[Cache] Prepare daily timestamp for cache" 64 | if: runner.os != 'Windows' 65 | id: date 66 | run: echo "::set-output name=date::$(date '+%d-%m')" 67 | 68 | - name: "[Cache] Cache R packages" 69 | if: runner.os != 'Windows' 70 | uses: pat-s/always-upload-cache@v2.1.3 71 | with: 72 | path: ${{ env.R_LIBS_USER }} 73 | key: ${{ runner.os }}-r-${{ matrix.config.r }}-${{steps.date.outputs.date}} 74 | restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-${{steps.date.outputs.date}} 75 | 76 | # for some strange Windows reason this step and the next one need to be decoupled 77 | - name: "[Stage] Prepare" 78 | run: | 79 | Rscript -e "if (!requireNamespace('remotes')) install.packages('remotes', type = 'source')" 80 | Rscript -e "if (getRversion() < '3.2' && !requireNamespace('curl')) install.packages('curl', type = 'source')" 81 | 82 | - name: "[Stage] [Linux] Install curl and libgit2" 83 | if: runner.os == 'Linux' 84 | run: sudo apt install libcurl4-openssl-dev libgit2-dev 85 | 86 | - name: "[Stage] [macOS] Install libgit2" 87 | if: runner.os == 'macOS' 88 | run: brew install libgit2 89 | 90 | - name: "[Stage] [macOS] Install system libs for pkgdown" 91 | if: runner.os == 'macOS' && matrix.config.pkgdown != '' 92 | run: brew install harfbuzz fribidi 93 | 94 | - name: "[Stage] [Linux] Install system libs for pkgdown" 95 | if: runner.os == 'Linux' && matrix.config.pkgdown != '' 96 | run: sudo apt install libharfbuzz-dev libfribidi-dev 97 | 98 | - name: "[Stage] Install" 99 | if: matrix.config.os != 'macOS-latest' || matrix.config.r != 'devel' 100 | run: Rscript -e "remotes::install_github('ropensci/tic')" -e "print(tic::dsl_load())" -e "tic::prepare_all_stages()" -e "tic::before_install()" -e "tic::install()" 101 | 102 | # macOS devel needs its own stage because we need to work with an option to suppress the usage of binaries 103 | - name: "[Stage] Prepare & Install (macOS-devel)" 104 | if: matrix.config.os == 'macOS-latest' && matrix.config.r == 'devel' 105 | run: | 106 | echo -e 'options(Ncpus = 4, pkgType = "source", repos = structure(c(CRAN = "https://cloud.r-project.org/")))' > $HOME/.Rprofile 107 | Rscript -e "remotes::install_github('ropensci/tic')" -e "print(tic::dsl_load())" -e "tic::prepare_all_stages()" -e "tic::before_install()" -e "tic::install()" 108 | 109 | # - name: "[Stage] Install Vignette Dependencies" 110 | # run: | 111 | # Rscript -e "install.packages('neuralnet')" 112 | # Rscript -e "install.packages('PracTools')" 113 | # Rscript -e "install.packages('tidyverse')" 114 | # Rscript -e "install.packages('formattable')" 115 | # Rscript -e "install.packages('curl')" 116 | 117 | - name: "[Stage] Script" 118 | run: Rscript -e 'tic::script()' 119 | 120 | - name: "[Stage] After Success" 121 | if: matrix.config.os == 'macOS-latest' && matrix.config.r == 'release' 122 | run: Rscript -e "tic::after_success()" 123 | 124 | - name: "[Stage] Upload R CMD check artifacts" 125 | if: failure() 126 | uses: actions/upload-artifact@v2.2.1 127 | with: 128 | name: ${{ runner.os }}-r${{ matrix.config.r }}-results 129 | path: check 130 | - name: "[Stage] Before Deploy" 131 | run: | 132 | Rscript -e "tic::before_deploy()" 133 | 134 | - name: "[Stage] Deploy" 135 | run: Rscript -e "tic::deploy()" 136 | 137 | - name: "[Stage] After Deploy" 138 | run: Rscript -e "tic::after_deploy()" 139 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | inst/doc 6 | test/testthat/snaps 7 | docs/ 8 | docs 9 | /doc/ 10 | /Meta/ 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity and 10 | orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards 42 | of acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies 54 | when an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail 56 | address, posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at [INSERT CONTACT 63 | METHOD]. All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.0, 118 | available at https://www.contributor-covenant.org/version/2/0/ 119 | code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at https:// 128 | www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: mcboost 2 | Type: Package 3 | Title: Multi-Calibration Boosting 4 | Version: 0.4.3-9000 5 | Authors@R: 6 | c(person(given = "Florian", 7 | family = "Pfisterer", 8 | role = "aut", 9 | email = "pfistererf@googlemail.com", 10 | comment = c(ORCID = "0000-0001-8867-762X")), 11 | person(given = "Susanne", 12 | family = "Dandl", 13 | role = "ctb", 14 | email = "susanne.dandl@stat.uni-muenchen.de", 15 | comment = c(ORCID = "0000-0003-4324-4163")), 16 | person(given = "Christoph", 17 | family = "Kern", 18 | role = "ctb", 19 | email = "c.kern@uni-mannheim.de", 20 | comment = c(ORCID = "0000-0001-7363-4299")), 21 | person(given = "Carolin", 22 | family = "Becker", 23 | role = "ctb"), 24 | person(given = "Bernd", 25 | family = "Bischl", 26 | role = "ctb", 27 | email = "bernd_bischl@gmx.net", 28 | comment = c(ORCID = "0000-0001-6002-6980")), 29 | person(given = "Sebastian", 30 | family = "Fischer", 31 | role = c("ctb", "cre"), 32 | email = "sebf.fischer@gmail.com") 33 | ) 34 | Description: Implements 'Multi-Calibration Boosting' (2018) and 35 | 'Multi-Accuracy Boosting' (2019) for the multi-calibration of a machine learning model's prediction. 36 | 'MCBoost' updates predictions for sub-groups in an iterative fashion in order to mitigate biases like poor calibration or large accuracy differences across subgroups. 37 | Multi-Calibration works best in scenarios where the underlying data & labels are unbiased, but resulting models are. 38 | This is often the case, e.g. when an algorithm fits a majority population while ignoring or under-fitting minority populations. 39 | License: LGPL (>= 3) 40 | URL: https://github.com/mlr-org/mcboost 41 | BugReports: https://github.com/mlr-org/mcboost/issues 42 | Encoding: UTF-8 43 | Depends: 44 | R (>= 3.1.0) 45 | Imports: 46 | backports, 47 | checkmate (>= 2.0.0), 48 | data.table (>= 1.13.6), 49 | mlr3 (>= 0.10), 50 | mlr3misc (>= 0.8.0), 51 | mlr3pipelines (>= 0.3.0), 52 | R6 (>= 2.4.1), 53 | rmarkdown, 54 | rpart, 55 | glmnet 56 | Suggests: 57 | curl, 58 | lgr, 59 | formattable, 60 | tidyverse, 61 | PracTools, 62 | mlr3learners, 63 | mlr3oml, 64 | neuralnet, 65 | paradox, 66 | knitr, 67 | ranger, 68 | xgboost, 69 | covr, 70 | testthat (>= 3.1.0) 71 | Roxygen: list(markdown = TRUE, r6 = TRUE) 72 | RoxygenNote: 7.3.1 73 | VignetteBuilder: knitr 74 | Collate: 75 | 'AuditorFitters.R' 76 | 'MCBoost.R' 77 | 'PipelineMCBoost.R' 78 | 'PipeOpLearnerPred.R' 79 | 'PipeOpMCBoost.R' 80 | 'Predictor.R' 81 | 'ProbRange.R' 82 | 'helpers.R' 83 | 'zzz.R' 84 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU Lesser General Public License 2 | ================================= 3 | 4 | _Version 3, 29 June 2007_ 5 | _Copyright © 2007 Free Software Foundation, Inc. <>_ 6 | 7 | Everyone is permitted to copy and distribute verbatim copies 8 | of this license document, but changing it is not allowed. 9 | 10 | 11 | This version of the GNU Lesser General Public License incorporates 12 | the terms and conditions of version 3 of the GNU General Public 13 | License, supplemented by the additional permissions listed below. 14 | 15 | ### 0. Additional Definitions 16 | 17 | As used herein, “this License” refers to version 3 of the GNU Lesser 18 | General Public License, and the “GNU GPL” refers to version 3 of the GNU 19 | General Public License. 20 | 21 | “The Library” refers to a covered work governed by this License, 22 | other than an Application or a Combined Work as defined below. 23 | 24 | An “Application” is any work that makes use of an interface provided 25 | by the Library, but which is not otherwise based on the Library. 26 | Defining a subclass of a class defined by the Library is deemed a mode 27 | of using an interface provided by the Library. 28 | 29 | A “Combined Work” is a work produced by combining or linking an 30 | Application with the Library. The particular version of the Library 31 | with which the Combined Work was made is also called the “Linked 32 | Version”. 33 | 34 | The “Minimal Corresponding Source” for a Combined Work means the 35 | Corresponding Source for the Combined Work, excluding any source code 36 | for portions of the Combined Work that, considered in isolation, are 37 | based on the Application, and not on the Linked Version. 38 | 39 | The “Corresponding Application Code” for a Combined Work means the 40 | object code and/or source code for the Application, including any data 41 | and utility programs needed for reproducing the Combined Work from the 42 | Application, but excluding the System Libraries of the Combined Work. 43 | 44 | ### 1. Exception to Section 3 of the GNU GPL 45 | 46 | You may convey a covered work under sections 3 and 4 of this License 47 | without being bound by section 3 of the GNU GPL. 48 | 49 | ### 2. Conveying Modified Versions 50 | 51 | If you modify a copy of the Library, and, in your modifications, a 52 | facility refers to a function or data to be supplied by an Application 53 | that uses the facility (other than as an argument passed when the 54 | facility is invoked), then you may convey a copy of the modified 55 | version: 56 | 57 | * **a)** under this License, provided that you make a good faith effort to 58 | ensure that, in the event an Application does not supply the 59 | function or data, the facility still operates, and performs 60 | whatever part of its purpose remains meaningful, or 61 | 62 | * **b)** under the GNU GPL, with none of the additional permissions of 63 | this License applicable to that copy. 64 | 65 | ### 3. Object Code Incorporating Material from Library Header Files 66 | 67 | The object code form of an Application may incorporate material from 68 | a header file that is part of the Library. You may convey such object 69 | code under terms of your choice, provided that, if the incorporated 70 | material is not limited to numerical parameters, data structure 71 | layouts and accessors, or small macros, inline functions and templates 72 | (ten or fewer lines in length), you do both of the following: 73 | 74 | * **a)** Give prominent notice with each copy of the object code that the 75 | Library is used in it and that the Library and its use are 76 | covered by this License. 77 | * **b)** Accompany the object code with a copy of the GNU GPL and this license 78 | document. 79 | 80 | ### 4. Combined Works 81 | 82 | You may convey a Combined Work under terms of your choice that, 83 | taken together, effectively do not restrict modification of the 84 | portions of the Library contained in the Combined Work and reverse 85 | engineering for debugging such modifications, if you also do each of 86 | the following: 87 | 88 | * **a)** Give prominent notice with each copy of the Combined Work that 89 | the Library is used in it and that the Library and its use are 90 | covered by this License. 91 | 92 | * **b)** Accompany the Combined Work with a copy of the GNU GPL and this license 93 | document. 94 | 95 | * **c)** For a Combined Work that displays copyright notices during 96 | execution, include the copyright notice for the Library among 97 | these notices, as well as a reference directing the user to the 98 | copies of the GNU GPL and this license document. 99 | 100 | * **d)** Do one of the following: 101 | - **0)** Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | - **1)** Use a suitable shared library mechanism for linking with the 109 | Library. A suitable mechanism is one that **(a)** uses at run time 110 | a copy of the Library already present on the user's computer 111 | system, and **(b)** will operate properly with a modified version 112 | of the Library that is interface-compatible with the Linked 113 | Version. 114 | 115 | * **e)** Provide Installation Information, but only if you would otherwise 116 | be required to provide such information under section 6 of the 117 | GNU GPL, and only to the extent that such information is 118 | necessary to install and execute a modified version of the 119 | Combined Work produced by recombining or relinking the 120 | Application with a modified version of the Linked Version. (If 121 | you use option **4d0**, the Installation Information must accompany 122 | the Minimal Corresponding Source and Corresponding Application 123 | Code. If you use option **4d1**, you must provide the Installation 124 | Information in the manner specified by section 6 of the GNU GPL 125 | for conveying Corresponding Source.) 126 | 127 | ### 5. Combined Libraries 128 | 129 | You may place library facilities that are a work based on the 130 | Library side by side in a single library together with other library 131 | facilities that are not Applications and are not covered by this 132 | License, and convey such a combined library under terms of your 133 | choice, if you do both of the following: 134 | 135 | * **a)** Accompany the combined library with a copy of the same work based 136 | on the Library, uncombined with any other library facilities, 137 | conveyed under the terms of this License. 138 | * **b)** Give prominent notice with the combined library that part of it 139 | is a work based on the Library, and explaining where to find the 140 | accompanying uncombined form of the same work. 141 | 142 | ### 6. Revised Versions of the GNU Lesser General Public License 143 | 144 | The Free Software Foundation may publish revised and/or new versions 145 | of the GNU Lesser General Public License from time to time. Such new 146 | versions will be similar in spirit to the present version, but may 147 | differ in detail to address new problems or concerns. 148 | 149 | Each version is given a distinguishing version number. If the 150 | Library as you received it specifies that a certain numbered version 151 | of the GNU Lesser General Public License “or any later version” 152 | applies to it, you have the option of following the terms and 153 | conditions either of that published version or of any later version 154 | published by the Free Software Foundation. If the Library as you 155 | received it does not specify a version number of the GNU Lesser 156 | General Public License, you may choose any version of the GNU Lesser 157 | General Public License ever published by the Free Software Foundation. 158 | 159 | If the Library as you received it specifies that a proxy can decide 160 | whether future versions of the GNU Lesser General Public License shall 161 | apply, that proxy's public statement of acceptance of any version is 162 | permanent authorization for you to choose that version for the 163 | Library. -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(AuditorFitter) 4 | export(CVLearnerAuditorFitter) 5 | export(CVRidgeAuditorFitter) 6 | export(CVTreeAuditorFitter) 7 | export(LearnerAuditorFitter) 8 | export(MCBoost) 9 | export(PipeOpLearnerPred) 10 | export(PipeOpMCBoost) 11 | export(RidgeAuditorFitter) 12 | export(SubgroupAuditorFitter) 13 | export(SubpopAuditorFitter) 14 | export(TreeAuditorFitter) 15 | export(mlr3_init_predictor) 16 | export(one_hot) 17 | export(ppl_mcboost) 18 | import(checkmate) 19 | import(data.table) 20 | import(glmnet) 21 | import(mlr3) 22 | import(mlr3misc) 23 | import(mlr3pipelines) 24 | import(rmarkdown) 25 | import(rpart) 26 | importFrom(R6,R6Class) 27 | importFrom(R6,is.R6) 28 | importFrom(stats,contrasts) 29 | importFrom(stats,quantile) 30 | importFrom(stats,rnorm) 31 | importFrom(stats,runif) 32 | importFrom(stats,setNames) 33 | importFrom(utils,head) 34 | -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | # dev 2 | 3 | # mcboost 0.4.3 4 | 5 | * Compatibility with upcoming 'paradox' release. 6 | * Change the vignette to not use the holdout task. 7 | 8 | # mcboost 0.4.2 9 | * Removed new functionality for survival tasks added in `0.4.0`. 10 | A dependency, `mlr3proba` was removed from CRAN for now. 11 | The functionality will be added back when `mlr3proba` is re-introduced to CRAN. 12 | Users who wish to use `mcboost` for `survival` are adviced to use version `0.4.1` usetogether with the GitHub version of `mlr3proba`. 13 | * Improved stability of unit tests and example checks on CRAN. 14 | 15 | # mcboost 0.4.1 16 | * Fixed unit error in unit tests that led to non-passing unit tests with new mlr3proba version. 17 | 18 | # mcboost 0.4.0 19 | * [Experimental] mcboost now has experimental support for *survival* tasks. 20 | See `MCBoostSurv` and the corresponding vignette "MCBoostSurv - Basics" for more information. 21 | * We have published an article about mcboost in the Journal of Open Source Software: "https://joss.theoj.org/papers/10.21105/joss.03453". See citation("mcboost") for the citation info. 22 | 23 | 24 | # mcboost 0.3.3 25 | * Auditors can now also update weights if correlations are negative by switching the sign of the update direction as intended in the paper. 26 | * Minor adaptions to improve stability of unit tests 27 | 28 | # mcboost 0.3.2 29 | * Minor adpations to improve stability of unit tests 30 | 31 | # mcboost 0.3.1 32 | 33 | * Fixed a bug for additive weight updates, were updates went 34 | in the wrong direction. 35 | * Added new parameter `eval_fulldata` that allows to compute 36 | auditor effect across the full sample (as opposed to the bucket). 37 | 38 | # mcboost 0.3.0 39 | 40 | * First CRAN-ready version of the package. 41 | * Added a `NEWS.md` file to track changes to the package. 42 | -------------------------------------------------------------------------------- /R/AuditorFitters.R: -------------------------------------------------------------------------------- 1 | #' AuditorFitter Abstract Base Class 2 | #' @description 3 | #' Defines an `AuditorFitter` abstract base class. 4 | #' @export 5 | AuditorFitter = R6::R6Class("AuditorFitter", 6 | public = list( 7 | #' @description 8 | #' Initialize a [`AuditorFitter`]. 9 | #' This is an abstract base class. 10 | initialize = function() { 11 | }, 12 | #' @description 13 | #' Fit to residuals. 14 | #' @template params_data_resid 15 | #' @template params_mask 16 | #' @template return_fit 17 | fit_to_resid = function(data, resid, mask) { # 18 | 19 | #Learners fail on constant residuals. 20 | if (all(unlist(unique(resid)) == unlist(resid[1]))) { 21 | return(list(0, ConstantPredictor$new(0))) 22 | } 23 | self$fit(data, resid, mask) 24 | }, 25 | #' @description 26 | #' Fit (mostly used internally, use `fit_to_resid`). 27 | fit = function(data, resid, mask) { 28 | stop("Not implemented") 29 | } 30 | ) 31 | ) 32 | 33 | #' Create an AuditorFitter from a Learner 34 | #' @description 35 | #' Instantiates an AuditorFitter that trains a [`mlr3::Learner`] 36 | #' on the data. 37 | #' @family AuditorFitter 38 | #' @export 39 | LearnerAuditorFitter = R6::R6Class("LearnerAuditorFitter", 40 | inherit = AuditorFitter, 41 | public = list( 42 | #' @field learner `LearnerPredictor`\cr 43 | #' Learner used for fitting residuals. 44 | learner = NULL, 45 | #' @description 46 | #' Define an `AuditorFitter` from a Learner. 47 | #' Available instantiations:\cr [`TreeAuditorFitter`] (rpart) and 48 | #' [`RidgeAuditorFitter`] (glmnet). 49 | #' 50 | #' @param learner [`mlr3::Learner`]\cr 51 | #' Regression learner to use. 52 | #' @template return_auditor 53 | initialize = function(learner) { 54 | self$learner = LearnerPredictor$new(learner) 55 | }, 56 | #' @description 57 | #' Fit the learner and compute correlation 58 | #' 59 | #' @template params_data_resid 60 | #' @template params_mask 61 | #' @template return_fit 62 | fit = function(data, resid, mask) { 63 | l = self$learner$clone() 64 | l$fit(data, resid) 65 | h = l$predict(data) 66 | corr = mean(h * resid) 67 | return(list(corr, l)) 68 | } 69 | ) 70 | ) 71 | 72 | #' @describeIn LearnerAuditorFitter Learner auditor based on rpart 73 | #' @family AuditorFitter 74 | #' @export 75 | TreeAuditorFitter = R6::R6Class("TreeAuditorFitter", 76 | inherit = LearnerAuditorFitter, 77 | public = list( 78 | #' @description 79 | #' Define a AuditorFitter from a rpart learner. 80 | initialize = function() { 81 | mlr3misc::require_namespaces("rpart") 82 | super$initialize(learner = lrn("regr.rpart")) 83 | } 84 | ) 85 | ) 86 | 87 | #' @describeIn LearnerAuditorFitter Learner auditor based on glmnet 88 | #' @family AuditorFitter 89 | #' @export 90 | RidgeAuditorFitter = R6::R6Class("RidgeAuditorFitter", 91 | inherit = LearnerAuditorFitter, 92 | public = list( 93 | #' @description 94 | #' Define a AuditorFitter from a glmnet learner. 95 | initialize = function() { 96 | mlr3misc::require_namespaces(c("mlr3learners", "glmnet")) 97 | super$initialize(learner = lrn("regr.glmnet", alpha = 0, s = 0.01)) 98 | } 99 | ) 100 | ) 101 | 102 | #' Static AuditorFitter based on Subpopulations 103 | #' @description 104 | #' Used to assess multi-calibration based on a list of 105 | #' binary valued columns: `subpops` passed during initialization. 106 | #' @family AuditorFitter 107 | #' @examples 108 | #' library("data.table") 109 | #' data = data.table( 110 | #' "AGE_NA" = c(0, 0, 0, 0, 0), 111 | #' "AGE_0_10" = c(1, 1, 0, 0, 0), 112 | #' "AGE_11_20" = c(0, 0, 1, 0, 0), 113 | #' "AGE_21_31" = c(0, 0, 0, 1, 1), 114 | #' "X1" = runif(5), 115 | #' "X2" = runif(5) 116 | #' ) 117 | #' label = c(1,0,0,1,1) 118 | #' pops = list("AGE_NA", "AGE_0_10", "AGE_11_20", "AGE_21_31", function(x) {x[["X1" > 0.5]]}) 119 | #' sf = SubpopAuditorFitter$new(subpops = pops) 120 | #' sf$fit(data, label - 0.5) 121 | #' @export 122 | SubpopAuditorFitter = R6::R6Class("SubpopAuditorFitter", 123 | inherit = AuditorFitter, 124 | public = list( 125 | #' @field subpops [`list`] \cr 126 | #' List of subpopulation indicators. 127 | subpops = NULL, 128 | #' Initialize a SubpopAuditorFitter 129 | #' @description 130 | #' Initializes a [`SubpopAuditorFitter`] that 131 | #' assesses multi-calibration within each group defined 132 | #' by the `subpops'. Names in `subpops` must correspond to 133 | #' columns in the data. 134 | #' 135 | #' @template params_subpops 136 | #' @template return_auditor 137 | initialize = function(subpops) { 138 | assert_list(subpops) 139 | self$subpops = map(subpops, function(pop) { 140 | # Can be character (referring to a column) 141 | if (is.character(pop)) { 142 | function(rw) { 143 | rw[[pop]] 144 | } 145 | } else { 146 | assert_function(pop) 147 | } 148 | }) 149 | }, 150 | #' @description 151 | #' Fit the learner and compute correlation 152 | #' 153 | #' @template params_data_resid 154 | #' @template params_mask 155 | #' @template return_fit 156 | fit = function(data, resid, mask) { 157 | worstCorr = 0 158 | worst_subpop = function(pt) { 159 | return(rep(0L, nrow(pt))) 160 | } # nocov 161 | for (sfn in self$subpops) { 162 | sub = data[, sfn(.SD)] 163 | corr = mean(sub * resid) 164 | if (abs(corr) > abs(worstCorr)) { 165 | worstCorr = corr 166 | worst_subpop = sfn 167 | } 168 | } 169 | return(list(worstCorr, SubpopPredictor$new(worst_subpop, worstCorr))) 170 | } 171 | ) 172 | ) 173 | 174 | #' @title Static AuditorFitter based on Subgroups 175 | #' @description 176 | #' Used to assess multi-calibration based on a list of 177 | #' binary `subgroup_masks` passed during initialization. 178 | #' @family AuditorFitter 179 | #' @examples 180 | #' library("data.table") 181 | #' data = data.table( 182 | #' "AGE_0_10" = c(1, 1, 0, 0, 0), 183 | #' "AGE_11_20" = c(0, 0, 1, 0, 0), 184 | #' "AGE_21_31" = c(0, 0, 0, 1, 1), 185 | #' "X1" = runif(5), 186 | #' "X2" = runif(5) 187 | #' ) 188 | #' label = c(1,0,0,1,1) 189 | #' masks = list( 190 | #' "M1" = c(1L, 0L, 1L, 1L, 0L), 191 | #' "M2" = c(1L, 0L, 0L, 0L, 1L) 192 | #' ) 193 | #' sg = SubgroupAuditorFitter$new(masks) 194 | #' @export 195 | SubgroupAuditorFitter = R6::R6Class("SubgroupAuditorFitter", 196 | inherit = AuditorFitter, 197 | public = list( 198 | #' @field subgroup_masks [`list`] \cr 199 | #' List of subgroup masks. 200 | subgroup_masks = NULL, 201 | #' Initialize a SubgroupAuditorFitter 202 | #' @description 203 | #' Initializes a [`SubgroupAuditorFitter`] that 204 | #' assesses multi-calibration within each group defined 205 | #' by the `subpops'. 206 | #' 207 | #' @param subgroup_masks [`list`] \cr 208 | #' List of subgroup masks. Subgroup masks are list(s) of integer masks, 209 | #' each with the same length as data to be fitted on. 210 | #' They allow defining subgroups of the data. 211 | #' @template return_auditor 212 | initialize = function(subgroup_masks) { 213 | subgroup_masks = tryCatch({ 214 | map(subgroup_masks, as.integer) 215 | }, 216 | warning = function(w) { 217 | stop("subgroup_masks must be a list of integers.") 218 | }) 219 | self$subgroup_masks = assert_list(subgroup_masks, types = "integer") 220 | if (!all(map_lgl(self$subgroup_masks, function(x) { 221 | test_numeric(x, lower = 0, upper = 1) 222 | }))) { 223 | stop("subgroup_masks must be binary vectors") 224 | } 225 | }, 226 | #' @description 227 | #' Fit the learner and compute correlation 228 | #' 229 | #' @template params_data_resid 230 | #' @template params_mask 231 | #' @template return_fit 232 | fit = function(data, resid, mask) { 233 | sg = map(self$subgroup_masks, function(x) x[mask]) 234 | if (!all(map_lgl(sg, function(x) { 235 | nrow(data) == length(x) 236 | }))) { 237 | stop("Length of subgroup masks must match length of data!") 238 | } 239 | m = SubgroupModel$new(sg) 240 | m$fit(data, resid) 241 | preds = m$predict(data) 242 | corr = mean(preds * resid) 243 | return(list(corr, m)) 244 | } 245 | ) 246 | ) 247 | 248 | #' Cross-validated AuditorFitter from a Learner 249 | #' @description CVLearnerAuditorFitter returns the cross-validated predictions 250 | #' instead of the in-sample predictions. 251 | #' 252 | #' Available data is cut into complementary subsets (folds). 253 | #' For each subset out-of-sample predictions are received by training a model 254 | #' on all other subsets and predicting afterwards on the left-out subset. 255 | #' @family AuditorFitter 256 | #' @export 257 | CVLearnerAuditorFitter = R6::R6Class("CVLearnerAuditorFitter", 258 | inherit = AuditorFitter, 259 | public = list( 260 | #' @field learner `CVLearnerPredictor`\cr 261 | #' Learner used for fitting residuals. 262 | learner = NULL, 263 | #' @description 264 | #' Define a `CVAuditorFitter` from a learner. 265 | #' Available instantiations:\cr [`CVTreeAuditorFitter`] (rpart) and 266 | #' [`CVRidgeAuditorFitter`] (glmnet). 267 | #' See [`mlr3pipelines::PipeOpLearnerCV`] for more information on 268 | #' cross-validated learners. 269 | #' 270 | #' @param learner [`mlr3::Learner`]\cr 271 | #' Regression Learner to use. 272 | #' @param folds [`integer`]\cr 273 | #' Number of folds to use for PipeOpLearnerCV. Defaults to 3. 274 | #' @template return_auditor 275 | initialize = function(learner, folds = 3L) { 276 | self$learner = CVLearnerPredictor$new(learner, folds) 277 | }, 278 | #' @description 279 | #' Fit the cross-validated learner and compute correlation 280 | #' 281 | #' @template params_data_resid 282 | #' @template params_mask 283 | #' @template return_fit 284 | fit = function(data, resid, mask) { 285 | l = self$learner$clone() 286 | h = l$fit_transform(data, resid) 287 | corr = mean(h * resid) 288 | return(list(corr, l)) 289 | } 290 | ) 291 | ) 292 | 293 | #' @describeIn CVLearnerAuditorFitter Cross-Validated auditor based on rpart 294 | #' @family AuditorFitter 295 | #' @export 296 | CVTreeAuditorFitter = R6::R6Class("CVTreeAuditorFitter", 297 | inherit = CVLearnerAuditorFitter, 298 | public = list( 299 | #' @description 300 | #' Define a cross-validated AuditorFitter from a rpart learner 301 | #' See [`mlr3pipelines::PipeOpLearnerCV`] for more information on 302 | #' cross-validated learners. 303 | initialize = function() { 304 | mlr3misc::require_namespaces(c("mlr3learners", "rpart")) 305 | super$initialize(learner = lrn("regr.rpart")) 306 | } 307 | ) 308 | ) 309 | 310 | #' @describeIn CVLearnerAuditorFitter Cross-Validated auditor based on glmnet 311 | #' @family AuditorFitter 312 | #' @export 313 | CVRidgeAuditorFitter = R6::R6Class("CVRidgeAuditorFitter", 314 | inherit = CVLearnerAuditorFitter, 315 | public = list( 316 | #' @description 317 | #' Define a cross-validated AuditorFitter from a glmnet learner. 318 | #' See [`mlr3pipelines::PipeOpLearnerCV`] for more information on 319 | #' cross-validated learners. 320 | initialize = function() { 321 | mlr3misc::require_namespaces(c("mlr3learners", "glmnet")) 322 | super$initialize(learner = lrn("regr.glmnet", alpha = 0, lambda = .01)) 323 | } 324 | ) 325 | ) 326 | -------------------------------------------------------------------------------- /R/PipeOpLearnerPred.R: -------------------------------------------------------------------------------- 1 | #' @title Multi-Calibrate a Learner's Prediction 2 | #' 3 | #' @usage NULL 4 | #' @name mlr_pipeops_mcboost 5 | #' @format [`R6Class`] inheriting from [`mlr3pipelines::PipeOp`]. 6 | #' 7 | #' @description 8 | #' [`mlr3pipelines::PipeOp`] that trains a [`Learner`][mlr3::Learner] and passes its predictions forward during training and prediction. 9 | #' 10 | #' @section Construction: 11 | #' ``` 12 | #'PipeOpLearnerPred$new(learner, id = NULL, param_vals = list()) 13 | #' 14 | #' * `learner` :: [`Learner`][mlr3::Learner] \cr 15 | #' [`Learner`][mlr3::Learner] to prediction, or a string identifying a 16 | #' [`Learner`][mlr3::Learner] in the [`mlr3::mlr_learners`] [`Dictionary`][mlr3misc::Dictionary]. 17 | #' * `id` :: `character(1)` 18 | #' Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped. 19 | #' * `param_vals` :: named `list`\cr 20 | #' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`. 21 | #' 22 | #' @section Input and Output Channels: 23 | #' [`PipeOpLearnerPred`] has one input channel named `"input"`, taking a [`Task`][mlr3::Task] specific to the [`Learner`][mlr3::Learner] 24 | #' type given to `learner` during construction; both during training and prediction. 25 | #' 26 | #' [`PipeOpLearnerPred`] has one output channel named `"output"`, producing a [`Task`][mlr3::Task] specific to the [`Learner`][mlr3::Learner] 27 | #' type given to `learner` during construction; both during training and prediction. 28 | #' 29 | #' @section State: 30 | # 31 | #' @section Parameters: 32 | #' The `$state` is set to the `$state` slot of the [`Learner`][mlr3::Learner] object, together with the `$state` elements inherited from 33 | #' [`mlr3pipelines::PipeOpTaskPreproc`]. It is a named `list` with the inherited members, as well as: 34 | #' * `model` :: `any`\cr 35 | #' Model created by the [`Learner`][mlr3::Learner]'s `$.train()` function. 36 | #' * `train_log` :: [`data.table`] with columns `class` (`character`), `msg` (`character`)\cr 37 | #' Errors logged during training. 38 | #' * `train_time` :: `numeric(1)`\cr 39 | #' Training time, in seconds. 40 | #' * `predict_log` :: `NULL` | [`data.table`] with columns `class` (`character`), `msg` (`character`)\cr 41 | #' Errors logged during prediction. 42 | #' * `predict_time` :: `NULL` | `numeric(1)` 43 | #' Prediction time, in seconds. 44 | #' 45 | #' @section Fields: 46 | #' Fields inherited from [`PipeOp`], as well as: 47 | #' * `learner` :: [`Learner`][mlr3::Learner]\cr 48 | #' [`Learner`][mlr3::Learner] that is being wrapped. Read-only. 49 | #' * `learner_model` :: [`Learner`][mlr3::Learner]\cr 50 | #' [`Learner`][mlr3::Learner] that is being wrapped. This learner contains the model if the `PipeOp` is trained. Read-only. 51 | #' 52 | #' @section Methods: 53 | #' Methods inherited from [`mlr3pipelines::PipeOpTaskPreproc`]/[`mlr3pipelines::PipeOp`]. 54 | #' 55 | #' @family PipeOps 56 | #' @seealso https://mlr3book.mlr-org.com/list-pipeops.html 57 | #' @export 58 | PipeOpLearnerPred = R6Class("PipeOpLearnerPred", 59 | inherit = mlr3pipelines::PipeOpTaskPreproc, 60 | public = list( 61 | #' @description 62 | #' Initialize a Learner Predictor PipeOp. Can be used to wrap trained or untrainted 63 | #' mlr3 learners. 64 | #' @param learner [`Learner`]\cr 65 | #' The learner that should be wrapped. 66 | #' @param id [`character`] \cr 67 | #' The `PipeOp`'s id. Defaults to "mcboost". 68 | #' @param param_vals [`list`] \cr 69 | #' List of hyperparameters for the `PipeOp`. 70 | initialize = function(learner, id = NULL, param_vals = list()) { 71 | private$.learner = as_learner(learner, clone = TRUE) 72 | private$.learner$param_set$set_id = "" 73 | id = id %??% private$.learner$id 74 | task_type = mlr_reflections$task_types[get("type") == private$.learner$task_type][order(get("package"))][1L]$task 75 | super$initialize(id, alist(private$.learner$param_set), 76 | param_vals = param_vals, 77 | can_subset_cols = TRUE, 78 | task_type = task_type, 79 | tags = c("learner") 80 | ) 81 | } 82 | 83 | ), 84 | active = list( 85 | #' @field learner The wrapped learner. 86 | learner = function(val) { 87 | if (!missing(val)) { 88 | if (!identical(val, private$.learner)) { 89 | stop("$learner is read-only.") 90 | } 91 | } 92 | private$.learner 93 | }, 94 | #' @field learner_model The wrapped learner's model(s). 95 | learner_model = function(val) { 96 | if (!missing(val)) { 97 | if (!identical(val, private$.learner)) { 98 | stop("$learner is read-only.") 99 | } 100 | } 101 | if (is.null(self$state) || mlr3pipelines::is_noop(self$state)) { 102 | private$.learner 103 | } else { 104 | multiplicity_recurse(self$state, clone_with_state, learner = private$.learner) 105 | } 106 | } 107 | ), 108 | private = list( 109 | .train_task = function(task) { 110 | on.exit({private$.learner$state = NULL}) 111 | 112 | # Train a learner for predicting 113 | state = private$.learner$state 114 | if (is.null(state)) { 115 | self$state = private$.learner$train(task)$state 116 | } else { 117 | self$state = state 118 | } 119 | 120 | prds = as.data.table(private$.learner$predict(task)) 121 | private$pred_to_task(prds, task) 122 | }, 123 | 124 | .predict_task = function(task) { 125 | on.exit({private$.learner$state = NULL}) 126 | private$.learner$state = self$state 127 | prediction = as.data.table(private$.learner$predict(task)) 128 | private$pred_to_task(prediction, task) 129 | }, 130 | 131 | pred_to_task = function(prds, task) { 132 | renaming = setdiff(colnames(prds), c( "row_ids")) 133 | setnames(prds, renaming, sprintf("%s.%s", self$id, renaming)) 134 | setnames(prds, old = "row_ids", new = task$backend$primary_key) 135 | task$select(character(0))$cbind(prds) 136 | }, 137 | .learner = NULL 138 | ) 139 | ) 140 | 141 | clone_with_state = function(learner, state) { 142 | lrn = learner$clone(deep = TRUE) 143 | lrn$state = state 144 | lrn 145 | } 146 | 147 | multiplicity_recurse = function(.multip, .fun, ...) { 148 | if (mlr3pipelines::is.Multiplicity(.multip)) { 149 | mlr3pipelines::as.Multiplicity(lapply(.multip, function(m) multiplicity_recurse(.multip = m, .fun = .fun, ...))) 150 | } else { 151 | .fun(.multip, ...) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /R/PipeOpMCBoost.R: -------------------------------------------------------------------------------- 1 | #' @title Multi-Calibrate a Learner's Prediction 2 | #' 3 | #' @usage NULL 4 | #' @name mlr_pipeops_mcboost 5 | #' @format [`R6Class`] inheriting from [`mlr3pipelines::PipeOp`]. 6 | #' 7 | #' @description 8 | #' Post-process a learner prediction using multi-calibration. 9 | #' For more details, please refer to \url{https://arxiv.org/pdf/1805.12317.pdf} (Kim et al. 2018) 10 | #' or the help for [`MCBoost`]. 11 | #' If no `init_predictor` is provided, the preceding learner's predictions 12 | #' corresponding to the `prediction` slot are used as an initial predictor for `MCBoost`. 13 | #' 14 | #' @section Construction: 15 | #' ``` 16 | #' PipeOpMCBoost$new(id = "mcboost", param_vals = list()) 17 | #' ``` 18 | #' * `id` :: `character(1)` 19 | #' Identifier of the resulting object, default `"threshold"`. 20 | #' * `param_vals` :: named `list`\cr 21 | #' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. 22 | #' See `MCBoost` for a comprehensive description of all hyperparameters. 23 | #' 24 | #' @section Input and Output Channels: 25 | #' During training, the input and output are `"data"` and `"prediction"`, two [`TaskClassif`][mlr3::TaskClassif]. 26 | #' A [`PredictionClassif`][mlr3::PredictionClassif] is required as input and returned as output during prediction. 27 | #' 28 | #' @section State: 29 | #' The `$state` is a `MCBoost` Object as obtained from `MCBoost$new()`. 30 | #' 31 | #' @section Parameters: 32 | #' * `max_iter` :: `integer`\cr 33 | #' A integer specifying the number of multi-calibration rounds. Defaults to 5. 34 | #' 35 | #' @section Fields: 36 | #' Only fields inherited from [`mlr3pipelines::PipeOp`]. 37 | #' 38 | #' @section Methods: 39 | #' Only methods inherited from [`mlr3pipelines::PipeOp`]. 40 | #' 41 | #' @examples 42 | #' \dontrun{ 43 | #' gr = gunion(list( 44 | #' "data" = po("nop"), 45 | #' "prediction" = po("learner_cv", lrn("classif.rpart")) 46 | #' )) %>>% 47 | #' PipeOpMCBoost$new() 48 | #' tsk = tsk("sonar") 49 | #' tid = sample(1:208, 108) 50 | #' gr$train(tsk$clone()$filter(tid)) 51 | #' gr$predict(tsk$clone()$filter(setdiff(1:208, tid))) 52 | #' } 53 | #' @family PipeOps 54 | #' @seealso https://mlr3book.mlr-org.com/list-pipeops.html 55 | #' @export 56 | PipeOpMCBoost = R6Class("PipeOpMCBoost", 57 | inherit = mlr3pipelines::PipeOp, 58 | public = list( 59 | 60 | #' @description 61 | #' Initialize a Multi-Calibration PipeOp. 62 | #' 63 | #' @param id [`character`] \cr 64 | #' The `PipeOp`'s id. Defaults to "mcboost". 65 | #' @param param_vals [`list`] \cr 66 | #' List of hyperparameters for the `PipeOp`. 67 | initialize = function(id = "mcboost", param_vals = list()) { 68 | param_set = paradox::ps( 69 | max_iter = paradox::p_int(lower = 0L, upper = Inf, default = 5L, tags = "train"), 70 | alpha = paradox::p_dbl(lower = 0, upper = 1, default = 1e-4, tags = "train"), 71 | eta = paradox::p_dbl(lower = 0, upper = 1, default = 1, tags = "train"), 72 | partition = paradox::p_lgl(tags = "train", default = TRUE), 73 | num_buckets = paradox::p_int(lower = 1, upper = Inf, default = 2L, tags = "train"), 74 | rebucket = paradox::p_lgl(default = FALSE, tags = "train"), 75 | multiplicative = paradox::p_lgl(default = TRUE, tags = "train"), 76 | auditor_fitter = paradox::p_uty(default = NULL, tags = "train"), 77 | subpops = paradox::p_uty(default = NULL, tags = "train"), 78 | default_model_class = paradox::p_uty(default = ConstantPredictor, tags = "train"), 79 | init_predictor = paradox::p_uty(default = NULL, tags = "train") 80 | ) 81 | super$initialize(id, 82 | param_set = param_set, param_vals = param_vals, packages = character(0), 83 | input = data.table(name = c("data", "prediction"), train = c("TaskClassif", "TaskClassif"), predict = c("TaskClassif", "TaskClassif")), 84 | output = data.table(name = "output", train = "NULL", predict = "PredictionClassif"), 85 | tags = "target transform") 86 | } 87 | ), 88 | private = list( 89 | .train = function(inputs) { 90 | 91 | d = inputs$data$data(cols = inputs$data$feature_names) 92 | l = inputs$data$data(cols = inputs$data$target_names) 93 | 94 | args = self$param_set$get_values(tags = "train") 95 | 96 | if (is.null(args$init_predictor)) { 97 | # Construct an initial predictor from the input model if non is provided. 98 | init_predictor = function(data, prediction) { 99 | # Prob or response prediction 100 | if (length(prediction$feature_names) > 1L) { 101 | prds = prediction$data(cols = prediction$feature_names) 102 | as.matrix(prds) 103 | } else { 104 | prds = prediction$data(cols = prediction$feature_names)[[1]] 105 | one_hot(prds) 106 | } 107 | } 108 | args$init_predictor = init_predictor 109 | } 110 | mc = invoke(MCBoost$new, .args = args) 111 | mc$multicalibrate(d, l, predictor_args = inputs$prediction) 112 | self$state = list("mc" = mc) 113 | list(NULL) 114 | }, 115 | 116 | .predict = function(inputs) { 117 | d = inputs$data$data(cols = inputs$data$feature_names) 118 | prob = self$state$mc$predict_probs(d, predictor_args = inputs$prediction) 119 | prob = cbind(1 - prob, prob) 120 | lvls = c(inputs$prediction$negative, inputs$prediction$positive) 121 | colnames(prob) = lvls 122 | list(PredictionClassif$new( 123 | inputs$prediction, 124 | row_ids = inputs$prediction$row_ids, 125 | truth = inputs$prediction$truth(), 126 | prob = prob 127 | )) 128 | } 129 | ), 130 | active = list( 131 | #' @field predict_type Predict type of the PipeOp. 132 | predict_type = function(val) { 133 | if (!missing(val)) { 134 | if (!identical(val, private$.learner)) { 135 | stop("$predict_type for PipeOpMCBoost is read-only.") 136 | } 137 | } 138 | return("prob") 139 | } 140 | ) 141 | ) 142 | 143 | 144 | -------------------------------------------------------------------------------- /R/PipelineMCBoost.R: -------------------------------------------------------------------------------- 1 | #' Multi-calibration pipeline 2 | #' 3 | #' Wraps MCBoost in a Pipeline to be used with `mlr3pipelines`. 4 | #' For now this assumes training on the same dataset that is later used 5 | #' for multi-calibration. 6 | #' @param learner (mlr3)[`mlr3::Learner`]\cr 7 | #' Initial learner. Internally wrapped into a `PipeOpLearnerCV` 8 | #' with `resampling.method = "insample"` as a default. 9 | #' All parameters can be adjusted through the resulting Graph's `param_set`. 10 | #' Defaults to `lrn("classif.featureless")`. 11 | #' Note: An initial predictor can also be supplied via the `init_predictor` parameter. 12 | #' @param param_vals `list` \cr 13 | #' List of parameter values passed on to `MCBoost$new`. 14 | #' @return (mlr3pipelines) [`Graph`] 15 | #' @examples 16 | #' \dontrun{ 17 | #' library("mlr3pipelines") 18 | #' gr = ppl_mcboost() 19 | #' } 20 | #' @export 21 | ppl_mcboost = function(learner = lrn("classif.featureless"), param_vals = list()) { 22 | mlr3misc::require_namespaces("mlr3pipelines") 23 | po_lrn = mlr3pipelines::po("learner_cv", learner = learner, resampling.method = "insample") 24 | gr = mlr3pipelines::`%>>%`( 25 | mlr3pipelines::gunion(list( 26 | "data" = mlr3pipelines::po("nop"), 27 | "prediction" = po_lrn 28 | )), 29 | PipeOpMCBoost$new(param_vals = param_vals) 30 | ) 31 | } 32 | -------------------------------------------------------------------------------- /R/Predictor.R: -------------------------------------------------------------------------------- 1 | #' Predictor 2 | #' @family Predictor 3 | #' @noRd 4 | Predictor = R6::R6Class("Predictor", 5 | public = list( 6 | #' @description 7 | #' Instantiate a Predictor 8 | initialize = function() { 9 | invisible(self) 10 | }, 11 | #' @description 12 | #' Fit the predictor. 13 | #' @template params_data_label 14 | fit = function(data, labels) { 15 | stop("Abstract base class") 16 | }, 17 | #' @description 18 | #' Predict a dataset with constant predictions. 19 | #' @param data [`data.table`] \cr 20 | #' Prediction data. 21 | #' @param ... [`any`] \cr 22 | #' Not used, only for compatibility with other methods. 23 | predict = function(data, ...) { 24 | stop("Abstract base Class") 25 | } 26 | ) 27 | ) 28 | 29 | #' ConstantPredictor 30 | #' @family Predictor 31 | #' @noRd 32 | ConstantPredictor = R6::R6Class("ConstantPredictor", 33 | inherit = Predictor, 34 | public = list( 35 | #' @field constant [`numeric`]\cr 36 | #' mlr3 Constant to predict with. 37 | constant = 0., 38 | #' @field is_fitted [`logical`]\cr 39 | #' Whether the model is fitted. 40 | is_fitted = TRUE, 41 | #' @description 42 | #' Instantiate a ConstantPredictor 43 | #' 44 | #' @param constant [`numeric`]\cr 45 | #' Constant to predict with. 46 | #' @template return_predictor 47 | initialize = function(constant = 0.5) { 48 | self$constant = assert_number(constant) 49 | invisible(self) 50 | }, 51 | #' @description 52 | #' Fit the constant predictor. 53 | #' Does nothing. 54 | #' @template params_data_label 55 | fit = function(data, labels) { 56 | }, 57 | #' @description 58 | #' Predict a dataset with constant predictions. 59 | #' @param data [`data.table`] \cr 60 | #' Prediction data. 61 | #' @param ... [`any`] \cr 62 | #' Not used, only for compatibility with other methods. 63 | predict = function(data, ...) { 64 | rep(self$constant, nrow(data)) 65 | } 66 | ) 67 | ) 68 | 69 | #' LearnerPredictor 70 | #' @family Predictor 71 | #' Wraps a mlr3 Learner into a `LearnerPredictor` object that can be used 72 | #' with mcboost. 73 | #' @noRd 74 | LearnerPredictor = R6::R6Class("LearnerPredictor", 75 | inherit = Predictor, 76 | public = list( 77 | #' @field learner [`mlr3::Learner`]\cr 78 | #' mlr3 Learner used for fitting residuals. 79 | learner = NULL, 80 | #' @description 81 | #' Instantiate a LearnerPredictor 82 | #' 83 | #' @param learner [`mlr3::Learner`]\cr 84 | #' Learner used for train/predict. 85 | #' @template return_predictor 86 | initialize = function(learner) { 87 | self$learner = assert_class(learner, "Learner") 88 | }, 89 | #' @description 90 | #' Fit the learner. 91 | #' @template params_data_label 92 | fit = function(data, labels) { 93 | task = xy_to_task(data, labels) 94 | self$learner$train(task) 95 | }, 96 | #' @description 97 | #' Predict a dataset with leaner predictions. 98 | #' @param data [`data.table`] \cr 99 | #' Prediction data. 100 | #' @param ... [`any`] \cr 101 | #' Not used, only for compatibility with other methods. 102 | predict = function(data, ...) { 103 | prd = self$learner$predict_newdata(data) 104 | if (inherits(prd, "PredictionRegr")) { 105 | return(prd$response) 106 | } else if (inherits(prd, "PredictionClassif")) { 107 | if ("prob" %in% self$learner$predict_type) { 108 | p = prd$prob 109 | if (ncol(p) == 2L) p = p[, 1L] 110 | } else { 111 | p = one_hot(prd$response) 112 | } 113 | return(p) 114 | } else if (inherits(prd ,"PredictionSurv")) { 115 | return(as.data.table(prd)$distr[[1]][[1]]) 116 | } 117 | } 118 | ), 119 | active = list( 120 | #' @field is_fitted [`logical`]\cr 121 | #' Whether the Learner is trained 122 | is_fitted = function() { 123 | !is.null(self$learner$state) 124 | } 125 | ) 126 | ) 127 | 128 | 129 | #' SubpopPredictor 130 | #' @family Predictor 131 | #' @noRd 132 | SubpopPredictor = R6::R6Class("SubpopPredictor", 133 | inherit = Predictor, 134 | public = list( 135 | 136 | #' @field subpop [`function`] \cr 137 | #' A [`function`] that evaluates to binary for each row in a dataset. 138 | #' Defines a sub-population. 139 | subpop = NULL, 140 | #' @field value [`numeric`] \cr 141 | #' A correlation value. 142 | value = numeric(1), 143 | 144 | #' @description 145 | #' Instantiate a SubpopPredictor 146 | #' @param subpop [`character`]|[`function`] \cr 147 | #' Either a [`function`], that yields a binary value for each 148 | #' row in a dataset, or a [`character`] string referring to a 149 | #' feature column, that defines a sub-population. 150 | #' @param value [`numeric`] \cr 151 | #' Correlation value for the given subpop. 152 | #' @template return_predictor 153 | initialize = function(subpop, value) { 154 | # Can be character (referring to a column) or a function. 155 | if (is.character(subpop)) { 156 | self$subpop = function(rw) { 157 | rw[[subpop]] 158 | } # nocov 159 | } else { 160 | self$subpop = assert_function(subpop) 161 | } 162 | self$value = assert_number(value) 163 | invisible(self) 164 | }, 165 | #' @description 166 | #' Fit the predictor. 167 | #' @template params_data_label 168 | fit = function(data, labels) { 169 | }, 170 | #' @description 171 | #' Predict a dataset with sub-population predictions. 172 | #' @param data [`data.table`] \cr 173 | #' Prediction data. 174 | #' @param ... [`any`] \cr 175 | #' Not used, only for compatibility with other methods. 176 | #' @template return_predictor 177 | predict = function(data, ...) { 178 | data[, self$subpop(.SD)] * self$value 179 | } 180 | ) 181 | ) 182 | 183 | 184 | #' SubgroupModel 185 | #' @family Predictor 186 | #' @noRd 187 | SubgroupModel = R6::R6Class("SubgroupModel", 188 | public = list( 189 | #' @field subgroup_masks [`list`] \cr 190 | #' List of subgroup masks. 191 | subgroup_masks = NULL, 192 | #' @field subgroup_preds [`list`] \cr 193 | #' List of subgroup predictions after fitting. 194 | subgroup_preds = NULL, 195 | #' @description 196 | #' Instantiate a SubpopPredictor 197 | #' @param subgroup_masks [`list`] \cr 198 | #' List of subgroup masks. 199 | #' @template return_predictor 200 | initialize = function(subgroup_masks) { 201 | self$subgroup_masks = assert_list(subgroup_masks) 202 | invisible(self) 203 | }, 204 | #' @description 205 | #' Fit the predictor. 206 | #' @template params_data_label 207 | fit = function(data, labels) { 208 | self$subgroup_preds = map(self$subgroup_masks, function(mask) { 209 | mean(labels[as.logical(mask)]) 210 | }) 211 | }, 212 | #' @description 213 | #' Predict a dataset with sub-population predictions. 214 | #' @param data [`data.table`] \cr 215 | #' Prediction data. 216 | #' @param subgroup_masks [`list`] \cr 217 | #' List of subgroup masks for the data. 218 | #' @param partition_mask [`integer`] \cr 219 | #' Mask defined by partitions. 220 | predict = function(data, subgroup_masks = NULL, partition_mask = NULL) { 221 | # Check that masks fit 222 | if (is.null(subgroup_masks)) { 223 | subgroup_masks = self$subgroup_masks 224 | } 225 | if (!all(map_lgl(subgroup_masks, function(x) { 226 | nrow(data) == length(x) 227 | }))) { 228 | stop("Length of subgroup masks must match length of data!\n 229 | Subgroups are currently not implemented for 'partition=TRUE'.") 230 | } 231 | # If no paritition mask, use all datapoints 232 | if (is.null(partition_mask)) partition_mask = rep(1L, nrow(data)) 233 | # Predict 234 | preds = numeric(nrow(data)) 235 | for (i in seq_along(self$subgroup_preds)) { 236 | preds[subgroup_masks[[i]] & partition_mask] = self$subgroup_preds[[i]] 237 | } 238 | return(preds) 239 | } 240 | ) 241 | ) 242 | 243 | #' CVLearnerPredictor 244 | #' @family Predictor 245 | #' @description Wraps a mlr3 Learner into a `CVLearnerPredictor` object that can be used 246 | #' with mcboost. Internally cross-validates predictions. 247 | #' @noRd 248 | CVLearnerPredictor = R6::R6Class("CVLearnerPredictor", 249 | inherit = Predictor, 250 | public = list( 251 | #' @field pipeop [`mlr3::Learner`]\cr 252 | #' mlr3pipelines PipeOp used for fitting residuals. 253 | pipeop = NULL, 254 | 255 | #' @description 256 | #' Instantiate a LearnerPredictor with internal cross-validation. 257 | #' See [`mlr3pipelines::PipeOpLearnerCV`] for more information. 258 | #' 259 | #' @param learner [`mlr3::Learner`]\cr 260 | #' Learner used for train/predict. 261 | #' @param folds [`integer`]\cr 262 | #' Number of folds to use for PipeOpLearnerCV. 263 | #' @template return_predictor 264 | initialize = function(learner, folds) { 265 | self$pipeop = mlr3pipelines::po("learner_cv", learner, resampling.folds = folds) 266 | }, 267 | #' @description 268 | #' Fit the learner. 269 | #' @template params_data_label 270 | fit_transform = function(data, labels) { 271 | task = xy_to_task(data, labels) 272 | t = self$pipeop$train(list(task))$output 273 | return(as.matrix(t$data(cols = t$feature_names))) 274 | }, 275 | #' @description 276 | #' Predict a dataset with leaner predictions. 277 | #' @param data [`data.table`] \cr 278 | #' Prediction data. 279 | #' @param ... [`any`] \cr 280 | #' Not used, only for compatibility with other methods. 281 | predict = function(data, ...) { 282 | task = xy_to_task(data, runif(NROW(data))) 283 | t = self$pipeop$predict(list(task))$output 284 | return(as.matrix(t$data(cols = t$feature_names))) 285 | } 286 | ), 287 | active = list( 288 | #' @field is_fitted [`logical`]\cr 289 | #' Whether the Learner is trained 290 | is_fitted = function() { 291 | !is.null(self$pipeop$state) 292 | } 293 | ) 294 | ) 295 | -------------------------------------------------------------------------------- /R/ProbRange.R: -------------------------------------------------------------------------------- 1 | #' Range of Probabilities 2 | #' @description 3 | #' Range of format [lower; upper). 4 | #' @noRd 5 | ProbRange = R6::R6Class("ProbRange", 6 | public = list( 7 | #' @field lower [`numeric`] \cr 8 | #' Lower bound of the ProbRange. 9 | lower = -Inf, 10 | 11 | #' @field upper [`numeric`] \cr 12 | #' upper bound of the ProbRange. 13 | upper = Inf, 14 | #' @description 15 | #' Instantiate a Probability Range 16 | #' 17 | #' @param lower [`numeric`]\cr 18 | #' Lower bound of the ProbRange. 19 | #' @param upper [`numeric`]\cr 20 | #' Upper bound of the ProbRange. 21 | #' @return [`ProbRange`] 22 | initialize = function(lower = -Inf, upper = Inf) { 23 | self$lower = assert_number(lower) 24 | self$upper = assert_number(upper) 25 | invisible(self) 26 | }, 27 | #' @description 28 | #' Compare with 'other' Probability Range regarding equality 29 | #' 30 | #' @param other [`ProbRange`]\cr 31 | #' ProbRange to compare to. 32 | #' @return 33 | #' Logical, whether ProbRanges are equal. 34 | is_equal = function(other) { 35 | if (test_class(other, "ProbRange")) 36 | return((self$lower == other$lower) && (self$upper == other$upper)) 37 | return(FALSE) 38 | }, 39 | #' @description 40 | #' Compare with 'other' Probability Range regarding in-equality 41 | #' 42 | #' @param other [`ProbRange`]\cr 43 | #' ProbRange to compare to. 44 | #' @return 45 | #' Logical, whether ProbRanges are in-equal. 46 | is_not_equal = function(other) { 47 | if (test_class(other, "ProbRange")) 48 | return(!((self$lower == other$lower) && (self$upper == other$upper))) 49 | return(TRUE) 50 | }, 51 | #' @description 52 | #' Check whether elements of an array are in the ProbRange. 53 | #' 54 | #' @param x [`numeric`]\cr 55 | #' Array of probabilities 56 | #' @return 57 | #' Logical array, whether elements are in ProbRange or not. 58 | in_range_mask = function(x) { 59 | if(self$upper==Inf) 60 | return((x >= self$lower) & (x <= self$upper)) #FIXME 61 | 62 | (x >= self$lower) & (x < self$upper) 63 | }, 64 | #' @description 65 | #' Printer for ProbRange 66 | print = function() { 67 | cat(paste0("ProbRange: [", self$lower, ";", self$upper, ")\n")) 68 | } 69 | ) 70 | ) 71 | -------------------------------------------------------------------------------- /R/helpers.R: -------------------------------------------------------------------------------- 1 | #' One-hot encode a factor variable 2 | #' @param labels [`factor`]\cr 3 | #' Factor to encode. 4 | #' @examples 5 | #' \dontrun{ 6 | #' one_hot(factor(c("a", "b", "a"))) 7 | #' } 8 | #' @return [`integer`]\cr 9 | #' Integer vector of encoded labels. 10 | #' @export 11 | one_hot = function(labels) { 12 | con = contrasts(labels, contrasts = FALSE) 13 | mat = con[as.integer(labels), ] 14 | rownames(mat) = NULL 15 | if (ncol(mat) == 2L) mat = mat[, 1L] 16 | return(mat) 17 | } 18 | 19 | 20 | # clip numbers (probabilities) to [0;1] 21 | clip_prob = function(prob) { 22 | prob[prob > 1] = 1 23 | prob[prob < 0] = 0 24 | return(prob) 25 | } 26 | 27 | 28 | # Convert a X,y pair to a task 29 | # Required for interacting with 'mlr3' 30 | xy_to_task = function(x, y) { 31 | 32 | x = data.table::data.table(x) 33 | yname = "ytmp" 34 | 35 | # Safe yname 36 | i = 0 37 | while (yname %in% names(x)) { 38 | i = i + 1 39 | yname = paste0("ytmp", i) 40 | } 41 | 42 | x[, (yname) := y] 43 | 44 | if (is.numeric(y)) { 45 | ti = mlr3::TaskRegr 46 | } else { 47 | ti = mlr3::TaskClassif 48 | } 49 | ti$new(id = "tmptsk", backend = x, target = yname) 50 | } 51 | 52 | #' Create an initial predictor function from a trained mlr3 learner 53 | #' 54 | #' @param learner [`mlr3::Learner`] 55 | #' A trained learner used for initialization. 56 | #' @examples 57 | #' \dontrun{ 58 | #' library("mlr3") 59 | #' l = lrn("classif.featureless")$train(tsk("sonar")) 60 | #' mlr3_init_predictor(l) 61 | #' } 62 | #' @return [`function`] 63 | #' @export 64 | mlr3_init_predictor = function(learner) { 65 | if (is.null(learner$state)) stop("Learner needs to be trained first!") 66 | if (learner$predict_type == "response") { 67 | function(data, ...) { 68 | one_hot(learner$predict_newdata(data)$response) 69 | } 70 | } else if ("distr" %in% learner$predict_types) { 71 | function(data, ...) { 72 | as.data.table(learner$predict_newdata(data))$distr[[1]][[1]] 73 | } 74 | } else if(learner$predict_type == "prob") { 75 | function(data, ...) { 76 | learner$predict_newdata(data)$prob[, 1L] 77 | } 78 | } else{ 79 | stop("Predict type of your learner is not implemented. (response, distr, prob)") 80 | } 81 | } 82 | 83 | 84 | 85 | #' Create even intervals 86 | #' @param frac [`numeric`] 87 | #' number of buckets 88 | #' @param min [`numeric`] 89 | #' maximum value 90 | #' @param max [`numeric`] 91 | #' minimum value 92 | #' @return [`numeric`] 93 | #' @noRd 94 | even_bucket = function(frac, min, max) { 95 | pos = c(0, seq_len(frac)) 96 | min + pos / frac * (max - min) 97 | } 98 | -------------------------------------------------------------------------------- /R/zzz.R: -------------------------------------------------------------------------------- 1 | #' @import data.table 2 | #' @import checkmate 3 | #' @import mlr3 4 | #' @import mlr3misc 5 | #' @import glmnet 6 | #' @import rpart 7 | #' @import mlr3pipelines 8 | #' @import rmarkdown 9 | #' @importFrom R6 R6Class is.R6 10 | #' @importFrom utils head 11 | #' @importFrom stats contrasts runif rnorm setNames quantile 12 | #' @references 13 | #' Kim et al., 2019: Multiaccuracy: Black-Box Post-Processing for Fairness in Classification. 14 | #' Hebert-Johnson et al., 2018: Multicalibration: Calibration for the ({C}omputationally-Identifiable) Masses. 15 | #' `r tools::toRd(citation("mcboost"))` 16 | "_PACKAGE" 17 | 18 | 19 | register_pipeops = function() { # nocov start 20 | mlr3pipelines::mlr_pipeops$add("mcboost", PipeOpMCBoost) 21 | mlr3pipelines::mlr_pipeops$add("learner_pred", PipeOpLearnerPred) 22 | mlr3pipelines::mlr_graphs$add("ppl_mcboost", ppl_mcboost) 23 | } # nocov end 24 | 25 | .onLoad = function(libname, pkgname) { # nocov start 26 | if (requireNamespace("mlr3pipelines")) { 27 | register_pipeops() 28 | setHook(packageEvent("mlr3pipelines", "onLoad"), function(...) register_pipeops(), action = "append") 29 | } 30 | backports::import(pkgname) 31 | } # nocov end 32 | 33 | .onUnload = function(libpath) { # nocov start 34 | if (requireNamespace("mlr3pipelines")) { 35 | event = packageEvent("mlr3pipelines", "onLoad") 36 | hooks = getHook(event) 37 | pkgname = vapply(hooks[-1], function(x) environment(x)$pkgname, NA_character_) 38 | setHook(event, hooks[pkgname != "mcboost"], action = "replace") 39 | } 40 | } # nocov end 41 | 42 | leanify_package() 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mcboost 2 | 3 | 4 | [![tic](https://github.com/mlr-org/mcboost/workflows/tic/badge.svg?branch=main)](https://github.com/mlr-org/mcboost/actions) 5 | [![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental) 6 | [![CRAN Status](https://www.r-pkg.org/badges/version-ago/mcboost)](https://cran.r-project.org/package=mcboost) 7 | [![DOI](https://joss.theoj.org/papers/10.21105/joss.03453/status.svg)](https://doi.org/10.21105/joss.03453) 8 | [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) 9 | [![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/) 10 | 11 | 12 | ## What does it do? 13 | 14 | **mcboost** implements Multi-Calibration Boosting ([Hebert-Johnson et al., 2018](https://proceedings.mlr.press/v80/hebert-johnson18a.html); [Kim et al., 2019](https://arxiv.org/pdf/1805.12317.pdf)) for the multi-calibration of a machine learning model's prediction. Multi-Calibration works best in scenarios where the underlying data & labels are unbiased but a bias is introduced within the algorithm's fitting procedure. This is often the case, e.g. when an algorithm fits a majority population while ignoring or under-fitting minority populations. 15 | 16 | For more information and example, see the package's [website](https://mlr-org.github.io/mcboost/). 17 | 18 | More details with respect to usage and the procedures can be found in the package vignettes. 19 | 20 | ## Installation 21 | 22 | The current version can be downloaded from CRAN using: 23 | 24 | ```r 25 | install.packages("mcboost") 26 | ``` 27 | 28 | You can install the development version of mcboost from **Github** with: 29 | 30 | ```r 31 | remotes::install_github("mlr-org/mcboost") 32 | ``` 33 | 34 | ## Usage 35 | 36 | Post-processing with `mcboost` needs three components. We start with an initial prediction model (1) and an auditing algorithm (2) that may be customized by the user. The auditing algorithm then runs Multi-Calibration-Boosting on a labeled auditing dataset (3). The resulting model can be used for obtaining multi-calibrated predictions. 37 | 38 |

39 | 40 |

41 | 42 | ## Example 43 | 44 | In this simple example, our goal is to improve calibration 45 | for an `initial predictor`, e.g. a ML algorithm trained on 46 | an initial task. 47 | Internally, `mcboost` often makes use of `mlr3` and learners that come with `mlr3learners`. 48 | 49 | 50 | ``` r 51 | library(mcboost) 52 | library(mlr3) 53 | ``` 54 | 55 | First we set up an example dataset. 56 | 57 | ```r 58 | # Example Data: Sonar Task 59 | tsk = tsk("sonar") 60 | tid = sample(tsk$row_ids, 100) # 100 rows for training 61 | train_data = tsk$data(cols = tsk$feature_names, rows = tid) 62 | train_labels = tsk$data(cols = tsk$target_names, rows = tid)[[1]] 63 | ``` 64 | 65 | To provide an example, we assume that we have already a learner `l` which we train below. 66 | We can now wrap this initial learner's predict function for use with `mcboost`, since `mcboost` expects the initial model to be specified as a `function` with `data` as input. 67 | 68 | ```r 69 | l = lrn("classif.rpart") 70 | l$train(tsk$clone()$filter(tid)) 71 | 72 | init_predictor = function(data) { 73 | # Get response prediction from Learner 74 | p = l$predict_newdata(data)$response 75 | # One-hot encode and take first column 76 | one_hot(p) 77 | } 78 | ``` 79 | 80 | We can now run Multi-Calibration Boosting by instantiating the object and calling the `multicalibrate` method. 81 | Note, that typically, we would use Multi-Calibration on a separate validation set! 82 | We furthermore select the auditor model, a `SubpopAuditorFitter`, 83 | in our case a `Decision Tree`: 84 | 85 | ```r 86 | mc = MCBoost$new( 87 | init_predictor = init_predictor, 88 | auditor_fitter = "TreeAuditorFitter") 89 | mc$multicalibrate(train_data, train_labels) 90 | ``` 91 | 92 | Lastly, we predict on new data. 93 | 94 | ```r 95 | tstid = setdiff(tsk$row_ids, tid) # held-out data 96 | test_data = tsk$data(cols = tsk$feature_names, rows = tstid) 97 | mc$predict_probs(test_data) 98 | ``` 99 | 100 | ### Multi-Calibration 101 | 102 | While `mcboost` in its defaults implements Multi-Accuracy ([Kim et al., 2019](https://arxiv.org/pdf/1805.12317.pdf)), 103 | it can also multi-calibrate predictors ([Hebert-Johnson et al., 2018](http://proceedings.mlr.press/v80/hebert-johnson18a.html)). 104 | In order to achieve this, we have to set the following hyperparameters: 105 | 106 | ```r 107 | mc = MCBoost$new( 108 | init_predictor = init_predictor, 109 | auditor_fitter = "TreeAuditorFitter", 110 | num_buckets = 10, 111 | multiplicative = FALSE 112 | ) 113 | ``` 114 | 115 | ## MCBoost as a PipeOp 116 | 117 | `mcboost` can also be used within a `mlr3pipeline` in order to use at the full end-to-end pipeline (in the form of a `GraphLearner`). 118 | 119 | ```r 120 | library(mlr3) 121 | library(mlr3pipelines) 122 | gr = ppl_mcboost(lrn("classif.rpart")) 123 | tsk = tsk("sonar") 124 | tid = sample(1:208, 108) 125 | gr$train(tsk$clone()$filter(tid)) 126 | gr$predict(tsk$clone()$filter(setdiff(1:208, tid))) 127 | ``` 128 | 129 | 130 | 131 | ## Further Examples 132 | 133 | The `mcboost` vignettes [**Basics and Extensions**](https://mlr-org.github.io/mcboost/articles/mcboost_basics_extensions.html) and [**Health Survey Example**](https://mlr-org.github.io/mcboost/articles/mcboost_example.html) demonstrate a lot of interesting showcases for applying `mcboost`. 134 | 135 | 136 | ## Contributing 137 | 138 | This R package is licensed under the LGPL-3. 139 | If you encounter problems using this software (lack of documentation, misleading or wrong documentation, unexpected behaviour, bugs, …) or just want to suggest features, please open an issue in the issue tracker. 140 | Pull requests are welcome and will be included at the discretion of the maintainers. 141 | 142 | As this project is developed with [mlr3's](https://github.com/mlr-org/mlr3/) style guide in mind, the following resources can be helpful 143 | to individuals wishing to contribute: Please consult the [wiki](https://github.com/mlr-org/mlr3/wiki/) for a [style guide](https://github.com/mlr-org/mlr3/wiki/Style-Guide), a [roxygen guide](https://github.com/mlr-org/mlr3/wiki/Roxygen-Guide) and a [pull request guide](https://github.com/mlr-org/mlr3/wiki/PR-Guidelines). 144 | 145 | ### Code of Conduct 146 | 147 | Please note that the mcboost project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/0/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. 148 | 149 | ## Citing mcboost 150 | 151 | If you use `mcboost`, please cite our package as well as the two papers it is based on: 152 | 153 | ``` 154 | @article{pfisterer2021, 155 | author = {Pfisterer, Florian and Kern, Christoph and Dandl, Susanne and Sun, Matthew and 156 | Kim, Michael P. and Bischl, Bernd}, 157 | title = {mcboost: Multi-Calibration Boosting for R}, 158 | journal = {Journal of Open Source Software}, 159 | doi = {10.21105/joss.03453}, 160 | url = {https://doi.org/10.21105/joss.03453}, 161 | year = {2021}, 162 | publisher = {The Open Journal}, 163 | volume = {6}, 164 | number = {64}, 165 | pages = {3453} 166 | } 167 | # Multi-Calibration 168 | @inproceedings{hebert-johnson2018, 169 | title = {Multicalibration: Calibration for the ({C}omputationally-Identifiable) Masses}, 170 | author = {Hebert-Johnson, Ursula and Kim, Michael P. and Reingold, Omer and Rothblum, Guy}, 171 | booktitle = {Proceedings of the 35th International Conference on Machine Learning}, 172 | pages = {1939--1948}, 173 | year = {2018}, 174 | editor = {Jennifer Dy and Andreas Krause}, 175 | volume = {80}, 176 | series = {Proceedings of Machine Learning Research}, 177 | address = {Stockholmsmässan, Stockholm Sweden}, 178 | publisher = {PMLR} 179 | } 180 | # Multi-Accuracy 181 | @inproceedings{kim2019, 182 | author = {Kim, Michael P. and Ghorbani, Amirata and Zou, James}, 183 | title = {Multiaccuracy: Black-Box Post-Processing for Fairness in Classification}, 184 | year = {2019}, 185 | isbn = {9781450363242}, 186 | publisher = {Association for Computing Machinery}, 187 | address = {New York, NY, USA}, 188 | url = {https://doi.org/10.1145/3306618.3314287}, 189 | doi = {10.1145/3306618.3314287}, 190 | booktitle = {Proceedings of the 2019 AAAI/ACM Conference on AI, Ethics, and Society}, 191 | pages = {247--254}, 192 | location = {Honolulu, HI, USA}, 193 | series = {AIES '19} 194 | } 195 | ``` 196 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mcboost/5c80fc18d839791a4ffe183b515d5f41c3de25cd/_pkgdown.yml -------------------------------------------------------------------------------- /attic/PipeOpMCBoostSurv.R: -------------------------------------------------------------------------------- 1 | #' @title Multi-Calibrate a Learner's Prediction (Survival Model) 2 | #' 3 | #' @usage NULL 4 | #' @name mlr_pipeops_mcboostsurv 5 | #' @format [`R6Class`] inheriting from [`mlr3pipelines::PipeOp`]. 6 | #' 7 | #' @description 8 | #' Post-process a survival learner prediction using multi-calibration. 9 | #' For more details, please refer to \url{https://arxiv.org/pdf/1805.12317.pdf} (Kim et al. 2018) 10 | #' or the help for [`MCBoostSurv`]. 11 | #' If no `init_predictor` is provided, the preceding learner's predictions 12 | #' corresponding to the `prediction` slot are used as an initial predictor for `MCBoostSurv`. 13 | #' 14 | #' @section Construction: 15 | #' ``` 16 | #' PipeOpMCBoostSurv$new(id = "mcboostsurv", param_vals = list()) 17 | #' ``` 18 | #' * `id` :: `character(1)` 19 | #' Identifier of the resulting object, default `"threshold"`. 20 | #' * `param_vals` :: named `list`\cr 21 | #' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. 22 | #' See `MCBoostSurv` for a comprehensive description of all hyperparameters. 23 | #' 24 | #' @section Input and Output Channels: 25 | #' During training, the input and output are `"data"` and `"prediction"`, two [`TaskSurv`][mlr3proba::TaskSurv]. 26 | #' A [`PredictionSurv`][mlr3proba::PredictionSurv] is required as input and returned as output during prediction. 27 | #' 28 | #' @section State: 29 | #' The `$state` is a `MCBoostSurv` Object as obtained from `MCBoostSurv$new()`. 30 | #' 31 | #' @section Parameters: 32 | #' * `max_iter` :: `integer`\cr 33 | #' A integer specifying the number of multi-calibration rounds. Defaults to 5. 34 | #' 35 | #' @section Fields: 36 | #' Only fields inherited from [`mlr3pipelines::PipeOp`]. 37 | #' 38 | #' @section Methods: 39 | #' Only methods inherited from [`mlr3pipelines::PipeOp`]. 40 | #' 41 | #' @examples 42 | #' library(mlr3) 43 | #' library(mlr3pipelines) 44 | #' # Attention: gunion inputs have to be in the correct order for now. 45 | #' \dontrun{ 46 | #' gr = gunion(list( 47 | #' "data" = po("nop"), 48 | #' "prediction" = po("learner_pred", lrn("surv.ranger")) 49 | #' )) %>>% 50 | #' PipeOpMCBoostSurv$new() 51 | #' tsk = tsk("rats") 52 | #' tid = sample(1:300, 100) 53 | #' gr$train(tsk$clone()$filter(tid)) 54 | #' gr$predict(tsk$clone()$filter(setdiff(1:300, tid))) 55 | #' } 56 | #' @family PipeOps 57 | #' @seealso https://mlr3book.mlr-org.com/list-pipeops.html 58 | #' @export 59 | PipeOpMCBoostSurv = R6Class("PipeOpMCBoostSurv", 60 | inherit = mlr3pipelines::PipeOp, 61 | public = list( 62 | #' @description 63 | #' Initialize a Multi-Calibration PipeOp (for Survival). 64 | #' 65 | #' @param id [`character`] \cr 66 | #' The `PipeOp`'s id. Defaults to "mcboostsurv". 67 | #' @param param_vals [`list`] \cr 68 | #' List of hyperparameters for the `PipeOp`. 69 | initialize = function(id = "mcboostsurv", param_vals = list()) { 70 | param_set = paradox::ParamSet$new(list( 71 | paradox::ParamInt$new("max_iter", lower = 0L, upper = Inf, default = 5L, tags = "train"), 72 | paradox::ParamDbl$new("alpha", lower = 0, upper = 1, default = 1e-4, tags = "train"), 73 | paradox::ParamDbl$new("eta", lower = 0, upper = 1, default = 1, tags = "train"), 74 | paradox::ParamInt$new("num_buckets", lower = 1, upper = Inf, default = 2L, tags = "train"), 75 | paradox::ParamInt$new("time_buckets", lower = 1, upper = Inf, default = 1L, tags = "train"), 76 | paradox::ParamDbl$new("time_eval", lower = 0, upper = 1, default = 1, tags = "train"), 77 | paradox::ParamUty$new("bucket_strategy", default = "quantiles", tags = "train"), 78 | paradox::ParamUty$new("bucket_aggregation", default = NULL, tags = "train"), 79 | paradox::ParamLgl$new("eval_fulldata", default = FALSE, tags = "train"), 80 | paradox::ParamLgl$new("rebucket", default = FALSE, tags = "train"), 81 | paradox::ParamLgl$new("multiplicative", default = TRUE, tags = "train"), 82 | paradox::ParamUty$new("auditor_fitter", default = NULL, tags = "train"), 83 | paradox::ParamUty$new("subpops", default = NULL, tags = "train"), 84 | paradox::ParamUty$new("default_model_class", default = NULL, tags = "train"), 85 | paradox::ParamUty$new("init_predictor", default = NULL, tags = "train") 86 | )) 87 | super$initialize(id, 88 | param_set = param_set, param_vals = param_vals, packages = c("mlr3proba", "survival"), 89 | input = data.table( 90 | name = c("data", "prediction"), 91 | train = c("TaskSurv", "TaskSurv"), 92 | predict = c("TaskSurv", "TaskSurv") 93 | ), 94 | output = data.table(name = "output", train = "NULL", predict = "PredictionSurv"), 95 | tags = "target transform") 96 | } 97 | ), 98 | private = list( 99 | .train = function(inputs) { 100 | 101 | d = inputs$data$data(cols = inputs$data$feature_names) 102 | l = inputs$data$data(cols = inputs$data$target_names) 103 | 104 | args = self$param_set$get_values(tags = "train") 105 | 106 | if (is.null(args$init_predictor)) { 107 | # Construct an initial predictor from the input model if none is provided. 108 | init_predictor = function(data, prediction) { 109 | distr_col = prediction$feature_names[grepl("distr$",prediction$feature_names)] 110 | if (is.null(distr_col)) stop("No distr output in the predictions.") 111 | if (length(distr_col) > 1) stop("More than one distr columns in the prediction?") 112 | as.data.table(prediction)[[distr_col]][[1]][[1]] 113 | } 114 | args$init_predictor = init_predictor 115 | } 116 | mc = mlr3misc::invoke(MCBoostSurv$new, .args = args) 117 | mc$multicalibrate(d, l, predictor_args = inputs$prediction) 118 | self$state = list("mc" = mc) 119 | list(NULL) 120 | }, 121 | 122 | .predict = function(inputs) { 123 | d = inputs$data$data(cols = inputs$data$feature_names) 124 | probs = as.matrix(self$state$mc$predict_probs(d, predictor_args = inputs$prediction)) 125 | 126 | time = as.numeric(colnames(probs)) 127 | 128 | list(mlr3proba::PredictionSurv$new( 129 | truth = inputs$prediction$truth(), 130 | distr = probs, 131 | row_ids = inputs$prediction$row_ids, 132 | crank = -apply(1 - probs, 1, function(.x) sum(c(.x[1], diff(.x)) * time)) 133 | )) 134 | 135 | } 136 | ), 137 | active = list( 138 | #' @field predict_type Predict type of the PipeOp. 139 | predict_type = function(val) { 140 | if (!missing(val)) { 141 | if (!identical(val, private$.learner)) { 142 | stop("$predict_type for PipeOpMCBoostSurv is read-only.") 143 | } 144 | } 145 | return("distr") 146 | } 147 | ) 148 | ) 149 | 150 | 151 | #' Multi-calibration pipeline (for survival models) 152 | #' 153 | #' Wraps MCBoostSurv in a Pipeline to be used with `mlr3pipelines`. 154 | #' For now this assumes training on the same dataset that is later used 155 | #' for multi-calibration. 156 | #' @param learner (mlr3)[`mlr3::Learner`]\cr 157 | #' Initial learner. 158 | #' Defaults to `lrn("surv.kaplan")`. 159 | #' Note: An initial predictor can also be supplied via the `init_predictor` parameter. 160 | #' The learner is internally wrapped into a `PipeOpLearnerCV` 161 | #' with `resampling.method = "insample"` as a default. 162 | #' All parameters can be adjusted through the resulting Graph's `param_set`. 163 | #' @param param_vals `list` \cr 164 | #' List of parameter values passed on to `MCBoostSurv$new` 165 | #' @return (mlr3pipelines) [`Graph`] 166 | #' @examples 167 | #' library("mlr3pipelines") 168 | #' gr = ppl_mcboostsurv() 169 | #' @export 170 | ppl_mcboostsurv = function(learner = lrn("surv.kaplan"), param_vals = list()) { 171 | mlr3misc::require_namespaces("mlr3pipelines") 172 | mlr3misc::require_namespaces("mlr3proba") 173 | gr = mlr3pipelines::`%>>%`( 174 | mlr3pipelines::gunion(list( 175 | "data" = mlr3pipelines::po("nop"), 176 | "prediction" = mlr3pipelines::po("learner_pred", learner = learner) 177 | )), 178 | PipeOpMCBoostSurv$new(param_vals = param_vals) 179 | ) 180 | } 181 | -------------------------------------------------------------------------------- /attic/ProbRange2D.R: -------------------------------------------------------------------------------- 1 | #' Range of Probabilities and Time 2 | #' @description 3 | #' Range of format list(prob = [lower_prob; upper_prob), time = [lower_time; upper_time). 4 | #' @noRd 5 | ProbRange2D = R6::R6Class("ProbRange2D", 6 | public = list( 7 | #' @field prob [`ProbRange`] \cr 8 | #' Range of probabilities 9 | prob = NULL, 10 | #' @field time [`ProbRange`] \cr 11 | #' Range of time 12 | time = NULL, 13 | #' @field aggregation [`function`] \cr 14 | #' Type of aggregation, if there are several values per row 15 | #' (e.g., mean or median) 16 | aggregation = NULL, 17 | 18 | 19 | #' @description 20 | #' Instantiate a Probability Range 2D 21 | #' 22 | #' @param prob [`ProbRange`]\cr 23 | #' Range of probabilities 24 | #' @param time [`ProbRange`]\cr 25 | #' Range of time 26 | #' @param aggregation [`character`]\cr 27 | #' Aggegation 28 | # FIXME 29 | #' @return [`ProbRange2D`] 30 | initialize = function(prob = ProbRange$new(), time = ProbRange$new(), aggregation = NULL) { 31 | self$prob = assert_r6(prob, "ProbRange") 32 | self$time = assert_r6(time, "ProbRange") 33 | 34 | self$aggregation = assert_function(aggregation, null.ok = TRUE) 35 | invisible(self) 36 | }, 37 | #' @description 38 | #' Compare with 'other' Probability Range (2D) regarding equality 39 | #' 40 | #' @param other [`ProbRange2D`]\cr 41 | #' ProbRange2D to compare to. 42 | #' @return 43 | #' Logical, whether ProbRanges2D are equal. 44 | is_equal = function(other) { 45 | if (test_class(other, "ProbRange2D")) { 46 | return(self$prob$is_equal(other$prob) && self$time$is_equal(other$time)) 47 | } 48 | return(FALSE) 49 | }, 50 | #' @description 51 | #' Compare with 'other' ProbabRange2D regarding in-equality 52 | #' 53 | #' @param other [`ProbRange2D`]\cr 54 | #' ProbRange2D to compare to. 55 | #' @return 56 | #' Logical, whether ProbRanges2D are in-equal. 57 | is_not_equal = function(other) { 58 | if (test_class(other, "ProbRange2D")) { 59 | return(self$prob$is_not_equal(other$prob) || self$time$is_not_equal(other$time)) 60 | } 61 | return(TRUE) 62 | }, 63 | #' @description 64 | #' Check whether elements of an array with dimensions individuals x time_points 65 | #' are in the ProbRange2D. 66 | #' 67 | #' @param x [`numeric`]\cr 68 | #' Matrix of probabilities with times as columnnames 69 | #' @return 70 | #' Logical array, whether elements are in ProbRange2D or not. 71 | in_range_mask = function(x) { 72 | 73 | time = as.numeric(colnames(x)) 74 | 75 | time_mask = self$time$in_range_mask(time) 76 | 77 | if (!(any(time_mask))) { 78 | return(NULL) 79 | } 80 | 81 | 82 | # do we take the mean to decide in which bucket? (= one bucket per subject) 83 | if (!is.null(self$aggregation)) { 84 | x = x[, time_mask] 85 | 86 | if (length(dim(x)) < 1) { 87 | return(NULL) 88 | } 89 | 90 | x = apply(x, 1, self$aggregation) 91 | n = self$prob$in_range_mask(x) 92 | 93 | if (!(any(n))) { 94 | return(NULL) 95 | } 96 | 97 | matrix = matrix(TRUE, nrow = sum(n), ncol = sum(time_mask)) 98 | } else { 99 | matrix = self$prob$in_range_mask(x) 100 | matrix [, !time_mask] = FALSE 101 | n = apply(matrix, 1, any) 102 | 103 | if (!(any(n))) { 104 | return(NULL) 105 | } 106 | 107 | matrix = matrix[n, time_mask] 108 | } 109 | return(list( 110 | matrix = matrix, 111 | n = n, 112 | time = time_mask)) 113 | }, 114 | #' @description 115 | #' Printer for ProbRange2D 116 | print = function() { 117 | cat(paste0("ProbRange2D: Probability Range: [", self$prob$lower, ";", self$prob$upper, "), 118 | Time Range[", self$time$lower, ";", self$time$upper, ")")) 119 | } 120 | ) 121 | ) 122 | -------------------------------------------------------------------------------- /attic/helpers_survival.R: -------------------------------------------------------------------------------- 1 | #' Make every row monotonically decreasing in order to obtain the survival property. 2 | #' Additionally, many predicitions need 1 as a first value and 0 as a last value. 3 | #' (e.g. `PredictionSurv` needs this attribute.) 4 | #' @param prediction [`data.table`] 5 | #' Data.table with predictions. Every row is survival probability for the corresponding time. 6 | #' Every column corresponds to a specific time point. 7 | #' @return [`data.table`] 8 | #' @export 9 | make_survival_curve = function(prediction) { 10 | survival_curves = apply(prediction, 1, function(x) { 11 | x[is.na(x)] = 0 12 | cm = cummin(x) 13 | if (any(x != cm)) { 14 | cm 15 | }else{ 16 | x 17 | } 18 | }) 19 | survival_curves = t(survival_curves) 20 | 21 | #needed for PredictionSurv 22 | survival_curves[,1]=1 23 | survival_curves[,ncol(survival_curves)]=0 24 | 25 | as.data.table(survival_curves) 26 | } 27 | 28 | # if (inherits(y, "Surv")) { 29 | # ti = mlr3proba::TaskSurv 30 | # } else -------------------------------------------------------------------------------- /attic/mcboost_step_by_step.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "MCBoost step by step" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{MCBoost step by step} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r, include = FALSE} 11 | knitr::opts_chunk$set( 12 | collapse = TRUE, 13 | comment = "#>" 14 | ) 15 | ``` 16 | 17 | ```{r setup} 18 | library(mcboost) 19 | ``` 20 | 21 | ```{r} 22 | library(data.table) 23 | n = 10000L 24 | x = rnorm(n, 5, 2.5) 25 | s = sample(1:2, n, replace = TRUE) 26 | itcpt = c(5.2, 1.1) 27 | betas = c(-.12, .7) 28 | y = x * betas[s] + itcpt[s] + rnorm(n, -3, .4) + 1 29 | dt = data.table(x = x, s = s, y = y) 30 | dt[, yprob := 1 / (1 + exp(-(y - mean(y)))) + rnorm(n, 0, abs(0.1*(s-2)))] 31 | dt[, train:=FALSE][1:(ceiling(n/2)), train := TRUE] 32 | dt[, y := as.integer(runif(n) > yprob)] 33 | ``` 34 | 35 | ```{r} 36 | library(ggplot2) 37 | ggplot(dt) + geom_point(aes(x=x,y=yprob,color=factor(s))) 38 | dt[, mean(y), by = s] 39 | ``` 40 | 41 | ```{r} 42 | mod = glm(y ~ x, data = dt[train == TRUE,], family = binomial()) 43 | dt[, yh := predict(mod, dt, type = "response")] 44 | dt[, .(mean((yh > 0.5) == y), .N), by = .(s, train)] 45 | ``` 46 | 47 | ```{r} 48 | ggplot(dt) + geom_point(aes(x=yprob,y=yh,color=factor(s))) 49 | dt[, mean(y), by = s] 50 | ``` 51 | 52 | 53 | ```{r} 54 | init_predictor = function(data) { 55 | predict(mod, data, type = "response") 56 | } 57 | mc = MCBoost$new( 58 | auditor_fitter = "TreeAuditorFitter", 59 | init_predictor = init_predictor, 60 | max_iter = 10L, 61 | eta = .2, 62 | multiplicative = TRUE 63 | ) 64 | mc$multicalibrate(dt[, c("s", "x"), with = FALSE], dt$y) 65 | dt[, yh_mc := mc$predict_probs(dt)] 66 | ggplot(dt) + geom_point(aes(x=yprob,y=yh_mc,color=factor(s))) 67 | dt[, mean(y), by = s] 68 | ``` -------------------------------------------------------------------------------- /attic/mcboostsurv_basics.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "MCBoostSurv - Basics" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{MCBoostSurv - Basics} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r, echo = FALSE} 11 | NOT_CRAN <- identical(tolower(Sys.getenv("NOT_CRAN")), "true") 12 | knitr::opts_chunk$set( 13 | collapse = TRUE, 14 | comment = "#>", 15 | purl = NOT_CRAN, 16 | eval = NOT_CRAN 17 | ) 18 | ``` 19 | 20 | 21 | ```{r setup} 22 | library("mcboost") 23 | library("mlr3") 24 | library("mlr3proba") 25 | library("mlr3pipelines") 26 | library("mlr3learners") 27 | library("tidyverse") 28 | set.seed(27099) 29 | ``` 30 | 31 | 32 | ## Minimal Example: McBoostSurv 33 | 34 | To show the basic functionality of `MCBoostSurv`, we provide a minimal example on 35 | the standard survival data set rats. After loading and pre-processing the data, we train 36 | a `mlr3learner` on the training data. We instantiate a `MCBoostSurv` instance 37 | with the default parameters. Then, we run the `$multicalibrate()` method on our data to start multi-calibration in survival analysis. With `$predict_probs()`, we can get 38 | multicalibrated predictions. 39 | 40 | ```{r} 41 | 42 | #prepare task 43 | task = tsk("rats") 44 | prep_pipe = po("encode", param_vals = list(method="one-hot")) 45 | prep = prep_pipe$train(list(task))[[1]] 46 | 47 | #split data 48 | train = prep$clone()$filter(1:199) 49 | val = prep$clone()$filter(200:250) 50 | test = prep$clone()$filter(256:300) 51 | 52 | # get trained survival model 53 | baseline = lrn("surv.ranger")$train(train) 54 | 55 | # initialize mcboost 56 | mc_surv = MCBoostSurv$new(init_predictor = baseline) 57 | 58 | # multicalibrate model 59 | mc_surv$multicalibrate(data = val$data(cols = val$feature_names), 60 | labels = val$data(cols = val$target_names)) 61 | 62 | # get new predictions 63 | mc_surv$predict_probs(test$data(cols = test$feature_names)) 64 | 65 | ``` 66 | ## What does mcboost do? 67 | 68 | Internally mcboostsurv runs the following procedure `max_iter` times (similar ro `mcboost`, just for distributions over time): 69 | 70 | 1. Predict on X using the model from the previous iteration, `init_predictor` in the first iteration. 71 | 1. Compute the residuals `res = y - y_hat` for all time points 72 | 1. Split predictions into `num_buckets` according to `y_hat` and time. 73 | 1. Fit the auditor (`auditor_fitter`) (here called`c(x)`) on the data in each bucket with target variable `r`. 74 | 1. Compute `misscal = mean(c(x) * res(x))` 75 | 1. if `misscal > alpha`: 76 | For the bucket with highest `misscal`, update the model using the prediction `c(x)`. 77 | else: 78 | Stop the procedure 79 | 80 | 81 | 82 | ## Multicalibrate model trained on PBC data 83 | 84 | Based on this, we can now show multicalibration on a data set with two sensitive attributes (age and gender). Again, we load and pre-process the data. 85 | 86 | ### Load Dataset 87 | ```{r} 88 | library(survival) 89 | data_pbc = pbc %>% 90 | mutate(status = if_else(status == 2, 1, 0) 91 | ) %>% 92 | select(-id) %>% 93 | drop_na() 94 | 95 | task_pbc = TaskSurv$new("pbc", backend = as_data_backend(data_pbc), 96 | time = "time", event = "status") 97 | 98 | 99 | #Create data split 100 | 101 | train_test = rsmp("holdout", ratio = 0.8)$instantiate(task_pbc) 102 | train_g = train_test$train_set(1) 103 | test_ids = train_test$test_set(1) 104 | train_val = rsmp("holdout", ratio = 0.75)$instantiate(task_pbc$clone()$filter(train_g)) 105 | train_ids = train_val$train_set(1) 106 | val_ids = train_val$test_set(1) 107 | 108 | # Train distributional survival model 109 | 110 | xgb_distr = as_learner(ppl("distrcompositor", 111 | learner = as_learner(prep_pipe %>>% lrn("surv.xgboost")))) 112 | 113 | xgb_distr$train(task_pbc$clone()$filter(train_ids)) 114 | 115 | 116 | 117 | ``` 118 | 119 | ### Mutlicalibrate survival model with validation data 120 | 121 | ```{r} 122 | # initialize mcboost 123 | mcboost_learner = as_learner( 124 | prep_pipe %>>% ppl_mcboostsurv( 125 | learner = as_learner(prep_pipe %>>% xgb_distr), 126 | param_vals = list( 127 | alpha = 1e-6, 128 | eta = 0.2, 129 | time_buckets = 2, 130 | num_buckets = 1 ) 131 | ) 132 | ) 133 | 134 | # multicalibrate model 135 | mcboost_learner$train(task_pbc$clone()$filter(val_ids)) 136 | 137 | # get new predictions 138 | test_task = task_pbc$clone()$filter(test_ids) 139 | pred_pbc_mc = mcboost_learner$predict(task_pbc$clone()$filter(test_ids)) 140 | 141 | pred_pbc_xgb = xgb_distr$predict(task_pbc$clone()$filter(test_ids)) 142 | 143 | ``` 144 | 145 | ### Development of IBS in the defined subgroups 146 | ```{r} 147 | 148 | pred_pbc_xgb$score(msr("surv.graf")) 149 | pred_pbc_mc$score(msr("surv.graf")) 150 | ``` 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /attic/test_mcboostsurv.R: -------------------------------------------------------------------------------- 1 | context("MCBoostSurv Class") 2 | 3 | test_that("MCBoostSurv class instantiation", { 4 | skip_on_cran() 5 | skip_on_os("solaris") 6 | skip_if_not_installed("mlr3proba") 7 | mc = MCBoostSurv$new(auditor_fitter = "TreeAuditorFitter") 8 | expect_class(mc, "MCBoostSurv") 9 | expect_class(mc$auditor_fitter, "AuditorFitter") 10 | expect_function(mc$predictor, args = "data") 11 | }) 12 | 13 | test_that("MCBoostSurv multicalibrate and predict_probs", { 14 | skip_on_cran() 15 | skip_on_os("solaris") 16 | skip_if_not_installed("mlr3learners") 17 | skip_if_not_installed("mlr3proba") 18 | skip_if_not_installed("survival") 19 | library("mlr3learners") 20 | library("mlr3proba") 21 | library("survival") 22 | rats$sex = (as.character(rats$sex))=="f" 23 | b = as_data_backend(rats) 24 | tsk = TaskSurv$new("rats", 25 | backend = b, time = "time", 26 | event = "status") 27 | data = tsk$data(cols = tsk$feature_names) 28 | labels = tsk$data(cols = tsk$target_names) 29 | l = lrn("surv.ranger")$train(tsk) 30 | 31 | mc = MCBoostSurv$new(init_predictor = l, 32 | auditor_fitter = "TreeAuditorFitter", 33 | max_iter = 3, 34 | alpha = 0) 35 | mc$multicalibrate(data, labels) 36 | 37 | expect_list(mc$iter_models, types = "LearnerPredictor", len = mc$max_iter) 38 | expect_list(mc$iter_partitions, types = "ProbRange2D", len = mc$max_iter) 39 | 40 | prds = mc$predict_probs(data) 41 | expect_data_frame(prds, nrows = nrow(data), ncol = length(tsk$unique_times())) 42 | 43 | 44 | mc = MCBoostSurv$new( 45 | init_predictor = l, 46 | auditor_fitter = "TreeAuditorFitter", 47 | max_iter = 3, 48 | alpha = 0.001, 49 | loss = "brier", 50 | multiplicative = FALSE, 51 | max_time_quantile = 0.9 52 | ) 53 | mc$multicalibrate(data, labels) 54 | ae = mc$auditor_effect(data) 55 | expect_matrix(ae, nrows = 300, ncols = length(tsk$unique_times())) 56 | expect_true(mc$loss == "brier") 57 | }) 58 | -------------------------------------------------------------------------------- /attic/test_pipeop_learner_pred.R: -------------------------------------------------------------------------------- 1 | test_that("PipeOp Learner Pred", { 2 | skip_on_cran() 3 | skip_on_os("solaris") 4 | 5 | pop = mlr3pipelines::po("learner_pred", learner = lrn("surv.kaplan")) 6 | expect_is(pop, "PipeOp") 7 | 8 | out = pop$train(list(tsk("rats")))[[1]] 9 | expect_is(out, "Task") 10 | expect_true(all( 11 | c(colnames(out$data(cols = out$feature_names)) %in% 12 | c("surv.kaplan.time", "surv.kaplan.status", "surv.kaplan.crank", 13 | "surv.kaplan.distr")) 14 | )) 15 | dist = out$data()[["surv.kaplan.distr"]] 16 | expect_list(dist, types = "list", len = 300L) 17 | out = map(dist[[1]], function(x) expect_matrix(as.matrix(x), nrows = 300, ncols = length(tsk("rats")$unique_times()))) 18 | 19 | out = pop$predict(list(tsk("rats")))[[1]] 20 | expect_is(out, "Task") 21 | expect_true(all( 22 | c(colnames(out$data(cols = out$feature_names)) %in% 23 | c("surv.kaplan.time", "surv.kaplan.status", "surv.kaplan.crank", "surv.kaplan.distr")) 24 | )) 25 | dist = out$data()[["surv.kaplan.distr"]] 26 | expect_list(dist, types = "list", len = 300L) 27 | out = map(dist[[1]], function(x) expect_matrix(as.matrix(x), nrows = 300, ncols = length(tsk("rats")$unique_times()))) 28 | }) 29 | -------------------------------------------------------------------------------- /attic/test_pipeop_mcboostsurv.R: -------------------------------------------------------------------------------- 1 | context("MCBoostSurv PipeOp") 2 | 3 | test_that("MCBoostSurv class instantiation", { 4 | skip_on_cran() 5 | skip_on_os("solaris") 6 | library("mlr3pipelines") 7 | gr = gunion(list( 8 | "data" = po("encode") %>>% po("nop"), 9 | "prediction" = po("learner_pred", lrn("surv.ranger")) 10 | )) %>>% 11 | PipeOpMCBoostSurv$new(param_vals = list(multiplicative = FALSE, alpha = 0, max_iter = 3)) 12 | expect_is(gr, "Graph") 13 | tsk = tsk("rats") 14 | tid = sample(1:300, 200) 15 | train_out = gr$train(tsk$clone()$filter(tid)) 16 | expect_is(gr$state$mcboostsurv$mc, "MCBoostSurv") 17 | expect_list(gr$state$mcboostsurv$mc$iter_models, types = "LearnerPredictor") 18 | expect_true(!gr$state$mcboostsurv$mc$multiplicative) 19 | pr = gr$predict(tsk$clone()$filter(setdiff(1:300, tid))) 20 | expect_is(pr[[1]], "Prediction") 21 | }) 22 | 23 | test_that("pipeop instantiation", { 24 | skip_on_cran() 25 | skip_on_os("solaris") 26 | library("mlr3pipelines") 27 | pop = po("mcboostsurv") 28 | expect_is(pop, "PipeOpMCBoostSurv") 29 | expect_is(pop, "PipeOp") 30 | expect_list(pop$param_set$values, len = 0L) 31 | expect_true(pop$predict_type == "distr") 32 | }) 33 | 34 | test_that("MCBoostSurv ppl", { 35 | skip_on_cran() 36 | skip_on_os("solaris") 37 | library("survival") 38 | rats$sex = (as.character(rats$sex)) == "f" 39 | task = TaskSurv$new("rats", 40 | backend = as_data_backend(rats), 41 | time = "time", 42 | event = "status" 43 | ) 44 | l = lrn("surv.kaplan")# $train(task) 45 | pp = ppl_mcboostsurv(learner = l, param_vals = list(max_iter = 3, alpha = 0)) 46 | expect_is(pp, "Graph") 47 | pp$train(task) 48 | expect_true(!is.null(pp$state)) 49 | prd = pp$predict(task) 50 | expect_is(prd[[1]], "PredictionSurv") 51 | state = pp$pipeops$mcboostsurv$state$mc 52 | expect_true(length(state$iter_models) == 3) 53 | expect_true(state$alpha == 0) 54 | }) 55 | 56 | test_that("MCBoostSurv ppl", { 57 | skip_on_cran() 58 | gr = ppl_mcboostsurv(lrn("surv.kaplan")) 59 | expect_is(gr, "Graph") 60 | }) -------------------------------------------------------------------------------- /attic/test_probrange2d.R: -------------------------------------------------------------------------------- 1 | context("ProbRange2D") 2 | 3 | test_that("ProbRange2D works", { 4 | pr = ProbRange$new(0.1, 0.55) 5 | t = ProbRange$new(40,101) 6 | pr2d = ProbRange2D$new(pr, t) 7 | 8 | prs2d = list( 9 | pr2 = ProbRange2D$new(ProbRange$new(0.1, 0.55), 10 | ProbRange$new(40,101)), 11 | pr3 = ProbRange2D$new(ProbRange$new(0.3, 0.55), 12 | ProbRange$new(0,100)), 13 | pr4 = ProbRange2D$new(ProbRange$new(0.1, 0.55), 14 | ProbRange$new(0,100)) 15 | ) 16 | 17 | expect_class(pr2d, "ProbRange2D") 18 | expect_equal(pr2d$prob$lower, 0.1) 19 | expect_equal(pr2d$prob$upper, 0.55) 20 | expect_equal(pr2d$time$lower, 40) 21 | expect_equal(pr2d$time$upper, 101) 22 | 23 | values = c(TRUE, FALSE, FALSE) 24 | 25 | out = mlr3misc::map_lgl(prs2d, function(x) { 26 | pr2d$is_equal(x) 27 | }) 28 | expect_equal(out, values, check.attributes = FALSE) 29 | 30 | out = mlr3misc::map_lgl(prs2d, function(x) { 31 | pr2d$is_not_equal(x) 32 | }) 33 | expect_equal(!out, values, check.attributes = FALSE) 34 | 35 | prs = matrix(c(0.09, 0.1, 0.4, 0.55, 0.7), ncol = 5) 36 | time = seq(0,200,length.out = 5) 37 | colnames(prs) = time 38 | 39 | in_prob = pr2d$in_range_mask(prs) 40 | expect_equal(in_prob$n, TRUE) 41 | expect_equal(in_prob$time, c(FALSE, TRUE, TRUE, FALSE, FALSE)) 42 | expect_equal(as.logical(in_prob$matrix), c(TRUE, TRUE)) 43 | expect_equal(names(in_prob$matrix), c("50","100")) 44 | 45 | expect_false(pr2d$is_equal(5)) 46 | expect_true(pr2d$is_not_equal(5)) 47 | 48 | expect_output(print(pr2d), "ProbRange2D") 49 | }) 50 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | coverage: 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 1% 9 | informational: true 10 | patch: 11 | default: 12 | target: auto 13 | threshold: 1% 14 | informational: true 15 | -------------------------------------------------------------------------------- /contributions.md: -------------------------------------------------------------------------------- 1 | ### Florian Pfisterer 2 | 3 | Florian Pfisterer implemented and extended the main part of the package, heavily influcenced by 4 | an unpublished python code-base written in large parts by Matthew. Florian furthermore worked on the interaction with 5 | **mlr3** by integrating **mlr3** learners as Auditing Mechanism as well as exporting functionality to integrate 6 | **mcboost** as a `PipeOp` into **mlr3pipelines**. 7 | 8 | ### Christoph Kern 9 | 10 | Christoph Kern prepared and contributed to the vignettes and co-authored the summary paper. Christoph contributed (very) moderately to the python code underlying this package and helped conceptionally in transitioning from the python code to the R implementation. 11 | 12 | ### Susanne Dandl 13 | 14 | Susanne Dandl reviewed the R package, she provided advice on extensions, extended and improved vignettes 15 | and worked towards thorough unit testing of the different methods. 16 | 17 | ### Matthew Sun 18 | 19 | Matthew Sun wrote the initial Python implementation of MCBoost, with feedback and oversight provided by Michael. 20 | His version guided large parts of mcboost's current design and architecture. 21 | 22 | ### Michael P. Kim 23 | 24 | Michael P. Kim is a coauthor of the research papers that introduced Multi-Calibration. 25 | Michael oversaw the development of the initial python implementation of MCBoost 26 | and provided additional advice and directions in the development of this R package. 27 | 28 | ### Bernd Bischl 29 | 30 | Oversaw the package development and provided feedback with respect to API design, implementation details and methodology. 31 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | ## R CMD check 2 | 3 | 0 errors | 0 warnings | 1 note 4 | 5 | New maintainer: 6 | Sebastian Fischer 7 | -------------------------------------------------------------------------------- /inst/CITATION: -------------------------------------------------------------------------------- 1 | bibentry( 2 | bibtype = "article", 3 | key = "mcboost", 4 | title = "mcboost: Multi-Calibration Boosting for R", 5 | author = c( 6 | person(given = "Florian", family = "Pfisterer"), 7 | person(given = "Christoph", family = "Kern"), 8 | person(given = "Susanne", family = "Dandl"), 9 | person(given = "Matthew", family = "Sun"), 10 | person(given = "Michael P.", family = "Kim"), 11 | person(given = "Bernd", family = "Bischl") 12 | ), 13 | journal = "Journal of Open Source Software", 14 | year = 2021, 15 | month = "aug", 16 | doi = "10.21105/joss.03453", 17 | url = "https://joss.theoj.org/papers/10.21105/joss.03453", 18 | publisher = "The Open Journal", 19 | volume = "6", 20 | number = "64", 21 | pages = "3453" 22 | ) 23 | -------------------------------------------------------------------------------- /man-roxygen/params_data_label.R: -------------------------------------------------------------------------------- 1 | #' @param data [`data.table`]\cr 2 | #' Features. 3 | #' @param labels [`numeric`]\cr 4 | #' One-hot encoded labels (of same length as data). 5 | -------------------------------------------------------------------------------- /man-roxygen/params_data_resid.R: -------------------------------------------------------------------------------- 1 | #' @param data [`data.table`]\cr 2 | #' Features. 3 | #' @param resid [`numeric`]\cr 4 | #' Residuals (of same length as data). 5 | -------------------------------------------------------------------------------- /man-roxygen/params_mask.R: -------------------------------------------------------------------------------- 1 | #' @param mask [`integer`]\cr 2 | #' Mask applied to the data. Only used for `SubgroupAuditorFitter`. 3 | -------------------------------------------------------------------------------- /man-roxygen/params_subpops.R: -------------------------------------------------------------------------------- 1 | #' @param subpops [`list`] \cr 2 | #' Specifies a collection of characteristic attributes 3 | #' and the values they take to define subpopulations 4 | #' e.g. list(age = c('20-29','30-39','40+'), nJobs = c(0,1,2,'3+'), ,..). 5 | -------------------------------------------------------------------------------- /man-roxygen/return_auditor.R: -------------------------------------------------------------------------------- 1 | #' @return [`AuditorFitter`]\cr 2 | -------------------------------------------------------------------------------- /man-roxygen/return_fit.R: -------------------------------------------------------------------------------- 1 | #' @return `list` with items\cr 2 | #' - `corr`: pseudo-correlation between residuals and learner prediction. 3 | #' - `l`: the trained learner. 4 | -------------------------------------------------------------------------------- /man-roxygen/return_predictor.R: -------------------------------------------------------------------------------- 1 | #' @return [`Predictor`]\cr 2 | -------------------------------------------------------------------------------- /man/AuditorFitter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/AuditorFitters.R 3 | \name{AuditorFitter} 4 | \alias{AuditorFitter} 5 | \title{AuditorFitter Abstract Base Class} 6 | \value{ 7 | \code{list} with items\cr 8 | \itemize{ 9 | \item \code{corr}: pseudo-correlation between residuals and learner prediction. 10 | \item \code{l}: the trained learner. 11 | } 12 | } 13 | \description{ 14 | Defines an \code{AuditorFitter} abstract base class. 15 | } 16 | \section{Methods}{ 17 | \subsection{Public methods}{ 18 | \itemize{ 19 | \item \href{#method-AuditorFitter-new}{\code{AuditorFitter$new()}} 20 | \item \href{#method-AuditorFitter-fit_to_resid}{\code{AuditorFitter$fit_to_resid()}} 21 | \item \href{#method-AuditorFitter-fit}{\code{AuditorFitter$fit()}} 22 | \item \href{#method-AuditorFitter-clone}{\code{AuditorFitter$clone()}} 23 | } 24 | } 25 | \if{html}{\out{
}} 26 | \if{html}{\out{}} 27 | \if{latex}{\out{\hypertarget{method-AuditorFitter-new}{}}} 28 | \subsection{Method \code{new()}}{ 29 | Initialize a \code{\link{AuditorFitter}}. 30 | This is an abstract base class. 31 | \subsection{Usage}{ 32 | \if{html}{\out{
}}\preformatted{AuditorFitter$new()}\if{html}{\out{
}} 33 | } 34 | 35 | } 36 | \if{html}{\out{
}} 37 | \if{html}{\out{}} 38 | \if{latex}{\out{\hypertarget{method-AuditorFitter-fit_to_resid}{}}} 39 | \subsection{Method \code{fit_to_resid()}}{ 40 | Fit to residuals. 41 | \subsection{Usage}{ 42 | \if{html}{\out{
}}\preformatted{AuditorFitter$fit_to_resid(data, resid, mask)}\if{html}{\out{
}} 43 | } 44 | 45 | \subsection{Arguments}{ 46 | \if{html}{\out{
}} 47 | \describe{ 48 | \item{\code{data}}{\code{\link{data.table}}\cr 49 | Features.} 50 | 51 | \item{\code{resid}}{\code{\link{numeric}}\cr 52 | Residuals (of same length as data).} 53 | 54 | \item{\code{mask}}{\code{\link{integer}}\cr 55 | Mask applied to the data. Only used for \code{SubgroupAuditorFitter}.} 56 | } 57 | \if{html}{\out{
}} 58 | } 59 | } 60 | \if{html}{\out{
}} 61 | \if{html}{\out{}} 62 | \if{latex}{\out{\hypertarget{method-AuditorFitter-fit}{}}} 63 | \subsection{Method \code{fit()}}{ 64 | Fit (mostly used internally, use \code{fit_to_resid}). 65 | \subsection{Usage}{ 66 | \if{html}{\out{
}}\preformatted{AuditorFitter$fit(data, resid, mask)}\if{html}{\out{
}} 67 | } 68 | 69 | \subsection{Arguments}{ 70 | \if{html}{\out{
}} 71 | \describe{ 72 | \item{\code{data}}{\code{\link{data.table}}\cr 73 | Features.} 74 | 75 | \item{\code{resid}}{\code{\link{numeric}}\cr 76 | Residuals (of same length as data).} 77 | 78 | \item{\code{mask}}{\code{\link{integer}}\cr 79 | Mask applied to the data. Only used for \code{SubgroupAuditorFitter}.} 80 | } 81 | \if{html}{\out{
}} 82 | } 83 | } 84 | \if{html}{\out{
}} 85 | \if{html}{\out{}} 86 | \if{latex}{\out{\hypertarget{method-AuditorFitter-clone}{}}} 87 | \subsection{Method \code{clone()}}{ 88 | The objects of this class are cloneable with this method. 89 | \subsection{Usage}{ 90 | \if{html}{\out{
}}\preformatted{AuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 91 | } 92 | 93 | \subsection{Arguments}{ 94 | \if{html}{\out{
}} 95 | \describe{ 96 | \item{\code{deep}}{Whether to make a deep clone.} 97 | } 98 | \if{html}{\out{
}} 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /man/CVLearnerAuditorFitter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/AuditorFitters.R 3 | \name{CVLearnerAuditorFitter} 4 | \alias{CVLearnerAuditorFitter} 5 | \alias{CVTreeAuditorFitter} 6 | \alias{CVRidgeAuditorFitter} 7 | \title{Cross-validated AuditorFitter from a Learner} 8 | \value{ 9 | \code{\link{AuditorFitter}}\cr 10 | 11 | \code{list} with items\cr 12 | \itemize{ 13 | \item \code{corr}: pseudo-correlation between residuals and learner prediction. 14 | \item \code{l}: the trained learner. 15 | } 16 | } 17 | \description{ 18 | CVLearnerAuditorFitter returns the cross-validated predictions 19 | instead of the in-sample predictions. 20 | 21 | Available data is cut into complementary subsets (folds). 22 | For each subset out-of-sample predictions are received by training a model 23 | on all other subsets and predicting afterwards on the left-out subset. 24 | } 25 | \section{Functions}{ 26 | \itemize{ 27 | \item \code{CVTreeAuditorFitter}: Cross-Validated auditor based on rpart 28 | 29 | \item \code{CVRidgeAuditorFitter}: Cross-Validated auditor based on glmnet 30 | 31 | }} 32 | \seealso{ 33 | Other AuditorFitter: 34 | \code{\link{LearnerAuditorFitter}}, 35 | \code{\link{SubgroupAuditorFitter}}, 36 | \code{\link{SubpopAuditorFitter}} 37 | 38 | Other AuditorFitter: 39 | \code{\link{LearnerAuditorFitter}}, 40 | \code{\link{SubgroupAuditorFitter}}, 41 | \code{\link{SubpopAuditorFitter}} 42 | 43 | Other AuditorFitter: 44 | \code{\link{LearnerAuditorFitter}}, 45 | \code{\link{SubgroupAuditorFitter}}, 46 | \code{\link{SubpopAuditorFitter}} 47 | } 48 | \concept{AuditorFitter} 49 | \section{Super class}{ 50 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{CVLearnerAuditorFitter} 51 | } 52 | \section{Public fields}{ 53 | \if{html}{\out{
}} 54 | \describe{ 55 | \item{\code{learner}}{\code{CVLearnerPredictor}\cr 56 | Learner used for fitting residuals.} 57 | } 58 | \if{html}{\out{
}} 59 | } 60 | \section{Methods}{ 61 | \subsection{Public methods}{ 62 | \itemize{ 63 | \item \href{#method-CVLearnerAuditorFitter-new}{\code{CVLearnerAuditorFitter$new()}} 64 | \item \href{#method-CVLearnerAuditorFitter-fit}{\code{CVLearnerAuditorFitter$fit()}} 65 | \item \href{#method-CVLearnerAuditorFitter-clone}{\code{CVLearnerAuditorFitter$clone()}} 66 | } 67 | } 68 | \if{html}{\out{ 69 |
Inherited methods 70 | 73 |
74 | }} 75 | \if{html}{\out{
}} 76 | \if{html}{\out{}} 77 | \if{latex}{\out{\hypertarget{method-CVLearnerAuditorFitter-new}{}}} 78 | \subsection{Method \code{new()}}{ 79 | Define a \code{CVAuditorFitter} from a learner. 80 | Available instantiations:\cr \code{\link{CVTreeAuditorFitter}} (rpart) and 81 | \code{\link{CVRidgeAuditorFitter}} (glmnet). 82 | See \code{\link[mlr3pipelines:mlr_pipeops_learner_cv]{mlr3pipelines::PipeOpLearnerCV}} for more information on 83 | cross-validated learners. 84 | \subsection{Usage}{ 85 | \if{html}{\out{
}}\preformatted{CVLearnerAuditorFitter$new(learner, folds = 3L)}\if{html}{\out{
}} 86 | } 87 | 88 | \subsection{Arguments}{ 89 | \if{html}{\out{
}} 90 | \describe{ 91 | \item{\code{learner}}{\code{\link[mlr3:Learner]{mlr3::Learner}}\cr 92 | Regression Learner to use.} 93 | 94 | \item{\code{folds}}{\code{\link{integer}}\cr 95 | Number of folds to use for PipeOpLearnerCV. Defaults to 3.} 96 | } 97 | \if{html}{\out{
}} 98 | } 99 | } 100 | \if{html}{\out{
}} 101 | \if{html}{\out{}} 102 | \if{latex}{\out{\hypertarget{method-CVLearnerAuditorFitter-fit}{}}} 103 | \subsection{Method \code{fit()}}{ 104 | Fit the cross-validated learner and compute correlation 105 | \subsection{Usage}{ 106 | \if{html}{\out{
}}\preformatted{CVLearnerAuditorFitter$fit(data, resid, mask)}\if{html}{\out{
}} 107 | } 108 | 109 | \subsection{Arguments}{ 110 | \if{html}{\out{
}} 111 | \describe{ 112 | \item{\code{data}}{\code{\link{data.table}}\cr 113 | Features.} 114 | 115 | \item{\code{resid}}{\code{\link{numeric}}\cr 116 | Residuals (of same length as data).} 117 | 118 | \item{\code{mask}}{\code{\link{integer}}\cr 119 | Mask applied to the data. Only used for \code{SubgroupAuditorFitter}.} 120 | } 121 | \if{html}{\out{
}} 122 | } 123 | } 124 | \if{html}{\out{
}} 125 | \if{html}{\out{}} 126 | \if{latex}{\out{\hypertarget{method-CVLearnerAuditorFitter-clone}{}}} 127 | \subsection{Method \code{clone()}}{ 128 | The objects of this class are cloneable with this method. 129 | \subsection{Usage}{ 130 | \if{html}{\out{
}}\preformatted{CVLearnerAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 131 | } 132 | 133 | \subsection{Arguments}{ 134 | \if{html}{\out{
}} 135 | \describe{ 136 | \item{\code{deep}}{Whether to make a deep clone.} 137 | } 138 | \if{html}{\out{
}} 139 | } 140 | } 141 | } 142 | \section{Super classes}{ 143 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{\link[mcboost:CVLearnerAuditorFitter]{mcboost::CVLearnerAuditorFitter}} -> \code{CVTreeAuditorFitter} 144 | } 145 | \section{Methods}{ 146 | \subsection{Public methods}{ 147 | \itemize{ 148 | \item \href{#method-CVTreeAuditorFitter-new}{\code{CVTreeAuditorFitter$new()}} 149 | \item \href{#method-CVTreeAuditorFitter-clone}{\code{CVTreeAuditorFitter$clone()}} 150 | } 151 | } 152 | \if{html}{\out{ 153 |
Inherited methods 154 | 158 |
159 | }} 160 | \if{html}{\out{
}} 161 | \if{html}{\out{}} 162 | \if{latex}{\out{\hypertarget{method-CVTreeAuditorFitter-new}{}}} 163 | \subsection{Method \code{new()}}{ 164 | Define a cross-validated AuditorFitter from a rpart learner 165 | See \code{\link[mlr3pipelines:mlr_pipeops_learner_cv]{mlr3pipelines::PipeOpLearnerCV}} for more information on 166 | cross-validated learners. 167 | \subsection{Usage}{ 168 | \if{html}{\out{
}}\preformatted{CVTreeAuditorFitter$new()}\if{html}{\out{
}} 169 | } 170 | 171 | } 172 | \if{html}{\out{
}} 173 | \if{html}{\out{}} 174 | \if{latex}{\out{\hypertarget{method-CVTreeAuditorFitter-clone}{}}} 175 | \subsection{Method \code{clone()}}{ 176 | The objects of this class are cloneable with this method. 177 | \subsection{Usage}{ 178 | \if{html}{\out{
}}\preformatted{CVTreeAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 179 | } 180 | 181 | \subsection{Arguments}{ 182 | \if{html}{\out{
}} 183 | \describe{ 184 | \item{\code{deep}}{Whether to make a deep clone.} 185 | } 186 | \if{html}{\out{
}} 187 | } 188 | } 189 | } 190 | \section{Super classes}{ 191 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{\link[mcboost:CVLearnerAuditorFitter]{mcboost::CVLearnerAuditorFitter}} -> \code{CVRidgeAuditorFitter} 192 | } 193 | \section{Methods}{ 194 | \subsection{Public methods}{ 195 | \itemize{ 196 | \item \href{#method-CVRidgeAuditorFitter-new}{\code{CVRidgeAuditorFitter$new()}} 197 | \item \href{#method-CVRidgeAuditorFitter-clone}{\code{CVRidgeAuditorFitter$clone()}} 198 | } 199 | } 200 | \if{html}{\out{ 201 |
Inherited methods 202 | 206 |
207 | }} 208 | \if{html}{\out{
}} 209 | \if{html}{\out{}} 210 | \if{latex}{\out{\hypertarget{method-CVRidgeAuditorFitter-new}{}}} 211 | \subsection{Method \code{new()}}{ 212 | Define a cross-validated AuditorFitter from a glmnet learner. 213 | See \code{\link[mlr3pipelines:mlr_pipeops_learner_cv]{mlr3pipelines::PipeOpLearnerCV}} for more information on 214 | cross-validated learners. 215 | \subsection{Usage}{ 216 | \if{html}{\out{
}}\preformatted{CVRidgeAuditorFitter$new()}\if{html}{\out{
}} 217 | } 218 | 219 | } 220 | \if{html}{\out{
}} 221 | \if{html}{\out{}} 222 | \if{latex}{\out{\hypertarget{method-CVRidgeAuditorFitter-clone}{}}} 223 | \subsection{Method \code{clone()}}{ 224 | The objects of this class are cloneable with this method. 225 | \subsection{Usage}{ 226 | \if{html}{\out{
}}\preformatted{CVRidgeAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 227 | } 228 | 229 | \subsection{Arguments}{ 230 | \if{html}{\out{
}} 231 | \describe{ 232 | \item{\code{deep}}{Whether to make a deep clone.} 233 | } 234 | \if{html}{\out{
}} 235 | } 236 | } 237 | } 238 | -------------------------------------------------------------------------------- /man/LearnerAuditorFitter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/AuditorFitters.R 3 | \name{LearnerAuditorFitter} 4 | \alias{LearnerAuditorFitter} 5 | \alias{TreeAuditorFitter} 6 | \alias{RidgeAuditorFitter} 7 | \title{Create an AuditorFitter from a Learner} 8 | \value{ 9 | \code{\link{AuditorFitter}}\cr 10 | 11 | \code{list} with items\cr 12 | \itemize{ 13 | \item \code{corr}: pseudo-correlation between residuals and learner prediction. 14 | \item \code{l}: the trained learner. 15 | } 16 | } 17 | \description{ 18 | Instantiates an AuditorFitter that trains a \code{\link[mlr3:Learner]{mlr3::Learner}} 19 | on the data. 20 | } 21 | \section{Functions}{ 22 | \itemize{ 23 | \item \code{TreeAuditorFitter}: Learner auditor based on rpart 24 | 25 | \item \code{RidgeAuditorFitter}: Learner auditor based on glmnet 26 | 27 | }} 28 | \seealso{ 29 | Other AuditorFitter: 30 | \code{\link{CVLearnerAuditorFitter}}, 31 | \code{\link{SubgroupAuditorFitter}}, 32 | \code{\link{SubpopAuditorFitter}} 33 | 34 | Other AuditorFitter: 35 | \code{\link{CVLearnerAuditorFitter}}, 36 | \code{\link{SubgroupAuditorFitter}}, 37 | \code{\link{SubpopAuditorFitter}} 38 | 39 | Other AuditorFitter: 40 | \code{\link{CVLearnerAuditorFitter}}, 41 | \code{\link{SubgroupAuditorFitter}}, 42 | \code{\link{SubpopAuditorFitter}} 43 | } 44 | \concept{AuditorFitter} 45 | \section{Super class}{ 46 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{LearnerAuditorFitter} 47 | } 48 | \section{Public fields}{ 49 | \if{html}{\out{
}} 50 | \describe{ 51 | \item{\code{learner}}{\code{LearnerPredictor}\cr 52 | Learner used for fitting residuals.} 53 | } 54 | \if{html}{\out{
}} 55 | } 56 | \section{Methods}{ 57 | \subsection{Public methods}{ 58 | \itemize{ 59 | \item \href{#method-LearnerAuditorFitter-new}{\code{LearnerAuditorFitter$new()}} 60 | \item \href{#method-LearnerAuditorFitter-fit}{\code{LearnerAuditorFitter$fit()}} 61 | \item \href{#method-LearnerAuditorFitter-clone}{\code{LearnerAuditorFitter$clone()}} 62 | } 63 | } 64 | \if{html}{\out{ 65 |
Inherited methods 66 | 69 |
70 | }} 71 | \if{html}{\out{
}} 72 | \if{html}{\out{}} 73 | \if{latex}{\out{\hypertarget{method-LearnerAuditorFitter-new}{}}} 74 | \subsection{Method \code{new()}}{ 75 | Define an \code{AuditorFitter} from a Learner. 76 | Available instantiations:\cr \code{\link{TreeAuditorFitter}} (rpart) and 77 | \code{\link{RidgeAuditorFitter}} (glmnet). 78 | \subsection{Usage}{ 79 | \if{html}{\out{
}}\preformatted{LearnerAuditorFitter$new(learner)}\if{html}{\out{
}} 80 | } 81 | 82 | \subsection{Arguments}{ 83 | \if{html}{\out{
}} 84 | \describe{ 85 | \item{\code{learner}}{\code{\link[mlr3:Learner]{mlr3::Learner}}\cr 86 | Regression learner to use.} 87 | } 88 | \if{html}{\out{
}} 89 | } 90 | } 91 | \if{html}{\out{
}} 92 | \if{html}{\out{}} 93 | \if{latex}{\out{\hypertarget{method-LearnerAuditorFitter-fit}{}}} 94 | \subsection{Method \code{fit()}}{ 95 | Fit the learner and compute correlation 96 | \subsection{Usage}{ 97 | \if{html}{\out{
}}\preformatted{LearnerAuditorFitter$fit(data, resid, mask)}\if{html}{\out{
}} 98 | } 99 | 100 | \subsection{Arguments}{ 101 | \if{html}{\out{
}} 102 | \describe{ 103 | \item{\code{data}}{\code{\link{data.table}}\cr 104 | Features.} 105 | 106 | \item{\code{resid}}{\code{\link{numeric}}\cr 107 | Residuals (of same length as data).} 108 | 109 | \item{\code{mask}}{\code{\link{integer}}\cr 110 | Mask applied to the data. Only used for \code{SubgroupAuditorFitter}.} 111 | } 112 | \if{html}{\out{
}} 113 | } 114 | } 115 | \if{html}{\out{
}} 116 | \if{html}{\out{}} 117 | \if{latex}{\out{\hypertarget{method-LearnerAuditorFitter-clone}{}}} 118 | \subsection{Method \code{clone()}}{ 119 | The objects of this class are cloneable with this method. 120 | \subsection{Usage}{ 121 | \if{html}{\out{
}}\preformatted{LearnerAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 122 | } 123 | 124 | \subsection{Arguments}{ 125 | \if{html}{\out{
}} 126 | \describe{ 127 | \item{\code{deep}}{Whether to make a deep clone.} 128 | } 129 | \if{html}{\out{
}} 130 | } 131 | } 132 | } 133 | \section{Super classes}{ 134 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{\link[mcboost:LearnerAuditorFitter]{mcboost::LearnerAuditorFitter}} -> \code{TreeAuditorFitter} 135 | } 136 | \section{Methods}{ 137 | \subsection{Public methods}{ 138 | \itemize{ 139 | \item \href{#method-TreeAuditorFitter-new}{\code{TreeAuditorFitter$new()}} 140 | \item \href{#method-TreeAuditorFitter-clone}{\code{TreeAuditorFitter$clone()}} 141 | } 142 | } 143 | \if{html}{\out{ 144 |
Inherited methods 145 | 149 |
150 | }} 151 | \if{html}{\out{
}} 152 | \if{html}{\out{}} 153 | \if{latex}{\out{\hypertarget{method-TreeAuditorFitter-new}{}}} 154 | \subsection{Method \code{new()}}{ 155 | Define a AuditorFitter from a rpart learner. 156 | \subsection{Usage}{ 157 | \if{html}{\out{
}}\preformatted{TreeAuditorFitter$new()}\if{html}{\out{
}} 158 | } 159 | 160 | } 161 | \if{html}{\out{
}} 162 | \if{html}{\out{}} 163 | \if{latex}{\out{\hypertarget{method-TreeAuditorFitter-clone}{}}} 164 | \subsection{Method \code{clone()}}{ 165 | The objects of this class are cloneable with this method. 166 | \subsection{Usage}{ 167 | \if{html}{\out{
}}\preformatted{TreeAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 168 | } 169 | 170 | \subsection{Arguments}{ 171 | \if{html}{\out{
}} 172 | \describe{ 173 | \item{\code{deep}}{Whether to make a deep clone.} 174 | } 175 | \if{html}{\out{
}} 176 | } 177 | } 178 | } 179 | \section{Super classes}{ 180 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{\link[mcboost:LearnerAuditorFitter]{mcboost::LearnerAuditorFitter}} -> \code{RidgeAuditorFitter} 181 | } 182 | \section{Methods}{ 183 | \subsection{Public methods}{ 184 | \itemize{ 185 | \item \href{#method-RidgeAuditorFitter-new}{\code{RidgeAuditorFitter$new()}} 186 | \item \href{#method-RidgeAuditorFitter-clone}{\code{RidgeAuditorFitter$clone()}} 187 | } 188 | } 189 | \if{html}{\out{ 190 |
Inherited methods 191 | 195 |
196 | }} 197 | \if{html}{\out{
}} 198 | \if{html}{\out{}} 199 | \if{latex}{\out{\hypertarget{method-RidgeAuditorFitter-new}{}}} 200 | \subsection{Method \code{new()}}{ 201 | Define a AuditorFitter from a glmnet learner. 202 | \subsection{Usage}{ 203 | \if{html}{\out{
}}\preformatted{RidgeAuditorFitter$new()}\if{html}{\out{
}} 204 | } 205 | 206 | } 207 | \if{html}{\out{
}} 208 | \if{html}{\out{}} 209 | \if{latex}{\out{\hypertarget{method-RidgeAuditorFitter-clone}{}}} 210 | \subsection{Method \code{clone()}}{ 211 | The objects of this class are cloneable with this method. 212 | \subsection{Usage}{ 213 | \if{html}{\out{
}}\preformatted{RidgeAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 214 | } 215 | 216 | \subsection{Arguments}{ 217 | \if{html}{\out{
}} 218 | \describe{ 219 | \item{\code{deep}}{Whether to make a deep clone.} 220 | } 221 | \if{html}{\out{
}} 222 | } 223 | } 224 | } 225 | -------------------------------------------------------------------------------- /man/MCBoost.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/MCBoost.R 3 | \name{MCBoost} 4 | \alias{MCBoost} 5 | \title{Multi-Calibration Boosting} 6 | \description{ 7 | Implements Multi-Calibration Boosting by Hebert-Johnson et al. (2018) and 8 | Multi-Accuracy Boosting by Kim et al. (2019) for the multi-calibration of a 9 | machine learning model's prediction. 10 | Multi-Calibration works best in scenarios where the underlying data & labels are unbiased 11 | but a bias is introduced within the algorithm's fitting procedure. This is often the case, 12 | e.g. when an algorithm fits a majority population while ignoring or under-fitting minority 13 | populations.\cr 14 | Expects initial models that fit binary outcomes or continuous outcomes with 15 | predictions that are in (or scaled to) the 0-1 range. 16 | The method defaults to \verb{Multi-Accuracy Boosting} as described in Kim et al. (2019). 17 | In order to obtain behaviour as described in Hebert-Johnson et al. (2018) set 18 | \code{multiplicative=FALSE} and \code{num_buckets} to 10. 19 | \itemize{ 20 | For additional details, please refer to the relevant publications: 21 | \item{Hebert-Johnson et al., 2018. Multicalibration: Calibration for the (Computationally-Identifiable) Masses. 22 | Proceedings of the 35th International Conference on Machine Learning, PMLR 80:1939-1948. 23 | https://proceedings.mlr.press/v80/hebert-johnson18a.html.}{} 24 | \item{Kim et al., 2019. Multiaccuracy: Black-Box Post-Processing for Fairness in Classification. 25 | Proceedings of the 2019 AAAI/ACM Conference on AI, Ethics, and Society (AIES '19). 26 | Association for Computing Machinery, New York, NY, USA, 247–254. 27 | https://dl.acm.org/doi/10.1145/3306618.3314287}{} 28 | } 29 | } 30 | \examples{ 31 | # See vignette for more examples. 32 | # Instantiate the object 33 | \dontrun{ 34 | mc = MCBoost$new() 35 | # Run multi-calibration on training dataset. 36 | mc$multicalibrate(iris[1:100, 1:4], factor(sample(c("A", "B"), 100, TRUE))) 37 | # Predict on test set 38 | mc$predict_probs(iris[101:150, 1:4]) 39 | # Get auditor effect 40 | mc$auditor_effect(iris[101:150, 1:4]) 41 | } 42 | } 43 | \section{Public fields}{ 44 | \if{html}{\out{
}} 45 | \describe{ 46 | \item{\code{max_iter}}{\code{\link{integer}} \cr 47 | The maximum number of iterations of the multi-calibration/multi-accuracy method.} 48 | 49 | \item{\code{alpha}}{\code{\link{numeric}} \cr 50 | Accuracy parameter that determines the stopping condition.} 51 | 52 | \item{\code{eta}}{\code{\link{numeric}} \cr 53 | Parameter for multiplicative weight update (step size).} 54 | 55 | \item{\code{num_buckets}}{\code{\link{integer}} \cr 56 | The number of buckets to split into in addition to using the whole sample.} 57 | 58 | \item{\code{bucket_strategy}}{\code{\link{character}} \cr 59 | Currently only supports "simple", even split along probabilities. 60 | Only relevant for \code{num_buckets} > 1.} 61 | 62 | \item{\code{rebucket}}{\code{\link{logical}} \cr 63 | Should buckets be re-calculated at each iteration?} 64 | 65 | \item{\code{eval_fulldata}}{\code{\link{logical}} \cr 66 | Should auditor be evaluated on the full data?} 67 | 68 | \item{\code{partition}}{\code{\link{logical}} \cr 69 | True/False flag for whether to split up predictions by their "partition" 70 | (e.g., predictions less than 0.5 and predictions greater than 0.5).} 71 | 72 | \item{\code{multiplicative}}{\code{\link{logical}} \cr 73 | Specifies the strategy for updating the weights (multiplicative weight vs additive).} 74 | 75 | \item{\code{iter_sampling}}{\code{\link{character}} \cr 76 | Specifies the strategy to sample the validation data for each iteration.} 77 | 78 | \item{\code{auditor_fitter}}{\code{\link{AuditorFitter}} \cr 79 | Specifies the type of model used to fit the residuals.} 80 | 81 | \item{\code{predictor}}{\code{\link{function}} \cr 82 | Initial predictor function.} 83 | 84 | \item{\code{iter_models}}{\code{\link{list}} \cr 85 | Cumulative list of fitted models.} 86 | 87 | \item{\code{iter_partitions}}{\code{\link{list}} \cr 88 | Cumulative list of data partitions for models.} 89 | 90 | \item{\code{iter_corr}}{\code{\link{list}} \cr 91 | Auditor correlation in each iteration.} 92 | 93 | \item{\code{auditor_effects}}{\code{\link{list}} \cr 94 | Auditor effect in each iteration.} 95 | 96 | \item{\code{bucket_strategies}}{\code{\link{character}} \cr 97 | Possible bucket_strategies.} 98 | 99 | \item{\code{weight_degree}}{\code{\link{integer}} \cr 100 | Weighting degree for low-degree multi-calibration.} 101 | } 102 | \if{html}{\out{
}} 103 | } 104 | \section{Methods}{ 105 | \subsection{Public methods}{ 106 | \itemize{ 107 | \item \href{#method-MCBoost-new}{\code{MCBoost$new()}} 108 | \item \href{#method-MCBoost-multicalibrate}{\code{MCBoost$multicalibrate()}} 109 | \item \href{#method-MCBoost-predict_probs}{\code{MCBoost$predict_probs()}} 110 | \item \href{#method-MCBoost-auditor_effect}{\code{MCBoost$auditor_effect()}} 111 | \item \href{#method-MCBoost-print}{\code{MCBoost$print()}} 112 | \item \href{#method-MCBoost-clone}{\code{MCBoost$clone()}} 113 | } 114 | } 115 | \if{html}{\out{
}} 116 | \if{html}{\out{}} 117 | \if{latex}{\out{\hypertarget{method-MCBoost-new}{}}} 118 | \subsection{Method \code{new()}}{ 119 | Initialize a multi-calibration instance. 120 | \subsection{Usage}{ 121 | \if{html}{\out{
}}\preformatted{MCBoost$new( 122 | max_iter = 5, 123 | alpha = 1e-04, 124 | eta = 1, 125 | num_buckets = 2, 126 | partition = ifelse(num_buckets > 1, TRUE, FALSE), 127 | bucket_strategy = "simple", 128 | rebucket = FALSE, 129 | eval_fulldata = FALSE, 130 | multiplicative = TRUE, 131 | auditor_fitter = NULL, 132 | subpops = NULL, 133 | default_model_class = ConstantPredictor, 134 | init_predictor = NULL, 135 | iter_sampling = "none", 136 | weight_degree = 1L 137 | )}\if{html}{\out{
}} 138 | } 139 | 140 | \subsection{Arguments}{ 141 | \if{html}{\out{
}} 142 | \describe{ 143 | \item{\code{max_iter}}{\code{\link{integer}} \cr 144 | The maximum number of iterations of the multi-calibration/multi-accuracy method. 145 | Default \code{5L}.} 146 | 147 | \item{\code{alpha}}{\code{\link{numeric}} \cr 148 | Accuracy parameter that determines the stopping condition. Default \code{1e-4}.} 149 | 150 | \item{\code{eta}}{\code{\link{numeric}} \cr 151 | Parameter for multiplicative weight update (step size). Default \code{1.0}.} 152 | 153 | \item{\code{num_buckets}}{\code{\link{integer}} \cr 154 | The number of buckets to split into in addition to using the whole sample. Default \code{2L}.} 155 | 156 | \item{\code{partition}}{\code{\link{logical}} \cr 157 | True/False flag for whether to split up predictions by their "partition" 158 | (e.g., predictions less than 0.5 and predictions greater than 0.5). 159 | Defaults to \code{TRUE} (multi-accuracy boosting).} 160 | 161 | \item{\code{bucket_strategy}}{\code{\link{character}} \cr 162 | Currently only supports "simple", even split along probabilities. 163 | Only taken into account for \code{num_buckets} > 1.} 164 | 165 | \item{\code{rebucket}}{\code{\link{logical}} \cr 166 | Should buckets be re-done at each iteration? Default \code{FALSE}.} 167 | 168 | \item{\code{eval_fulldata}}{\code{\link{logical}} \cr 169 | Should the auditor be evaluated on the full data or on the respective bucket for determining 170 | the stopping criterion? Default \code{FALSE}, auditor is only evaluated on the bucket. 171 | This setting keeps the implementation closer to the Algorithm proposed in the corresponding 172 | multi-accuracy paper (Kim et al., 2019) where auditor effects are computed across the full 173 | sample (i.e. eval_fulldata = TRUE).} 174 | 175 | \item{\code{multiplicative}}{\code{\link{logical}} \cr 176 | Specifies the strategy for updating the weights (multiplicative weight vs additive). 177 | Defaults to \code{TRUE} (multi-accuracy boosting). Set to \code{FALSE} for multi-calibration.} 178 | 179 | \item{\code{auditor_fitter}}{\code{\link{AuditorFitter}}|\code{\link{character}}|\code{\link[mlr3:Learner]{mlr3::Learner}} \cr 180 | Specifies the type of model used to fit the 181 | residuals. The default is \code{\link{RidgeAuditorFitter}}. 182 | Can be a \code{character}, the name of a \code{\link{AuditorFitter}}, a \code{\link[mlr3:Learner]{mlr3::Learner}} that is then 183 | auto-converted into a \code{\link{LearnerAuditorFitter}} or a custom \code{\link{AuditorFitter}}.} 184 | 185 | \item{\code{subpops}}{\code{\link{list}} \cr 186 | Specifies a collection of characteristic attributes 187 | and the values they take to define subpopulations 188 | e.g. list(age = c('20-29','30-39','40+'), nJobs = c(0,1,2,'3+'), ,..).} 189 | 190 | \item{\code{default_model_class}}{\code{Predictor} \cr 191 | The class of the model that should be used as the init predictor model if 192 | \code{init_predictor} is not specified. Defaults to \code{ConstantPredictor} which 193 | predicts a constant value.} 194 | 195 | \item{\code{init_predictor}}{\code{\link{function}}|\code{\link[mlr3:Learner]{mlr3::Learner}} \cr 196 | The initial predictor function to use (i.e., if the user has a pretrained model). 197 | If a \code{mlr3} \code{Learner} is passed, it will be autoconverted using \code{mlr3_init_predictor}. 198 | This requires the \code{\link[mlr3:Learner]{mlr3::Learner}} to be trained.} 199 | 200 | \item{\code{iter_sampling}}{\code{\link{character}} \cr 201 | How to sample the validation data for each iteration? 202 | Can be \code{bootstrap}, \code{split} or \code{none}.\cr 203 | "split" splits the data into \code{max_iter} parts and validates on each sample in each iteration.\cr 204 | "bootstrap" uses a new bootstrap sample in each iteration.\cr 205 | "none" uses the same dataset in each iteration.} 206 | 207 | \item{\code{weight_degree}}{\code{\link{character}} \cr 208 | Weighting degree for low-degree multi-calibration. Initialized to 1, which applies constant weighting with 1.} 209 | } 210 | \if{html}{\out{
}} 211 | } 212 | } 213 | \if{html}{\out{
}} 214 | \if{html}{\out{}} 215 | \if{latex}{\out{\hypertarget{method-MCBoost-multicalibrate}{}}} 216 | \subsection{Method \code{multicalibrate()}}{ 217 | Run multi-calibration. 218 | \subsection{Usage}{ 219 | \if{html}{\out{
}}\preformatted{MCBoost$multicalibrate(data, labels, predictor_args = NULL, audit = FALSE, ...)}\if{html}{\out{
}} 220 | } 221 | 222 | \subsection{Arguments}{ 223 | \if{html}{\out{
}} 224 | \describe{ 225 | \item{\code{data}}{\code{\link{data.table}}\cr 226 | Features.} 227 | 228 | \item{\code{labels}}{\code{\link{numeric}}\cr 229 | One-hot encoded labels (of same length as data).} 230 | 231 | \item{\code{predictor_args}}{\code{\link{any}} \cr 232 | Arguments passed on to \code{init_predictor}. Defaults to \code{NULL}.} 233 | 234 | \item{\code{audit}}{\code{\link{logical}} \cr 235 | Perform auditing? Initialized to \code{TRUE}.} 236 | 237 | \item{\code{...}}{\code{\link{any}} \cr 238 | Params passed on to other methods.} 239 | } 240 | \if{html}{\out{
}} 241 | } 242 | \subsection{Returns}{ 243 | \code{NULL} 244 | } 245 | } 246 | \if{html}{\out{
}} 247 | \if{html}{\out{}} 248 | \if{latex}{\out{\hypertarget{method-MCBoost-predict_probs}{}}} 249 | \subsection{Method \code{predict_probs()}}{ 250 | Predict a dataset with multi-calibrated predictions 251 | \subsection{Usage}{ 252 | \if{html}{\out{
}}\preformatted{MCBoost$predict_probs(x, t = Inf, predictor_args = NULL, audit = FALSE, ...)}\if{html}{\out{
}} 253 | } 254 | 255 | \subsection{Arguments}{ 256 | \if{html}{\out{
}} 257 | \describe{ 258 | \item{\code{x}}{\code{\link{data.table}} \cr 259 | Prediction data.} 260 | 261 | \item{\code{t}}{\code{\link{integer}} \cr 262 | Number of multi-calibration steps to predict. Default: \code{Inf} (all).} 263 | 264 | \item{\code{predictor_args}}{\code{\link{any}} \cr 265 | Arguments passed on to \code{init_predictor}. Defaults to \code{NULL}.} 266 | 267 | \item{\code{audit}}{\code{\link{logical}} \cr 268 | Should audit weights be stored? Default \code{FALSE}.} 269 | 270 | \item{\code{...}}{\code{\link{any}} \cr 271 | Params passed on to the residual prediction model's predict method.} 272 | } 273 | \if{html}{\out{
}} 274 | } 275 | \subsection{Returns}{ 276 | \code{\link{numeric}}\cr 277 | Numeric vector of multi-calibrated predictions. 278 | } 279 | } 280 | \if{html}{\out{
}} 281 | \if{html}{\out{}} 282 | \if{latex}{\out{\hypertarget{method-MCBoost-auditor_effect}{}}} 283 | \subsection{Method \code{auditor_effect()}}{ 284 | Compute the auditor effect for each instance which are the cumulative 285 | absolute predictions of the auditor. It indicates "how much" 286 | each observation was affected by multi-calibration on average across iterations. 287 | \subsection{Usage}{ 288 | \if{html}{\out{
}}\preformatted{MCBoost$auditor_effect( 289 | x, 290 | aggregate = TRUE, 291 | t = Inf, 292 | predictor_args = NULL, 293 | ... 294 | )}\if{html}{\out{
}} 295 | } 296 | 297 | \subsection{Arguments}{ 298 | \if{html}{\out{
}} 299 | \describe{ 300 | \item{\code{x}}{\code{\link{data.table}} \cr 301 | Prediction data.} 302 | 303 | \item{\code{aggregate}}{\code{\link{logical}} \cr 304 | Should the auditor effect be aggregated across iterations? Defaults to \code{TRUE}.} 305 | 306 | \item{\code{t}}{\code{\link{integer}} \cr 307 | Number of multi-calibration steps to predict. Defaults to \code{Inf} (all).} 308 | 309 | \item{\code{predictor_args}}{\code{\link{any}} \cr 310 | Arguments passed on to \code{init_predictor}. Defaults to \code{NULL}.} 311 | 312 | \item{\code{...}}{\code{\link{any}} \cr 313 | Params passed on to the residual prediction model's predict method.} 314 | } 315 | \if{html}{\out{
}} 316 | } 317 | \subsection{Returns}{ 318 | \code{\link{numeric}} \cr 319 | Numeric vector of auditor effects for each row in \code{x}. 320 | } 321 | } 322 | \if{html}{\out{
}} 323 | \if{html}{\out{}} 324 | \if{latex}{\out{\hypertarget{method-MCBoost-print}{}}} 325 | \subsection{Method \code{print()}}{ 326 | Prints information about multi-calibration. 327 | \subsection{Usage}{ 328 | \if{html}{\out{
}}\preformatted{MCBoost$print(...)}\if{html}{\out{
}} 329 | } 330 | 331 | \subsection{Arguments}{ 332 | \if{html}{\out{
}} 333 | \describe{ 334 | \item{\code{...}}{\code{any}\cr 335 | Not used.} 336 | } 337 | \if{html}{\out{
}} 338 | } 339 | } 340 | \if{html}{\out{
}} 341 | \if{html}{\out{}} 342 | \if{latex}{\out{\hypertarget{method-MCBoost-clone}{}}} 343 | \subsection{Method \code{clone()}}{ 344 | The objects of this class are cloneable with this method. 345 | \subsection{Usage}{ 346 | \if{html}{\out{
}}\preformatted{MCBoost$clone(deep = FALSE)}\if{html}{\out{
}} 347 | } 348 | 349 | \subsection{Arguments}{ 350 | \if{html}{\out{
}} 351 | \describe{ 352 | \item{\code{deep}}{Whether to make a deep clone.} 353 | } 354 | \if{html}{\out{
}} 355 | } 356 | } 357 | } 358 | -------------------------------------------------------------------------------- /man/SubgroupAuditorFitter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/AuditorFitters.R 3 | \name{SubgroupAuditorFitter} 4 | \alias{SubgroupAuditorFitter} 5 | \title{Static AuditorFitter based on Subgroups} 6 | \value{ 7 | \code{\link{AuditorFitter}}\cr 8 | 9 | \code{list} with items\cr 10 | \itemize{ 11 | \item \code{corr}: pseudo-correlation between residuals and learner prediction. 12 | \item \code{l}: the trained learner. 13 | } 14 | } 15 | \description{ 16 | Used to assess multi-calibration based on a list of 17 | binary \code{subgroup_masks} passed during initialization. 18 | } 19 | \examples{ 20 | library("data.table") 21 | data = data.table( 22 | "AGE_0_10" = c(1, 1, 0, 0, 0), 23 | "AGE_11_20" = c(0, 0, 1, 0, 0), 24 | "AGE_21_31" = c(0, 0, 0, 1, 1), 25 | "X1" = runif(5), 26 | "X2" = runif(5) 27 | ) 28 | label = c(1,0,0,1,1) 29 | masks = list( 30 | "M1" = c(1L, 0L, 1L, 1L, 0L), 31 | "M2" = c(1L, 0L, 0L, 0L, 1L) 32 | ) 33 | sg = SubgroupAuditorFitter$new(masks) 34 | } 35 | \seealso{ 36 | Other AuditorFitter: 37 | \code{\link{CVLearnerAuditorFitter}}, 38 | \code{\link{LearnerAuditorFitter}}, 39 | \code{\link{SubpopAuditorFitter}} 40 | } 41 | \concept{AuditorFitter} 42 | \section{Super class}{ 43 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{SubgroupAuditorFitter} 44 | } 45 | \section{Public fields}{ 46 | \if{html}{\out{
}} 47 | \describe{ 48 | \item{\code{subgroup_masks}}{\code{\link{list}} \cr 49 | List of subgroup masks. 50 | Initialize a SubgroupAuditorFitter} 51 | } 52 | \if{html}{\out{
}} 53 | } 54 | \section{Methods}{ 55 | \subsection{Public methods}{ 56 | \itemize{ 57 | \item \href{#method-SubgroupAuditorFitter-new}{\code{SubgroupAuditorFitter$new()}} 58 | \item \href{#method-SubgroupAuditorFitter-fit}{\code{SubgroupAuditorFitter$fit()}} 59 | \item \href{#method-SubgroupAuditorFitter-clone}{\code{SubgroupAuditorFitter$clone()}} 60 | } 61 | } 62 | \if{html}{\out{ 63 |
Inherited methods 64 | 67 |
68 | }} 69 | \if{html}{\out{
}} 70 | \if{html}{\out{}} 71 | \if{latex}{\out{\hypertarget{method-SubgroupAuditorFitter-new}{}}} 72 | \subsection{Method \code{new()}}{ 73 | Initializes a \code{\link{SubgroupAuditorFitter}} that 74 | assesses multi-calibration within each group defined 75 | by the `subpops'. 76 | \subsection{Usage}{ 77 | \if{html}{\out{
}}\preformatted{SubgroupAuditorFitter$new(subgroup_masks)}\if{html}{\out{
}} 78 | } 79 | 80 | \subsection{Arguments}{ 81 | \if{html}{\out{
}} 82 | \describe{ 83 | \item{\code{subgroup_masks}}{\code{\link{list}} \cr 84 | List of subgroup masks. Subgroup masks are list(s) of integer masks, 85 | each with the same length as data to be fitted on. 86 | They allow defining subgroups of the data.} 87 | } 88 | \if{html}{\out{
}} 89 | } 90 | } 91 | \if{html}{\out{
}} 92 | \if{html}{\out{}} 93 | \if{latex}{\out{\hypertarget{method-SubgroupAuditorFitter-fit}{}}} 94 | \subsection{Method \code{fit()}}{ 95 | Fit the learner and compute correlation 96 | \subsection{Usage}{ 97 | \if{html}{\out{
}}\preformatted{SubgroupAuditorFitter$fit(data, resid, mask)}\if{html}{\out{
}} 98 | } 99 | 100 | \subsection{Arguments}{ 101 | \if{html}{\out{
}} 102 | \describe{ 103 | \item{\code{data}}{\code{\link{data.table}}\cr 104 | Features.} 105 | 106 | \item{\code{resid}}{\code{\link{numeric}}\cr 107 | Residuals (of same length as data).} 108 | 109 | \item{\code{mask}}{\code{\link{integer}}\cr 110 | Mask applied to the data. Only used for \code{SubgroupAuditorFitter}.} 111 | } 112 | \if{html}{\out{
}} 113 | } 114 | } 115 | \if{html}{\out{
}} 116 | \if{html}{\out{}} 117 | \if{latex}{\out{\hypertarget{method-SubgroupAuditorFitter-clone}{}}} 118 | \subsection{Method \code{clone()}}{ 119 | The objects of this class are cloneable with this method. 120 | \subsection{Usage}{ 121 | \if{html}{\out{
}}\preformatted{SubgroupAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 122 | } 123 | 124 | \subsection{Arguments}{ 125 | \if{html}{\out{
}} 126 | \describe{ 127 | \item{\code{deep}}{Whether to make a deep clone.} 128 | } 129 | \if{html}{\out{
}} 130 | } 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /man/SubpopAuditorFitter.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/AuditorFitters.R 3 | \name{SubpopAuditorFitter} 4 | \alias{SubpopAuditorFitter} 5 | \title{Static AuditorFitter based on Subpopulations} 6 | \value{ 7 | \code{\link{AuditorFitter}}\cr 8 | 9 | \code{list} with items\cr 10 | \itemize{ 11 | \item \code{corr}: pseudo-correlation between residuals and learner prediction. 12 | \item \code{l}: the trained learner. 13 | } 14 | } 15 | \description{ 16 | Used to assess multi-calibration based on a list of 17 | binary valued columns: \code{subpops} passed during initialization. 18 | } 19 | \examples{ 20 | library("data.table") 21 | data = data.table( 22 | "AGE_NA" = c(0, 0, 0, 0, 0), 23 | "AGE_0_10" = c(1, 1, 0, 0, 0), 24 | "AGE_11_20" = c(0, 0, 1, 0, 0), 25 | "AGE_21_31" = c(0, 0, 0, 1, 1), 26 | "X1" = runif(5), 27 | "X2" = runif(5) 28 | ) 29 | label = c(1,0,0,1,1) 30 | pops = list("AGE_NA", "AGE_0_10", "AGE_11_20", "AGE_21_31", function(x) {x[["X1" > 0.5]]}) 31 | sf = SubpopAuditorFitter$new(subpops = pops) 32 | sf$fit(data, label - 0.5) 33 | } 34 | \seealso{ 35 | Other AuditorFitter: 36 | \code{\link{CVLearnerAuditorFitter}}, 37 | \code{\link{LearnerAuditorFitter}}, 38 | \code{\link{SubgroupAuditorFitter}} 39 | } 40 | \concept{AuditorFitter} 41 | \section{Super class}{ 42 | \code{\link[mcboost:AuditorFitter]{mcboost::AuditorFitter}} -> \code{SubpopAuditorFitter} 43 | } 44 | \section{Public fields}{ 45 | \if{html}{\out{
}} 46 | \describe{ 47 | \item{\code{subpops}}{\code{\link{list}} \cr 48 | List of subpopulation indicators. 49 | Initialize a SubpopAuditorFitter} 50 | } 51 | \if{html}{\out{
}} 52 | } 53 | \section{Methods}{ 54 | \subsection{Public methods}{ 55 | \itemize{ 56 | \item \href{#method-SubpopAuditorFitter-new}{\code{SubpopAuditorFitter$new()}} 57 | \item \href{#method-SubpopAuditorFitter-fit}{\code{SubpopAuditorFitter$fit()}} 58 | \item \href{#method-SubpopAuditorFitter-clone}{\code{SubpopAuditorFitter$clone()}} 59 | } 60 | } 61 | \if{html}{\out{ 62 |
Inherited methods 63 | 66 |
67 | }} 68 | \if{html}{\out{
}} 69 | \if{html}{\out{}} 70 | \if{latex}{\out{\hypertarget{method-SubpopAuditorFitter-new}{}}} 71 | \subsection{Method \code{new()}}{ 72 | Initializes a \code{\link{SubpopAuditorFitter}} that 73 | assesses multi-calibration within each group defined 74 | by the \verb{subpops'. Names in }subpops` must correspond to 75 | columns in the data. 76 | \subsection{Usage}{ 77 | \if{html}{\out{
}}\preformatted{SubpopAuditorFitter$new(subpops)}\if{html}{\out{
}} 78 | } 79 | 80 | \subsection{Arguments}{ 81 | \if{html}{\out{
}} 82 | \describe{ 83 | \item{\code{subpops}}{\code{\link{list}} \cr 84 | Specifies a collection of characteristic attributes 85 | and the values they take to define subpopulations 86 | e.g. list(age = c('20-29','30-39','40+'), nJobs = c(0,1,2,'3+'), ,..).} 87 | } 88 | \if{html}{\out{
}} 89 | } 90 | } 91 | \if{html}{\out{
}} 92 | \if{html}{\out{}} 93 | \if{latex}{\out{\hypertarget{method-SubpopAuditorFitter-fit}{}}} 94 | \subsection{Method \code{fit()}}{ 95 | Fit the learner and compute correlation 96 | \subsection{Usage}{ 97 | \if{html}{\out{
}}\preformatted{SubpopAuditorFitter$fit(data, resid, mask)}\if{html}{\out{
}} 98 | } 99 | 100 | \subsection{Arguments}{ 101 | \if{html}{\out{
}} 102 | \describe{ 103 | \item{\code{data}}{\code{\link{data.table}}\cr 104 | Features.} 105 | 106 | \item{\code{resid}}{\code{\link{numeric}}\cr 107 | Residuals (of same length as data).} 108 | 109 | \item{\code{mask}}{\code{\link{integer}}\cr 110 | Mask applied to the data. Only used for \code{SubgroupAuditorFitter}.} 111 | } 112 | \if{html}{\out{
}} 113 | } 114 | } 115 | \if{html}{\out{
}} 116 | \if{html}{\out{}} 117 | \if{latex}{\out{\hypertarget{method-SubpopAuditorFitter-clone}{}}} 118 | \subsection{Method \code{clone()}}{ 119 | The objects of this class are cloneable with this method. 120 | \subsection{Usage}{ 121 | \if{html}{\out{
}}\preformatted{SubpopAuditorFitter$clone(deep = FALSE)}\if{html}{\out{
}} 122 | } 123 | 124 | \subsection{Arguments}{ 125 | \if{html}{\out{
}} 126 | \describe{ 127 | \item{\code{deep}}{Whether to make a deep clone.} 128 | } 129 | \if{html}{\out{
}} 130 | } 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /man/figures/lifecycle-archived.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclearchivedarchived -------------------------------------------------------------------------------- /man/figures/lifecycle-defunct.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledefunctdefunct -------------------------------------------------------------------------------- /man/figures/lifecycle-deprecated.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycledeprecateddeprecated -------------------------------------------------------------------------------- /man/figures/lifecycle-experimental.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecycleexperimentalexperimental -------------------------------------------------------------------------------- /man/figures/lifecycle-maturing.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclematuringmaturing -------------------------------------------------------------------------------- /man/figures/lifecycle-questioning.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclequestioningquestioning -------------------------------------------------------------------------------- /man/figures/lifecycle-stable.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclestablestable -------------------------------------------------------------------------------- /man/figures/lifecycle-superseded.svg: -------------------------------------------------------------------------------- 1 | lifecyclelifecyclesupersededsuperseded -------------------------------------------------------------------------------- /man/mcboost-package.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/zzz.R 3 | \docType{package} 4 | \name{mcboost-package} 5 | \alias{mcboost} 6 | \alias{mcboost-package} 7 | \title{mcboost: Multi-Calibration Boosting} 8 | \description{ 9 | Implements 'Multi-Calibration Boosting' (2018) \url{https://proceedings.mlr.press/v80/hebert-johnson18a.html} and 'Multi-Accuracy Boosting' (2019) \doi{10.48550/arXiv.1805.12317} for the multi-calibration of a machine learning model's prediction. 'MCBoost' updates predictions for sub-groups in an iterative fashion in order to mitigate biases like poor calibration or large accuracy differences across subgroups. Multi-Calibration works best in scenarios where the underlying data & labels are unbiased, but resulting models are. This is often the case, e.g. when an algorithm fits a majority population while ignoring or under-fitting minority populations. 10 | } 11 | \references{ 12 | Kim et al., 2019: Multiaccuracy: Black-Box Post-Processing for Fairness in Classification. 13 | Hebert-Johnson et al., 2018: Multicalibration: Calibration for the ({C}omputationally-Identifiable) Masses. 14 | Pfisterer F, Kern C, Dandl S, Sun M, Kim M, Bischl B (2021). 15 | \dQuote{mcboost: Multi-Calibration Boosting for R.} 16 | \emph{Journal of Open Source Software}, \bold{6}(64), 3453. 17 | \doi{10.21105/joss.03453}, \url{https://joss.theoj.org/papers/10.21105/joss.03453}. 18 | } 19 | \seealso{ 20 | Useful links: 21 | \itemize{ 22 | \item \url{https://github.com/mlr-org/mcboost} 23 | \item Report bugs at \url{https://github.com/mlr-org/mcboost/issues} 24 | } 25 | 26 | } 27 | \author{ 28 | \strong{Maintainer}: Sebastian Fischer \email{sebf.fischer@gmail.com} [contributor] 29 | 30 | Authors: 31 | \itemize{ 32 | \item Florian Pfisterer \email{pfistererf@googlemail.com} (\href{https://orcid.org/0000-0001-8867-762X}{ORCID}) 33 | } 34 | 35 | Other contributors: 36 | \itemize{ 37 | \item Susanne Dandl \email{susanne.dandl@stat.uni-muenchen.de} (\href{https://orcid.org/0000-0003-4324-4163}{ORCID}) [contributor] 38 | \item Christoph Kern \email{c.kern@uni-mannheim.de} (\href{https://orcid.org/0000-0001-7363-4299}{ORCID}) [contributor] 39 | \item Carolin Becker [contributor] 40 | \item Bernd Bischl \email{bernd_bischl@gmx.net} (\href{https://orcid.org/0000-0001-6002-6980}{ORCID}) [contributor] 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /man/mlr3_init_predictor.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/helpers.R 3 | \name{mlr3_init_predictor} 4 | \alias{mlr3_init_predictor} 5 | \title{Create an initial predictor function from a trained mlr3 learner} 6 | \usage{ 7 | mlr3_init_predictor(learner) 8 | } 9 | \arguments{ 10 | \item{learner}{\code{\link[mlr3:Learner]{mlr3::Learner}} 11 | A trained learner used for initialization.} 12 | } 13 | \value{ 14 | \code{\link{function}} 15 | } 16 | \description{ 17 | Create an initial predictor function from a trained mlr3 learner 18 | } 19 | \examples{ 20 | \dontrun{ 21 | library("mlr3") 22 | l = lrn("classif.featureless")$train(tsk("sonar")) 23 | mlr3_init_predictor(l) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /man/mlr_pipeops_mcboost.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipeOpLearnerPred.R, R/PipeOpMCBoost.R 3 | \name{mlr_pipeops_mcboost} 4 | \alias{mlr_pipeops_mcboost} 5 | \alias{PipeOpLearnerPred} 6 | \alias{PipeOpMCBoost} 7 | \title{Multi-Calibrate a Learner's Prediction} 8 | \format{ 9 | \code{\link{R6Class}} inheriting from \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}}. 10 | 11 | \code{\link{R6Class}} inheriting from \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}}. 12 | } 13 | \description{ 14 | \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}} that trains a \code{\link[mlr3:Learner]{Learner}} and passes its predictions forward during training and prediction. 15 | 16 | Post-process a learner prediction using multi-calibration. 17 | For more details, please refer to \url{https://arxiv.org/pdf/1805.12317.pdf} (Kim et al. 2018) 18 | or the help for \code{\link{MCBoost}}. 19 | If no \code{init_predictor} is provided, the preceding learner's predictions 20 | corresponding to the \code{prediction} slot are used as an initial predictor for \code{MCBoost}. 21 | } 22 | \section{Construction}{ 23 | 24 | 25 | \if{html}{\out{
}}\preformatted{PipeOpLearnerPred$new(learner, id = NULL, param_vals = list()) 26 | 27 | * `learner` :: [`Learner`][mlr3::Learner] \\cr 28 | [`Learner`][mlr3::Learner] to prediction, or a string identifying a 29 | [`Learner`][mlr3::Learner] in the [`mlr3::mlr_learners`] [`Dictionary`][mlr3misc::Dictionary]. 30 | * `id` :: `character(1)` 31 | Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped. 32 | * `param_vals` :: named `list`\\cr 33 | List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`. 34 | 35 | 36 | [mlr3::Learner]: R:mlr3::Learner 37 | [mlr3::Learner]: R:mlr3::Learner 38 | [mlr3::Learner]: R:mlr3::Learner 39 | [`mlr3::mlr_learners`]: R:\%60mlr3::mlr_learners\%60 40 | [mlr3misc::Dictionary]: R:mlr3misc::Dictionary 41 | [mlr3::Learner]: R:mlr3::Learner 42 | }\if{html}{\out{
}} 43 | 44 | 45 | 46 | \if{html}{\out{
}}\preformatted{PipeOpMCBoost$new(id = "mcboost", param_vals = list()) 47 | }\if{html}{\out{
}} 48 | \itemize{ 49 | \item \code{id} :: \code{character(1)} 50 | Identifier of the resulting object, default \code{"threshold"}. 51 | \item \code{param_vals} :: named \code{list}\cr 52 | List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. 53 | See \code{MCBoost} for a comprehensive description of all hyperparameters. 54 | } 55 | } 56 | 57 | \section{Input and Output Channels}{ 58 | 59 | \code{\link{PipeOpLearnerPred}} has one input channel named \code{"input"}, taking a \code{\link[mlr3:Task]{Task}} specific to the \code{\link[mlr3:Learner]{Learner}} 60 | type given to \code{learner} during construction; both during training and prediction. 61 | 62 | \code{\link{PipeOpLearnerPred}} has one output channel named \code{"output"}, producing a \code{\link[mlr3:Task]{Task}} specific to the \code{\link[mlr3:Learner]{Learner}} 63 | type given to \code{learner} during construction; both during training and prediction. 64 | 65 | 66 | During training, the input and output are \code{"data"} and \code{"prediction"}, two \code{\link[mlr3:TaskClassif]{TaskClassif}}. 67 | A \code{\link[mlr3:PredictionClassif]{PredictionClassif}} is required as input and returned as output during prediction. 68 | } 69 | 70 | \section{State}{ 71 | 72 | 73 | 74 | The \verb{$state} is a \code{MCBoost} Object as obtained from \code{MCBoost$new()}. 75 | } 76 | 77 | \section{Parameters}{ 78 | 79 | The \verb{$state} is set to the \verb{$state} slot of the \code{\link[mlr3:Learner]{Learner}} object, together with the \verb{$state} elements inherited from 80 | \code{\link[mlr3pipelines:PipeOpTaskPreproc]{mlr3pipelines::PipeOpTaskPreproc}}. It is a named \code{list} with the inherited members, as well as: 81 | \itemize{ 82 | \item \code{model} :: \code{any}\cr 83 | Model created by the \code{\link[mlr3:Learner]{Learner}}'s \verb{$.train()} function. 84 | \item \code{train_log} :: \code{\link{data.table}} with columns \code{class} (\code{character}), \code{msg} (\code{character})\cr 85 | Errors logged during training. 86 | \item \code{train_time} :: \code{numeric(1)}\cr 87 | Training time, in seconds. 88 | \item \code{predict_log} :: \code{NULL} | \code{\link{data.table}} with columns \code{class} (\code{character}), \code{msg} (\code{character})\cr 89 | Errors logged during prediction. 90 | \item \code{predict_time} :: \code{NULL} | \code{numeric(1)} 91 | Prediction time, in seconds. 92 | } 93 | 94 | 95 | \itemize{ 96 | \item \code{max_iter} :: \code{integer}\cr 97 | A integer specifying the number of multi-calibration rounds. Defaults to 5. 98 | } 99 | } 100 | 101 | \section{Fields}{ 102 | 103 | Fields inherited from \code{\link{PipeOp}}, as well as: 104 | \itemize{ 105 | \item \code{learner} :: \code{\link[mlr3:Learner]{Learner}}\cr 106 | \code{\link[mlr3:Learner]{Learner}} that is being wrapped. Read-only. 107 | \item \code{learner_model} :: \code{\link[mlr3:Learner]{Learner}}\cr 108 | \code{\link[mlr3:Learner]{Learner}} that is being wrapped. This learner contains the model if the \code{PipeOp} is trained. Read-only. 109 | } 110 | 111 | 112 | Only fields inherited from \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}}. 113 | } 114 | 115 | \section{Methods}{ 116 | 117 | Methods inherited from \code{\link[mlr3pipelines:PipeOpTaskPreproc]{mlr3pipelines::PipeOpTaskPreproc}}/\code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}}. 118 | 119 | 120 | Only methods inherited from \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}}. 121 | } 122 | 123 | \examples{ 124 | \dontrun{ 125 | gr = gunion(list( 126 | "data" = po("nop"), 127 | "prediction" = po("learner_cv", lrn("classif.rpart")) 128 | )) \%>>\% 129 | PipeOpMCBoost$new() 130 | tsk = tsk("sonar") 131 | tid = sample(1:208, 108) 132 | gr$train(tsk$clone()$filter(tid)) 133 | gr$predict(tsk$clone()$filter(setdiff(1:208, tid))) 134 | } 135 | } 136 | \seealso{ 137 | https://mlr3book.mlr-org.com/list-pipeops.html 138 | 139 | https://mlr3book.mlr-org.com/list-pipeops.html 140 | } 141 | \concept{PipeOps} 142 | \section{Super classes}{ 143 | \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}} -> \code{\link[mlr3pipelines:PipeOpTaskPreproc]{mlr3pipelines::PipeOpTaskPreproc}} -> \code{PipeOpLearnerPred} 144 | } 145 | \section{Active bindings}{ 146 | \if{html}{\out{
}} 147 | \describe{ 148 | \item{\code{learner}}{The wrapped learner.} 149 | 150 | \item{\code{learner_model}}{The wrapped learner's model(s).} 151 | } 152 | \if{html}{\out{
}} 153 | } 154 | \section{Methods}{ 155 | \subsection{Public methods}{ 156 | \itemize{ 157 | \item \href{#method-PipeOpLearnerPred-new}{\code{PipeOpLearnerPred$new()}} 158 | \item \href{#method-PipeOpLearnerPred-clone}{\code{PipeOpLearnerPred$clone()}} 159 | } 160 | } 161 | \if{html}{\out{ 162 |
Inherited methods 163 | 169 |
170 | }} 171 | \if{html}{\out{
}} 172 | \if{html}{\out{}} 173 | \if{latex}{\out{\hypertarget{method-PipeOpLearnerPred-new}{}}} 174 | \subsection{Method \code{new()}}{ 175 | Initialize a Learner Predictor PipeOp. Can be used to wrap trained or untrainted 176 | mlr3 learners. 177 | \subsection{Usage}{ 178 | \if{html}{\out{
}}\preformatted{PipeOpLearnerPred$new(learner, id = NULL, param_vals = list())}\if{html}{\out{
}} 179 | } 180 | 181 | \subsection{Arguments}{ 182 | \if{html}{\out{
}} 183 | \describe{ 184 | \item{\code{learner}}{\code{\link{Learner}}\cr 185 | The learner that should be wrapped.} 186 | 187 | \item{\code{id}}{\code{\link{character}} \cr 188 | The \code{PipeOp}'s id. Defaults to "mcboost".} 189 | 190 | \item{\code{param_vals}}{\code{\link{list}} \cr 191 | List of hyperparameters for the \code{PipeOp}.} 192 | } 193 | \if{html}{\out{
}} 194 | } 195 | } 196 | \if{html}{\out{
}} 197 | \if{html}{\out{}} 198 | \if{latex}{\out{\hypertarget{method-PipeOpLearnerPred-clone}{}}} 199 | \subsection{Method \code{clone()}}{ 200 | The objects of this class are cloneable with this method. 201 | \subsection{Usage}{ 202 | \if{html}{\out{
}}\preformatted{PipeOpLearnerPred$clone(deep = FALSE)}\if{html}{\out{
}} 203 | } 204 | 205 | \subsection{Arguments}{ 206 | \if{html}{\out{
}} 207 | \describe{ 208 | \item{\code{deep}}{Whether to make a deep clone.} 209 | } 210 | \if{html}{\out{
}} 211 | } 212 | } 213 | } 214 | \section{Super class}{ 215 | \code{\link[mlr3pipelines:PipeOp]{mlr3pipelines::PipeOp}} -> \code{PipeOpMCBoost} 216 | } 217 | \section{Active bindings}{ 218 | \if{html}{\out{
}} 219 | \describe{ 220 | \item{\code{predict_type}}{Predict type of the PipeOp.} 221 | } 222 | \if{html}{\out{
}} 223 | } 224 | \section{Methods}{ 225 | \subsection{Public methods}{ 226 | \itemize{ 227 | \item \href{#method-PipeOpMCBoost-new}{\code{PipeOpMCBoost$new()}} 228 | \item \href{#method-PipeOpMCBoost-clone}{\code{PipeOpMCBoost$clone()}} 229 | } 230 | } 231 | \if{html}{\out{ 232 |
Inherited methods 233 | 239 |
240 | }} 241 | \if{html}{\out{
}} 242 | \if{html}{\out{}} 243 | \if{latex}{\out{\hypertarget{method-PipeOpMCBoost-new}{}}} 244 | \subsection{Method \code{new()}}{ 245 | Initialize a Multi-Calibration PipeOp. 246 | \subsection{Usage}{ 247 | \if{html}{\out{
}}\preformatted{PipeOpMCBoost$new(id = "mcboost", param_vals = list())}\if{html}{\out{
}} 248 | } 249 | 250 | \subsection{Arguments}{ 251 | \if{html}{\out{
}} 252 | \describe{ 253 | \item{\code{id}}{\code{\link{character}} \cr 254 | The \code{PipeOp}'s id. Defaults to "mcboost".} 255 | 256 | \item{\code{param_vals}}{\code{\link{list}} \cr 257 | List of hyperparameters for the \code{PipeOp}.} 258 | } 259 | \if{html}{\out{
}} 260 | } 261 | } 262 | \if{html}{\out{
}} 263 | \if{html}{\out{}} 264 | \if{latex}{\out{\hypertarget{method-PipeOpMCBoost-clone}{}}} 265 | \subsection{Method \code{clone()}}{ 266 | The objects of this class are cloneable with this method. 267 | \subsection{Usage}{ 268 | \if{html}{\out{
}}\preformatted{PipeOpMCBoost$clone(deep = FALSE)}\if{html}{\out{
}} 269 | } 270 | 271 | \subsection{Arguments}{ 272 | \if{html}{\out{
}} 273 | \describe{ 274 | \item{\code{deep}}{Whether to make a deep clone.} 275 | } 276 | \if{html}{\out{
}} 277 | } 278 | } 279 | } 280 | -------------------------------------------------------------------------------- /man/one_hot.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/helpers.R 3 | \name{one_hot} 4 | \alias{one_hot} 5 | \title{One-hot encode a factor variable} 6 | \usage{ 7 | one_hot(labels) 8 | } 9 | \arguments{ 10 | \item{labels}{\code{\link{factor}}\cr 11 | Factor to encode.} 12 | } 13 | \value{ 14 | \code{\link{integer}}\cr 15 | Integer vector of encoded labels. 16 | } 17 | \description{ 18 | One-hot encode a factor variable 19 | } 20 | \examples{ 21 | \dontrun{ 22 | one_hot(factor(c("a", "b", "a"))) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /man/ppl_mcboost.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/PipelineMCBoost.R 3 | \name{ppl_mcboost} 4 | \alias{ppl_mcboost} 5 | \title{Multi-calibration pipeline} 6 | \usage{ 7 | ppl_mcboost(learner = lrn("classif.featureless"), param_vals = list()) 8 | } 9 | \arguments{ 10 | \item{learner}{(mlr3)\code{\link[mlr3:Learner]{mlr3::Learner}}\cr 11 | Initial learner. Internally wrapped into a \code{PipeOpLearnerCV} 12 | with \code{resampling.method = "insample"} as a default. 13 | All parameters can be adjusted through the resulting Graph's \code{param_set}. 14 | Defaults to \code{lrn("classif.featureless")}. 15 | Note: An initial predictor can also be supplied via the \code{init_predictor} parameter.} 16 | 17 | \item{param_vals}{\code{list} \cr 18 | List of parameter values passed on to \code{MCBoost$new}.} 19 | } 20 | \value{ 21 | (mlr3pipelines) \code{\link{Graph}} 22 | } 23 | \description{ 24 | Wraps MCBoost in a Pipeline to be used with \code{mlr3pipelines}. 25 | For now this assumes training on the same dataset that is later used 26 | for multi-calibration. 27 | } 28 | \examples{ 29 | \dontrun{ 30 | library("mlr3pipelines") 31 | gr = ppl_mcboost() 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /mcboost.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: XeLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | 18 | BuildType: Package 19 | PackageUseDevtools: Yes 20 | PackageInstallArgs: --no-multiarch --with-keep.source 21 | -------------------------------------------------------------------------------- /paper/MCBoost.drawio: -------------------------------------------------------------------------------- 1 | 3Vjbbts4EP0aA+1DCt1tP8ZO0izQpAWMot2nBS2OZaKU6FL0bb9+hxJliZJiO429XWwMxOLwqnNmzgw98Kfp7qMkq+WToMAHnkN3A/9u4HljP8D/2rAvDaHvloZEMlqaGoYZ+xuM0THWNaOQWwOVEFyxlW2MRZZBrCwbkVJs7WELwe1dVySBjmEWE961fmNULUvrKHRq+yOwZFnt7DqmJyXVYGPIl4SKbcPk3w/8qRRClU/pbgpcY1fhUs57eKH3cDAJmTpnwvqTN/v61/f97PkzfHn8tgmfR4sbLyqX2RC+Nm9sTqv2FQS4DKKNjcl2yRTMViTWPVvkG21LlXJsufhI8lVJwYLtAHedLESmHkjKuCZ/KtaSgcS1nwFhmORKih8wFVzIYh/f9fQHe7pvZl52A1LBrmEyb/oRRApK7nGI6fX9sJxi3C4Ymva2JtH1DTPLJoGVkRjHSQ5r19jig4H3NVCHp6GWYp1RDd2dcxru16H7UPxdBt0giCx0wwq0bTNEuuB6V8P2DC/O6K3WA2zFnOQ5i200YcfU98bzn5qED6Fp3e0MJ0VjbxovQpkjGzEcOfCwHAfUUp8u4BI4UWxji1EffGbqF8HwKAei3JFjExW0GCgPamY1xaO1kDc+sZAiMgHVWahg8/A+v06w30NwRFIdEdk8XxVcRFyZqLCYj36uRdVxkxcJ5hYHjFe7uk9nBhLb4+2IagyNEv2dxthTTHCephMhctw1yGD77n11EnzR8jDlhMEhETScEvn/ROaYMv2JBDwcmRddTktYOUsy7bjoZ3gif6JDlWGSujUdKaNUTzyuCbrTZNgxNrneeCIkBdnQiftIf3qd+1jomfxrzl9nvTNE5Y0+fhPYM8RikcNVnLCbK2/XlCmWJWi9I4q8Uc8Z5w0iaAgjGvRJ+cib+1HUZvsR+Aa0V1wogwZ2Bg3HXY13vZ4M2laGi4n8sAP/HxmiT/REU3f+f+APWnrrDj/0VDC9+IdXwn/Up8GF0lG2qVTuaY21+c0UFWsuMapFhjMKdSyipBw+l7UmHqSyscYFy6ILs4L3mRYtZwZFdK2gcPtKnw4rDZ06n4TuKjwRkqll+i8R+Z8KyLYeBuHv1kMvOFkU6SIlSHVIxiYiQdcnLVYQAWVDb0OciQxafBjT+YVJH9e2NxwtXaRQpZrUY99Sybxejt1WNuy58bjuqMv+9a48fXL8yjvPi8Bc+15ygO+t95IwDI8vdO17yVnyaydF9PhDwXJ11VwsIIrjPtWkw/Hc6QTehVUzDEenE6YzvEjCxGb9a1pJcP2TpH//Dw== -------------------------------------------------------------------------------- /paper/MCBoost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlr-org/mcboost/5c80fc18d839791a4ffe183b515d5f41c3de25cd/paper/MCBoost.png -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | % This file was created with JabRef 2.10. 2 | % Encoding: ISO8859_1 3 | 4 | 5 | @Article{Barda2020covid, 6 | Title = {Developing a COVID-19 mortality risk prediction model when individual-level data are not available}, 7 | Author = {Barda, Noam and Riesel, Dan and Akriv, Amichay and Levy, Joseph and Finkel, Uriah and Yona, Gal and Greenfeld, Daniel and Sheiba, Shimon and Somer, Jonathan and Bachmat, Eitan and Rothblum, Guy and Shalit, Uri and Netzer, Doron and Balicer, Ran and Dagan, Noa}, 8 | Journal = {Nature communications}, 9 | Year = {2020}, 10 | 11 | Month = {09}, 12 | Pages = {4439}, 13 | Volume = {11}, 14 | 15 | Doi = {10.1038/s41467-020-18297-9} 16 | } 17 | 18 | @Article{Barda2020bias, 19 | Title = {Addressing bias in prediction models by improving subpopulation calibration}, 20 | Author = {Barda, Noam and Yona, Gal and Rothblum, Guy N and Greenland, Philip and Leibowitz, Morton and Balicer, Ran and Bachmat, Eitan and Dagan, Noa}, 21 | Journal = {Journal of the American Medical Informatics Association}, 22 | Year = {2020}, 23 | 24 | Month = {11}, 25 | Number = {3}, 26 | Pages = {549-558}, 27 | Volume = {28}, 28 | 29 | Doi = {10.1093/jamia/ocaa283}, 30 | Eprint = {https://academic.oup.com/jamia/article-pdf/28/3/549/36428833/ocaa283.pdf}, 31 | ISSN = {1527-974X}, 32 | Url = {https://doi.org/10.1093/jamia/ocaa283} 33 | } 34 | 35 | @Misc{dwork-oi, 36 | Title = {Outcome Indistinguishability}, 37 | 38 | Author = {Cynthia Dwork and Michael P. Kim and Omer Reingold and Guy N. Rothblum and Gal Yona}, 39 | Year = {2020}, 40 | 41 | Archiveprefix = {arXiv}, 42 | Eprint = {2011.13426}, 43 | Primaryclass = {cs.LG}, 44 | Url = {https://arxiv.org/abs/2011.13426} 45 | } 46 | 47 | @InProceedings{dwork-rankings, 48 | Title = {Learning from Outcomes: Evidence-Based Rankings}, 49 | Author = {Dwork, Cynthia and Kim, Michael P. and Reingold, Omer and Rothblum, Guy N. and Yona, Gal}, 50 | Booktitle = {2019 IEEE 60th Annual Symposium on Foundations of Computer Science (FOCS)}, 51 | Year = {2019}, 52 | Pages = {106--125}, 53 | 54 | Doi = {10.1109/FOCS.2019.00016} 55 | } 56 | 57 | @InProceedings{hebert-johnson2018, 58 | Title = {Multicalibration: Calibration for the ({C}omputationally-Identifiable) Masses}, 59 | Author = {Hebert-Johnson, Ursula and Kim, Michael and Reingold, Omer and Rothblum, Guy}, 60 | Booktitle = {Proceedings of the 35th International Conference on Machine Learning}, 61 | Year = {2018}, 62 | 63 | Address = {Stockholmsmässan, Stockholm Sweden}, 64 | Editor = {Jennifer Dy and Andreas Krause}, 65 | Month = {10--15 Jul}, 66 | Pages = {1939--1948}, 67 | Publisher = {PMLR}, 68 | Series = {Proceedings of Machine Learning Research}, 69 | Volume = {80} 70 | } 71 | 72 | @InProceedings{kim2019, 73 | Title = {Multiaccuracy: Black-Box Post-Processing for Fairness in Classification}, 74 | Author = {Kim, Michael P. and Ghorbani, Amirata and Zou, James}, 75 | Booktitle = {Proceedings of the 2019 AAAI/ACM Conference on AI, Ethics, and Society}, 76 | Year = {2019}, 77 | 78 | Address = {New York, NY, USA}, 79 | Pages = {247–254}, 80 | Publisher = {Association for Computing Machinery}, 81 | Series = {AIES '19}, 82 | 83 | Doi = {10.1145/3306618.3314287}, 84 | ISBN = {9781450363242}, 85 | Keywords = {fairness, discrimination, post-processing, machine learning}, 86 | Location = {Honolulu, HI, USA}, 87 | Numpages = {8}, 88 | Url = {https://doi.org/10.1145/3306618.3314287} 89 | } 90 | 91 | @Misc{kimkern2021, 92 | Title = {Universal Generalization versus Propensity Scoring}, 93 | 94 | Author = {Michael P. Kim and Christoph Kern and Shafi Goldwasser and Frauke Kreuter and Omer Reingold}, 95 | HowPublished = {Manuscript submitted for publication}, 96 | Year = {2021} 97 | } 98 | 99 | @Article{mlr3, 100 | Title = {{mlr3}: A modern object-oriented machine learning framework in {R}}, 101 | Author = {Michel Lang and Martin Binder and Jakob Richter and Patrick Schratz and Florian Pfisterer and Stefan Coors and Quay Au and Giuseppe Casalicchio and Lars Kotthoff and Bernd Bischl}, 102 | Journal = {Journal of Open Source Software}, 103 | Year = {2019}, 104 | 105 | Month = {dec}, 106 | 107 | Doi = {10.21105/joss.01903}, 108 | Url = {https://joss.theoj.org/papers/10.21105/joss.01903} 109 | } 110 | 111 | -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'mcboost: Multi-Calibration Boosting for R' 3 | tags: 4 | - R 5 | - Multi-Calibration 6 | - Multi-Accuracy 7 | - Boosting 8 | - Post-Processing 9 | - Fair ML 10 | authors: 11 | - name: Florian Pfisterer^[Corresponding author] 12 | orcid: 0000-0001-8867-762X 13 | affiliation: 1 14 | - name: Christoph Kern 15 | orcid: 0000-0001-7363-4299 16 | affiliation: 2 17 | - name: Susanne Dandl 18 | orcid: 0000-0003-4324-4163 19 | affiliation: 1 20 | - name: Matthew Sun 21 | affiliation: 3 22 | - name: Michael P. Kim 23 | affiliation: 4 24 | - name: Bernd Bischl 25 | orcid: 0000-0001-6002-6980 26 | affiliation: 1 27 | affiliations: 28 | - name: Ludwig Maximilian University of Munich 29 | index: 1 30 | - name: University of Mannheim 31 | index: 2 32 | - name: Princeton University 33 | index: 3 34 | - name: UC Berkeley 35 | index: 4 36 | date: 01 June 2021 37 | bibliography: paper.bib 38 | --- 39 | 40 | # Summary 41 | 42 | Given the increasing usage of automated prediction systems in the context of high-stakes decisions, a growing body of research focuses on methods for detecting and mitigating biases in algorithmic decision-making. 43 | One important framework to audit for and mitigate biases in predictions is that of Multi-Calibration, introduced by @hebert-johnson2018. 44 | The underlying fairness notion, Multi-Calibration, promotes the idea of multi-group fairness and requires calibrated predictions not only for marginal populations, but also for subpopulations that may be defined by complex intersections of many attributes. 45 | A simpler variant of Multi-Calibration, referred to as Multi-Accuracy, requires unbiased predictions for large collections of subpopulations. 46 | @hebert-johnson2018 proposed a boosting-style algorithm for learning multi-calibrated predictors. 47 | @kim2019 demonstrated how to turn this algorithm into a post-processing strategy to achieve multi-accuracy, demonstrating empirical effectiveness across various domains. 48 | This package provides a stable implementation of the multi-calibration algorithm, called MCBoost. 49 | In contrast to other Fair ML approaches, MCBoost does not harm the overall utility of a prediction model, but rather aims at improving calibration and accuracy for large sets of subpopulations post-training. 50 | MCBoost comes with strong theoretical guarantees, which have been explored formally in @hebert-johnson2018, @kim2019, @dwork-rankings, @dwork-oi and @kimkern2021. 51 | 52 | `mcboost` implements Multi-Calibration Boosting for R. 53 | `mcboost` is model agnostic and allows the user to post-process any supervised machine learning model. 54 | It accepts initial models that fit binary outcomes or continuous outcomes with predictions that are in (or scaled to) the range [0, 1]. 55 | For convenience and ease of use, `mcboost` tightly integrates with the **mlr3** [@mlr3] machine learning eco-system in R by allowing to calibrate regression or classification models fitted either within or outside of mlr3. 56 | Post-processing with `mcboost` starts with an initial prediction model that is passed on to an auditing algorithm that runs Multi-Calibration-Boosting on a labeled auditing dataset (Fig. 1). The resulting model can be used for obtaining multi-calibrated predictions. 57 | `mcboost` includes two pre-defined learners for auditing (ridge regression and decision trees), and allows to easily adjust the learner and its parameters for Multi-Calibration Boosting. 58 | Users may also specify a fixed set of subgroups, instead of a learner, on which predictions should be audited. 59 | Furthermore, `mcboost` includes utilities to guard against overfitting to the auditing dataset during post-processing. 60 | 61 | ![Fig 1. Conceptual illustration of Multi-Calibration Boosting with `mcboost`.\label{fig:overview}](MCBoost.png) 62 | 63 | # Statement of need 64 | 65 | Given the ubiquitous use of machine learning models in crucial areas and growing concerns of biased predictions for minority subpopulations, Multi-Calibration Boosting should be widely accessible in the form of a free and open-source software package. 66 | Prior to the development of `mcboost`, Multi-Calibration Boosting has not been released as a software package for R. 67 | 68 | The results in @kim2019 highlight that MCBoost can improve classification accuracy for subpopulations in various settings, including gender detection with image data, income classification with survey data and disease prediction using biomedical data. 69 | @Barda2020bias show that post-processing for Multi-Calibration can greatly improve calibration metrics of two medical risk assessment models when evaluated in subpopulations defined by intersections of age, sex, ethnicity, socioeconomic status and immigration history. 70 | @Barda2020covid demonstrate that Multi-Calibration can also be used to adjust an initial classifier for a new task. They re-calibrate a baseline model for predicting the risk of severe respiratory infection with data on COVID-19 fatality rates in subpopulations, resulting in an accurate and calibrated COVID-19 mortality prediction model. 71 | 72 | 73 | We hope that `mcboost` lets Multi-Calibration Boosting be utilized by a wide community of developers and data scientists to audit and post-process prediction models, and helps to promote fairness in machine learning and statistical estimation applications. 74 | 75 | # Acknowledgements 76 | 77 | We thank Matthew Sun for developing an initial Python implementation of MCBoost. 78 | This work has been partially supported by the German Federal Ministry of Education and Research (BMBF) under Grant No. 01IS18036A. The authors of this work take full responsibilities for its content. 79 | 80 | # References 81 | -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | if (requireNamespace("testthat", quietly = TRUE)) { 2 | library(checkmate) 3 | library(testthat) 4 | library(mlr3) 5 | library(mcboost) 6 | library(mlr3pipelines) 7 | 8 | test_check("mcboost") 9 | } 10 | -------------------------------------------------------------------------------- /tests/testthat/setup.R: -------------------------------------------------------------------------------- 1 | if (requireNamespace("lgr")) { 2 | lg = lgr::get_logger("mlr3") 3 | old_threshold = lg$threshold 4 | lg$set_threshold("warn") 5 | } 6 | -------------------------------------------------------------------------------- /tests/testthat/teardown.R: -------------------------------------------------------------------------------- 1 | lg$set_threshold(old_threshold) 2 | -------------------------------------------------------------------------------- /tests/testthat/test_auditor_fitters.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("mlr3") 2 | 3 | test_that("AuditorFitters work", { 4 | rf = AuditorFitter$new() 5 | expect_class(rf, "AuditorFitter") 6 | expect_error(rf$fit(1, 1), "Not implemented") 7 | }) 8 | 9 | test_that("LearnerAuditorFitters work", { 10 | skip_on_cran() 11 | rf = LearnerAuditorFitter$new(lrn("regr.featureless")) 12 | out = rf$fit(iris[, 1:4], runif(150)) 13 | expect_number(out[[1]]) 14 | expect_class(out[[2]], "LearnerPredictor") 15 | expect_true(out[[2]]$is_fitted) 16 | }) 17 | 18 | test_that("TreeAuditorFitters work", { 19 | skip_if_not_installed("mlr3learners") 20 | skip_if_not_installed("rpart") 21 | rf = TreeAuditorFitter$new() 22 | out = rf$fit(iris[, 1:4], runif(150)) 23 | expect_number(out[[1]]) 24 | expect_class(out[[2]], "LearnerPredictor") 25 | expect_true(out[[2]]$is_fitted) 26 | }) 27 | 28 | test_that("RidgeAuditorFitters work", { 29 | skip_on_cran() 30 | skip_on_os("mac") 31 | skip_on_os("solaris") 32 | skip_if_not_installed("mlr3learners") 33 | skip_if_not_installed("glmnet") 34 | 35 | rf = RidgeAuditorFitter$new() 36 | out = rf$fit(iris[, 1:4], runif(150)) 37 | expect_number(out[[1]]) 38 | expect_class(out[[2]], "LearnerPredictor") 39 | expect_true(out[[2]]$is_fitted) 40 | }) 41 | 42 | test_that("SubPopFitter work", { 43 | skip_on_cran() 44 | data = data.table( 45 | "AGE_NA" = c(0, 0, 0, 0, 0), 46 | "AGE_0_10" = c(1, 1, 0, 0, 0), 47 | "AGE_11_20" = c(0, 0, 1, 0, 0), 48 | "AGE_21_31" = c(0, 0, 0, 1, 1), 49 | "X1" = runif(5), 50 | "X2" = runif(5) 51 | ) 52 | label = c(1,0,0,1,1) 53 | 54 | pops = list("AGE_NA", "AGE_0_10", "AGE_11_20", "AGE_21_31", function(x) {x[["X1" > 0.5]]}) 55 | rf = SubpopAuditorFitter$new(subpops = pops) 56 | out = rf$fit(data, label - 0.5) 57 | expect_list(out) 58 | expect_number(out[[1]], lower = 0.2, upper = 0.2) 59 | expect_class(out[[2]], "SubpopPredictor") 60 | 61 | pops = list("AGE_NA") 62 | rf = SubpopAuditorFitter$new(subpops = pops) 63 | out = rf$fit(data, label - 0.5) 64 | expect_list(out) 65 | expect_number(out[[1]], lower = 0, upper = 0) 66 | expect_class(out[[2]], "SubpopPredictor") 67 | }) 68 | 69 | test_that("SubGroupFitter work", { 70 | skip_on_cran() 71 | data = data.table( 72 | "AGE_0_10" = c(1, 1, 0, 0, 0), 73 | "AGE_11_20" = c(0, 0, 1, 0, 0), 74 | "AGE_21_31" = c(0, 0, 0, 1, 1), 75 | "X1" = runif(5), 76 | "X2" = runif(5) 77 | ) 78 | label = c(1,0,0,1,1) 79 | 80 | masks = list( 81 | "M1" = c(1L, 0L, 1L, 1L, 0L), 82 | "M2" = c(1L, 0L, 0L, 0L, 1L) 83 | ) 84 | rf = SubgroupAuditorFitter$new(masks) 85 | out = rf$fit(data, label - 0.5, rep(1, length(label))) 86 | expect_list(out) 87 | expect_number(out[[1]], lower = 0, upper = 1) 88 | expect_class(out[[2]], "SubgroupModel") 89 | }) 90 | 91 | test_that("SubPopFitter iterates through all columns", { 92 | skip_on_cran() 93 | data = data.table( 94 | "AGE_NA" = c(0, 0, 0, 0, 0), 95 | "AGE_0_10" = c(1, 1, 0, 0, 0), 96 | "AGE_11_20" = c(0, 0, 1, 0, 0), 97 | "AGE_21_31" = c(0, 0, 0, 1, 1), 98 | "X1" = runif(5), 99 | "X2" = runif(5) 100 | ) 101 | label = c(1,0,0,1,1) 102 | 103 | pops = list("AGE_21_31", "AGE_11_20") 104 | rf = SubpopAuditorFitter$new(subpops = pops) 105 | out = rf$fit(data, label - 0.5) 106 | expect_list(out) 107 | expect_number(out[[1]], lower = 0.2, upper = 0.2) 108 | expect_class(out[[2]], "SubpopPredictor") 109 | 110 | pops = rev(list("AGE_21_31", "AGE_11_20")) 111 | rf = SubpopAuditorFitter$new(subpops = pops) 112 | out = rf$fit(data, label - 0.5) 113 | expect_list(out) 114 | expect_number(out[[1]], lower = 0.2, upper = 0.2) 115 | expect_class(out[[2]], "SubpopPredictor") 116 | }) 117 | 118 | test_that("SubPopFitter throws proper error if not binary or wrong length", { 119 | skip_on_cran() 120 | data = data.frame(X1 = rnorm(n = 10L), X2 = rnorm(n = 10L)) 121 | rs = c(1, rep(0, 9)) 122 | 123 | # 0/1 chracters are fine 124 | masks = list( 125 | rep(c("1", "0"), 5) 126 | ) 127 | sf = SubgroupAuditorFitter$new(masks) 128 | 129 | mean1 = sf$fit(data = data, resid = rs,rep(1, length(rs))) 130 | sm = SubgroupModel$new(masks) 131 | mean2 = sm$fit(data = data, labels = rs) # should not be 1! 132 | 133 | # wrong type 134 | masks = list( 135 | c("ab", "cc") 136 | ) 137 | expect_error(SubgroupAuditorFitter$new(masks), "subgroup_masks must be a list of integers") 138 | 139 | # wrong length 140 | masks = list( 141 | rep(c(1, 0), 10) 142 | ) 143 | sf = SubgroupAuditorFitter$new(masks) 144 | expect_error(sf$fit(data = data, resid = rs, mask = rep(1, 20)), "Length of subgroup masks must match length of data") 145 | 146 | # not binary 147 | masks = list( 148 | rep(c(1, 3, 0, 4)) 149 | ) 150 | expect_error(SubgroupAuditorFitter$new(masks), "subgroup_masks must be binary vectors") 151 | }) 152 | 153 | test_that("Bug in SubgroupAuditorFitter #16", { 154 | skip_on_cran() 155 | data = data.frame(X1 = rnorm(n = 10L), X2 = rnorm(n = 10L)) 156 | masks = list( 157 | rep(c(1, 0), 5) 158 | ) 159 | sf = SubgroupAuditorFitter$new(masks) 160 | resid = c(1, rep(0, 9)) 161 | sm = SubgroupModel$new(masks) 162 | mn = sm$fit(data = data, labels = resid) 163 | expect_equal(mn[[1]], mean(masks[[1]] * resid) / mean(masks[[1]])) 164 | }) 165 | -------------------------------------------------------------------------------- /tests/testthat/test_cv_predictors.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("mlr3") 2 | 3 | test_that("TreeAuditorFitters work", { 4 | skip_on_cran() 5 | skip_if_not_installed("mlr3learners") 6 | skip_if_not_installed("rpart") 7 | rf = CVTreeAuditorFitter$new() 8 | out = rf$fit(iris[, 1:4], runif(150)) 9 | expect_number(out[[1]]) 10 | expect_is(out[[2]], "CVLearnerPredictor") 11 | expect_true(out[[2]]$is_fitted) 12 | out = out[[2]]$predict(iris[,1:4]) 13 | expect_numeric(out) 14 | }) 15 | 16 | test_that("MCBoost multicalibrate and predict_probs - CV Predictor", { 17 | skip_on_cran() 18 | skip_if_not_installed("mlr3learners") 19 | skip_if_not_installed("rpart") 20 | # Sonar task 21 | tsk = tsk("sonar") 22 | data = tsk$data(cols = tsk$feature_names) 23 | labels = tsk$data(cols = tsk$target_names)[[1]] 24 | set.seed(123L) 25 | mc = MCBoost$new(auditor_fitter = "CVTreeAuditorFitter") 26 | mc$multicalibrate(data, labels) 27 | 28 | expect_list(mc$iter_models, types = "CVLearnerPredictor") 29 | expect_list(mc$iter_partitions, types = "ProbRange") 30 | 31 | prds = mc$predict_probs(data) 32 | expect_numeric(prds, lower = 0, upper = 1, len = nrow(data)) 33 | }) 34 | 35 | test_that("Creating own CV Predictor works with different folds", { 36 | skip_on_cran() 37 | skip_if_not_installed("mlr3learners") 38 | skip_if_not_installed("rpart") 39 | # Sonar task 40 | tsk = tsk("sonar") 41 | data = tsk$data(cols = tsk$feature_names) 42 | labels = tsk$data(cols = tsk$target_names)[[1]] 43 | ln = lrn("regr.rpart") 44 | cvfit = CVLearnerAuditorFitter$new(ln, folds = 2L) 45 | set.seed(123L) 46 | mc = MCBoost$new(auditor_fitter = cvfit) 47 | mc$multicalibrate(data, labels) 48 | 49 | expect_equal(cvfit$learner$pipeop$param_set$values$resampling.folds, 2L) 50 | expect_list(mc$iter_models, types = "CVLearnerPredictor") 51 | expect_list(mc$iter_partitions, types = "ProbRange") 52 | 53 | prds = mc$predict_probs(data) 54 | expect_numeric(prds, lower = 0, upper = 1, len = nrow(data)) 55 | }) 56 | -------------------------------------------------------------------------------- /tests/testthat/test_mcboost_low_degree.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("mlr3") 2 | 3 | test_that("MCBoost multicalibrate with subpops d = 2", { 4 | skip_on_os("solaris") 5 | skip_on_cran() 6 | # Sonar task 7 | tsk = tsk("sonar") 8 | data = tsk$data(cols = tsk$feature_names) 9 | labels = tsk$data(cols = tsk$target_names)[[1]] 10 | 11 | # Add group indicators for subpops 12 | data[, AGE_LE := sample(c(0,1), nrow(data), TRUE)] 13 | data[, G_1 := c(rep(0, 100), rep(1, 108))] 14 | data[, G_2 := c(rep(1, 50), rep(0, 158))] 15 | subpops = list("AGE_LE", "G_1", "G_2", function(x) x[["V1"]] > quantile(data$V1,.9)) 16 | 17 | # Fit initial model 18 | lp = LearnerPredictor$new(lrn("classif.rpart", maxdepth = 1L, predict_type = "prob")) 19 | lp$fit(data, labels) 20 | 21 | mc = MCBoost$new(default_model_class = lp, subpops = subpops, alpha = 0, weight_degree = 2L) 22 | mc$multicalibrate(data, labels) 23 | expect_is(mc$auditor_fitter, "SubpopAuditorFitter") 24 | expect_list(mc$iter_models, types = "SubpopPredictor", len = mc$max_iter) 25 | expect_list(mc$iter_partitions, types = "ProbRange", len = mc$max_iter) 26 | expect_numeric(mc$predict_probs(data), lower = 0, upper = 1, len = nrow(data)) 27 | }) 28 | -------------------------------------------------------------------------------- /tests/testthat/test_pipeop_mcboost.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("mlr3") 2 | 3 | test_that("MCBoost class instantiation", { 4 | skip_on_cran() 5 | skip_on_os("solaris") 6 | skip_if_not_installed("mlr3") 7 | skip_if_not_installed("mlr3learners") 8 | skip_if_not_installed("mlr3pipelines") 9 | gr = gunion(list( 10 | "data" = po("nop"), 11 | "prediction" = po("learner_cv", lrn("classif.rpart")) 12 | )) %>>% 13 | PipeOpMCBoost$new(param_vals = list(multiplicative = FALSE)) 14 | expect_is(gr, "Graph") 15 | tsk = tsk("sonar") 16 | tid = sample(1:208, 108) 17 | train_out = gr$train(tsk$clone()$filter(tid)) 18 | expect_is(gr$state$mcboost$mc, "MCBoost") 19 | expect_list(gr$state$mcboost$mc$iter_models, types = "LearnerPredictor") 20 | expect_true(!gr$state$mcboost$mc$multiplicative) 21 | pr = gr$predict(tsk$clone()$filter(setdiff(1:208, tid))) 22 | expect_is(pr[[1]], "Prediction") 23 | }) 24 | 25 | test_that("pipeop instantiation", { 26 | skip_on_cran() 27 | skip_on_os("solaris") 28 | skip_if_not_installed("mlr3") 29 | skip_if_not_installed("mlr3pipelines") 30 | pop = po("mcboost") 31 | expect_is(pop, "PipeOpMCBoost") 32 | expect_is(pop, "PipeOp") 33 | expect_list(pop$param_set$values, len = 0L) 34 | expect_true(pop$predict_type == "prob") 35 | }) 36 | 37 | test_that("MCBoost ppl", { 38 | skip_on_cran() 39 | skip_on_os("solaris") 40 | skip_if_not_installed("mlr3") 41 | skip_if_not_installed("mlr3learners") 42 | skip_if_not_installed("mlr3pipelines") 43 | 44 | l = lrn("classif.featureless")$train(tsk("sonar")) 45 | pp = ppl_mcboost() 46 | expect_is(pp, "Graph") 47 | pp$param_set$values$mcboost.init_predictor = l 48 | pp$train(tsk("sonar")) 49 | expect_true(!is.null(pp$state)) 50 | prd = pp$predict(tsk("sonar")) 51 | expect_is(prd[[1]], "PredictionClassif") 52 | }) 53 | 54 | test_that("MCBoostSurv ppl", { 55 | skip_on_cran() 56 | skip_on_os("solaris") 57 | skip_if_not_installed("mlr3") 58 | skip_if_not_installed("mlr3learners") 59 | skip_if_not_installed("mlr3pipelines") 60 | skip_on_os("solaris") 61 | gr = ppl_mcboost(lrn("classif.rpart")) 62 | expect_is(gr, "Graph") 63 | }) -------------------------------------------------------------------------------- /tests/testthat/test_predictor.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("mlr3") 2 | 3 | test_that("Predictor class instantiation", { 4 | skip_on_cran() 5 | prd = Predictor$new() 6 | expect_class(prd, "Predictor") 7 | expect_error(prd$fit(), fixed = "Abstract base class") 8 | expect_error(prd$predict(), fixed = "Abstract base class") 9 | }) 10 | 11 | test_that("ConstantPredictor", { 12 | skip_on_cran() 13 | prd = ConstantPredictor$new(0.7) 14 | expect_class(prd, "ConstantPredictor") 15 | expect_true(prd$is_fitted) 16 | p = prd$predict(data.frame(x = 1:3)) 17 | expect_equal(p, rep(0.7, 3)) 18 | }) 19 | 20 | test_that("LearnerPredictor - response", { 21 | skip_on_cran() 22 | skip_if_not_installed("mlr3learners") 23 | prd = LearnerPredictor$new(lrn("classif.rpart")) 24 | expect_class(prd, "LearnerPredictor") 25 | prd$fit(iris[1:100,1:4], factor(iris$Species[1:100])) 26 | expect_true(prd$is_fitted) 27 | p = prd$predict(iris[,1:4]) 28 | expect_numeric(p, len = 150L, lower = 0, upper = 1) 29 | }) 30 | 31 | test_that("LearnerPredictor - probs", { 32 | skip_on_cran() 33 | skip_if_not_installed("mlr3learners") 34 | prd = LearnerPredictor$new(lrn("classif.rpart", predict_type = "prob")) 35 | expect_class(prd, "LearnerPredictor") 36 | prd$fit(iris[51:150,1:4], factor(iris$Species[51:150])) 37 | expect_true(prd$is_fitted) 38 | p = prd$predict(iris[,1:4]) 39 | expect_numeric(p, len = 150L, lower = 0, upper = 1) 40 | }) 41 | -------------------------------------------------------------------------------- /tests/testthat/test_probrange.R: -------------------------------------------------------------------------------- 1 | test_that("ProbRange works", { 2 | skip_on_cran() 3 | pr = ProbRange$new(0.1, 0.55) 4 | prs = list( 5 | pr2 = ProbRange$new(0.1, 0.55), 6 | pr3 = ProbRange$new(0, 1), 7 | pr4 = ProbRange$new(0,0.4) 8 | ) 9 | 10 | expect_class(pr, "ProbRange") 11 | expect_equal(pr$lower, 0.1) 12 | expect_equal(pr$upper, 0.55) 13 | 14 | values = c(TRUE, FALSE, FALSE) 15 | out = mlr3misc::map_lgl(prs, function(x) { 16 | pr$is_equal(x) 17 | }) 18 | expect_equal(out, values, check.attributes = FALSE) 19 | 20 | out = mlr3misc::map_lgl(prs, function(x) { 21 | pr$is_not_equal(x) 22 | }) 23 | expect_equal(!out, values, check.attributes = FALSE) 24 | 25 | prs = c(0.09, 0.1, 0.4, 0.55, 0.7) 26 | expect_equal(pr$in_range_mask(prs), c(FALSE, TRUE, TRUE, FALSE, FALSE)) 27 | 28 | expect_false(pr$is_equal(5)) 29 | expect_true(pr$is_not_equal(5)) 30 | 31 | expect_output(print(pr), "ProbRange") 32 | }) 33 | -------------------------------------------------------------------------------- /tests/testthat/test_sonar_usecase.R: -------------------------------------------------------------------------------- 1 | skip_if_not_installed("mlr3") 2 | 3 | test_that("TreeAuditorFitters work", { 4 | skip_on_cran() 5 | skip_if_not_installed("mlr3learners") 6 | skip_if_not_installed("rpart") 7 | tsk = tsk("sonar") 8 | data = tsk$data()[, Class := as.integer(Class) - 1L] 9 | mod = glm(data = data, formula = Class ~ .) 10 | init_predictor = function(data) {predict(mod, data)} 11 | d = data[, -1] 12 | l = data$Class 13 | mc = MCBoost$new(init_predictor = init_predictor) 14 | mc$multicalibrate(d[1:200,], l[1:200]) 15 | expect_list(mc$iter_models, len = 5) 16 | out = mc$predict_probs(d[201:208,]) 17 | expect_numeric(out) 18 | }) 19 | -------------------------------------------------------------------------------- /tic.R: -------------------------------------------------------------------------------- 1 | # installs dependencies, runs R CMD check, runs covr::codecov() 2 | do_package_checks() 3 | 4 | if (ci_on_ghactions() && ci_has_env("BUILD_PKGDOWN")) { 5 | # creates pkgdown site and pushes to gh-pages branch 6 | # only for the runner with the "BUILD_PKGDOWN" env var set 7 | do_pkgdown() 8 | } 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/mcboost_example.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "MCBoost - Health Survey Example" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{MCBoost - Health Survey Example} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | %\VignetteDepends{} 9 | --- 10 | 11 | ```{r, echo = FALSE} 12 | # We do not build vignettes on CRAN or when the depends only flag is active. 13 | NOT_CRAN = identical(tolower(Sys.getenv("NOT_CRAN")), "true") 14 | HAS_DEPS = identical(tolower(Sys.getenv("_R_CHECK_DEPENDS_ONLY_")), "true") 15 | knitr::opts_chunk$set( 16 | collapse = TRUE, 17 | comment = "#>", 18 | purl = (NOT_CRAN & !HAS_DEPS), 19 | eval = (NOT_CRAN & !HAS_DEPS) 20 | ) 21 | ``` 22 | 23 | ```{r setup, message = FALSE} 24 | library(tidyverse) 25 | library(PracTools) 26 | library(ranger) 27 | library(neuralnet) 28 | library(formattable) 29 | library(mlr3) 30 | library(mlr3learners) 31 | library(mcboost) 32 | ``` 33 | 34 | ## Data and Setup 35 | 36 | This vignette presents two typical use cases of MCBoost with data from a health survey. The goal is to post-process two initial prediction models for multi-accuracy using different flavors of MCBoost, and to eventually compare the naive and post-processed predictors overall and for subpopulations. The first scenario starts with a neural net and, as an example, evaluates the initial and post-processed predictors with a focus on subgroup accuracy after running MCBoost. The second scenario uses a random forest and evaluates the initial and post-processed predictors with respect to subgroup calibration. 37 | 38 | We use data derived from the National Health Interview Survey (NHIS 2003), which includes demographic and health-related variables for 21,588 individuals. This data can directly be included from the `PracTools` package. 39 | 40 | ```{r, eval = TRUE} 41 | data(nhis.large) 42 | ``` 43 | We can obtain more information using: 44 | 45 | ```{r, eval = FALSE} 46 | ?nhis.large 47 | ``` 48 | 49 | 50 | In the following, our outcome of interest is whether an individual is covered by any type of health insurance (`notcov`, 1 = not covered, 0 = covered). We additionally prepare two sets of variables: 51 | 52 | - Predictor variables (age, parents in household, education, income, employment status, physical or other limitations) 53 | - Subpopulation variables (sex, hispanic ethnicity, race) 54 | 55 | The second set of variables will not be used for training the initial prediction models, but will be our focus when it comes to evaluating prediction performance for subgroups. 56 | 57 | Before we training an initial model, we preprocess the data: 58 | 59 | * We encode categorical features as `factor`. 60 | * We explicitly assign `NA`s in categorical features to a dedicated factor level 61 | * We drop `NA`s in the outcome variable `notcov` 62 | * We encode `notcov` as a factor variable instead of a dummy variable (1 = `notcov`, 0 = `cov`) 63 | * We create a new feature `inv_wt` as the inverse of survey weights `svwyt` 64 | 65 | ```{r} 66 | categorical <- c("age.grp", "parents", "educ", "inc.grp", "doing.lw", 67 | "limited", "sex", "hisp", "race") 68 | 69 | nhis <- nhis.large %>% 70 | mutate_at(categorical, as.factor) %>% 71 | mutate_at(categorical, fct_explicit_na) %>% 72 | drop_na(notcov) %>% 73 | select(all_of(categorical), notcov, svywt, ID) 74 | 75 | nhis$notcov <- factor(ifelse(nhis$notcov == 1, "notcov", "cov")) 76 | 77 | nhis_enc <- data.frame(model.matrix(notcov ~ ., data = nhis)[,-1]) 78 | nhis_enc$notcov <- nhis$notcov 79 | nhis_enc$sex <- nhis$sex 80 | nhis_enc$hisp <- nhis$hisp 81 | nhis_enc$race <- nhis$race 82 | nhis_enc$inv_wt <- (1 / nhis$svywt) 83 | ``` 84 | 85 | The pre-processed NHIS data will be split into three datasets: 86 | 87 | - A training set `train` for training the initial prediction models (55 \% of data) 88 | - An auditing set `post` for post-processing the initial models with MCBoost (20 \%) 89 | - A test set `test`for model evaluation (25 \%) 90 | 91 | To increase the difficulty of the prediction task, we sample from the NHIS data such that the prevalence of demographic subgroups in the test data differs from their prevalence in the training and auditing data. This is achieved by employing weighted sampling from NHIS (variable `inv_wt` from above). 92 | 93 | ```{r} 94 | set.seed(2953) 95 | 96 | test <- nhis_enc %>% slice_sample(prop = 0.25, weight_by = inv_wt) 97 | 98 | nontest_g <- nhis_enc %>% anti_join(test, by = "ID") 99 | 100 | train_g <- nontest_g %>% slice_sample(prop = 0.75) 101 | 102 | post <- nontest_g %>% anti_join(train_g, by = "ID") %>% select(-ID, -svywt, -inv_wt, -c(sex:race)) 103 | 104 | train <- train_g %>% select(-ID, -svywt, -inv_wt, -c(sex:race), -c(sex2:race3)) 105 | ``` 106 | 107 | As a result, non-hispanic white individuals (`hisp2`) are overrepresented and hispanic individuals are underrepresented in both the training and auditing set, compared to their prevalence in the test set. 108 | 109 | ```{r} 110 | train_g %>% summarise_at(vars(sex2:race3), mean) 111 | # hispanic individuals 112 | 1 - sum(train_g %>% summarise_at(vars(hisp2:hisp4), mean)) 113 | ``` 114 | 115 | ```{r} 116 | post %>% summarise_at(vars(sex2:race3), mean) 117 | # hispanic individuals 118 | 1 - sum(post %>% summarise_at(vars(hisp2:hisp4), mean)) 119 | ``` 120 | 121 | ```{r} 122 | test %>% summarise_at(vars(sex2:race3), mean) 123 | # hispanic individuals 124 | 1 - sum(test %>% summarise_at(vars(hisp2:hisp4), mean)) 125 | ``` 126 | 127 | ## Scenario 1: Improve Subgroup Accuracy 128 | 129 | We train an initial model for predicting healthcare coverage with the training set. Here, we use a neural network with one hidden layer, rather naively with little tweaking. 130 | 131 | ```{r, message = FALSE} 132 | nnet <- neuralnet(notcov ~ ., 133 | hidden = 5, 134 | linear.output = FALSE, 135 | err.fct = 'ce', 136 | threshold = 0.5, 137 | lifesign = 'full', 138 | data = train 139 | ) 140 | ``` 141 | 142 | ### MCBoost Auditing 143 | 144 | We prepare a function that allows us to pass the predictions of the model to MCBoost for post-processing. 145 | 146 | ```{r} 147 | init_nnet = function(data) { 148 | predict(nnet, data)[, 2] 149 | } 150 | ``` 151 | 152 | To showcase different use cases of MCBoost, we prepare two post-processing data sets based on the auditing set. The first set includes only the predictor variables that were used by the initial models, whereas the second set will allow post-processing based on our demographic subgroups of interest (sex, hispanic ethnicity, race). 153 | 154 | ```{r} 155 | d1 <- select(post, -c(notcov, sex2:race3)) 156 | d2 <- select(post, -notcov) 157 | l <- 1 - one_hot(post$notcov) 158 | ``` 159 | 160 | We initialize two custom auditors for MCBoost: Ridge regression with a small penalty on model complexity, and a `SubpopAuditorFitter` with a fixed set of subpopulations. 161 | 162 | ```{r} 163 | ridge = LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 0, lambda = 2 / nrow(post))) 164 | 165 | pops = SubpopAuditorFitter$new(list("sex2", "hisp2", "hisp3", "hisp4", "race2", "race3")) 166 | ``` 167 | 168 | The ridge regression will only be given access to the initial predictor variables when post-processing the neural net predictions with the auditing data. In contrast, we guide the subpop-fitter to audit the initial predictions explicitly on the outlined subpopulations (sex, hispanic ethnicity, race). In summary, we have: 169 | 170 | - `nnet`: Initial neural net 171 | - `nnet_mc_ridge`: Neural net, post-processed with ridge regression and the initial set of predictor variables 172 | - `nnet_mc_subpop`: Neural net, post-processed with a fixed set of subpopulations 173 | 174 | ```{r} 175 | nnet_mc_ridge = MCBoost$new(init_predictor = init_nnet, 176 | auditor_fitter = ridge, 177 | multiplicative = TRUE, 178 | partition = TRUE, 179 | max_iter = 15) 180 | nnet_mc_ridge$multicalibrate(d1, l) 181 | 182 | nnet_mc_subpop = MCBoost$new(init_predictor = init_nnet, 183 | auditor_fitter = pops, 184 | partition = TRUE, 185 | max_iter = 15) 186 | nnet_mc_subpop$multicalibrate(d2, l) 187 | ``` 188 | 189 | ### Model Evaluation 190 | 191 | Next, we use the initial and post-processed models to predict the outcome in the test data. We compute predicted probabilities and class predictions. 192 | 193 | ```{r} 194 | test$nnet <- predict(nnet, newdata = test)[, 2] 195 | test$nnet_mc_ridge <- nnet_mc_ridge$predict_probs(test) 196 | test$nnet_mc_subpop <- nnet_mc_subpop$predict_probs(test) 197 | 198 | test$c_nnet <- round(test$nnet) 199 | test$c_nnet_mc_ridge <- round(test$nnet_mc_ridge) 200 | test$c_nnet_mc_subpop <- round(test$nnet_mc_subpop) 201 | test$label <- 1 - one_hot(test$notcov) 202 | ``` 203 | 204 | Here we compare the overall accuracy of the initial and post-processed models. Overall, we observe little differences in performance. 205 | 206 | ```{r} 207 | mean(test$c_nnet == test$label) 208 | mean(test$c_nnet_mc_ridge == test$label) 209 | mean(test$c_nnet_mc_subpop == test$label) 210 | ``` 211 | 212 | However, we might be concerned with model performance for smaller subpopulations. In the following, we focus on subgroups defined by 2-way conjunctions of sex, hispanic ethnicity, and race. 213 | 214 | ```{r, warning = FALSE} 215 | test <- test %>% 216 | group_by(sex, hisp) %>% 217 | mutate(sex_hisp = cur_group_id()) %>% 218 | group_by(sex, race) %>% 219 | mutate(sex_race = cur_group_id()) %>% 220 | group_by(hisp, race) %>% 221 | mutate(hisp_race = cur_group_id()) %>% 222 | ungroup() 223 | 224 | grouping_vars <- c("sex", "hisp", "race", "sex_hisp", "sex_race", "hisp_race") 225 | 226 | eval <- map(grouping_vars, group_by_at, .tbl = test) %>% 227 | map(summarise, 228 | 'accuracy_nnet' = mean(c_nnet == label), 229 | 'accuracy_nnet_mc_ridge' = mean(c_nnet_mc_ridge == label), 230 | 'accuracy_nnet_mc_subpop' = mean(c_nnet_mc_subpop == label), 231 | 'size' = n()) %>% 232 | bind_rows() 233 | ``` 234 | 235 | We evaluate classification accuracy on these subpopulations, and order the results according to the size of the selected subgroups (`size`). Subgroup accuracy varies between methods, with MCBoost-Ridge (`nnet_mc_ridge`) and MCBoost-Subpop (`nnet_mc_subpop`) stabilizing subgroup performance when compared to the initial model, respectively. 236 | 237 | ```{r} 238 | eval %>% 239 | arrange(desc(size)) %>% 240 | select(size, accuracy_nnet:accuracy_nnet_mc_subpop) %>% 241 | round(., digits = 3) %>% 242 | formattable(., lapply(1:nrow(eval), function(row) { 243 | area(row, col = 2:4) ~ color_tile("transparent", "lightgreen") 244 | })) 245 | ``` 246 | 247 | ## Scenario 2: Improve Subgroup Calibration 248 | 249 | In this scenario, we use a random forest with the default settings of the ranger package as the initial predictor. 250 | 251 | ```{r} 252 | rf <- ranger(notcov ~ ., data = train, probability = TRUE) 253 | ``` 254 | 255 | ### MCBoost Auditing 256 | 257 | We again prepare a function to pass the predictions to MCBoost for post-processing. 258 | 259 | ```{r} 260 | init_rf = function(data) { 261 | predict(rf, data)$prediction[, 2] 262 | } 263 | ``` 264 | 265 | We use two custom auditors for MCBoost, i.e., ridge and lasso regression with different penalties on model complexity. 266 | 267 | ```{r} 268 | ridge = LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 0, lambda = 2 / nrow(post))) 269 | 270 | lasso = LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 1, lambda = 40 / nrow(post))) 271 | ``` 272 | 273 | The ridge regression will only be given access to the initial predictor variables when post-processing the random forest predictions. In contrast, we allow the lasso regression to audit the initial predictions both with the initial predictors and the subpopulations (sex, hispanic ethnicity, race). In summary, we have: 274 | 275 | - `rf`: Initial random forest 276 | - `rf_mc_ridge`: Random forest, post-processed with ridge regression and the initial set of predictor variables 277 | - `rf_mc_lasso`: Random forest, post-processed with lasso regression and the extended set of predictors 278 | 279 | ```{r} 280 | rf_mc_ridge = MCBoost$new(init_predictor = init_rf, 281 | auditor_fitter = ridge, 282 | multiplicative = TRUE, 283 | partition = TRUE, 284 | max_iter = 15) 285 | rf_mc_ridge$multicalibrate(d1, l) 286 | 287 | rf_mc_lasso = MCBoost$new(init_predictor = init_rf, 288 | auditor_fitter = lasso, 289 | multiplicative = TRUE, 290 | partition = TRUE, 291 | max_iter = 15) 292 | rf_mc_lasso$multicalibrate(d2, l) 293 | ``` 294 | 295 | ### Model Evaluation 296 | 297 | We again compute predicted probabilities and class predictions using the initial and post-processed models. 298 | 299 | ```{r} 300 | test$rf <- predict(rf, test)$prediction[, 2] 301 | test$rf_mc_ridge <- rf_mc_ridge$predict_probs(test) 302 | test$rf_mc_lasso <- rf_mc_lasso$predict_probs(test) 303 | 304 | test$c_rf <- round(test$rf) 305 | test$c_rf_mc_ridge <- round(test$rf_mc_ridge) 306 | test$c_rf_mc_lasso <- round(test$rf_mc_lasso) 307 | ``` 308 | 309 | Here we compare the overall accuracy of the initial and post-processed models. As before, we observe small differences in overall performance. 310 | 311 | ```{r} 312 | mean(test$c_rf == test$label) 313 | mean(test$c_rf_mc_ridge == test$label) 314 | mean(test$c_rf_mc_lasso == test$label) 315 | ``` 316 | 317 | However, we might be concerned with calibration in subpopulations. In the following we focus on subgroups defined by 2-way conjunctions of sex, hispanic ethnicity, and race. 318 | 319 | ```{r} 320 | eval <- map(grouping_vars, group_by_at, .tbl = test) %>% 321 | map(summarise, 322 | 'bias_rf' = abs(mean(rf) - mean(label))*100, 323 | 'bias_rf_mc_ridge' = abs(mean(rf_mc_ridge) - mean(label))*100, 324 | 'bias_rf_mc_lasso' = abs(mean(rf_mc_lasso) - mean(label))*100, 325 | 'size' = n()) %>% 326 | bind_rows() 327 | ``` 328 | 329 | This evaluation focuses on the difference between the average predicted risk of healthcare non-coverage and the observed proportion of non-coverage in the test data for subgroups. Considering the MCBoost-Ridge (`rf_mc_ridge`) and MCBoost-Lasso (`rf_mc_lasso`) results, post-processing with MCBoost reduces bias for many subpopulations. 330 | 331 | ```{r} 332 | eval %>% 333 | arrange(desc(size)) %>% 334 | select(size, bias_rf:bias_rf_mc_lasso) %>% 335 | round(., digits = 3) %>% 336 | formattable(., lapply(1:nrow(eval), function(row) { 337 | area(row, col = 2:4) ~ color_tile("lightgreen", "transparent") 338 | })) 339 | ``` 340 | --------------------------------------------------------------------------------