├── .Rbuildignore
├── .github
│   ├── .gitignore
│   └── workflows
│       ├── R-CMD-check.yaml
│       ├── pkgdown.yaml
│       └── test-coverage.yaml
├── .gitignore
├── CRAN-SUBMISSION
├── DESCRIPTION
├── LICENSE.md
├── NAMESPACE
├── NEWS.md
├── R
│   ├── additive_shap.R
│   ├── kernelshap.R
│   ├── methods.R
│   ├── permshap.R
│   ├── pred_fun.R
│   ├── utils.R
│   ├── utils_kernelshap.R
│   └── utils_permshap.R
├── README.md
├── backlog
│   ├── 2023-11-11 Permutation-SHAP.R
│   ├── compare_with_python.R
│   ├── plot_settings.R
│   ├── test_additive_shap.R
│   └── test_ranger.R
├── cran-comments.md
├── logo.png
├── man
│   ├── additive_shap.Rd
│   ├── figures
│   │   ├── README-gam-dep.svg
│   │   ├── README-gam-imp.svg
│   │   ├── README-nn-dep.svg
│   │   ├── README-nn-imp.svg
│   │   ├── README-prob-dep.svg
│   │   ├── README-prob-imp.svg
│   │   ├── README-rf-dep.svg
│   │   ├── README-rf-imp.svg
│   │   └── logo.png
│   ├── is.kernelshap.Rd
│   ├── kernelshap.Rd
│   ├── permshap.Rd
│   ├── print.kernelshap.Rd
│   └── summary.kernelshap.Rd
├── packaging.R
├── pkgdown
│   ├── _pkgdown.yml
│   └── favicon
│       ├── apple-touch-icon-120x120.png
│       ├── apple-touch-icon-152x152.png
│       ├── apple-touch-icon-180x180.png
│       ├── apple-touch-icon-60x60.png
│       ├── apple-touch-icon-76x76.png
│       ├── apple-touch-icon.png
│       ├── favicon-16x16.png
│       ├── favicon-32x32.png
│       └── favicon.ico
├── revdep
│   ├── .gitignore
│   ├── README.md
│   ├── cran.md
│   ├── email.yml
│   ├── failures.md
│   └── problems.md
└── tests
    ├── testthat.R
    └── testthat
        ├── test-additive_shap.R
        ├── test-basic.R
        ├── test-kernelshap-utils.R
        ├── test-methods.R
        ├── test-multioutput.R
        ├── test-permshap-utils.R
        ├── test-utils.R
        └── test-weights.R

--------------------------------------------------------------------------------
/.Rbuildignore:
--------------------------------------------------------------------------------
^LICENSE\.md$
^packaging.R$
[.]Rproj$
^cran-comments.md$
^logo.png$
^cran-comments\.md$
^.*\.Rproj$
^\.Rproj\.user$
^backlog$
^CRAN-SUBMISSION$
^_pkgdown\.yml$
^docs$
^pkgdown$
^pkgdown/_pkgdown\.yml$
^\.github$
^revdep$
^compare_with_python.R$

--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
*.html

--------------------------------------------------------------------------------
/.github/workflows/R-CMD-check.yaml:
--------------------------------------------------------------------------------
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
  push:
    branches: [main, master]
  pull_request:
    branches: [main, master]

name: R-CMD-check

jobs:
  R-CMD-check:
    runs-on: ${{ matrix.config.os }}

    name: ${{ matrix.config.os }} (${{ matrix.config.r }})

    strategy:
      fail-fast: false
      matrix:
        config:
          - {os: macos-latest, r: 'release'}
          - {os: windows-latest, r: 'release'}
          - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
          - {os: ubuntu-latest, r: 'release'}
          - {os: ubuntu-latest, r: 'oldrel-1'}

    env:
      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
      R_KEEP_PKG_SOURCE: yes

    steps:
      - uses: actions/checkout@v3

      - uses: r-lib/actions/setup-pandoc@v2

      - uses: r-lib/actions/setup-r@v2
        with:
          r-version: ${{ matrix.config.r }}
          http-user-agent: ${{ matrix.config.http-user-agent }}
          use-public-rspm: true

      - uses: r-lib/actions/setup-r-dependencies@v2
        with:
          extra-packages: any::rcmdcheck
          needs: check

      - uses: r-lib/actions/check-r-package@v2
        with:
          upload-snapshots: true

--------------------------------------------------------------------------------
/.github/workflows/pkgdown.yaml:
--------------------------------------------------------------------------------
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
  push:
    branches: [main, master]
  pull_request:
    branches: [main, master]
  release:
    types: [published]
  workflow_dispatch:

name: pkgdown

jobs:
  pkgdown:
    runs-on: ubuntu-latest
    # Only restrict concurrency for non-PR jobs
    concurrency:
      group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }}
    env:
      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
    steps:
      - uses: actions/checkout@v3

      - uses: r-lib/actions/setup-pandoc@v2

      - uses: r-lib/actions/setup-r@v2
        with:
          use-public-rspm: true

      - uses: r-lib/actions/setup-r-dependencies@v2
        with:
          extra-packages: any::pkgdown, local::.
          needs: website

      - name: Install DrWhy theme
        run: |
          install.packages("remotes")
          remotes::install_deps(dependencies = TRUE)
          install.packages("future.apply")
          remotes::install_github("ModelOriented/DrWhyTemplate")
        shell: Rscript {0}

      - name: Build site
        run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
        shell: Rscript {0}

      - name: Deploy to GitHub pages 🚀
        if: github.event_name != 'pull_request'
        uses: JamesIves/github-pages-deploy-action@v4.4.1
        with:
          clean: false
          branch: gh-pages
          folder: docs

--------------------------------------------------------------------------------
/.github/workflows/test-coverage.yaml:
--------------------------------------------------------------------------------
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
  push:
    branches: [main, master]
  pull_request:

name: test-coverage.yaml

permissions: read-all

jobs:
  test-coverage:
    runs-on: ubuntu-latest
    env:
      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}

    steps:
      - uses: actions/checkout@v4

      - uses: r-lib/actions/setup-r@v2
        with:
          use-public-rspm: true

      - uses: r-lib/actions/setup-r-dependencies@v2
        with:
          extra-packages: any::covr, any::xml2
          needs: coverage

      - name: Test coverage
        run: |
          cov <- covr::package_coverage(
            quiet = FALSE,
            clean = FALSE,
            install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package"),
            function_exclusions = c(
              "kernelshap\\.ranger",
              "permshap\\.ranger",
              "pred_ranger"
            )
          )
          print(cov)
          covr::to_cobertura(cov)
        shell: Rscript {0}

      - uses: codecov/codecov-action@v5
        with:
          # Fail if error if not on PR, or if on PR and token is given
          fail_ci_if_error: ${{ github.event_name != 'pull_request' || secrets.CODECOV_TOKEN }}
          files: ./cobertura.xml
          plugins: noop
          disable_search: true
          token: ${{ secrets.CODECOV_TOKEN }}

      - name: Show testthat output
        if: always()
        run: |
          ## --------------------------------------------------------------------
          find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true
        shell: bash

      - name: Upload test results
        if: failure()
        uses: actions/upload-artifact@v4
        with:
          name: coverage-test-failures
          path: ${{ runner.temp }}/package

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# History files
.Rhistory
.Rapp.history

# Session Data files
.RData

# User-specific files
.Ruserdata

# Example code in package build process
*-Ex.R

# Output files from R CMD build
/*.tar.gz

# Output files from R CMD check
/*.Rcheck/

# RStudio files
.Rproj.user/

# produced vignettes
vignettes/*.html
vignettes/*.pdf

# OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
.httr-oauth

# knitr and R markdown default cache directories
*_cache/
/cache/

# Temporary files created by R markdown
*.utf8.md
*.knit.md

# R Environment Variables
.Renviron

# other
stuff.R
*.Rproj

# pkgdown
docs

--------------------------------------------------------------------------------
/CRAN-SUBMISSION:
--------------------------------------------------------------------------------
Version: 0.7.0
Date: 2024-08-17 15:37:44 UTC
SHA: bb03d4c0b075dabdc7aee816c2a786829a5e92fc

--------------------------------------------------------------------------------
/DESCRIPTION:
--------------------------------------------------------------------------------
Package: kernelshap
Title: Kernel SHAP
Version: 0.7.1
Authors@R: c(
    person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"),
           comment = c(ORCID = "0009-0007-2540-9629")),
    person("David", "Watson", , "david.s.watson11@gmail.com", role = "aut",
           comment = c(ORCID = "0000-0001-9632-2159")),
person("Przemyslaw", "Biecek", , "przemyslaw.biecek@gmail.com", role = "ctb", 10 | comment = c(ORCID = "0000-0001-8423-1823")) 11 | ) 12 | Description: Efficient implementation of Kernel SHAP, see Lundberg and Lee 13 | (2017), and Covert and Lee (2021) 14 | . Furthermore, for up to 15 | 14 features, exact permutation SHAP values can be calculated. The 16 | package plays well together with meta-learning packages like 17 | 'tidymodels', 'caret' or 'mlr3'. Visualizations can be done using the 18 | R package 'shapviz'. 19 | License: GPL (>= 2) 20 | Depends: 21 | R (>= 3.2.0) 22 | Encoding: UTF-8 23 | Roxygen: list(markdown = TRUE) 24 | RoxygenNote: 7.3.2 25 | Imports: 26 | foreach, 27 | MASS, 28 | stats, 29 | utils 30 | Suggests: 31 | doFuture, 32 | testthat (>= 3.0.0) 33 | Config/testthat/edition: 3 34 | URL: https://github.com/ModelOriented/kernelshap 35 | BugReports: https://github.com/ModelOriented/kernelshap/issues 36 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | GNU General Public License 2 | ========================== 3 | 4 | _Version 2, June 1991_ 5 | _Copyright © 1989, 1991 Free Software Foundation, Inc.,_ 6 | _51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA_ 7 | 8 | Everyone is permitted to copy and distribute verbatim copies 9 | of this license document, but changing it is not allowed. 10 | 11 | ### Preamble 12 | 13 | The licenses for most software are designed to take away your 14 | freedom to share and change it. By contrast, the GNU General Public 15 | License is intended to guarantee your freedom to share and change free 16 | software--to make sure the software is free for all its users. This 17 | General Public License applies to most of the Free Software 18 | Foundation's software and to any other program whose authors commit to 19 | using it. (Some other Free Software Foundation software is covered by 20 | the GNU Lesser General Public License instead.) You can apply it to 21 | your programs, too. 22 | 23 | When we speak of free software, we are referring to freedom, not 24 | price. Our General Public Licenses are designed to make sure that you 25 | have the freedom to distribute copies of free software (and charge for 26 | this service if you wish), that you receive source code or can get it 27 | if you want it, that you can change the software or use pieces of it 28 | in new free programs; and that you know you can do these things. 29 | 30 | To protect your rights, we need to make restrictions that forbid 31 | anyone to deny you these rights or to ask you to surrender the rights. 32 | These restrictions translate to certain responsibilities for you if you 33 | distribute copies of the software, or if you modify it. 34 | 35 | For example, if you distribute copies of such a program, whether 36 | gratis or for a fee, you must give the recipients all the rights that 37 | you have. You must make sure that they, too, receive or can get the 38 | source code. And you must show them these terms so they know their 39 | rights. 40 | 41 | We protect your rights with two steps: **(1)** copyright the software, and 42 | **(2)** offer you this license which gives you legal permission to copy, 43 | distribute and/or modify the software. 44 | 45 | Also, for each author's protection and ours, we want to make certain 46 | that everyone understands that there is no warranty for this free 47 | software. 
If the software is modified by someone else and passed on, we 48 | want its recipients to know that what they have is not the original, so 49 | that any problems introduced by others will not reflect on the original 50 | authors' reputations. 51 | 52 | Finally, any free program is threatened constantly by software 53 | patents. We wish to avoid the danger that redistributors of a free 54 | program will individually obtain patent licenses, in effect making the 55 | program proprietary. To prevent this, we have made it clear that any 56 | patent must be licensed for everyone's free use or not licensed at all. 57 | 58 | The precise terms and conditions for copying, distribution and 59 | modification follow. 60 | 61 | ### TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 62 | 63 | **0.** This License applies to any program or other work which contains 64 | a notice placed by the copyright holder saying it may be distributed 65 | under the terms of this General Public License. The “Program”, below, 66 | refers to any such program or work, and a “work based on the Program” 67 | means either the Program or any derivative work under copyright law: 68 | that is to say, a work containing the Program or a portion of it, 69 | either verbatim or with modifications and/or translated into another 70 | language. (Hereinafter, translation is included without limitation in 71 | the term “modification”.) Each licensee is addressed as “you”. 72 | 73 | Activities other than copying, distribution and modification are not 74 | covered by this License; they are outside its scope. The act of 75 | running the Program is not restricted, and the output from the Program 76 | is covered only if its contents constitute a work based on the 77 | Program (independent of having been made by running the Program). 78 | Whether that is true depends on what the Program does. 79 | 80 | **1.** You may copy and distribute verbatim copies of the Program's 81 | source code as you receive it, in any medium, provided that you 82 | conspicuously and appropriately publish on each copy an appropriate 83 | copyright notice and disclaimer of warranty; keep intact all the 84 | notices that refer to this License and to the absence of any warranty; 85 | and give any other recipients of the Program a copy of this License 86 | along with the Program. 87 | 88 | You may charge a fee for the physical act of transferring a copy, and 89 | you may at your option offer warranty protection in exchange for a fee. 90 | 91 | **2.** You may modify your copy or copies of the Program or any portion 92 | of it, thus forming a work based on the Program, and copy and 93 | distribute such modifications or work under the terms of Section 1 94 | above, provided that you also meet all of these conditions: 95 | 96 | * **a)** You must cause the modified files to carry prominent notices 97 | stating that you changed the files and the date of any change. 98 | * **b)** You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 
102 | * **c)** If the modified program normally reads commands interactively 103 | when run, you must cause it, when started running for such 104 | interactive use in the most ordinary way, to print or display an 105 | announcement including an appropriate copyright notice and a 106 | notice that there is no warranty (or else, saying that you provide 107 | a warranty) and that users may redistribute the program under 108 | these conditions, and telling the user how to view a copy of this 109 | License. (Exception: if the Program itself is interactive but 110 | does not normally print such an announcement, your work based on 111 | the Program is not required to print an announcement.) 112 | 113 | These requirements apply to the modified work as a whole. If 114 | identifiable sections of that work are not derived from the Program, 115 | and can be reasonably considered independent and separate works in 116 | themselves, then this License, and its terms, do not apply to those 117 | sections when you distribute them as separate works. But when you 118 | distribute the same sections as part of a whole which is a work based 119 | on the Program, the distribution of the whole must be on the terms of 120 | this License, whose permissions for other licensees extend to the 121 | entire whole, and thus to each and every part regardless of who wrote it. 122 | 123 | Thus, it is not the intent of this section to claim rights or contest 124 | your rights to work written entirely by you; rather, the intent is to 125 | exercise the right to control the distribution of derivative or 126 | collective works based on the Program. 127 | 128 | In addition, mere aggregation of another work not based on the Program 129 | with the Program (or with a work based on the Program) on a volume of 130 | a storage or distribution medium does not bring the other work under 131 | the scope of this License. 132 | 133 | **3.** You may copy and distribute the Program (or a work based on it, 134 | under Section 2) in object code or executable form under the terms of 135 | Sections 1 and 2 above provided that you also do one of the following: 136 | 137 | * **a)** Accompany it with the complete corresponding machine-readable 138 | source code, which must be distributed under the terms of Sections 139 | 1 and 2 above on a medium customarily used for software interchange; or, 140 | * **b)** Accompany it with a written offer, valid for at least three 141 | years, to give any third party, for a charge no more than your 142 | cost of physically performing source distribution, a complete 143 | machine-readable copy of the corresponding source code, to be 144 | distributed under the terms of Sections 1 and 2 above on a medium 145 | customarily used for software interchange; or, 146 | * **c)** Accompany it with the information you received as to the offer 147 | to distribute corresponding source code. (This alternative is 148 | allowed only for noncommercial distribution and only if you 149 | received the program in object code or executable form with such 150 | an offer, in accord with Subsection b above.) 151 | 152 | The source code for a work means the preferred form of the work for 153 | making modifications to it. For an executable work, complete source 154 | code means all the source code for all modules it contains, plus any 155 | associated interface definition files, plus the scripts used to 156 | control compilation and installation of the executable. 
However, as a 157 | special exception, the source code distributed need not include 158 | anything that is normally distributed (in either source or binary 159 | form) with the major components (compiler, kernel, and so on) of the 160 | operating system on which the executable runs, unless that component 161 | itself accompanies the executable. 162 | 163 | If distribution of executable or object code is made by offering 164 | access to copy from a designated place, then offering equivalent 165 | access to copy the source code from the same place counts as 166 | distribution of the source code, even though third parties are not 167 | compelled to copy the source along with the object code. 168 | 169 | **4.** You may not copy, modify, sublicense, or distribute the Program 170 | except as expressly provided under this License. Any attempt 171 | otherwise to copy, modify, sublicense or distribute the Program is 172 | void, and will automatically terminate your rights under this License. 173 | However, parties who have received copies, or rights, from you under 174 | this License will not have their licenses terminated so long as such 175 | parties remain in full compliance. 176 | 177 | **5.** You are not required to accept this License, since you have not 178 | signed it. However, nothing else grants you permission to modify or 179 | distribute the Program or its derivative works. These actions are 180 | prohibited by law if you do not accept this License. Therefore, by 181 | modifying or distributing the Program (or any work based on the 182 | Program), you indicate your acceptance of this License to do so, and 183 | all its terms and conditions for copying, distributing or modifying 184 | the Program or works based on it. 185 | 186 | **6.** Each time you redistribute the Program (or any work based on the 187 | Program), the recipient automatically receives a license from the 188 | original licensor to copy, distribute or modify the Program subject to 189 | these terms and conditions. You may not impose any further 190 | restrictions on the recipients' exercise of the rights granted herein. 191 | You are not responsible for enforcing compliance by third parties to 192 | this License. 193 | 194 | **7.** If, as a consequence of a court judgment or allegation of patent 195 | infringement or for any other reason (not limited to patent issues), 196 | conditions are imposed on you (whether by court order, agreement or 197 | otherwise) that contradict the conditions of this License, they do not 198 | excuse you from the conditions of this License. If you cannot 199 | distribute so as to satisfy simultaneously your obligations under this 200 | License and any other pertinent obligations, then as a consequence you 201 | may not distribute the Program at all. For example, if a patent 202 | license would not permit royalty-free redistribution of the Program by 203 | all those who receive copies directly or indirectly through you, then 204 | the only way you could satisfy both it and this License would be to 205 | refrain entirely from distribution of the Program. 206 | 207 | If any portion of this section is held invalid or unenforceable under 208 | any particular circumstance, the balance of the section is intended to 209 | apply and the section as a whole is intended to apply in other 210 | circumstances. 
211 | 212 | It is not the purpose of this section to induce you to infringe any 213 | patents or other property right claims or to contest validity of any 214 | such claims; this section has the sole purpose of protecting the 215 | integrity of the free software distribution system, which is 216 | implemented by public license practices. Many people have made 217 | generous contributions to the wide range of software distributed 218 | through that system in reliance on consistent application of that 219 | system; it is up to the author/donor to decide if he or she is willing 220 | to distribute software through any other system and a licensee cannot 221 | impose that choice. 222 | 223 | This section is intended to make thoroughly clear what is believed to 224 | be a consequence of the rest of this License. 225 | 226 | **8.** If the distribution and/or use of the Program is restricted in 227 | certain countries either by patents or by copyrighted interfaces, the 228 | original copyright holder who places the Program under this License 229 | may add an explicit geographical distribution limitation excluding 230 | those countries, so that distribution is permitted only in or among 231 | countries not thus excluded. In such case, this License incorporates 232 | the limitation as if written in the body of this License. 233 | 234 | **9.** The Free Software Foundation may publish revised and/or new versions 235 | of the General Public License from time to time. Such new versions will 236 | be similar in spirit to the present version, but may differ in detail to 237 | address new problems or concerns. 238 | 239 | Each version is given a distinguishing version number. If the Program 240 | specifies a version number of this License which applies to it and “any 241 | later version”, you have the option of following the terms and conditions 242 | either of that version or of any later version published by the Free 243 | Software Foundation. If the Program does not specify a version number of 244 | this License, you may choose any version ever published by the Free Software 245 | Foundation. 246 | 247 | **10.** If you wish to incorporate parts of the Program into other free 248 | programs whose distribution conditions are different, write to the author 249 | to ask for permission. For software which is copyrighted by the Free 250 | Software Foundation, write to the Free Software Foundation; we sometimes 251 | make exceptions for this. Our decision will be guided by the two goals 252 | of preserving the free status of all derivatives of our free software and 253 | of promoting the sharing and reuse of software generally. 254 | 255 | ### NO WARRANTY 256 | 257 | **11.** BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 258 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 259 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 260 | PROVIDE THE PROGRAM “AS IS” WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 261 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 262 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 263 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 264 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 265 | REPAIR OR CORRECTION. 
266 | 267 | **12.** IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 268 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 269 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 270 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 271 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 272 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 273 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 274 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 275 | POSSIBILITY OF SUCH DAMAGES. 276 | 277 | END OF TERMS AND CONDITIONS 278 | 279 | ### How to Apply These Terms to Your New Programs 280 | 281 | If you develop a new program, and you want it to be of the greatest 282 | possible use to the public, the best way to achieve this is to make it 283 | free software which everyone can redistribute and change under these terms. 284 | 285 | To do so, attach the following notices to the program. It is safest 286 | to attach them to the start of each source file to most effectively 287 | convey the exclusion of warranty; and each file should have at least 288 | the “copyright” line and a pointer to where the full notice is found. 289 | 290 | 291 | Copyright (C) 292 | 293 | This program is free software; you can redistribute it and/or modify 294 | it under the terms of the GNU General Public License as published by 295 | the Free Software Foundation; either version 2 of the License, or 296 | (at your option) any later version. 297 | 298 | This program is distributed in the hope that it will be useful, 299 | but WITHOUT ANY WARRANTY; without even the implied warranty of 300 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 301 | GNU General Public License for more details. 302 | 303 | You should have received a copy of the GNU General Public License along 304 | with this program; if not, write to the Free Software Foundation, Inc., 305 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 306 | 307 | Also add information on how to contact you by electronic and paper mail. 308 | 309 | If the program is interactive, make it output a short notice like this 310 | when it starts in an interactive mode: 311 | 312 | Gnomovision version 69, Copyright (C) year name of author 313 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 314 | This is free software, and you are welcome to redistribute it 315 | under certain conditions; type `show c' for details. 316 | 317 | The hypothetical commands `show w` and `show c` should show the appropriate 318 | parts of the General Public License. Of course, the commands you use may 319 | be called something other than `show w` and `show c`; they could even be 320 | mouse-clicks or menu items--whatever suits your program. 321 | 322 | You should also get your employer (if you work as a programmer) or your 323 | school, if any, to sign a “copyright disclaimer” for the program, if 324 | necessary. Here is a sample; alter the names: 325 | 326 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 327 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 328 | 329 | , 1 April 1989 330 | Ty Coon, President of Vice 331 | 332 | This General Public License does not permit incorporating your program into 333 | proprietary programs. 
If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License.

--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
# Generated by roxygen2: do not edit by hand

S3method(kernelshap,default)
S3method(kernelshap,ranger)
S3method(permshap,default)
S3method(permshap,ranger)
S3method(print,kernelshap)
S3method(summary,kernelshap)
export(additive_shap)
export(is.kernelshap)
export(kernelshap)
export(permshap)
importFrom(foreach,"%dopar%")

--------------------------------------------------------------------------------
/NEWS.md:
--------------------------------------------------------------------------------
# kernelshap 0.7.1

## Documentation

- More compact README.
- Updated function description.

## Maintenance

- Update code coverage version [#150](https://github.com/ModelOriented/kernelshap/pull/150).

# kernelshap 0.7.0

This release is intended to be the last before stable version 1.0.0.

## Major change

Passing a background dataset `bg_X` is now optional.

If the explanation data `X` is sufficiently large (>= 50 rows), `bg_X` is derived as a random sample of `bg_n = 200` rows from `X`. If `X` has fewer than `bg_n` rows, then simply `bg_X = X`. If `X` has too few rows (< 50), you will have to pass an explicit `bg_X`.

## Minor changes

- `ranger()` survival models now also work out-of-the-box without passing a tailored prediction function. Use the new argument `survival = "chf"` in `kernelshap()` and `permshap()` to distinguish cumulative hazards (default) and survival probabilities per time point.
- The resulting objects of `kernelshap()` and `permshap()` now contain `bg_X` and `bg_w` used to calculate the SHAP values.
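A minimal sketch of the new background logic, using a toy `lm()` fit (the `iris` example is illustrative only):

```r
fit <- lm(Sepal.Length ~ ., data = iris)

# X has 150 rows (>= 50, but fewer than bg_n = 200), so bg_X = X is used internally
s1 <- kernelshap(fit, iris[-1])

# X with fewer than 50 rows: an explicit background dataset is required
s2 <- kernelshap(fit, iris[1:4, -1], bg_X = iris[-1])
```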
# kernelshap 0.6.0

## Major changes

- Factor-valued predictions are not supported anymore.

## Maintenance

- Fix CRAN note about unavailable link to `gam::gam()`.
- Added dependency on {MASS} for calculating the Moore-Penrose generalized matrix inverse.

# kernelshap 0.5.0

## New features

New additive explainer `additive_shap()` that works for models fitted via

- `lm()`,
- `glm()`,
- `mgcv::gam()`,
- `mgcv::bam()`,
- `gam::gam()`,
- `survival::coxph()`,
- `survival::survreg()`.

The explainer uses `predict(..., type = "terms")`, a beautiful trick used in `fastshap::explain.lm()`. The results will be identical to those returned by `kernelshap()` and `permshap()` but exponentially faster. Thanks, David Watson, for the great idea discussed in [#130](https://github.com/ModelOriented/kernelshap/issues/130).
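A small sketch of this equivalence, mirroring the package examples (background data = full training data):

```r
fit <- lm(Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length), data = iris)
s_add <- additive_shap(fit, head(iris))
s_perm <- permshap(
  fit, head(iris[c("Sepal.Width", "Petal.Length")]), bg_X = iris
)
all.equal(s_add$S, s_perm$S)  # TRUE (up to numerical error)
```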
## User visible changes

- `permshap()` now returns an object of class "kernelshap" to reduce the number of redundant methods.
- To distinguish which algorithm has generated the "kernelshap" object, the outputs of `kernelshap()`, `permshap()` (and `additive_shap()`) got an element "algorithm".
- `is.permshap()` has been removed.

# kernelshap 0.4.2

## API

- {mlr3}: Non-probabilistic classification now works.
- {mlr3}: For *probabilistic* classification, you now have to pass `predict_type = "prob"`.

## Documentation

- The README has received an {mlr3} and {caret} example.

# kernelshap 0.4.1

## Performance improvements

- Significant speed-up for pure data.frames, i.e., no data.tables or tibbles.
- Some small performance improvements, e.g., for factor predictions and univariate predictions.
- Slight speed-up of `permshap()` by caching calculations for the two special permutations of all 0 and all 1. Consequently, the `m_exact` component in the output is reduced by 2.

## Documentation

- Rewrote many examples in the README.
- Added reference to Erik Štrumbelj and Igor Kononenko (2014).

# kernelshap 0.4.0

## Major changes

- Added `permshap()` to calculate exact permutation SHAP values. The function currently works for up to 14 features.
- Factor-valued predictions are now supported. Each level is represented by its dummy variable.

## Other changes

- Slight speed-up.
- Integer-valued case weights are now turned into doubles to avoid integer overflow.

# kernelshap 0.3.8

## API improvements

- Multi-output case: column names of predictions are now used as list names of the resulting `S` and `SE` lists.

## Bug fixes

- {mlr3} probabilistic classification would not work out-of-the-box. This has been fixed (with a corresponding example in the README) in https://github.com/ModelOriented/kernelshap/pull/100.
- The progress bar was initialized at 1 instead of 0. This is fixed.

## Maintenance

- Added an explanation of sampling Kernel SHAP to the help file.
- In internal calculations, use explicit `feature_names` as dimnames (https://github.com/ModelOriented/kernelshap/issues/96).

# kernelshap 0.3.7

## Maintenance

- Fixed a problem in LaTeX math on macOS.

# kernelshap 0.3.6

## Maintenance

- Improved help files and README.

# kernelshap 0.3.5

## Maintenance

- New contributor: Przemyslaw Biecek - welcome on board!
- My new cozy home: https://github.com/ModelOriented/kernelshap
- Webpage created with "pkgdown".
- Introduced GitHub workflows.
- More unit tests.

## Small visible changes

- Removed the `ks_extract()` function. It was designed to extract objects like the matrix `S` of SHAP values from the resulting "kernelshap" object `x`. We feel that the standard extraction options (`x$S`, `x[["S"]]`, or `getElement(x, "S")`) are sufficient.
- Added the $(n \times K)$ matrix of predictions to the output, where $n$ is the number of rows in the explainer data `X`, and $K$ is the dimension of a single prediction (usually 1).
- Setting `verbose = FALSE` does not suppress the warning on too large background data anymore. Use `suppressWarnings()` instead.

# kernelshap 0.3.4

## Documentation

- New logo.
- Better package description.
- Better README.

# kernelshap 0.3.3

## Fewer dependencies

- Removed dependency "doRNG". This might have an impact on seeding in parallel mode.
- Removed dependency "MASS".

# kernelshap 0.3.2

## Documentation

- Rewrote README and examples to better show the role of the background data.

## Bug fixes

- When `bg_X` contained more columns than `X`, inflexible prediction functions could fail when being applied to `bg_X`.

# kernelshap 0.3.1

## Changes

- The new argument `feature_names` allows specifying the features for which SHAP values are calculated. The default equals `colnames(X)`. This should be changed only in situations when `X` (the dataset to be explained) contains non-feature columns.
- The background dataset can now consist of a single row only. This is useful in situations with a natural "off" value, such as for image data, or for models that can naturally deal with missing values.

# kernelshap 0.3.0

## Major improvements

### Exact calculations

Thanks to David Watson, exact calculations are now also possible for $p>5$ features. By default, the algorithm uses exact calculations for $p \le 8$ and a hybrid strategy otherwise, see the next section. At the same time, the exact algorithm became much more efficient.

A word of caution: Exact calculations mean creating $2^p-2$ on-off vectors $z$ (cheap step) and evaluating the model on a whopping $(2^p-2)N$ rows, where $N$ is the number of rows of the background data (expensive step). As this explodes with large $p$, we do not recommend the exact strategy for $p > 10$.

### Hybrid strategy

The iterative Kernel SHAP sampling algorithm of Covert and Lee (2021) [1] works by randomly sampling $m$ on-off vectors $z$ so that their sum follows the SHAP Kernel weight distribution (renormalized to the range from $1$ to $p-1$). Based on these vectors, many predictions are formed. Then, Kernel SHAP values are derived as the solution of a constrained linear regression, see [1] for details. This is done multiple times until convergence.

A drawback of this strategy is that many (at least 75%) of the $z$ vectors will have $\sum z \in \{1, p-1\}$, producing many duplicates. Similarly, at least 92% of the mass will be used for the $p(p+1)$ possible vectors with $\sum z \in \{1, 2, p-2, p-1\}$, etc. This inefficiency can be fixed by a hybrid strategy, combining exact calculations with sampling.
The hybrid algorithm has two steps:

1. Step 1 (exact part): There are $2p$ different on-off vectors $z$ with $\sum z \in \{1, p-1\}$, covering a large proportion of the Kernel SHAP distribution. The degree 1 hybrid will list those vectors and use them according to their weights in the upcoming calculations. Depending on $p$, we can also go a step further to a degree 2 hybrid by adding all $p(p-1)$ vectors with $\sum z \in \{2, p-2\}$ to the process, etc. The necessary predictions are obtained along with other calculations similar to those in [1].
2. Step 2 (sampling part): The remaining weight is filled by sampling vectors $z$ according to Kernel SHAP weights renormalized to the values not yet covered by Step 1. Together with the results from Step 1 - correctly weighted - this now forms a complete iteration as in Covert and Lee (2021). The difference is that most mass is covered by exact calculations. Afterwards, the algorithm iterates until convergence. The output of Step 1 is reused in every iteration, leading to an extremely efficient strategy.
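As a sketch, the strategy can be steered via the arguments introduced in this release (toy `lm()` example; `exact = FALSE` is needed to activate the hybrid or sampling variants for such a small $p$):

```r
fit <- lm(Sepal.Length ~ ., data = iris)
X <- iris[-1]

s_exact  <- kernelshap(fit, X, bg_X = X)                                    # exact (p <= 8)
s_hybrid <- kernelshap(fit, X, bg_X = X, exact = FALSE, hybrid_degree = 2)  # degree 2 hybrid
s_sample <- kernelshap(fit, X, bg_X = X, exact = FALSE, hybrid_degree = 0)  # pure sampling
```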
The default behaviour of `kernelshap()` is as follows:

- $p \le 8$: Exact Kernel SHAP (with respect to the background data)
- $9 \le p \le 16$: Degree 2 hybrid
- $p > 16$: Degree 1 hybrid
- $p = 1$: Exact Shapley values

It is also possible to use a pure sampling strategy, see Section "User visible changes" below. While this is usually not advisable compared to a hybrid approach, the options of `kernelshap()` allow studying different properties of Kernel SHAP and doing empirical research on the topic.

Kernel SHAP in the Python implementation "shap" uses a quite similar hybrid strategy, but without iterating. The new logic in the R package thus combines the efficiency of the Python implementation with the convergence monitoring of [1].

[1] Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.

## User visible changes

- The default value of `m` is reduced from $8p$ to $2p$ except when `hybrid_degree = 0` (pure sampling).
- The default value of `exact` is now `TRUE` for $p \le 8$ instead of $p \le 5$.
- A new argument `hybrid_degree` is introduced to control the exact part of the hybrid algorithm. The default is 2 for $4 \le p \le 16$ and degree 1 otherwise. Set to 0 to force a pure sampling strategy (not recommended, but useful to demonstrate the superiority of hybrid approaches).
- The default value of `tol` was reduced from 0.01 to 0.005.
- The default of `max_iter` was reduced from 250 to 100.
- The order of some of the arguments behind the first four has been changed.
- Paired sampling no longer duplicates `m`.
- Thanks to Mathias Ambuehl, the random sampling of $z$ vectors is now fully vectorized.
- The output of `print()` is now slimmer.
- A new `summary()` function shows more information.

## Other changes

- The resulting object now contains `m_exact` (the number of on-off vectors used for the exact part), `prop_exact` (proportion of mass treated in exact fashion), the `exact` flag, and `txt` (the info message shown when starting the algorithm).

## Bug fixes

- Predictions of `mgcv::gam()` would cause an error in `check_pred()` (they are 1D arrays).
- Fixed small mistakes in the examples of the README (mlr3 and mgcv).

# kernelshap 0.2.0

## Breaking change

The interface of `kernelshap()` has been revised. Instead of specifying a prediction function, it now suffices to pass the fitted model object. The default `pred_fun` is now `stats::predict`, which works in most cases. Some other cases are caught via the model class ("ranger" and mlr3 "Learner"). The `pred_fun` can be overwritten by a function of the form `function(object, X, ...)`. Additional arguments to the prediction function are passed via `...` of `kernelshap()`.

Some examples:

- Logistic regression (logit scale): `kernelshap(fit, X, bg_X)`
- Logistic regression (probabilities): `kernelshap(fit, X, bg_X, type = "response")`
- Linear regression with logarithmic response, but evaluated on the original scale: Here, the default predict function needs to be overwritten: `kernelshap(fit, X, bg_X, pred_fun = function(m, X) exp(predict(m, X)))`
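A slightly fuller sketch of such calls (hypothetical `glm()` fit on `iris`, purely for illustration):

```r
fit <- glm(I(Sepal.Length > 5.8) ~ ., data = iris[-5], family = binomial())
X <- iris[2:4]  # the three features of the model

s_logit <- kernelshap(fit, X, bg_X = X)                     # SHAP on the logit scale
s_prob  <- kernelshap(fit, X, bg_X = X, type = "response")  # SHAP on the probability scale
```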
## Major improvements

- `kernelshap()` has received a more intuitive interface, see the breaking change above.
- The package now supports multidimensional predictions. Hurray!
- Thanks to David Watson, parallel computing is now supported. The user needs to set up the parallel backend before calling `kernelshap()`, e.g., using the "doFuture" package, and then set `parallel = TRUE` (see the sketch after this list). Especially on Windows, sometimes not all global variables or packages are loaded in the parallel instances. These can be specified by `parallel_args`, a list of arguments passed to `foreach()`.
- Even without parallel computing, `kernelshap()` has become much faster.
- For $2 \le p \le 5$ features, the algorithm now returns exact Kernel SHAP values with respect to the given background data. (For $p = 1$, exact *Shapley values* are returned.)
- Direct handling of "tidymodels" models.
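A sketch of registering a "doFuture" backend before the call (the worker count is an arbitrary assumption):

```r
library(doFuture)
registerDoFuture()
plan(multisession, workers = 4)  # assumption: 4 local workers

fit <- lm(Sepal.Length ~ ., data = iris)
s <- kernelshap(fit, iris[-1], bg_X = iris[-1], parallel = TRUE)
```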
## User visible changes

- Besides `matrix`, `data.frame`s, and `tibble`s, the package now also accepts `data.table`s (if the prediction function can deal with them).
- `kernelshap()` is less picky regarding the output structure of `pred_fun()`.
- `kernelshap()` is less picky about the column structure of the background data `bg_X`. It should simply contain the columns of `X` (but can have more, or in a different order). The old behaviour was to throw an error if `colnames(X) != colnames(bg_X)`.
- The default `m = "auto"` has been changed from `trunc(20 * sqrt(p))` to `max(trunc(20 * sqrt(p)), 5 * p)`. This will have an effect for cases where the number of features $p > 16$, and implies more robust results for large $p$.
- There were too many "ks_*()" functions to extract elements of a "kernelshap" object. They are now all deprecated and replaced by `ks_extract(, what = "S")`.
- Added "MASS", "doRNG", and "foreach" to the dependencies.

## Bug fixes

- Depending on $m$ and $p$, the matrix inversion required in the constrained least-squares solution could fail. It is now replaced by `MASS::ginv()`, the Moore-Penrose pseudoinverse based on `svd()`.

## New contributor

- David Watson

# kernelshap 0.1.0

This is the initial release.

--------------------------------------------------------------------------------
/R/additive_shap.R:
--------------------------------------------------------------------------------
#' Additive SHAP
#'
#' Exact additive SHAP assuming feature independence. The implementation
#' works for models fitted via
#' - [lm()],
#' - [glm()],
#' - [mgcv::gam()],
#' - [mgcv::bam()],
#' - `gam::gam()`,
#' - [survival::coxph()], and
#' - [survival::survreg()].
#'
#' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`,
#' a logic adopted from `fastshap:::explain.lm(..., exact = TRUE)`.
#' Models with interactions (specified via `:` or `*`), or with terms of
#' multiple features like `log(x1/x2)` are not supported.
#'
#' Note that the SHAP values obtained by [additive_shap()] are expected to
#' match those of [permshap()] and [kernelshap()] as long as their background
#' data equals the full training data (which is typically not feasible).
#'
#' @param object Fitted additive model.
#' @param X Dataframe with rows to be explained. Passed to
#'   `predict(object, newdata = X, type = "terms")`.
#' @param verbose Set to `FALSE` to suppress messages.
#' @param ... Currently unused.
#' @returns
#'   An object of class "kernelshap" with the following components:
#'   - `S`: \eqn{(n \times p)} matrix with SHAP values.
#'   - `X`: Same as input argument `X`.
#'   - `baseline`: The baseline.
#'   - `exact`: `TRUE`.
#'   - `txt`: Summary text.
#'   - `predictions`: Vector with predictions of `X` on the scale of "terms".
#'   - `algorithm`: "additive_shap".
#' @export
#' @examples
#' # MODEL ONE: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- additive_shap(fit, head(iris))
#' s
#'
#' # MODEL TWO: More complicated (but not very clever) formula
#' fit <- lm(
#'   Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
#'   data = iris
#' )
#' s_add <- additive_shap(fit, head(iris))
#' s_add
#'
#' # Equals kernelshap()/permshap() when background data is full training data
#' s_kernel <- kernelshap(
#'   fit, head(iris[c("Sepal.Width", "Petal.Length")]), bg_X = iris
#' )
#' all.equal(s_add$S, s_kernel$S)
additive_shap <- function(object, X, verbose = TRUE, ...) {
  stopifnot(
    inherits(object, c("lm", "glm", "gam", "bam", "Gam", "coxph", "survreg"))
  )
  if (any(attr(stats::terms(object), "order") > 1)) {
    stop("Additive SHAP not appropriate for models with interactions.")
  }

  txt <- "Exact additive SHAP via predict(..., type = 'terms')"
  if (verbose) {
    message(txt)
  }

  S <- stats::predict(object, newdata = X, type = "terms")
  rownames(S) <- NULL

  # Baseline value
  b <- as.vector(attr(S, "constant"))
  if (is.null(b)) {
    b <- 0
  }

  # Which columns of X are used in each column of S?
  s_names <- colnames(S)
  cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z)))
  if (any(lengths(cols_used) > 1L)) {
    stop("The formula contains terms with multiple features (not supported).")
  }

  # Collapse all columns in S using the same column in X and rename accordingly
  mapping <- split(
    s_names, factor(unlist(cols_used), levels = colnames(X)), drop = TRUE
  )
  S <- do.call(
    cbind,
    lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
  )

  structure(
    list(
      S = S,
      X = X,
      baseline = b,
      exact = TRUE,
      txt = txt,
      predictions = b + rowSums(S),
      algorithm = "additive_shap"
    ),
    class = "kernelshap"
  )
}
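
# ------------------------------------------------------------------------------
# Illustration only (not part of the package): a minimal sketch of the
# "terms" trick used above. For an additive model, predict(..., type = "terms")
# returns one (centered) contribution per term plus a "constant" attribute;
# adding everything up reproduces the ordinary predictions.
if (FALSE) {
  fit <- lm(Sepal.Length ~ ., data = iris)
  tt <- predict(fit, newdata = head(iris), type = "terms")
  p1 <- attr(tt, "constant") + rowSums(tt)
  p2 <- predict(fit, head(iris))
  all.equal(p1, p2)  # TRUE
}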

--------------------------------------------------------------------------------
/R/kernelshap.R:
--------------------------------------------------------------------------------
#' Kernel SHAP
#'
#' @description
#' Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), and
#' Covert and Lee (2021), abbreviated by CL21.
#' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding
#' the selected background data. For larger \eqn{p}, an almost exact
#' hybrid algorithm combining exact calculations and iterative sampling is used,
#' see Details.
#'
#' Note that (exact) Kernel SHAP is only an approximation of (exact) permutation SHAP.
#' Thus, for up to eight features, we recommend [permshap()]. For more features,
#' [permshap()] is slow compared to the optimized hybrid strategy of our Kernel SHAP
#' implementation.
#'
#' @details
#' The pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:
#'
#' 1. A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p}
#'    such that its sum follows the SHAP Kernel weight distribution
#'    (normalized to the range \eqn{\{1, \dots, p-1\}}).
#' 2. For each \eqn{j} with \eqn{z_j = 1}, the \eqn{j}-th column of the
#'    original background data is replaced by the corresponding feature value \eqn{x_j}
#'    of the observation to be explained.
#' 3. The average prediction \eqn{v_z} on the data of Step 2 is calculated, and the
#'    average prediction \eqn{v_0} on the background data is subtracted.
#' 4. Steps 1 to 3 are repeated \eqn{m} times. This produces a binary \eqn{m \times p}
#'    matrix \eqn{Z} (each row equals one of the \eqn{z}) and a vector \eqn{v} of
#'    shifted predictions.
#' 5. \eqn{v} is regressed onto \eqn{Z} under the constraint that the sum of the
#'    coefficients equals \eqn{v_1 - v_0}, where \eqn{v_1} is the prediction of the
#'    observation to be explained. The resulting coefficients are the Kernel SHAP values.
#'
#' This is repeated multiple times until convergence, see CL21 for details.
#'
#' A drawback of this strategy is that many (at least 75%) of the \eqn{z} vectors will
#' have \eqn{\sum z \in \{1, p-1\}}, producing many duplicates. Similarly, at least 92%
#' of the mass will be used for the \eqn{p(p+1)} possible vectors with
#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}.
#' This inefficiency can be fixed by a hybrid strategy, combining exact calculations
#' with sampling.
#'
#' The hybrid algorithm has two steps:
#' 1. Step 1 (exact part): There are \eqn{2p} different on-off vectors \eqn{z} with
#'    \eqn{\sum z \in \{1, p-1\}}, covering a large proportion of the Kernel SHAP
#'    distribution. The degree 1 hybrid will list those vectors and use them according
#'    to their weights in the upcoming calculations. Depending on \eqn{p}, we can also go
#'    a step further to a degree 2 hybrid by adding all \eqn{p(p-1)} vectors with
#'    \eqn{\sum z \in \{2, p-2\}} to the process etc. The necessary predictions are
#'    obtained along with other calculations similar to those described in CL21.
#' 2. Step 2 (sampling part): The remaining weight is filled by sampling vectors \eqn{z}
#'    according to Kernel SHAP weights renormalized to the values not yet covered by Step 1.
#'    Together with the results from Step 1 - correctly weighted - this now forms a
#'    complete iteration as in CL21. The difference is that most mass is covered by exact
#'    calculations. Afterwards, the algorithm iterates until convergence.
#'    The output of Step 1 is reused in every iteration, leading to an extremely
#'    efficient strategy.
#'
#' If \eqn{p} is sufficiently small, all possible \eqn{2^p-2} on-off vectors \eqn{z} can be
#' evaluated. In this case, no sampling is required and the algorithm returns exact
#' Kernel SHAP values with respect to the given background data.
#' Since [kernelshap()] calculates predictions on data with \eqn{MN} rows
#' (\eqn{N} is the background data size and \eqn{M} the number of \eqn{z} vectors), \eqn{p}
#' should not be much higher than 10 for exact calculations.
#' For similar reasons, degree 2 hybrids should not use \eqn{p} much larger than 40.
#'
#' @importFrom foreach %dopar%
#'
#' @param object Fitted model object.
#' @param X \eqn{(n \times p)} matrix or `data.frame` with rows to be explained.
#'   The columns should only represent model features, not the response
#'   (but see `feature_names` on how to overrule this).
#' @param bg_X Background data used to integrate out "switched off" features,
#'   often a subset of the training data (typically 50 to 500 rows).
#'   In cases with a natural "off" value (like MNIST digits),
#'   this can also be a single row with all values set to the off value.
#'   If no `bg_X` is passed (the default) and if `X` is sufficiently large,
#'   a random sample of `bg_n` rows from `X` serves as background data.
#' @param pred_fun Prediction function of the form `function(object, X, ...)`,
#'   providing \eqn{K \ge 1} predictions per row. Its first argument
#'   represents the model `object`, its second argument a data structure like `X`.
#'   Additional (named) arguments are passed via `...`.
#'   The default, [stats::predict()], will work in most cases.
#' @param feature_names Optional vector of column names in `X` used to calculate
#'   SHAP values. By default, this equals `colnames(X)`. Not supported if `X`
#'   is a matrix.
#' @param bg_w Optional vector of case weights for each row of `bg_X`.
#'   If `bg_X = NULL`, must be of the same length as `X`. Set to `NULL` for no weights.
#' @param bg_n If `bg_X = NULL`: Size of background data to be sampled from `X`.
#' @param exact If `TRUE`, the algorithm will produce exact Kernel SHAP values
#'   with respect to the background data. In this case, the arguments `hybrid_degree`,
#'   `m`, `paired_sampling`, `tol`, and `max_iter` are ignored.
#'   The default is `TRUE` up to eight features, and `FALSE` otherwise.
#' @param hybrid_degree Integer controlling the exactness of the hybrid strategy. For
#'   \eqn{4 \le p \le 16}, the default is 2, otherwise it is 1.
#'   Ignored if `exact = TRUE`.
#'   - `0`: Pure sampling strategy not involving any exact part. It is strictly
#'     worse than the hybrid strategy and should therefore only be used for
#'     studying properties of the Kernel SHAP algorithm.
#'   - `1`: Uses all \eqn{2p} on-off vectors \eqn{z} with \eqn{\sum z \in \{1, p-1\}}
#'     for the exact part, which covers at least 75% of the mass of the Kernel weight
#'     distribution. The remaining mass is covered by random sampling.
#'   - `2`: Uses all \eqn{p(p+1)} on-off vectors \eqn{z} with
#'     \eqn{\sum z \in \{1, 2, p-2, p-1\}}. This covers at least 92% of the mass of the
#'     Kernel weight distribution. The remaining mass is covered by sampling.
#'     Convergence usually happens in the minimal possible number of two iterations.
107 | #' - `k>2`: Uses all on-off vectors with
108 | #' \eqn{\sum z \in \{1, \dots, k, p-k, \dots, p-1\}}.
109 | #' @param paired_sampling Logical flag indicating whether to do the sampling in a paired
110 | #' manner. This means that with every on-off vector \eqn{z}, also \eqn{1-z} is
111 | #' considered. CL21 shows its superiority compared to standard sampling; therefore, the
112 | #' default (`TRUE`) should usually not be changed, except for studying properties
113 | #' of Kernel SHAP algorithms. Ignored if `exact = TRUE`.
114 | #' @param m Even number of on-off vectors sampled during one iteration.
115 | #' The default is \eqn{2p}, except when `hybrid_degree == 0`.
116 | #' Then it is set to \eqn{8p}. Ignored if `exact = TRUE`.
117 | #' @param tol Tolerance determining when to stop. Following CL21, the algorithm keeps
118 | #' iterating until \eqn{\textrm{max}(\sigma_n)/(\textrm{max}(\beta_n) - \textrm{min}(\beta_n)) < \textrm{tol}},
119 | #' where the \eqn{\beta_n} are the SHAP values of a given observation,
120 | #' and \eqn{\sigma_n} their standard errors.
121 | #' For multidimensional predictions, the criterion must be satisfied for each
122 | #' dimension separately. The stopping criterion uses the fact that standard errors
123 | #' and SHAP values are all on the same scale. Ignored if `exact = TRUE`.
124 | #' @param max_iter If the stopping criterion (see `tol`) is not reached after
125 | #' `max_iter` iterations, the algorithm stops. Ignored if `exact = TRUE`.
126 | #' @param parallel If `TRUE`, use parallel [foreach::foreach()] to loop over rows
127 | #' to be explained. A backend must be registered beforehand, e.g., via the 'doFuture'
128 | #' package; see the README for an example. Parallelization automatically disables the progress bar.
129 | #' @param parallel_args Named list of arguments passed to [foreach::foreach()].
130 | #' Ideally, this is `NULL` (default). Only relevant if `parallel = TRUE`.
131 | #' Example on Windows: if `object` is a GAM fitted with package 'mgcv',
132 | #' then one might need to set `parallel_args = list(.packages = "mgcv")`.
133 | #' @param verbose Set to `FALSE` to suppress messages and the progress bar.
134 | #' @param survival Should cumulative hazards ("chf", default) or survival
135 | #' probabilities ("prob") per time be predicted? Only used for `ranger()` survival models.
136 | #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`.
137 | #' @returns
138 | #' An object of class "kernelshap" with the following components:
139 | #' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has
140 | #' dimension \eqn{K > 1}, a list of \eqn{K} such matrices.
141 | #' - `X`: Same as input argument `X`.
142 | #' - `baseline`: Vector of length K representing the average prediction on the
143 | #' background data.
144 | #' - `bg_X`: The background data.
145 | #' - `bg_w`: The background case weights.
146 | #' - `SE`: Standard errors corresponding to `S` (and organized like `S`).
147 | #' - `n_iter`: Integer vector of length n providing the number of iterations
148 | #' per row of `X`.
149 | #' - `converged`: Logical vector of length n indicating convergence per row of `X`.
150 | #' - `m`: Integer providing the effective number of sampled on-off vectors used
151 | #' per iteration.
152 | #' - `m_exact`: Integer providing the effective number of exact on-off vectors used
153 | #' per iteration.
154 | #' - `prop_exact`: Proportion of the Kernel SHAP weight distribution covered by
155 | #' exact calculations.
156 | #' - `exact`: Logical flag indicating whether calculations are exact or not.
157 | #' - `txt`: Summary text.
158 | #' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
159 | #' - `algorithm`: "kernelshap".
160 | #' @references
161 | #' 1. Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model
162 | #' predictions. Proceedings of the 31st International Conference on Neural
163 | #' Information Processing Systems, 2017.
164 | #' 2. Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value
165 | #' Estimation Using Linear Regression. Proceedings of The 24th International
166 | #' Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.
167 | #' @export
168 | #' @examples
169 | #' # MODEL ONE: Linear regression
170 | #' fit <- lm(Sepal.Length ~ ., data = iris)
171 | #'
172 | #' # Select rows to explain (only feature columns)
173 | #' X_explain <- iris[-1]
174 | #'
175 | #' # Calculate SHAP values
176 | #' s <- kernelshap(fit, X_explain)
177 | #' s
178 | #'
179 | #' # MODEL TWO: Multi-response linear regression
180 | #' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
181 | #' s <- kernelshap(fit, iris[3:5])
182 | #' s
183 | #'
184 | #' # Note 1: Feature columns can also be selected via 'feature_names'
185 | #' # Note 2: Especially when X is small, pass sufficiently large background data bg_X
186 | #' s <- kernelshap(
187 | #'   fit,
188 | #'   iris[1:4, ],
189 | #'   bg_X = iris,
190 | #'   feature_names = c("Petal.Length", "Petal.Width", "Species")
191 | #' )
192 | #' s
193 | kernelshap <- function(object, ...) {
194 |   UseMethod("kernelshap")
195 | }
196 | 
197 | #' @describeIn kernelshap Default Kernel SHAP method.
198 | #' @export
199 | kernelshap.default <- function(
200 |   object,
201 |   X,
202 |   bg_X = NULL,
203 |   pred_fun = stats::predict,
204 |   feature_names = colnames(X),
205 |   bg_w = NULL,
206 |   bg_n = 200L,
207 |   exact = length(feature_names) <= 8L,
208 |   hybrid_degree = 1L + length(feature_names) %in% 4:16,
209 |   paired_sampling = TRUE,
210 |   m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
211 |   tol = 0.005,
212 |   max_iter = 100L,
213 |   parallel = FALSE,
214 |   parallel_args = NULL,
215 |   verbose = TRUE,
216 |   ...
217 | ) {
218 |   p <- length(feature_names)
219 |   basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
220 |   stopifnot(
221 |     exact %in% c(TRUE, FALSE),
222 |     p == 1L || exact || hybrid_degree %in% 0:(p / 2),
223 |     paired_sampling %in% c(TRUE, FALSE),
224 |     "m must be even" = trunc(m / 2) == m / 2
225 |   )
226 |   prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
227 |   bg_X <- prep_bg$bg_X
228 |   bg_w <- prep_bg$bg_w
229 |   bg_n <- nrow(bg_X)
230 |   n <- nrow(X)
231 | 
232 |   # Calculate v1 and v0
233 |   bg_preds <- align_pred(pred_fun(object, bg_X, ...))
234 |   v0 <- wcolMeans(bg_preds, bg_w)  # Average pred of bg data: 1 x K
235 |   v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K
236 | 
237 |   # For p = 1, exact Shapley values are returned
238 |   if (p == 1L) {
239 |     out <- case_p1(
240 |       n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
241 |     )
242 |     return(out)
243 |   }
244 | 
245 |   txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree)
246 |   if (verbose) {
247 |     message(txt)
248 |   }
249 | 
250 |   # Drop unnecessary columns in bg_X.
If X is matrix, also column order is relevant 251 | # In what follows, predictions will never be applied directly to bg_X anymore 252 | if (!identical(colnames(bg_X), feature_names)) { 253 | bg_X <- bg_X[, feature_names, drop = FALSE] 254 | } 255 | 256 | # Precalculations that are identical for each row to be explained 257 | if (exact || hybrid_degree >= 1L) { 258 | if (exact) { 259 | precalc <- input_exact(p, feature_names = feature_names) 260 | } else { 261 | precalc <- input_partly_exact( 262 | p, deg = hybrid_degree, feature_names = feature_names 263 | ) 264 | } 265 | m_exact <- nrow(precalc[["Z"]]) 266 | prop_exact <- sum(precalc[["w"]]) 267 | precalc[["bg_X_exact"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)) 268 | } else { 269 | precalc <- list() 270 | m_exact <- 0L 271 | prop_exact <- 0 272 | } 273 | if (!exact) { 274 | precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m)) 275 | } 276 | 277 | if (max(m, m_exact) * bg_n > 2e5) { 278 | warning_burden(max(m, m_exact), bg_n = bg_n) 279 | } 280 | 281 | # Apply Kernel SHAP to each row of X 282 | if (isTRUE(parallel)) { 283 | parallel_args <- c(list(i = seq_len(n)), parallel_args) 284 | res <- do.call(foreach::foreach, parallel_args) %dopar% kernelshap_one( 285 | x = X[i, , drop = FALSE], 286 | v1 = v1[i, , drop = FALSE], 287 | object = object, 288 | pred_fun = pred_fun, 289 | feature_names = feature_names, 290 | bg_w = bg_w, 291 | exact = exact, 292 | deg = hybrid_degree, 293 | paired = paired_sampling, 294 | m = m, 295 | tol = tol, 296 | max_iter = max_iter, 297 | v0 = v0, 298 | precalc = precalc, 299 | ... 300 | ) 301 | } else { 302 | if (verbose && n >= 2L) { 303 | pb <- utils::txtProgressBar(max = n, style = 3) 304 | } 305 | res <- vector("list", n) 306 | for (i in seq_len(n)) { 307 | res[[i]] <- kernelshap_one( 308 | x = X[i, , drop = FALSE], 309 | v1 = v1[i, , drop = FALSE], 310 | object = object, 311 | pred_fun = pred_fun, 312 | feature_names = feature_names, 313 | bg_w = bg_w, 314 | exact = exact, 315 | deg = hybrid_degree, 316 | paired = paired_sampling, 317 | m = m, 318 | tol = tol, 319 | max_iter = max_iter, 320 | v0 = v0, 321 | precalc = precalc, 322 | ... 323 | ) 324 | if (verbose && n >= 2L) { 325 | utils::setTxtProgressBar(pb, i) 326 | } 327 | } 328 | } 329 | 330 | # Organize output 331 | converged <- vapply(res, `[[`, "converged", FUN.VALUE = logical(1L)) 332 | if (verbose && !all(converged)) { 333 | warning("\nNon-convergence for ", sum(!converged), " rows.") 334 | } 335 | 336 | if (verbose) { 337 | cat("\n") 338 | } 339 | 340 | out <- list( 341 | S = reorganize_list(lapply(res, `[[`, "beta")), 342 | X = X, 343 | baseline = as.vector(v0), 344 | bg_X = bg_X, 345 | bg_w = bg_w, 346 | SE = reorganize_list(lapply(res, `[[`, "sigma")), 347 | n_iter = vapply(res, `[[`, "n_iter", FUN.VALUE = integer(1L)), 348 | converged = converged, 349 | m = m, 350 | m_exact = m_exact, 351 | prop_exact = prop_exact, 352 | exact = exact || trunc(p / 2) == hybrid_degree, 353 | txt = txt, 354 | predictions = v1, 355 | algorithm = "kernelshap" 356 | ) 357 | class(out) <- "kernelshap" 358 | out 359 | } 360 | 361 | #' @describeIn kernelshap Kernel SHAP method for "ranger" models, see Readme for an example. 
362 | #' @export 363 | kernelshap.ranger <- function( 364 | object, 365 | X, 366 | bg_X = NULL, 367 | pred_fun = NULL, 368 | feature_names = colnames(X), 369 | bg_w = NULL, 370 | bg_n = 200L, 371 | exact = length(feature_names) <= 8L, 372 | hybrid_degree = 1L + length(feature_names) %in% 4:16, 373 | paired_sampling = TRUE, 374 | m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)), 375 | tol = 0.005, 376 | max_iter = 100L, 377 | parallel = FALSE, 378 | parallel_args = NULL, 379 | verbose = TRUE, 380 | survival = c("chf", "prob"), 381 | ... 382 | ) { 383 | 384 | if (is.null(pred_fun)) { 385 | pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) 386 | } 387 | 388 | kernelshap.default( 389 | object = object, 390 | X = X, 391 | bg_X = bg_X, 392 | pred_fun = pred_fun, 393 | feature_names = feature_names, 394 | bg_w = bg_w, 395 | bg_n = bg_n, 396 | exact = exact, 397 | hybrid_degree = hybrid_degree, 398 | paired_sampling = paired_sampling, 399 | m = m, 400 | tol = tol, 401 | max_iter = max_iter, 402 | parallel = parallel, 403 | parallel_args = parallel_args, 404 | verbose = verbose, 405 | ... 406 | ) 407 | } 408 | 409 | -------------------------------------------------------------------------------- /R/methods.R: -------------------------------------------------------------------------------- 1 | #' Prints "kernelshap" Object 2 | #' 3 | #' @param x An object of class "kernelshap". 4 | #' @param n Maximum number of rows of SHAP values to print. 5 | #' @param ... Further arguments passed from other methods. 6 | #' @returns Invisibly, the input is returned. 7 | #' @export 8 | #' @examples 9 | #' fit <- lm(Sepal.Length ~ ., data = iris) 10 | #' s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[, -1]) 11 | #' s 12 | #' @seealso [kernelshap()] 13 | print.kernelshap <- function(x, n = 2L, ...) { 14 | cat("SHAP values of first observations:\n") 15 | print(head_list(getElement(x, "S"), n = n)) 16 | invisible(x) 17 | } 18 | 19 | #' Summarizes "kernelshap" Object 20 | #' 21 | #' @param object An object of class "kernelshap". 22 | #' @param compact Set to `TRUE` for a more compact summary. 23 | #' @param n Maximum number of rows of SHAP values etc. to print. 24 | #' @param ... Further arguments passed from other methods. 25 | #' @returns Invisibly, the input is returned. 26 | #' @export 27 | #' @examples 28 | #' fit <- lm(Sepal.Length ~ ., data = iris) 29 | #' s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[, -1]) 30 | #' summary(s) 31 | #' @seealso [kernelshap()] 32 | summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) 
{ 33 | cat(getElement(object, "txt")) 34 | 35 | S <- getElement(object, "S") 36 | if (!is.list(S)) { 37 | n <- min(n, nrow(S)) 38 | cat(paste("\n - SHAP matrix of dim", nrow(S), "x", ncol(S))) 39 | } else { 40 | n <- min(n, nrow(S[[1L]])) 41 | cat( 42 | "\n -", length(S), "SHAP matrices of dim", nrow(S[[1L]]), "x", ncol(S[[1L]]) 43 | ) 44 | } 45 | cat("\n - baseline:", getElement(object, "baseline")) 46 | ex <- getElement(object, "exact") 47 | if (!ex) { 48 | cat( 49 | "\n - average number of iterations:", mean(getElement(object, "n_iter")), 50 | "\n - rows not converged:", sum(!getElement(object, "converged")), 51 | "\n - proportion exact:", getElement(object, "prop_exact"), 52 | "\n - m/iter:", getElement(object, "m") 53 | ) 54 | } 55 | m_exact <- getElement(object, "m_exact") 56 | if (!is.null(m_exact)) { 57 | cat("\n - m_exact:", m_exact) 58 | } 59 | if (!compact) { 60 | cat("\n\nSHAP values of first observations:\n") 61 | print(head_list(S, n = n)) 62 | if (!ex) { 63 | cat("\nCorresponding standard errors:\n") 64 | print(head_list(getElement(object, "SE"), n = n)) 65 | } 66 | } 67 | invisible(object) 68 | } 69 | 70 | #' Check for kernelshap 71 | #' 72 | #' Is object of class "kernelshap"? 73 | #' 74 | #' @param object An R object. 75 | #' @returns `TRUE` if `object` is of class "kernelshap", and `FALSE` otherwise. 76 | #' @export 77 | #' @examples 78 | #' fit <- lm(Sepal.Length ~ ., data = iris) 79 | #' s <- kernelshap(fit, iris[1:2, -1], bg_X = iris[, -1]) 80 | #' is.kernelshap(s) 81 | #' is.kernelshap("a") 82 | #' @seealso [kernelshap()] 83 | is.kernelshap <- function(object){ 84 | inherits(object, "kernelshap") 85 | } 86 | -------------------------------------------------------------------------------- /R/permshap.R: -------------------------------------------------------------------------------- 1 | #' Permutation SHAP 2 | #' 3 | #' Exact permutation SHAP algorithm with respect to a background dataset, 4 | #' see Strumbelj and Kononenko. The function works for up to 14 features. 5 | #' For more than eight features, we recommend [kernelshap()] due to its higher speed. 6 | #' 7 | #' @inheritParams kernelshap 8 | #' @returns 9 | #' An object of class "kernelshap" with the following components: 10 | #' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has 11 | #' dimension \eqn{K > 1}, a list of \eqn{K} such matrices. 12 | #' - `X`: Same as input argument `X`. 13 | #' - `baseline`: Vector of length K representing the average prediction on the 14 | #' background data. 15 | #' - `bg_X`: The background data. 16 | #' - `bg_w`: The background case weights. 17 | #' - `m_exact`: Integer providing the effective number of exact on-off vectors used. 18 | #' - `exact`: Logical flag indicating whether calculations are exact or not 19 | #' (currently always `TRUE`). 20 | #' - `txt`: Summary text. 21 | #' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`. 22 | #' - `algorithm`: "permshap". 23 | #' @references 24 | #' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual 25 | #' predictions with feature contributions. Knowledge and Information Systems 41, 2014. 
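#' @details
#' A rough sketch of the computational cost, assuming background data with
#' `n_bg` rows: per observation to be explained, the prediction function is
#' evaluated on all \eqn{2^p - 2} non-trivial on-off vectors, i.e., on a
#' stacked dataset with about \eqn{(2^p - 2)} times `n_bg` rows:
#'
#' ```r
#' p <- 2:14
#' n_bg <- 200  # assumed background data size
#' (2^p - 2) * n_bg  # rows to predict per explained observation
#' ```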
26 | #' @export
27 | #' @examples
28 | #' # MODEL ONE: Linear regression
29 | #' fit <- lm(Sepal.Length ~ ., data = iris)
30 | #'
31 | #' # Select rows to explain (only feature columns)
32 | #' X_explain <- iris[-1]
33 | #'
34 | #' # Calculate SHAP values
35 | #' s <- permshap(fit, X_explain)
36 | #' s
37 | #'
38 | #' # MODEL TWO: Multi-response linear regression
39 | #' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
40 | #' s <- permshap(fit, iris[3:5])
41 | #' s
42 | #'
43 | #' # Note 1: Feature columns can also be selected via 'feature_names'
44 | #' # Note 2: Especially when X is small, pass sufficiently large background data bg_X
45 | #' s <- permshap(
46 | #'   fit,
47 | #'   iris[1:4, ],
48 | #'   bg_X = iris,
49 | #'   feature_names = c("Petal.Length", "Petal.Width", "Species")
50 | #' )
51 | #' s
52 | permshap <- function(object, ...) {
53 |   UseMethod("permshap")
54 | }
55 | 
56 | #' @describeIn permshap Default permutation SHAP method.
57 | #' @export
58 | permshap.default <- function(
59 |   object,
60 |   X,
61 |   bg_X = NULL,
62 |   pred_fun = stats::predict,
63 |   feature_names = colnames(X),
64 |   bg_w = NULL,
65 |   bg_n = 200L,
66 |   parallel = FALSE,
67 |   parallel_args = NULL,
68 |   verbose = TRUE,
69 |   ...
70 | ) {
71 |   p <- length(feature_names)
72 |   if (p <= 1L) {
73 |     stop("Case p = 1 not implemented. Use kernelshap() instead.")
74 |   }
75 |   if (p > 14L) {
76 |     stop("Permutation SHAP only supported for up to 14 features")
77 |   }
78 | 
79 |   txt <- "Exact permutation SHAP"
80 |   if (verbose) {
81 |     message(txt)
82 |   }
83 | 
84 |   basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
85 |   prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
86 |   bg_X <- prep_bg$bg_X
87 |   bg_w <- prep_bg$bg_w
88 |   bg_n <- nrow(bg_X)
89 |   n <- nrow(X)
90 | 
91 |   # Baseline and predictions on explanation data
92 |   bg_preds <- align_pred(pred_fun(object, bg_X, ...))
93 |   v0 <- wcolMeans(bg_preds, w = bg_w)  # Average pred of bg data: 1 x K
94 |   v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K
95 | 
96 |   # Drop unnecessary columns in bg_X. If X is a matrix, the column order is also relevant
97 |   # Predictions will never be applied directly to bg_X anymore
98 |   if (!identical(colnames(bg_X), feature_names)) {
99 |     bg_X <- bg_X[, feature_names, drop = FALSE]
100 |   }
101 | 
102 |   # Precalculations that are identical for each row to be explained
103 |   Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE)
104 |   m_exact <- nrow(Z) - 2L  # We won't evaluate vz for first and last row
105 |   precalc <- list(
106 |     Z = Z,
107 |     Z_code = rowpaste(Z),
108 |     bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
109 |   )
110 | 
111 |   if (m_exact * bg_n > 2e5) {
112 |     warning_burden(m_exact, bg_n = bg_n)
113 |   }
114 | 
115 |   # Apply permutation SHAP to each row of X
116 |   if (isTRUE(parallel)) {
117 |     parallel_args <- c(list(i = seq_len(n)), parallel_args)
118 |     res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
119 |       x = X[i, , drop = FALSE],
120 |       v1 = v1[i, , drop = FALSE],
121 |       object = object,
122 |       pred_fun = pred_fun,
123 |       bg_w = bg_w,
124 |       v0 = v0,
125 |       precalc = precalc,
126 |       ...
127 | ) 128 | } else { 129 | if (verbose && n >= 2L) { 130 | pb <- utils::txtProgressBar(max = n, style = 3) 131 | } 132 | res <- vector("list", n) 133 | for (i in seq_len(n)) { 134 | res[[i]] <- permshap_one( 135 | x = X[i, , drop = FALSE], 136 | v1 = v1[i, , drop = FALSE], 137 | object = object, 138 | pred_fun = pred_fun, 139 | bg_w = bg_w, 140 | v0 = v0, 141 | precalc = precalc, 142 | ... 143 | ) 144 | if (verbose && n >= 2L) { 145 | utils::setTxtProgressBar(pb, i) 146 | } 147 | } 148 | } 149 | if (verbose) { 150 | cat("\n") 151 | } 152 | out <- list( 153 | S = reorganize_list(res), 154 | X = X, 155 | baseline = as.vector(v0), 156 | bg_X = bg_X, 157 | bg_w = bg_w, 158 | m_exact = m_exact, 159 | exact = TRUE, 160 | txt = txt, 161 | predictions = v1, 162 | algorithm = "permshap" 163 | ) 164 | class(out) <- "kernelshap" 165 | out 166 | } 167 | 168 | #' @describeIn permshap Permutation SHAP method for "ranger" models, see Readme for an example. 169 | #' @export 170 | permshap.ranger <- function( 171 | object, 172 | X, 173 | bg_X = NULL, 174 | pred_fun = NULL, 175 | feature_names = colnames(X), 176 | bg_w = NULL, 177 | bg_n = 200L, 178 | parallel = FALSE, 179 | parallel_args = NULL, 180 | verbose = TRUE, 181 | survival = c("chf", "prob"), 182 | ... 183 | ) { 184 | 185 | if (is.null(pred_fun)) { 186 | pred_fun <- create_ranger_pred_fun(object$treetype, survival = match.arg(survival)) 187 | } 188 | 189 | permshap.default( 190 | object = object, 191 | X = X, 192 | bg_X = bg_X, 193 | pred_fun = pred_fun, 194 | feature_names = feature_names, 195 | bg_w = bg_w, 196 | bg_n = bg_n, 197 | parallel = parallel, 198 | parallel_args = parallel_args, 199 | verbose = verbose, 200 | ... 201 | ) 202 | } 203 | 204 | -------------------------------------------------------------------------------- /R/pred_fun.R: -------------------------------------------------------------------------------- 1 | #' Predict Function for Ranger 2 | #' 3 | #' Returns prediction function for different modes of ranger. 4 | #' 5 | #' @noRd 6 | #' @keywords internal 7 | #' @param treetype The value of `fit$treetype` in a fitted ranger model. 8 | #' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time. 9 | #' 10 | #' @returns A function with signature f(model, newdata, ...). 11 | create_ranger_pred_fun <- function(treetype, survival = c("chf", "prob")) { 12 | survival <- match.arg(survival) 13 | 14 | if (treetype != "Survival") { 15 | pred_fun <- function(model, newdata, ...) { 16 | stats::predict(model, newdata, ...)$predictions 17 | } 18 | return(pred_fun) 19 | } 20 | 21 | if (survival == "prob") { 22 | survival <- "survival" 23 | } 24 | 25 | pred_fun <- function(model, newdata, ...) { 26 | pred <- stats::predict(model, newdata, ...) 27 | out <- pred[[survival]] 28 | colnames(out) <- paste0("t", pred$unique.death.times) 29 | return(out) 30 | } 31 | return(pred_fun) 32 | } 33 | 34 | -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | #' Fast Row Subsetting 2 | #' 3 | #' Internal function used to row-subset data.frames. 4 | #' Brings a massive speed-up for data.frames. All other classes (tibble, data.table, 5 | #' matrix) are subsetted in the usual way. 6 | #' 7 | #' @noRd 8 | #' @keywords internal 9 | #' 10 | #' @param x A matrix-like object. 11 | #' @param i Logical or integer vector of rows to pick. 12 | #' @returns Subsetted version of `x`. 
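#' @examples
#' # Illustrative internal call: replicate rows 1 to 3 of iris twice each,
#' # equivalent to (but faster than) iris[rep(1:3, each = 2), ]
#' rep_rows(iris, rep(1:3, each = 2))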
13 | rep_rows <- function(x, i) { 14 | if (!(all(class(x) == "data.frame"))) { 15 | return(x[i, , drop = FALSE]) # matrix, tibble, data.table, ... 16 | } 17 | # data.frame 18 | out <- lapply(x, function(z) if (length(dim(z)) != 2L) z[i] else z[i, , drop = FALSE]) 19 | attr(out, "row.names") <- .set_row_names(length(i)) 20 | class(out) <- "data.frame" 21 | out 22 | } 23 | 24 | #' Weighted Version of colMeans() 25 | #' 26 | #' Internal function used to calculate column-wise weighted means. 27 | #' 28 | #' @noRd 29 | #' @keywords internal 30 | #' 31 | #' @param x A matrix-like object. 32 | #' @param w Optional case weights. 33 | #' @returns A (1 x ncol(x)) matrix of column means. 34 | wcolMeans <- function(x, w = NULL, ...) { 35 | x <- as.matrix(x) 36 | out <- if (is.null(w)) colMeans(x) else colSums(x * w) / sum(w) 37 | t.default(out) 38 | } 39 | 40 | #' All on-off Vectors 41 | #' 42 | #' Internal function that creates matrix of all on-off vectors of length `p`. 43 | #' 44 | #' @noRd 45 | #' @keywords internal 46 | #' 47 | #' @param p Number of features. 48 | #' @param feature_names Feature names. 49 | #' @param keep_extremes Should extremes be kept? Defaults to `FALSE` (for kernelshap). 50 | #' @returns An integer matrix of all on-off vectors of length `p`. 51 | exact_Z <- function(p, feature_names, keep_extremes = FALSE) { 52 | Z <- as.matrix(do.call(expand.grid, replicate(p, 0:1, simplify = FALSE))) 53 | colnames(Z) <- feature_names 54 | if (keep_extremes) Z else Z[2:(nrow(Z) - 1L), , drop = FALSE] 55 | } 56 | 57 | #' Masker 58 | #' 59 | #' Internal function. 60 | #' For each on-off vector (rows in `Z`), the (weighted) average prediction is returned. 61 | #' 62 | #' @noRd 63 | #' @keywords internal 64 | #' 65 | #' @inheritParams kernelshap 66 | #' @param X Row to be explained stacked m*n_bg times. 67 | #' @param bg Background data stacked m times. 68 | #' @param Z A (m x p) matrix with on-off values. 69 | #' @param w A vector with case weights (of the same length as the unstacked 70 | #' background data). 71 | #' @returns A (m x K) matrix with vz values. 72 | get_vz <- function(X, bg, Z, object, pred_fun, w, ...) { 73 | m <- nrow(Z) 74 | not_Z <- !Z 75 | n_bg <- nrow(bg) / m # because bg was replicated m times 76 | 77 | # Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p) 78 | g <- rep_each(m, each = n_bg) 79 | not_Z <- not_Z[g, , drop = FALSE] 80 | 81 | if (is.matrix(X)) { 82 | # Remember that columns of X and bg are perfectly aligned in this case 83 | X[not_Z] <- bg[not_Z] 84 | } else { 85 | for (v in colnames(Z)) { 86 | s <- not_Z[, v] 87 | X[[v]][s] <- bg[[v]][s] 88 | } 89 | } 90 | preds <- align_pred(pred_fun(object, X, ...)) 91 | 92 | # Aggregate (distinguishing fast 1-dim case) 93 | if (ncol(preds) == 1L) { 94 | return(wrowmean_vector(preds, ngroups = m, w = w)) 95 | } 96 | if (is.null(w)) { 97 | return(rowsum(preds, group = g, reorder = FALSE) / n_bg) 98 | } 99 | rowsum(preds * w, group = g, reorder = FALSE) / sum(w) 100 | } 101 | 102 | #' Combine Matrices 103 | #' 104 | #' Binds list of matrices along new first axis. 105 | #' 106 | #' @noRd 107 | #' @keywords internal 108 | #' 109 | #' @param a List of n (p x K) matrices. 110 | #' @returns A (n x p x K) array. 
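#' @examples
#' # Illustrative internal call: a list of two (3 x 2) matrices is bound
#' # into a (2 x 3 x 2) array
#' dim(abind1(list(matrix(1:6, nrow = 3), matrix(7:12, nrow = 3))))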
111 | abind1 <- function(a) { 112 | out <- array( 113 | dim = c(length(a), dim(a[[1L]])), 114 | dimnames = c(list(NULL), dimnames(a[[1L]])) 115 | ) 116 | for (i in seq_along(a)) { 117 | out[i, , ] <- a[[i]] 118 | } 119 | out 120 | } 121 | 122 | #' Reorganize List 123 | #' 124 | #' Internal function that turns list of n (p x K) matrices into list of K (n x p) 125 | #' matrices. Reduce if K = 1. 126 | #' 127 | #' @noRd 128 | #' @keywords internal 129 | #' 130 | #' @param alist List of n (p x K) matrices. 131 | #' @returns List of K (n x p) matrices. 132 | reorganize_list <- function(alist) { 133 | if (!is.list(alist)) { 134 | stop("alist must be a list") 135 | } 136 | out <- asplit(abind1(alist), MARGIN = 3L) 137 | if (length(out) == 1L) { 138 | return(as.matrix(out[[1L]])) 139 | } 140 | lapply(out, as.matrix) 141 | } 142 | 143 | #' Aligns Predictions 144 | #' 145 | #' Turns predictions into matrix. 146 | #' 147 | #' @noRd 148 | #' @keywords internal 149 | #' 150 | #' @param x Object representing model predictions. 151 | #' @returns Like `x`, but converted to matrix. 152 | align_pred <- function(x) { 153 | if (is.data.frame(x) && ncol(x) == 1L) { 154 | x <- x[[1L]] 155 | } 156 | if (!is.matrix(x)) { 157 | x <- as.matrix(x) 158 | } 159 | if (!is.numeric(x) && !is.logical(x)) { 160 | stop("Predictions must be numeric!") 161 | } 162 | return(x) 163 | } 164 | 165 | #' Head of List Elements 166 | #' 167 | #' Internal function that returns the top n rows of each element in the input list. 168 | #' 169 | #' @noRd 170 | #' @keywords internal 171 | #' 172 | #' @param x A list or a matrix-like. 173 | #' @param n Number of rows to show. 174 | #' @returns List of first rows of each element in the input. 175 | head_list <- function(x, n = 6L) { 176 | if (!is.list(x)) utils::head(x, n) else lapply(x, utils::head, n) 177 | } 178 | 179 | # Summarize details about the chosen algorithm (exact, hybrid, sampling) 180 | summarize_strategy <- function(p, exact, deg) { 181 | if (exact || trunc(p / 2) == deg) { 182 | txt <- "Exact Kernel SHAP values" 183 | if (!exact) { 184 | txt <- paste(txt, "by the hybrid approach") 185 | } 186 | return(txt) 187 | } 188 | if (deg == 0L) { 189 | return("Kernel SHAP values by iterative sampling") 190 | } 191 | paste("Kernel SHAP values by the hybrid strategy of degree", deg) 192 | } 193 | 194 | # Case p = 1 returns exact Shapley values 195 | case_p1 <- function(n, feature_names, v0, v1, X, verbose) { 196 | txt <- "Exact Shapley values (p = 1)" 197 | if (verbose) { 198 | message(txt) 199 | } 200 | S <- v1 - v0[rep(1L, n), , drop = FALSE] # (n x K) 201 | SE <- matrix(numeric(n), dimnames = list(NULL, feature_names)) # (n x 1) 202 | if (ncol(v1) > 1L) { 203 | SE <- replicate(ncol(v1), SE, simplify = FALSE) 204 | S <- lapply( 205 | asplit(S, MARGIN = 2L), function(M) 206 | as.matrix(M, dimnames = list(NULL, feature_names)) 207 | ) 208 | } else { 209 | colnames(S) <- feature_names 210 | } 211 | out <- list( 212 | S = S, 213 | X = X, 214 | baseline = as.vector(v0), 215 | bg_X = NULL, 216 | bg_w = NULL, 217 | SE = SE, 218 | n_iter = integer(n), 219 | converged = rep(TRUE, n), 220 | m = 0L, 221 | m_exact = 0L, 222 | prop_exact = 1, 223 | exact = TRUE, 224 | txt = txt, 225 | predictions = v1, 226 | algorithm = "kernelshap" 227 | ) 228 | class(out) <- "kernelshap" 229 | out 230 | } 231 | 232 | #' Fast Index Generation (from {hstats}) 233 | #' 234 | #' For not too small m, much faster than `rep(seq_len(m), each = each)`. 
235 | #'
236 | #' @noRd
237 | #' @keywords internal
238 | #'
239 | #' @param m Integer. See `each`.
240 | #' @param each Integer. How many times should each value in `1:m` be repeated?
241 | #' @returns The same as `rep(seq_len(m), each = each)`: an integer vector.
242 | #' @examples
243 | #' rep_each(10, 2)
244 | #' rep(1:10, each = 2)  # Ditto
245 | rep_each <- function(m, each) {
246 |   out <- .col(dim = c(each, m))
247 |   dim(out) <- NULL
248 |   out
249 | }
250 | 
251 | #' Grouped Means for Single-Column Matrices (adapted from {hstats})
252 | #'
253 | #' Grouped means for matrix with single column over fixed-length groups.
254 | #'
255 | #' @noRd
256 | #' @keywords internal
257 | #'
258 | #' @param x Matrix with one column.
259 | #' @param ngroups Number of subsequent, equal-sized groups.
260 | #' @param w Optional vector of case weights of length `NROW(x) / ngroups`.
261 | #' @returns Matrix with one column.
262 | wrowmean_vector <- function(x, ngroups = 1L, w = NULL) {
263 |   if (ncol(x) != 1L) {
264 |     stop("x must have a single column")
265 |   }
266 |   nm <- colnames(x)
267 |   dim(x) <- c(length(x) %/% ngroups, ngroups)
268 |   out <- if (is.null(w)) colMeans(x) else colSums(x * w) / sum(w)
269 |   dim(out) <- c(ngroups, 1L)
270 |   if (!is.null(nm)) {
271 |     colnames(out) <- nm
272 |   }
273 |   out
274 | }
275 | 
276 | #' Basic Input Checks
277 | #'
278 | #' @noRd
279 | #' @keywords internal
280 | #'
281 | #' @inheritParams kernelshap
282 | #'
283 | #' @returns TRUE or an error.
284 | basic_checks <- function(X, feature_names, pred_fun) {
285 |   stopifnot(
286 |     is.matrix(X) || is.data.frame(X),
287 |     dim(X) >= 1L,
288 |     length(feature_names) >= 1L,
289 |     all(feature_names %in% colnames(X)),
290 |     "If X is a matrix, feature_names must equal colnames(X)" =
291 |       !is.matrix(X) || identical(colnames(X), feature_names),
292 |     is.function(pred_fun)
293 |   )
294 |   TRUE
295 | }
296 | 
297 | #' Prepare Background Data
298 | #'
299 | #' @noRd
300 | #' @keywords internal
301 | #'
302 | #' @inheritParams kernelshap
303 | #'
304 | #' @returns List with bg_X and bg_w.
305 | prepare_bg <- function(X, bg_X, bg_n, bg_w, verbose) {
306 |   n <- nrow(X)
307 |   if (is.null(bg_X)) {
308 |     if (n <= bg_n) {  # No subsampling required
309 |       if (n < min(20L, bg_n)) {
310 |         stop("X is too small to act as background data. Please specify 'bg_X'.")
311 |       }
312 |       if (n < min(50L, bg_n)) {
313 |         warning("X is quite small to act as background data. Consider specifying a larger 'bg_X'.")
314 |       }
315 |       bg_X <- X
316 |     } else {  # Subsampling
317 |       if (verbose) {
318 |         message("Sampling ", bg_n, " rows from X as background data.")
319 |       }
320 |       ix <- sample(n, bg_n)
321 |       bg_X <- X[ix, , drop = FALSE]
322 |       if (!is.null(bg_w)) {
323 |         stopifnot(length(bg_w) == n)
324 |         bg_w <- bg_w[ix]
325 |       }
326 |     }
327 |   } else {
328 |     stopifnot(
329 |       is.matrix(bg_X) || is.data.frame(bg_X),
330 |       is.matrix(X) == is.matrix(bg_X),
331 |       nrow(bg_X) >= 1L,
332 |       all(colnames(X) %in% colnames(bg_X))
333 |     )
334 |     bg_X <- bg_X[, colnames(X), drop = FALSE]
335 |   }
336 | 
337 |   if (!is.null(bg_w)) {
338 |     bg_w <- prep_w(bg_w, bg_n = nrow(bg_X))
339 |   }
340 | 
341 |   return(list(bg_X = bg_X, bg_w = bg_w))
342 | }
343 | 
344 | #' Warning on Slow Computations
345 | #'
346 | #' @noRd
347 | #' @keywords internal
348 | #'
349 | #' @param m Number of on-off vectors.
350 | #' @param bg_n Number of rows in the background data.
351 | #'
352 | #' @returns TRUE.
353 | warning_burden <- function(m, bg_n) {
354 |   warning("\nPredictions on large data sets with ", m, "x", bg_n,
355 |     " observations are being done.\n",
356 |     "Consider reducing the computational burden (e.g., use a smaller 'bg_X')")
357 |   TRUE
358 | }
359 | 
360 | #' Prepare Case Weights
361 | #'
362 | #' @noRd
363 | #' @keywords internal
364 | #'
365 | #' @param w Vector of case weights.
366 | #' @param bg_n Number of rows in the background data.
367 | #'
368 | #' @returns The case weights as a double vector, or an error.
369 | prep_w <- function(w, bg_n) {
370 |   stopifnot(
371 |     length(w) == bg_n,
372 |     all(w >= 0),
373 |     !all(w == 0)
374 |   )
375 |   if (!is.double(w)) as.double(w) else w
376 | }
377 | 
378 | 
--------------------------------------------------------------------------------
/R/utils_kernelshap.R:
--------------------------------------------------------------------------------
1 | # Kernel SHAP algorithm for a single row x
2 | # If exact, a single call to predict() is necessary.
3 | # If sampling is involved, we need at least two additional calls to predict().
4 | kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact, deg,
5 |                            paired, m, tol, max_iter, v0, precalc, ...) {
6 |   p <- length(feature_names)
7 | 
8 |   # Calculate A_exact and b_exact
9 |   if (exact || deg >= 1L) {
10 |     A_exact <- precalc[["A"]]  # (p x p)
11 |     bg_X_exact <- precalc[["bg_X_exact"]]  # (m_ex*n_bg x p)
12 |     Z <- precalc[["Z"]]  # (m_ex x p)
13 |     m_exact <- nrow(Z)
14 |     v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE]  # (m_ex x K)
15 | 
16 |     # Most expensive part
17 |     vz <- get_vz(  # (m_ex x K)
18 |       X = rep_rows(x, rep.int(1L, nrow(bg_X_exact))),  # (m_ex*n_bg x p)
19 |       bg = bg_X_exact,  # (m_ex*n_bg x p)
20 |       Z = Z,  # (m_ex x p)
21 |       object = object,
22 |       pred_fun = pred_fun,
23 |       w = bg_w,
24 |       ...
25 |     )
26 |     # Note: w is correctly replicated along columns of (vz - v0_m_exact)
27 |     b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact))  # (p x K)
28 | 
29 |     # Some of the hybrid cases are exact as well
30 |     if (exact || trunc(p / 2) == deg) {
31 |       beta <- solver(A_exact, b_exact, constraint = v1 - v0)  # (p x K)
32 |       return(list(beta = beta, sigma = 0 * beta, n_iter = 1L, converged = TRUE))
33 |     }
34 |   }
35 | 
36 |   # Iterative sampling part, always using A_exact and b_exact to fill up the weights
37 |   bg_X_m <- precalc[["bg_X_m"]]  # (m*n_bg x p)
38 |   X <- rep_rows(x, rep.int(1L, nrow(bg_X_m)))  # (m*n_bg x p)
39 |   v0_m <- v0[rep.int(1L, m), , drop = FALSE]  # (m x K)
40 | 
41 |   est_m <- list()
42 |   converged <- FALSE
43 |   n_iter <- 0L
44 |   A_sum <- matrix(  # (p x p)
45 |     0, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
46 |   )
47 |   b_sum <- matrix(  # (p x K)
48 |     0, nrow = p, ncol = ncol(v0), dimnames = list(feature_names, colnames(v1))
49 |   )
50 |   if (deg == 0L) {
51 |     A_exact <- A_sum
52 |     b_exact <- b_sum
53 |   }
54 | 
55 |   while (!isTRUE(converged) && n_iter < max_iter) {
56 |     n_iter <- n_iter + 1L
57 |     input <- input_sampling(
58 |       p = p, m = m, deg = deg, paired = paired, feature_names = feature_names
59 |     )
60 |     Z <- input[["Z"]]
61 | 
62 |     # Expensive # (m x K)
63 |     vz <- get_vz(
64 |       X = X, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
65 |     )
66 | 
67 |     # The sum of weights of A_exact and input[["A"]] is 1, same for b
68 |     A_temp <- A_exact + input[["A"]]  # (p x p)
69 |     b_temp <- b_exact + crossprod(Z, input[["w"]] * (vz - v0_m))  # (p x K)
70 |     A_sum <- A_sum + A_temp  # (p x p)
71 |     b_sum <- b_sum + b_temp  # (p x K)
72 | 
73 |     # Least-squares with constraint that beta_1 + ... + beta_p = v_1 - v_0.
74 |     # The additional constraint beta_0 = v_0 is dealt with via offset
75 |     est_m[[n_iter]] <- solver(A_temp, b_temp, constraint = v1 - v0)  # (p x K)
76 | 
77 |     # Covariance calculation would fail in the first iteration
78 |     if (n_iter >= 2L) {
79 |       beta_n <- solver(A_sum / n_iter, b_sum / n_iter, constraint = v1 - v0)  # (p x K)
80 |       sigma_n <- get_sigma(est_m, iter = n_iter)  # (p x K)
81 |       converged <- all(conv_crit(sigma_n, beta_n) < tol)
82 |     }
83 |   }
84 |   list(beta = beta_n, sigma = sigma_n, n_iter = n_iter, converged = converged)
85 | }
86 | 
87 | # Regression coefficients given sum(beta) = constraint
88 | # A: (p x p), b: (p x k), constraint: (1 x K)
89 | solver <- function(A, b, constraint) {
90 |   p <- ncol(A)
91 |   Ainv <- MASS::ginv(A)
92 |   dimnames(Ainv) <- dimnames(A)
93 |   s <- (matrix(colSums(Ainv %*% b), nrow = 1L) - constraint) / sum(Ainv)  # (1 x K)
94 |   Ainv %*% (b - s[rep.int(1L, p), , drop = FALSE])  # (p x K)
95 | }
96 | 
97 | # Draw m binary vectors z of length p with sum(z) distributed according
98 | # to Kernel SHAP weights -> (m x p) matrix.
99 | # The argument S can be used to restrict the range of sum(z).
100 | sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
101 |   # First draw s = sum(z) according to Kernel weights (renormalized to sum 1)
102 |   probs <- kernel_weights(p, S = S)
103 |   N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
104 | 
105 |   # Then, conditional on that number, set random positions of z to 1
106 |   # Original, unvectorized code
107 |   # out <- vapply(
108 |   #   N,
109 |   #   function(z) {out <- numeric(p); out[sample(1:p, z)] <- 1; out},
110 |   #   FUN.VALUE = numeric(p)
111 |   # )
112 |   # t(out)
113 | 
114 |   # Vectorized by Mathias Ambuehl
115 |   out <- rep(rep.int(0:1, m), as.vector(rbind(p - N, N)))
116 |   dim(out) <- c(p, m)
117 |   ord <- order(col(out), sample.int(m * p))
118 |   out[] <- out[ord]
119 |   rownames(out) <- feature_names
120 |   t(out)
121 | }
122 | 
123 | # Calculate standard error from list of m estimates
124 | get_sigma <- function(est, iter) {
125 |   apply(abind1(est), 3L, FUN = function(Y) sqrt(diag(stats::cov(Y)) / iter))
126 | }
127 | 
128 | # Convergence criterion
129 | conv_crit <- function(sig, bet) {
130 |   if (any(dim(sig) != dim(bet))) {
131 |     stop("sig must have same dimension as bet")
132 |   }
133 |   apply(sig, 2L, FUN = max) / apply(bet, 2L, FUN = function(z) diff(range(z)))
134 | }
135 | 
136 | # Provides random input for SHAP sampling:
137 | # - Z: Matrix with m on-off vectors z with sum(z) following Kernel weight distribution.
138 | # - w: Vector (1/m, 1/m, ...) of length m (if pure sampling)
139 | # - A: Matrix A = Z'wZ
140 | # The weights are constant (Kernel weights have been used to draw the z vectors).
141 | # 
142 | # If deg > 0, vectors z with sum(z) restricted to [deg+1, p-deg-1] are sampled.
143 | # This case is used in combination with input_partly_exact(). Consequently, sum(w) < 1.
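# Illustration (hypothetical values, not run): for p = 5, m = 8, deg = 1 and
# paired sampling,
#   inp <- input_sampling(p = 5, m = 8, deg = 1, paired = TRUE,
#                         feature_names = letters[1:5])
# yields an (8 x 5) matrix inp$Z with row sums in {2, 3} (sums 1 and 4 are
# left to the exact part), weights inp$w summing to
# 1 - 2 * kernel_weights(5)[1], and inp$A equal to crossprod(inp$Z) * inp$w[1].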
144 | input_sampling <- function(p, m, deg, paired, feature_names) {
145 |   if (p < 2L * deg + 2L) {
146 |     stop("p must be >=2*deg + 2")
147 |   }
148 |   S <- (deg + 1L):(p - deg - 1L)
149 |   Z <- sample_Z(
150 |     p = p, m = if (paired) m / 2 else m, feature_names = feature_names, S = S
151 |   )
152 |   if (paired) {
153 |     Z <- rbind(Z, 1 - Z)
154 |   }
155 |   w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
156 |   w <- w_total / m
157 |   list(Z = Z, w = rep.int(w, m), A = crossprod(Z) * w)
158 | }
159 | 
160 | # Functions required only for handling (partly) exact cases
161 | 
162 | # Provides fixed input for the exact case:
163 | # - Z: Matrix with all 2^p-2 on-off vectors z
164 | # - w: Vector with row weights of Z ensuring that the distribution of sum(z) matches
165 | #   the SHAP kernel distribution
166 | # - A: Exact matrix A = Z'wZ
167 | input_exact <- function(p, feature_names) {
168 |   Z <- exact_Z(p, feature_names = feature_names, keep_extremes = FALSE)
169 |   # Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
170 |   w <- kernel_weights(p) / choose(p, 1:(p - 1L))
171 |   list(Z = Z, w = w[rowSums(Z)], A = exact_A(p, feature_names = feature_names))
172 | }
173 | 
174 | #' Exact Matrix A
175 | #'
176 | #' Internal function that calculates exact A.
177 | #' Note the difference from the off-diagonals in the Supplement of
178 | #' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
179 | #' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
180 | #'
181 | #' @noRd
182 | #' @keywords internal
183 | #'
184 | #' @param p Number of features.
185 | #' @param feature_names Feature names.
186 | #' @returns A (p x p) matrix.
187 | exact_A <- function(p, feature_names) {
188 |   S <- 1:(p - 1L)
189 |   c_pr <- S * (S - 1) / p / (p - 1)
190 |   off_diag <- sum(kernel_weights(p) * c_pr)
191 |   A <- matrix(
192 |     off_diag, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
193 |   )
194 |   diag(A) <- 0.5
195 |   A
196 | }
197 | 
198 | # List all length p vectors z with sum(z) in {k, p - k}
199 | partly_exact_Z <- function(p, k, feature_names) {
200 |   if (k < 1L) {
201 |     stop("k must be at least 1")
202 |   }
203 |   if (p < 2L * k) {
204 |     stop("p must be >=2*k")
205 |   }
206 |   if (k == 1L) {
207 |     Z <- diag(p)
208 |   } else {
209 |     Z <- t(
210 |       utils::combn(seq_len(p), k, FUN = function(z) {x <- numeric(p); x[z] <- 1; x})
211 |     )
212 |   }
213 |   if (p != 2L * k) {
214 |     Z <- rbind(Z, 1 - Z)
215 |   }
216 |   colnames(Z) <- feature_names
217 |   Z
218 | }
219 | 
220 | # Create Z, w, A for vectors z with sum(z) in {k, p-k} for k in {1, ..., deg}.
221 | # The total weights do not sum to one, except in the special (exact) case deg=p-deg.
222 | # (The remaining weight will be added via input_sampling(p, deg=deg)).
223 | # Note that for a given k, the weights are constant.
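# Illustration (hypothetical values, not run): for p = 4 and k = 1,
#   partly_exact_Z(4, k = 1, feature_names = paste0("x", 1:4))
# equals rbind(diag(4), 1 - diag(4)) with column names attached, i.e., the
# 2p = 8 vectors with sum(z) in {1, 3}.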
224 | input_partly_exact <- function(p, deg, feature_names) {
225 |   if (deg < 1L) {
226 |     stop("deg must be at least 1")
227 |   }
228 |   if (p < 2L * deg) {
229 |     stop("p must be >=2*deg")
230 |   }
231 | 
232 |   kw <- kernel_weights(p)
233 |   Z <- w <- vector("list", deg)
234 | 
235 |   for (k in seq_len(deg)) {
236 |     Z[[k]] <- partly_exact_Z(p, k = k, feature_names = feature_names)
237 |     n <- nrow(Z[[k]])
238 |     w_tot <- kw[k] * (2 - (p == 2L * k))
239 |     w[[k]] <- rep.int(w_tot / n, n)
240 |   }
241 |   w <- unlist(w, recursive = FALSE, use.names = FALSE)
242 |   Z <- do.call(rbind, Z)
243 | 
244 |   list(Z = Z, w = w, A = crossprod(Z, w * Z))
245 | }
246 | 
247 | # Kernel weights normalized to a non-empty subset S of {1, ..., p-1}
248 | kernel_weights <- function(p, S = seq_len(p - 1L)) {
249 |   probs <- (p - 1L) / (choose(p, S) * S * (p - S))
250 |   probs / sum(probs)
251 | }
252 | 
--------------------------------------------------------------------------------
/R/utils_permshap.R:
--------------------------------------------------------------------------------
1 | #' Shapley Weights
2 | #'
3 | #' Weights used in Shapley's formula. Vectorized over `p` and/or `ell`.
4 | #'
5 | #' @noRd
6 | #' @keywords internal
7 | #'
8 | #' @param p Number of features.
9 | #' @param ell Size of subset (i.e., sum of on-off vector z).
10 | #' @returns Shapley weights.
11 | shapley_weights <- function(p, ell) {
12 |   1 / choose(p, ell) / (p - ell)
13 | }
14 | 
15 | #' SHAP values for one row
16 | #'
17 | #' Calculates permutation SHAP values for a single row.
18 | #'
19 | #' @noRd
20 | #' @keywords internal
21 | #'
22 | #' @inheritParams permshap
23 | #' @param v1 Prediction of `x`.
24 | #' @param v0 Average prediction on background data.
25 | #' @param x A single row to be explained.
26 | #' @param precalc A list with precalculated values that are identical for all rows.
27 | #' @returns A (p x K) matrix of SHAP values.
28 | permshap_one <- function(x, v1, object, pred_fun, bg_w, v0, precalc, ...) {
29 |   Z <- precalc[["Z"]]  # ((m_ex+2) x p)
30 |   vz <- get_vz(  # (m_ex x K)
31 |     X = rep_rows(x, rep.int(1L, times = nrow(precalc[["bg_X_rep"]]))),  # (m_ex*n_bg x p)
32 |     bg = precalc[["bg_X_rep"]],  # (m_ex*n_bg x p)
33 |     Z = Z[2:(nrow(Z) - 1L), , drop = FALSE],  # (m_ex x p)
34 |     object = object,
35 |     pred_fun = pred_fun,
36 |     w = bg_w,
37 |     ...
38 |   )
39 |   vz <- rbind(v0, vz, v1)  # we add the cheaply calculated v0 and v1
40 |   rownames(vz) <- precalc[["Z_code"]]
41 |   shapley_formula(Z, vz = vz)
42 | }
43 | 
44 | #' Shapley's formula
45 | #'
46 | #' Evaluates Shapley's formula for each feature.
47 | #'
48 | #' @noRd
49 | #' @keywords internal
50 | #'
51 | #' @param Z Matrix of on-off row vectors.
52 | #' @param vz Matrix of vz values, with rows named by their on-off codes.
53 | #' @returns SHAP values organized as (p x K) matrix.
54 | shapley_formula <- function(Z, vz) {
55 |   p <- ncol(Z)
56 |   out <- matrix(nrow = p, ncol = ncol(vz), dimnames = list(colnames(Z), colnames(vz)))
57 |   for (j in seq_len(p)) {
58 |     s1 <- Z[, j] == 1L
59 |     vz1 <- vz[s1, , drop = FALSE]
60 |     L <- rowSums(Z[s1, -j, drop = FALSE])  # how many players are playing with j?
61 |     s0 <- rownames(vz1)
62 |     substr(s0, j, j) <- "0"
63 |     vz0 <- vz[s0, , drop = FALSE]
64 |     w <- shapley_weights(p, L)
65 |     out[j, ] <- wcolMeans(vz1 - vz0, w = w)
66 |   }
67 |   out
68 | }
69 | 
70 | #' Rowwise Paste
71 | #'
72 | #' Fast version of `apply(Z, 1L, FUN = paste0, collapse = "")`.
73 | #'
74 | #' @noRd
75 | #' @keywords internal
76 | #'
77 | #' @param Z A (n x p) matrix.
78 | #' @returns A length n vector.
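#' @examples
#' # Illustrative internal call: each row becomes one code string
#' rowpaste(rbind(c(1, 0, 1), c(0, 1, 1)))  # "101" "011"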
79 | rowpaste <- function(Z) {
80 |   do.call(paste0, asplit(Z, 2L))
81 | }
82 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kernelshap
2 | 
3 | 
4 | 
5 | [![R-CMD-check](https://github.com/ModelOriented/kernelshap/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/ModelOriented/kernelshap/actions/workflows/R-CMD-check.yaml)
6 | [![Codecov test coverage](https://codecov.io/gh/ModelOriented/kernelshap/graph/badge.svg)](https://app.codecov.io/gh/ModelOriented/kernelshap)
7 | [![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/kernelshap)](https://cran.r-project.org/package=kernelshap)
8 | 
9 | [![](https://cranlogs.r-pkg.org/badges/kernelshap)](https://cran.r-project.org/package=kernelshap)
10 | [![](https://cranlogs.r-pkg.org/badges/grand-total/kernelshap?color=orange)](https://cran.r-project.org/package=kernelshap)
11 | 
12 | 
13 | 
14 | ## Overview
15 | 
16 | The package contains three functions to crunch SHAP values:
17 | 
18 | - **`permshap()`**: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features.
19 | - **`kernelshap()`**: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features.
20 | - **`additive_shap()`**: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.
21 | 
22 | To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data, feature columns only) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.
23 | 
24 | **Remarks on `permshap()` and `kernelshap()`**
25 | 
26 | - Both algorithms need representative background data `bg_X` to calculate marginal means (up to 500 rows from the training data). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If unspecified, 200 rows are randomly sampled from `X`.
27 | - Exact Kernel SHAP is an approximation to exact permutation SHAP. Since exact calculations are usually sufficiently fast for up to eight features, we recommend `permshap()` in this case. With more features, `kernelshap()` switches to a comparably fast, almost exact algorithm. That is why we recommend `kernelshap()` in this case.
28 | - For models with interactions of order up to two, SHAP values of exact permutation SHAP and exact Kernel SHAP agree.
29 | - `permshap()` and `kernelshap()` give the same results as `additive_shap()` as long as the full training data is used as background data.
30 | 
31 | ## Installation
32 | 
33 | ```r
34 | # From CRAN
35 | install.packages("kernelshap")
36 | 
37 | # Or the development version:
38 | devtools::install_github("ModelOriented/kernelshap")
39 | ```
40 | 
41 | ## Basic Usage
42 | 
43 | Let's model diamond prices with a random forest. As an alternative, you could use the {treeshap} package in this situation.
44 | 45 | ```r 46 | library(kernelshap) 47 | library(ggplot2) 48 | library(ranger) 49 | library(shapviz) 50 | 51 | diamonds <- transform( 52 | diamonds, 53 | log_price = log(price), 54 | log_carat = log(carat) 55 | ) 56 | 57 | xvars <- c("log_carat", "clarity", "color", "cut") 58 | 59 | fit <- ranger( 60 | log_price ~ log_carat + clarity + color + cut, 61 | data = diamonds, 62 | num.trees = 100, 63 | seed = 20 64 | ) 65 | fit # OOB R-squared 0.989 66 | 67 | # 1) Sample rows to be explained 68 | set.seed(10) 69 | X <- diamonds[sample(nrow(diamonds), 1000), xvars] 70 | 71 | # 2) Optional: Select background data. If unspecified, 200 rows from X are used 72 | bg_X <- diamonds[sample(nrow(diamonds), 200), ] 73 | 74 | # 3) Crunch SHAP values (22 seconds) 75 | # Note: Since the number of features is small, we use permshap() 76 | system.time( 77 | ps <- permshap(fit, X, bg_X = bg_X) 78 | ) 79 | ps 80 | 81 | # SHAP values of first observations: 82 | log_carat clarity color cut 83 | [1,] 1.1913247 0.09005467 -0.13430720 0.000682593 84 | [2,] -0.4931989 -0.11724773 0.09868921 0.028563613 85 | 86 | # Kernel SHAP gives almost the same: 87 | system.time( # 22 s 88 | ks <- kernelshap(fit, X, bg_X = bg_X) 89 | ) 90 | ks 91 | # log_carat clarity color cut 92 | # [1,] 1.1911791 0.0900462 -0.13531648 0.001845958 93 | # [2,] -0.4927482 -0.1168517 0.09815062 0.028255442 94 | 95 | # 4) Analyze with {shapviz} 96 | ps <- shapviz(ps) 97 | sv_importance(ps) 98 | sv_dependence(ps, xvars) 99 | ``` 100 | 101 | ![](man/figures/README-rf-imp.svg) 102 | 103 | ![](man/figures/README-rf-dep.svg) 104 | 105 | ## More Examples 106 | 107 | {kernelshap} can deal with almost any situation. We will show some of the flexibility here. The first two examples require you to run at least up to Step 2 of the "Basic Usage" code. 108 | 109 | ### Parallel computing 110 | 111 | Parallel computing for `permshap()` and `kernelshap()` is supported via {foreach}. Note that this does not work for all models. 112 | 113 | On Windows, sometimes not all packages or global objects are passed to the parallel sessions. 
Often, this can be fixed via `parallel_args`; see this example:
114 | 
115 | ```r
116 | library(doFuture)
117 | library(mgcv)
118 | 
119 | registerDoFuture()
120 | plan(multisession, workers = 4)  # Windows
121 | # plan(multicore, workers = 4)  # Linux, macOS, Solaris
122 | 
123 | # GAM with interactions - we cannot use additive_shap()
124 | fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)
125 | 
126 | system.time(  # 4 seconds in parallel
127 |   ps <- permshap(
128 |     fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv")
129 |   )
130 | )
131 | ps
132 | 
133 | # SHAP values of first observations:
134 | #       log_carat    clarity       color         cut
135 | # [1,]    1.26801  0.1023518 -0.09223291 0.004512402
136 | # [2,]   -0.51546 -0.1174766  0.11122775 0.030243973
137 | 
138 | # Because there are no interactions of order above 2, Kernel SHAP gives the same:
139 | system.time(  # 13 s non-parallel
140 |   ks <- kernelshap(fit, X, bg_X = bg_X)
141 | )
142 | all.equal(ps$S, ks$S)
143 | # [1] TRUE
144 | 
145 | # Now the usual plots:
146 | sv <- shapviz(ps)
147 | sv_importance(sv, kind = "bee")
148 | sv_dependence(sv, xvars)
149 | ```
150 | 
151 | ![](man/figures/README-gam-imp.svg)
152 | 
153 | ![](man/figures/README-gam-dep.svg)
154 | 
155 | ### Tailored predict()
156 | 
157 | In this {keras} example, we show how to use a tailored `predict()` function that
158 | 
159 | - complies with the Keras API,
160 | - uses sufficiently large batches, and
161 | - turns off the Keras progress bar.
162 | 
163 | (The results are not fully reproducible.)
164 | 
165 | ```r
166 | library(keras)
167 | 
168 | nn <- keras_model_sequential()
169 | nn |>
170 |   layer_dense(units = 30, activation = "relu", input_shape = 4) |>
171 |   layer_dense(units = 15, activation = "relu") |>
172 |   layer_dense(units = 1)
173 | 
174 | nn |>
175 |   compile(optimizer = optimizer_adam(0.001), loss = "mse")
176 | 
177 | cb <- list(
178 |   callback_early_stopping(patience = 20),
179 |   callback_reduce_lr_on_plateau(patience = 5)
180 | )
181 | 
182 | nn |>
183 |   fit(
184 |     x = data.matrix(diamonds[xvars]),
185 |     y = diamonds$log_price,
186 |     epochs = 100,
187 |     batch_size = 400,
188 |     validation_split = 0.2,
189 |     callbacks = cb
190 |   )
191 | 
192 | pred_fun <- function(mod, X)
193 |   predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE, workers = 4)
194 | 
195 | system.time(  # 50 s
196 |   ps <- permshap(nn, X, bg_X = bg_X, pred_fun = pred_fun)
197 | )
198 | 
199 | ps <- shapviz(ps)
200 | sv_importance(ps, show_numbers = TRUE)
201 | sv_dependence(ps, xvars)
202 | ```
203 | 
204 | ![](man/figures/README-nn-imp.svg)
205 | 
206 | ![](man/figures/README-nn-dep.svg)
207 | 
208 | ### Additive SHAP
209 | 
210 | The additive explainer extracts the additive contribution of each feature from a model of suitable class.
211 | 
212 | ```r
213 | fit <- lm(log(price) ~ log(carat) + color + clarity + cut, data = diamonds)
214 | shap_values <- additive_shap(fit, diamonds) |>
215 |   shapviz()
216 | sv_importance(shap_values)
217 | sv_dependence(shap_values, v = "carat", color_var = NULL)
218 | ```
219 | 
220 | ### Multi-output models
221 | 
222 | {kernelshap} supports multivariate predictions like:
223 | 
224 | - probabilistic classification,
225 | - regression with multivariate response, and
226 | - predictions found by applying multiple regression models.
227 | 
228 | Here, we use the `iris` data (no need to run code from above).
229 | 
230 | ```r
231 | library(kernelshap)
232 | library(ranger)
233 | library(shapviz)
234 | 
235 | set.seed(1)
236 | 
237 | # Probabilistic classification
238 | fit_prob <- ranger(Species ~ ., data = iris, probability = TRUE)
239 | ps_prob <- permshap(fit_prob, X = iris[-5]) |>
240 |   shapviz()
241 | sv_importance(ps_prob)
242 | sv_dependence(ps_prob, "Petal.Length")
243 | ```
244 | 
245 | ![](man/figures/README-prob-imp.svg)
246 | 
247 | ![](man/figures/README-prob-dep.svg)
248 | 
249 | ### Meta-learners
250 | 
251 | Meta-learning packages like {tidymodels}, {caret}, or {mlr3} are straightforward to use. The following examples additionally show that the `...` arguments of `permshap()` and `kernelshap()` are passed to `predict()`.
252 | 
253 | #### Tidymodels
254 | 
255 | ```r
256 | library(kernelshap)
257 | library(tidymodels)
258 | 
259 | set.seed(1)
260 | 
261 | iris_recipe <- iris |>
262 |   recipe(Species ~ .)
263 | 
264 | mod <- rand_forest(trees = 100) |>
265 |   set_engine("ranger") |>
266 |   set_mode("classification")
267 | 
268 | iris_wf <- workflow() |>
269 |   add_recipe(iris_recipe) |>
270 |   add_model(mod)
271 | 
272 | fit <- iris_wf |>
273 |   fit(iris)
274 | 
275 | system.time(  # 3s
276 |   ps <- permshap(fit, iris[-5], type = "prob")
277 | )
278 | ps
279 | 
280 | # Some values
281 | # $.pred_setosa
282 | #      Sepal.Length Sepal.Width Petal.Length Petal.Width
283 | # [1,]   0.02186111 0.012137778    0.3658278   0.2667667
284 | # [2,]   0.02628333 0.001315556    0.3683833   0.2706111
285 | ```
286 | 
287 | #### caret
288 | 
289 | ```r
290 | library(kernelshap)
291 | library(caret)
292 | 
293 | fit <- train(
294 |   Sepal.Length ~ .,
295 |   data = iris,
296 |   method = "lm",
297 |   tuneGrid = data.frame(intercept = TRUE),
298 |   trControl = trainControl(method = "none")
299 | )
300 | 
301 | ps <- permshap(fit, iris[-1])
302 | ```
303 | 
304 | #### mlr3
305 | 
306 | ```r
307 | library(kernelshap)
308 | library(mlr3)
309 | library(mlr3learners)
310 | 
311 | set.seed(1)
312 | 
313 | task_classif <- TaskClassif$new(id = "1", backend = iris, target = "Species")
314 | learner_classif <- lrn("classif.rpart", predict_type = "prob")
315 | learner_classif$train(task_classif)
316 | 
317 | x <- learner_classif$selected_features()
318 | 
319 | # Don't forget to pass predict_type = "prob" to mlr3's predict()
320 | ps <- permshap(
321 |   learner_classif, X = iris, feature_names = x, predict_type = "prob"
322 | )
323 | ps
324 | # $setosa
325 | #      Petal.Length Petal.Width
326 | # [1,]    0.6666667           0
327 | # [2,]    0.6666667           0
328 | ```
329 | 
330 | ## References
331 | 
332 | [1] Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and Information Systems 41, 2014.
333 | 
334 | [2] Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30, 2017.
335 | 
336 | [3] Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.
337 | -------------------------------------------------------------------------------- /backlog/2023-11-11 Permutation-SHAP.R: -------------------------------------------------------------------------------- 1 | library(kernelshap) 2 | library(ranger) 3 | 4 | differences <- numeric(4) 5 | 6 | system.time({ 7 | set.seed(1) 8 | for (depth in 1:4) { 9 | print(depth) 10 | fit <- ranger( 11 | Sepal.Length ~ ., 12 | mtry = 3, 13 | data = iris, 14 | max.depth = depth 15 | ) 16 | ps <- permshap(fit, iris[2:5], bg_X = iris, verbose = FALSE) 17 | ks <- kernelshap(fit, iris[2:5], bg_X = iris, verbose = FALSE) 18 | differences[depth] <- mean(abs(ks$S - ps$S)) 19 | } 20 | }) 21 | 22 | differences # for tree depth 1, 2, 3, 4 23 | # 5.053249e-17 9.046443e-17 2.387905e-04 4.403375e-04 24 | 25 | # SHAP values of first two rows with tree depth 4 26 | ps 27 | # Sepal.Width Petal.Length Petal.Width Species 28 | # [1,] 0.11377616 -0.7130647 -0.1956012 -0.004437022 29 | # [2,] -0.06852539 -0.7596562 -0.2259017 -0.006575266 30 | 31 | ks 32 | # Sepal.Width Petal.Length Petal.Width Species 33 | # [1,] 0.11463191 -0.7125194 -0.1951810 -0.006258208 34 | # [2,] -0.06828866 -0.7597391 -0.2259833 -0.006647530 35 | 36 | 37 | # larger data, more features 38 | library(xgboost) 39 | library(shapviz) 40 | 41 | colnames(miami) <- tolower(colnames(miami)) 42 | miami$log_ocean <- log(miami$ocean_dist) 43 | x <- c("log_ocean", "tot_lvg_area", "lnd_sqfoot", "structure_quality", "age", "month_sold") 44 | 45 | # Train/valid split 46 | set.seed(1) 47 | ix <- sample(nrow(miami), 0.8 * nrow(miami)) 48 | 49 | y_train <- log(miami$sale_prc[ix]) 50 | y_valid <- log(miami$sale_prc[-ix]) 51 | X_train <- data.matrix(miami[ix, x]) 52 | X_valid <- data.matrix(miami[-ix, x]) 53 | 54 | dtrain <- xgb.DMatrix(X_train, label = y_train) 55 | dvalid <- xgb.DMatrix(X_valid, label = y_valid) 56 | 57 | # Fit via early stopping (depth 1 to 3) 58 | differences <- numeric(3) 59 | 60 | for (i in 1:3) { 61 | fit <- xgb.train( 62 | params = list(learning_rate = 0.15, objective = "reg:squarederror", max_depth = i), 63 | data = dtrain, 64 | watchlist = list(valid = dvalid), 65 | early_stopping_rounds = 20, 66 | nrounds = 1000, 67 | callbacks = list(cb.print.evaluation(period = 100)) 68 | ) 69 | ps <- permshap(fit, X = head(X_valid, 500), bg_X = head(X_valid, 500)) 70 | ks <- kernelshap(fit, X = head(X_valid, 500), bg_X = head(X_valid, 500)) 71 | differences[i] <- mean(abs(ks$S - ps$S)) 72 | } 73 | differences 74 | # 2.904010e-09 5.158383e-09 6.586577e-04 75 | 76 | ps 77 | # SHAP values of first observations: 78 | # log_ocean tot_lvg_area lnd_sqfoot structure_quality age month_sold 79 | # 0.2224359 0.04941044 0.1266136 0.1360166 0.01036866 0.005557032 80 | # 0.3674484 0.01045079 0.1192187 0.1180312 0.01426247 0.005465283 81 | 82 | ks 83 | # SHAP values of first observations: 84 | # log_ocean tot_lvg_area lnd_sqfoot structure_quality age month_sold 85 | # 0.2245202 0.049520308 0.1266020 0.1349770 0.01142703 0.003355770 86 | # 0.3697167 0.009575195 0.1198201 0.1168738 0.01544061 0.003450425 87 | -------------------------------------------------------------------------------- /backlog/compare_with_python.R: -------------------------------------------------------------------------------- 1 | library(ggplot2) 2 | library(kernelshap) 3 | 4 | # Turn ordinal factors into unordered 5 | ord <- c("clarity", "color", "cut") 6 | diamonds[, ord] <- lapply(diamonds[ord], factor, ordered = FALSE) 7 | 8 | # Fit model 9 | fit <- lm(log(price) ~ log(carat) * (clarity + color + 
cut), data = diamonds) 10 | 11 | # Subset of 120 diamonds used as background data 12 | bg_X <- diamonds[seq(1, nrow(diamonds), 450), ] 13 | 14 | # Subset of 1018 diamonds to explain 15 | X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)] 16 | 17 | # Exact KernelSHAP (2s) 18 | system.time( 19 | ks <- kernelshap(fit, X_small, bg_X = bg_X) 20 | ) 21 | ks 22 | 23 | # SHAP values of first 2 observations: 24 | # carat clarity color cut 25 | # [1,] -2.050074 -0.28048747 0.1281222 0.01587382 26 | # [2,] -2.085838 0.04050415 0.1283010 0.03731644 27 | 28 | # Pure sampling version takes a bit longer (7 seconds) 29 | system.time( 30 | ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0) 31 | ) 32 | ks2 33 | 34 | # SHAP values of first 2 observations: 35 | # carat clarity color cut 36 | # [1,] -2.050074 -0.28048747 0.1281222 0.01587382 37 | # [2,] -2.085838 0.04050415 0.1283010 0.03731644 38 | 39 | 40 | library(shapviz) 41 | 42 | sv <- shapviz(ks) 43 | sv_dependence(sv, "carat") 44 | 45 | 46 | # More features (but non-sensical model) 47 | # Fit model 48 | fit <- lm( 49 | log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth, 50 | data = diamonds 51 | ) 52 | 53 | # Subset of 1018 diamonds to explain 54 | X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")] 55 | 56 | # Exact KernelSHAP on X_small, using X_small as background data 57 | # (39s for exact, 15s for hybrid deg 2, 8s for hybrid deg 1, 16s for sampling) 58 | system.time( 59 | ks <- kernelshap(fit, X_small, bg_X = bg_X) 60 | ) 61 | ks 62 | 63 | # SHAP values of first 2 observations: 64 | # carat cut color clarity depth table x y z 65 | # [1,] -1.842799 0.01424231 0.1266108 -0.27033874 -0.0007084443 0.0017787647 -0.1720782 0.001330275 -0.006445693 66 | # [2,] -1.876709 0.03856957 0.1266546 0.03932912 -0.0004202636 -0.0004871776 -0.1739880 0.001397792 -0.006560624 67 | 68 | #======================== 69 | # The same in Python 70 | #======================== 71 | 72 | import numpy as np 73 | import pandas as pd 74 | from plotnine.data import diamonds 75 | from statsmodels.formula.api import ols 76 | from shap import KernelExplainer 77 | 78 | # Turn categoricals into integers because, inconveniently, kernel SHAP 79 | # requires numpy array as input 80 | ord = ["clarity", "color", "cut"] 81 | # x = ["carat"] + ord + ["table", "depth", "x", "y", "z"] 82 | x = ["carat"] + ord 83 | diamonds[ord] = diamonds[ord].apply(lambda x: x.cat.codes) 84 | X = diamonds[x].to_numpy() 85 | 86 | # Fit model with interactions and dummy variables 87 | fit = ols( 88 | "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth", 89 | data=diamonds 90 | ).fit() 91 | 92 | # Background data (120 rows) 93 | bg_X = X[0:len(X):450] 94 | 95 | # Define subset of 1018 diamonds to explain 96 | X_small = X[0:len(X):53] 97 | 98 | # Calculate KernelSHAP values 99 | ks = KernelExplainer( 100 | model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), 101 | data = bg_X 102 | ) 103 | sv = ks.shap_values(X_small) # 11 minutes 104 | sv[0:2] 105 | 106 | # Output for four features (exact calculations, 1 minute) 107 | # array([[-2.05007406, -0.28048747, 0.12812216, 0.01587382], 108 | # [-2.0858379 , 0.04050415, 0.12830103, 0.03731644]]) 109 | 110 | 111 | # Output for nine features (exact calculations, 13 minutes) 112 | # array([[-1.84279897e+00, -2.70338744e-01, 1.26610769e-01, 113 | # 1.42423108e-02, 1.77876470e-03, -7.08444295e-04, 114 | # -1.72078182e-01, 
1.33027467e-03, -6.44569296e-03], 115 | # [-1.87670887e+00, 3.93291219e-02, 1.26654599e-01, 116 | # 3.85695742e-02, -4.87177593e-04, -4.20263565e-04, 117 | # -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]]) 118 | -------------------------------------------------------------------------------- /backlog/plot_settings.R: -------------------------------------------------------------------------------- 1 | ggsave("man/figures/README-rf-imp.svg", scale = 2) 2 | ggsave("man/figures/README-rf-dep.svg", width = 8.5, height = 6) 3 | 4 | ggsave("man/figures/README-gam-imp.svg", scale = 2) 5 | ggsave("man/figures/README-gam-dep.svg", width = 8.5, height = 6) -------------------------------------------------------------------------------- /backlog/test_additive_shap.R: -------------------------------------------------------------------------------- 1 | # Some tests that need contributed packages 2 | 3 | library(mgcv) 4 | library(gam) 5 | library(survival) 6 | library(splines) 7 | library(testthat) 8 | 9 | formulas_ok <- list( 10 | Sepal.Length ~ Sepal.Width + Petal.Width + Species, 11 | Sepal.Length ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2), 12 | Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2) 13 | ) 14 | 15 | formulas_bad <- list( 16 | Sepal.Length ~ Species * Petal.Length, 17 | Sepal.Length ~ Species + Petal.Length + Species:Petal.Length, 18 | Sepal.Length ~ log(Petal.Length / Petal.Width) 19 | ) 20 | 21 | models <- list(mgcv::gam, mgcv::bam, gam::gam) 22 | 23 | X <- head(iris) 24 | for (formula in formulas_ok) { 25 | for (model in models) { 26 | fit <- model(formula, data = iris) 27 | s <- additive_shap(fit, X = X, verbose = FALSE) 28 | expect_equal(s$predictions, as.vector(predict(fit, newdata = X))) 29 | } 30 | } 31 | 32 | for (formula in formulas_bad) { 33 | for (model in models) { 34 | fit <- model(formula, data = iris) 35 | expect_error(s <- additive_shap(fit, X = X, verbose = FALSE)) 36 | } 37 | } 38 | 39 | # Survival 40 | iris$s <- rep(1, nrow(iris)) 41 | formulas_ok <- list( 42 | Surv(Sepal.Length, s) ~ Sepal.Width + Petal.Width + Species, 43 | Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2), 44 | Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Sepal.Width, 2) 45 | ) 46 | 47 | formulas_bad <- list( 48 | Surv(Sepal.Length, s) ~ Species * Petal.Length, 49 | Surv(Sepal.Length, s) ~ Species + Petal.Length + Species:Petal.Length, 50 | Surv(Sepal.Length, s) ~ log(Petal.Length / Petal.Width) 51 | ) 52 | 53 | models <- list(survival::coxph, survival::survreg) 54 | 55 | for (formula in formulas_ok) { 56 | for (model in models) { 57 | fit <- model(formula, data = iris) 58 | s <- additive_shap(fit, X = X, verbose = FALSE) 59 | } 60 | } 61 | 62 | for (formula in formulas_bad) { 63 | for (model in models) { 64 | fit <- model(formula, data = iris) 65 | expect_error(s <- additive_shap(fit, X = X, verbose = FALSE)) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /backlog/test_ranger.R: -------------------------------------------------------------------------------- 1 | library(ranger) 2 | library(survival) 3 | library(kernelshap) 4 | 5 | set.seed(1) 6 | 7 | fit <- ranger(Surv(time, status) ~ ., data = veteran, num.trees = 20) 8 | fit2 <- ranger(time ~ . - status, data = veteran, num.trees = 20) 9 | fit3 <- ranger(time ~ . - status, data = veteran, quantreg = TRUE, num.trees = 20) 10 | fit4 <- ranger(status ~ . 
- time, data = veteran, probability = TRUE, num.trees = 20) 11 | 12 | xvars <- setdiff(colnames(veteran), c("time", "status")) 13 | 14 | kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran) 15 | kernelshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob") 16 | kernelshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran) 17 | kernelshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles") 18 | kernelshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran) 19 | 20 | permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran) 21 | permshap(fit, head(veteran), feature_names = xvars, bg_X = veteran, survival = "prob") 22 | permshap(fit2, head(veteran), feature_names = xvars, bg_X = veteran) 23 | permshap(fit3, head(veteran), feature_names = xvars, bg_X = veteran, type = "quantiles") 24 | permshap(fit4, head(veteran), feature_names = xvars, bg_X = veteran) 25 | 26 | -------------------------------------------------------------------------------- /cran-comments.md: -------------------------------------------------------------------------------- 1 | # kernelshap 0.7.0 2 | 3 | Hello CRAN team 4 | 5 | This update comes with a major convenience improvement: Background data is automatically sampled from the explanation data, provided it is sufficiently large. 6 | 7 | ## Checks 8 | 9 | ### Local check 10 | 11 | 0 errors | 0 warnings | 0 notes 12 | 13 | ### `check_win_devel()` 14 | 15 | Status: OK 16 | 17 | ### Revdep 18 | 19 | survex 1.2.0 20 | - OK: 1 21 | - BROKEN: 0 22 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/logo.png -------------------------------------------------------------------------------- /man/additive_shap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/additive_shap.R 3 | \name{additive_shap} 4 | \alias{additive_shap} 5 | \title{Additive SHAP} 6 | \usage{ 7 | additive_shap(object, X, verbose = TRUE, ...) 8 | } 9 | \arguments{ 10 | \item{object}{Fitted additive model.} 11 | 12 | \item{X}{Dataframe with rows to be explained. Passed to 13 | \code{predict(object, newdata = X, type = "terms")}.} 14 | 15 | \item{verbose}{Set to \code{FALSE} to suppress messages.} 16 | 17 | \item{...}{Currently unused.} 18 | } 19 | \value{ 20 | An object of class "kernelshap" with the following components: 21 | \itemize{ 22 | \item \code{S}: \eqn{(n \times p)} matrix with SHAP values. 23 | \item \code{X}: Same as input argument \code{X}. 24 | \item \code{baseline}: The baseline. 25 | \item \code{exact}: \code{TRUE}. 26 | \item \code{txt}: Summary text. 27 | \item \code{predictions}: Vector with predictions of \code{X} on the scale of "terms". 28 | \item \code{algorithm}: "additive_shap". 29 | } 30 | } 31 | \description{ 32 | Exact additive SHAP assuming feature independence. The implementation 33 | works for models fitted via 34 | \itemize{ 35 | \item \code{\link[=lm]{lm()}}, 36 | \item \code{\link[=glm]{glm()}}, 37 | \item \code{\link[mgcv:gam]{mgcv::gam()}}, 38 | \item \code{\link[mgcv:bam]{mgcv::bam()}}, 39 | \item \code{gam::gam()}, 40 | \item \code{\link[survival:coxph]{survival::coxph()}}, and 41 | \item \code{\link[survival:survreg]{survival::survreg()}}.
42 | } 43 | } 44 | \details{ 45 | The SHAP values are extracted via \code{predict(object, newdata = X, type = "terms")}, 46 | a logic adopted from \code{fastshap:::explain.lm(..., exact = TRUE)}. 47 | Models with interactions (specified via \code{:} or \code{*}), or with terms of 48 | multiple features like \code{log(x1/x2)} are not supported. 49 | 50 | Note that the SHAP values obtained by \code{\link[=additive_shap]{additive_shap()}} are expected to 51 | match those of \code{\link[=permshap]{permshap()}} and \code{\link[=kernelshap]{kernelshap()}} as long as their background 52 | data equals the full training data (which is typically not feasible). 53 | } 54 | \examples{ 55 | # MODEL ONE: Linear regression 56 | fit <- lm(Sepal.Length ~ ., data = iris) 57 | s <- additive_shap(fit, head(iris)) 58 | s 59 | 60 | # MODEL TWO: More complicated (but not very clever) formula 61 | fit <- lm( 62 | Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width), 63 | data = iris 64 | ) 65 | s_add <- additive_shap(fit, head(iris)) 66 | s_add 67 | 68 | # Equals kernelshap()/permshap() when background data is full training data 69 | s_kernel <- kernelshap( 70 | fit, head(iris[c("Sepal.Width", "Petal.Length")]), bg_X = iris 71 | ) 72 | all.equal(s_add$S, s_kernel$S) 73 | } 74 | -------------------------------------------------------------------------------- /man/figures/README-nn-imp.svg: -------------------------------------------------------------------------------- [SVG figure; only the text layer is recoverable: SHAP importance bar chart with x-axis "mean(|SHAP value|)" (0.0 to 0.9) and bars log_carat 0.967, clarity 0.160, color 0.110, cut 0.026] -------------------------------------------------------------------------------- /man/figures/README-prob-imp.svg: -------------------------------------------------------------------------------- [SVG figure; only the text layer is recoverable: per-class SHAP importance bar chart (legend: setosa, versicolor, virginica) with x-axis "mean(|SHAP value|)" (0.00 to 0.20) for Petal.Length, Petal.Width, Sepal.Length, Sepal.Width] -------------------------------------------------------------------------------- /man/figures/README-rf-imp.svg: -------------------------------------------------------------------------------- [SVG figure; only the text layer is recoverable: SHAP importance bar chart with x-axis "mean(|SHAP value|)" (0.00 to 0.75) for log_carat, clarity, color, cut] -------------------------------------------------------------------------------- /man/figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/man/figures/logo.png --------------------------------------------------------------------------------
/man/is.kernelshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{is.kernelshap} 4 | \alias{is.kernelshap} 5 | \title{Check for kernelshap} 6 | \usage{ 7 | is.kernelshap(object) 8 | } 9 | \arguments{ 10 | \item{object}{An R object.} 11 | } 12 | \value{ 13 | \code{TRUE} if \code{object} is of class "kernelshap", and \code{FALSE} otherwise. 14 | } 15 | \description{ 16 | Is object of class "kernelshap"? 17 | } 18 | \examples{ 19 | fit <- lm(Sepal.Length ~ ., data = iris) 20 | s <- kernelshap(fit, iris[1:2, -1], bg_X = iris[, -1]) 21 | is.kernelshap(s) 22 | is.kernelshap("a") 23 | } 24 | \seealso{ 25 | \code{\link[=kernelshap]{kernelshap()}} 26 | } 27 | -------------------------------------------------------------------------------- /man/kernelshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/kernelshap.R 3 | \name{kernelshap} 4 | \alias{kernelshap} 5 | \alias{kernelshap.default} 6 | \alias{kernelshap.ranger} 7 | \title{Kernel SHAP} 8 | \usage{ 9 | kernelshap(object, ...) 10 | 11 | \method{kernelshap}{default}( 12 | object, 13 | X, 14 | bg_X = NULL, 15 | pred_fun = stats::predict, 16 | feature_names = colnames(X), 17 | bg_w = NULL, 18 | bg_n = 200L, 19 | exact = length(feature_names) <= 8L, 20 | hybrid_degree = 1L + length(feature_names) \%in\% 4:16, 21 | paired_sampling = TRUE, 22 | m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)), 23 | tol = 0.005, 24 | max_iter = 100L, 25 | parallel = FALSE, 26 | parallel_args = NULL, 27 | verbose = TRUE, 28 | ... 29 | ) 30 | 31 | \method{kernelshap}{ranger}( 32 | object, 33 | X, 34 | bg_X = NULL, 35 | pred_fun = NULL, 36 | feature_names = colnames(X), 37 | bg_w = NULL, 38 | bg_n = 200L, 39 | exact = length(feature_names) <= 8L, 40 | hybrid_degree = 1L + length(feature_names) \%in\% 4:16, 41 | paired_sampling = TRUE, 42 | m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)), 43 | tol = 0.005, 44 | max_iter = 100L, 45 | parallel = FALSE, 46 | parallel_args = NULL, 47 | verbose = TRUE, 48 | survival = c("chf", "prob"), 49 | ... 50 | ) 51 | } 52 | \arguments{ 53 | \item{object}{Fitted model object.} 54 | 55 | \item{...}{Additional arguments passed to \code{pred_fun(object, X, ...)}.} 56 | 57 | \item{X}{\eqn{(n \times p)} matrix or \code{data.frame} with rows to be explained. 58 | The columns should only represent model features, not the response 59 | (but see \code{feature_names} on how to overrule this).} 60 | 61 | \item{bg_X}{Background data used to integrate out "switched off" features, 62 | often a subset of the training data (typically 50 to 500 rows). 63 | In cases with a natural "off" value (like MNIST digits), 64 | this can also be a single row with all values set to the off value. 65 | If no \code{bg_X} is passed (the default) and if \code{X} is sufficiently large, 66 | a random sample of \code{bg_n} rows from \code{X} serves as background data.} 67 | 68 | \item{pred_fun}{Prediction function of the form \verb{function(object, X, ...)}, 69 | providing \eqn{K \ge 1} predictions per row. Its first argument 70 | represents the model \code{object}, its second argument a data structure like \code{X}. 71 | Additional (named) arguments are passed via \code{...}. 
72 | The default, \code{\link[stats:predict]{stats::predict()}}, will work in most cases.} 73 | 74 | \item{feature_names}{Optional vector of column names in \code{X} used to calculate 75 | SHAP values. By default, this equals \code{colnames(X)}. Not supported if \code{X} 76 | is a matrix.} 77 | 78 | \item{bg_w}{Optional vector of case weights for each row of \code{bg_X}. 79 | If \code{bg_X = NULL}, must be of same length as \code{X}. Set to \code{NULL} for no weights.} 80 | 81 | \item{bg_n}{If \code{bg_X = NULL}: Size of background data to be sampled from \code{X}.} 82 | 83 | \item{exact}{If \code{TRUE}, the algorithm will produce exact Kernel SHAP values 84 | with respect to the background data. In this case, the arguments \code{hybrid_degree}, 85 | \code{m}, \code{paired_sampling}, \code{tol}, and \code{max_iter} are ignored. 86 | The default is \code{TRUE} up to eight features, and \code{FALSE} otherwise.} 87 | 88 | \item{hybrid_degree}{Integer controlling the exactness of the hybrid strategy. For 89 | \eqn{4 \le p \le 16}, the default is 2, otherwise it is 1. 90 | Ignored if \code{exact = TRUE}. 91 | \itemize{ 92 | \item \code{0}: Pure sampling strategy not involving any exact part. It is strictly 93 | worse than the hybrid strategy and should therefore only be used for 94 | studying properties of the Kernel SHAP algorithm. 95 | \item \code{1}: Uses all \eqn{2p} on-off vectors \eqn{z} with \eqn{\sum z \in \{1, p-1\}} 96 | for the exact part, which covers at least 75\% of the mass of the Kernel weight 97 | distribution. The remaining mass is covered by random sampling. 98 | \item \code{2}: Uses all \eqn{p(p+1)} on-off vectors \eqn{z} with 99 | \eqn{\sum z \in \{1, 2, p-2, p-1\}}. This covers at least 92\% of the mass of the 100 | Kernel weight distribution. The remaining mass is covered by sampling. 101 | Convergence usually happens in the minimal possible number of iterations of two. 102 | \item \code{k>2}: Uses all on-off vectors with 103 | \eqn{\sum z \in \{1, \dots, k, p-k, \dots, p-1\}}. 104 | }} 105 | 106 | \item{paired_sampling}{Logical flag indicating whether to do the sampling in a paired 107 | manner. This means that with every on-off vector \eqn{z}, also \eqn{1-z} is 108 | considered. CL21 shows its superiority compared to standard sampling, therefore the 109 | default (\code{TRUE}) should usually not be changed except for studying properties 110 | of Kernel SHAP algorithms. Ignored if \code{exact = TRUE}.} 111 | 112 | \item{m}{Even number of on-off vectors sampled during one iteration. 113 | The default is \eqn{2p}, except when \code{hybrid_degree == 0}. 114 | Then it is set to \eqn{8p}. Ignored if \code{exact = TRUE}.} 115 | 116 | \item{tol}{Tolerance determining when to stop. Following CL21, the algorithm keeps 117 | iterating until \eqn{\textrm{max}(\sigma_n)/(\textrm{max}(\beta_n) - \textrm{min}(\beta_n)) < \textrm{tol}}, 118 | where the \eqn{\beta_n} are the SHAP values of a given observation, 119 | and \eqn{\sigma_n} their standard errors. 120 | For multidimensional predictions, the criterion must be satisfied for each 121 | dimension separately. The stopping criterion uses the fact that standard errors 122 | and SHAP values are all on the same scale. Ignored if \code{exact = TRUE}.} 123 | 124 | \item{max_iter}{If the stopping criterion (see \code{tol}) is not reached after 125 | \code{max_iter} iterations, the algorithm stops. 
Ignored if \code{exact = TRUE}.} 126 | 127 | \item{parallel}{If \code{TRUE}, use parallel \code{\link[foreach:foreach]{foreach::foreach()}} to loop over rows 128 | to be explained. Must register backend beforehand, e.g., via 'doFuture' package, 129 | see README for an example. Parallelization automatically disables the progress bar.} 130 | 131 | \item{parallel_args}{Named list of arguments passed to \code{\link[foreach:foreach]{foreach::foreach()}}. 132 | Ideally, this is \code{NULL} (default). Only relevant if \code{parallel = TRUE}. 133 | Example on Windows: if \code{object} is a GAM fitted with package 'mgcv', 134 | then one might need to set \code{parallel_args = list(.packages = "mgcv")}.} 135 | 136 | \item{verbose}{Set to \code{FALSE} to suppress messages and the progress bar.} 137 | 138 | \item{survival}{Should cumulative hazards ("chf", default) or survival 139 | probabilities ("prob") per time be predicted? Only in \code{ranger()} survival models.} 140 | } 141 | \value{ 142 | An object of class "kernelshap" with the following components: 143 | \itemize{ 144 | \item \code{S}: \eqn{(n \times p)} matrix with SHAP values or, if the model output has 145 | dimension \eqn{K > 1}, a list of \eqn{K} such matrices. 146 | \item \code{X}: Same as input argument \code{X}. 147 | \item \code{baseline}: Vector of length K representing the average prediction on the 148 | background data. 149 | \item \code{bg_X}: The background data. 150 | \item \code{bg_w}: The background case weights. 151 | \item \code{SE}: Standard errors corresponding to \code{S} (and organized like \code{S}). 152 | \item \code{n_iter}: Integer vector of length n providing the number of iterations 153 | per row of \code{X}. 154 | \item \code{converged}: Logical vector of length n indicating convergence per row of \code{X}. 155 | \item \code{m}: Integer providing the effective number of sampled on-off vectors used 156 | per iteration. 157 | \item \code{m_exact}: Integer providing the effective number of exact on-off vectors used 158 | per iteration. 159 | \item \code{prop_exact}: Proportion of the Kernel SHAP weight distribution covered by 160 | exact calculations. 161 | \item \code{exact}: Logical flag indicating whether calculations are exact or not. 162 | \item \code{txt}: Summary text. 163 | \item \code{predictions}: \eqn{(n \times K)} matrix with predictions of \code{X}. 164 | \item \code{algorithm}: "kernelshap". 165 | } 166 | } 167 | \description{ 168 | Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), and 169 | Covert and Lee (2021), abbreviated by CL21. 170 | For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding 171 | the selected background data. For larger \eqn{p}, an almost exact 172 | hybrid algorithm combining exact calculations and iterative sampling is used, 173 | see Details. 174 | 175 | Note that (exact) Kernel SHAP is only an approximation of (exact) permutation SHAP. 176 | Thus, for up to eight features, we recommend \code{\link[=permshap]{permshap()}}. For more features, 177 | \code{\link[=permshap]{permshap()}} is slow compared to the optimized hybrid strategy of our Kernel SHAP 178 | implementation. 179 | } 180 | \details{ 181 | The pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this: 182 | \enumerate{ 183 | \item A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p} 184 | such that its sum follows the SHAP Kernel weight distribution 185 | (normalized to the range \eqn{\{1, \dots, p-1\}}).
186 | \item For each \eqn{j} with \eqn{z_j = 1}, the \eqn{j}-th column of the 187 | original background data is replaced by the corresponding feature value \eqn{x_j} 188 | of the observation to be explained. 189 | \item The average prediction \eqn{v_z} on the data of Step 2 is calculated, and the 190 | average prediction \eqn{v_0} on the background data is subtracted. 191 | \item Steps 1 to 3 are repeated \eqn{m} times. This produces a binary \eqn{m \times p} 192 | matrix \eqn{Z} (each row equals one of the \eqn{z}) and a vector \eqn{v} of 193 | shifted predictions. 194 | \item \eqn{v} is regressed onto \eqn{Z} under the constraint that the sum of the 195 | coefficients equals \eqn{v_1 - v_0}, where \eqn{v_1} is the prediction of the 196 | observation to be explained. The resulting coefficients are the Kernel SHAP values. 197 | } 198 | 199 | This is repeated multiple times until convergence, see CL21 for details. 200 | 201 | A drawback of this strategy is that many (at least 75\%) of the \eqn{z} vectors will 202 | have \eqn{\sum z \in \{1, p-1\}}, producing many duplicates. Similarly, at least 92\% 203 | of the mass will be used for the \eqn{p(p+1)} possible vectors with 204 | \eqn{\sum z \in \{1, 2, p-2, p-1\}}. 205 | This inefficiency can be fixed by a hybrid strategy, combining exact calculations 206 | with sampling. 207 | 208 | The hybrid algorithm has two steps: 209 | \enumerate{ 210 | \item Step 1 (exact part): There are \eqn{2p} different on-off vectors \eqn{z} with 211 | \eqn{\sum z \in \{1, p-1\}}, covering a large proportion of the Kernel SHAP 212 | distribution. The degree 1 hybrid will list those vectors and use them according 213 | to their weights in the upcoming calculations. Depending on \eqn{p}, we can also go 214 | a step further to a degree 2 hybrid by adding all \eqn{p(p-1)} vectors with 215 | \eqn{\sum z \in \{2, p-2\}} to the process etc. The necessary predictions are 216 | obtained along with other calculations similar to those described in CL21. 217 | \item Step 2 (sampling part): The remaining weight is filled by sampling vectors z 218 | according to Kernel SHAP weights renormalized to the values not yet covered by Step 1. 219 | Together with the results from Step 1 - correctly weighted - this now forms a 220 | complete iteration as in CL21. The difference is that most mass is covered by exact 221 | calculations. Afterwards, the algorithm iterates until convergence. 222 | The output of Step 1 is reused in every iteration, leading to an extremely 223 | efficient strategy. 224 | } 225 | 226 | If \eqn{p} is sufficiently small, all possible \eqn{2^p-2} on-off vectors \eqn{z} can be 227 | evaluated. In this case, no sampling is required and the algorithm returns exact 228 | Kernel SHAP values with respect to the given background data. 229 | Since \code{\link[=kernelshap]{kernelshap()}} calculates predictions on data with \eqn{MN} rows 230 | (\eqn{N} is the background data size and \eqn{M} the number of \eqn{z} vectors), \eqn{p} 231 | should not be much higher than 10 for exact calculations. 232 | For similar reasons, degree 2 hybrids should not use \eqn{p} much larger than 40. 233 | } 234 | \section{Methods (by class)}{ 235 | \itemize{ 236 | \item \code{kernelshap(default)}: Default Kernel SHAP method. 237 | 238 | \item \code{kernelshap(ranger)}: Kernel SHAP method for "ranger" models, see Readme for an example. 
239 | 240 | }} 241 | \examples{ 242 | # MODEL ONE: Linear regression 243 | fit <- lm(Sepal.Length ~ ., data = iris) 244 | 245 | # Select rows to explain (only feature columns) 246 | X_explain <- iris[-1] 247 | 248 | # Calculate SHAP values 249 | s <- kernelshap(fit, X_explain) 250 | s 251 | 252 | # MODEL TWO: Multi-response linear regression 253 | fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) 254 | s <- kernelshap(fit, iris[3:5]) 255 | s 256 | 257 | # Note 1: Feature columns can also be selected via 'feature_names' 258 | # Note 2: Especially when X is small, pass a sufficiently large background dataset bg_X 259 | s <- kernelshap( 260 | fit, 261 | iris[1:4, ], 262 | bg_X = iris, 263 | feature_names = c("Petal.Length", "Petal.Width", "Species") 264 | ) 265 | s 266 | } 267 | \references{ 268 | \enumerate{ 269 | \item Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model 270 | predictions. Proceedings of the 31st International Conference on Neural 271 | Information Processing Systems, 2017. 272 | \item Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value 273 | Estimation Using Linear Regression. Proceedings of The 24th International 274 | Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021. 275 | } 276 | } 277 | -------------------------------------------------------------------------------- /man/permshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/permshap.R 3 | \name{permshap} 4 | \alias{permshap} 5 | \alias{permshap.default} 6 | \alias{permshap.ranger} 7 | \title{Permutation SHAP} 8 | \usage{ 9 | permshap(object, ...) 10 | 11 | \method{permshap}{default}( 12 | object, 13 | X, 14 | bg_X = NULL, 15 | pred_fun = stats::predict, 16 | feature_names = colnames(X), 17 | bg_w = NULL, 18 | bg_n = 200L, 19 | parallel = FALSE, 20 | parallel_args = NULL, 21 | verbose = TRUE, 22 | ... 23 | ) 24 | 25 | \method{permshap}{ranger}( 26 | object, 27 | X, 28 | bg_X = NULL, 29 | pred_fun = NULL, 30 | feature_names = colnames(X), 31 | bg_w = NULL, 32 | bg_n = 200L, 33 | parallel = FALSE, 34 | parallel_args = NULL, 35 | verbose = TRUE, 36 | survival = c("chf", "prob"), 37 | ... 38 | ) 39 | } 40 | \arguments{ 41 | \item{object}{Fitted model object.} 42 | 43 | \item{...}{Additional arguments passed to \code{pred_fun(object, X, ...)}.} 44 | 45 | \item{X}{\eqn{(n \times p)} matrix or \code{data.frame} with rows to be explained. 46 | The columns should only represent model features, not the response 47 | (but see \code{feature_names} on how to overrule this).} 48 | 49 | \item{bg_X}{Background data used to integrate out "switched off" features, 50 | often a subset of the training data (typically 50 to 500 rows). 51 | In cases with a natural "off" value (like MNIST digits), 52 | this can also be a single row with all values set to the off value. 53 | If no \code{bg_X} is passed (the default) and if \code{X} is sufficiently large, 54 | a random sample of \code{bg_n} rows from \code{X} serves as background data.} 55 | 56 | \item{pred_fun}{Prediction function of the form \verb{function(object, X, ...)}, 57 | providing \eqn{K \ge 1} predictions per row. Its first argument 58 | represents the model \code{object}, its second argument a data structure like \code{X}. 59 | Additional (named) arguments are passed via \code{...}.
60 | The default, \code{\link[stats:predict]{stats::predict()}}, will work in most cases.} 61 | 62 | \item{feature_names}{Optional vector of column names in \code{X} used to calculate 63 | SHAP values. By default, this equals \code{colnames(X)}. Not supported if \code{X} 64 | is a matrix.} 65 | 66 | \item{bg_w}{Optional vector of case weights for each row of \code{bg_X}. 67 | If \code{bg_X = NULL}, must be of same length as \code{X}. Set to \code{NULL} for no weights.} 68 | 69 | \item{bg_n}{If \code{bg_X = NULL}: Size of background data to be sampled from \code{X}.} 70 | 71 | \item{parallel}{If \code{TRUE}, use parallel \code{\link[foreach:foreach]{foreach::foreach()}} to loop over rows 72 | to be explained. Must register backend beforehand, e.g., via 'doFuture' package, 73 | see README for an example. Parallelization automatically disables the progress bar.} 74 | 75 | \item{parallel_args}{Named list of arguments passed to \code{\link[foreach:foreach]{foreach::foreach()}}. 76 | Ideally, this is \code{NULL} (default). Only relevant if \code{parallel = TRUE}. 77 | Example on Windows: if \code{object} is a GAM fitted with package 'mgcv', 78 | then one might need to set \code{parallel_args = list(.packages = "mgcv")}.} 79 | 80 | \item{verbose}{Set to \code{FALSE} to suppress messages and the progress bar.} 81 | 82 | \item{survival}{Should cumulative hazards ("chf", default) or survival 83 | probabilities ("prob") per time be predicted? Only in \code{ranger()} survival models.} 84 | } 85 | \value{ 86 | An object of class "kernelshap" with the following components: 87 | \itemize{ 88 | \item \code{S}: \eqn{(n \times p)} matrix with SHAP values or, if the model output has 89 | dimension \eqn{K > 1}, a list of \eqn{K} such matrices. 90 | \item \code{X}: Same as input argument \code{X}. 91 | \item \code{baseline}: Vector of length K representing the average prediction on the 92 | background data. 93 | \item \code{bg_X}: The background data. 94 | \item \code{bg_w}: The background case weights. 95 | \item \code{m_exact}: Integer providing the effective number of exact on-off vectors used. 96 | \item \code{exact}: Logical flag indicating whether calculations are exact or not 97 | (currently always \code{TRUE}). 98 | \item \code{txt}: Summary text. 99 | \item \code{predictions}: \eqn{(n \times K)} matrix with predictions of \code{X}. 100 | \item \code{algorithm}: "permshap". 101 | } 102 | } 103 | \description{ 104 | Exact permutation SHAP algorithm with respect to a background dataset, 105 | see Strumbelj and Kononenko. The function works for up to 14 features. 106 | For more than eight features, we recommend \code{\link[=kernelshap]{kernelshap()}} due to its higher speed. 107 | } 108 | \section{Methods (by class)}{ 109 | \itemize{ 110 | \item \code{permshap(default)}: Default permutation SHAP method. 111 | 112 | \item \code{permshap(ranger)}: Permutation SHAP method for "ranger" models, see Readme for an example. 
113 | 114 | }} 115 | \examples{ 116 | # MODEL ONE: Linear regression 117 | fit <- lm(Sepal.Length ~ ., data = iris) 118 | 119 | # Select rows to explain (only feature columns) 120 | X_explain <- iris[-1] 121 | 122 | # Calculate SHAP values 123 | s <- permshap(fit, X_explain) 124 | s 125 | 126 | # MODEL TWO: Multi-response linear regression 127 | fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris) 128 | s <- permshap(fit, iris[3:5]) 129 | s 130 | 131 | # Note 1: Feature columns can also be selected via 'feature_names' 132 | # Note 2: Especially when X is small, pass a sufficiently large background dataset bg_X 133 | s <- permshap( 134 | fit, 135 | iris[1:4, ], 136 | bg_X = iris, 137 | feature_names = c("Petal.Length", "Petal.Width", "Species") 138 | ) 139 | s 140 | } 141 | \references{ 142 | \enumerate{ 143 | \item Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual 144 | predictions with feature contributions. Knowledge and Information Systems 41, 2014. 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /man/print.kernelshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{print.kernelshap} 4 | \alias{print.kernelshap} 5 | \title{Prints "kernelshap" Object} 6 | \usage{ 7 | \method{print}{kernelshap}(x, n = 2L, ...) 8 | } 9 | \arguments{ 10 | \item{x}{An object of class "kernelshap".} 11 | 12 | \item{n}{Maximum number of rows of SHAP values to print.} 13 | 14 | \item{...}{Further arguments passed from other methods.} 15 | } 16 | \value{ 17 | Invisibly, the input is returned. 18 | } 19 | \description{ 20 | Prints "kernelshap" Object 21 | } 22 | \examples{ 23 | fit <- lm(Sepal.Length ~ ., data = iris) 24 | s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[, -1]) 25 | s 26 | } 27 | \seealso{ 28 | \code{\link[=kernelshap]{kernelshap()}} 29 | } 30 | -------------------------------------------------------------------------------- /man/summary.kernelshap.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/methods.R 3 | \name{summary.kernelshap} 4 | \alias{summary.kernelshap} 5 | \title{Summarizes "kernelshap" Object} 6 | \usage{ 7 | \method{summary}{kernelshap}(object, compact = FALSE, n = 2L, ...) 8 | } 9 | \arguments{ 10 | \item{object}{An object of class "kernelshap".} 11 | 12 | \item{compact}{Set to \code{TRUE} for a more compact summary.} 13 | 14 | \item{n}{Maximum number of rows of SHAP values etc. to print.} 15 | 16 | \item{...}{Further arguments passed from other methods.} 17 | } 18 | \value{ 19 | Invisibly, the input is returned.
20 | } 21 | \description{ 22 | Summarizes "kernelshap" Object 23 | } 24 | \examples{ 25 | fit <- lm(Sepal.Length ~ ., data = iris) 26 | s <- kernelshap(fit, iris[1:3, -1], bg_X = iris[, -1]) 27 | summary(s) 28 | } 29 | \seealso{ 30 | \code{\link[=kernelshap]{kernelshap()}} 31 | } 32 | -------------------------------------------------------------------------------- /packaging.R: -------------------------------------------------------------------------------- 1 | #============================================================================= 2 | # Put together the package 3 | #============================================================================= 4 | 5 | # WORKFLOW: UPDATE EXISTING PACKAGE 6 | # 1) Modify package content and documentation. 7 | # 2) Increase package number in "use_description" below. 8 | # 3) Go through this script and carefully answer "no" if a "use_*" function 9 | # asks to overwrite the existing files. Don't skip that function call. 10 | # devtools::load_all() 11 | 12 | library(usethis) 13 | 14 | # Sketch of description file 15 | use_description( 16 | fields = list( 17 | Title = "Kernel SHAP", 18 | Version = "0.7.1", 19 | Description = "Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), 20 | and Covert and Lee (2021) . 21 | Furthermore, for up to 14 features, exact permutation SHAP values can be calculated. 22 | The package plays well together with meta-learning packages like 'tidymodels', 'caret' or 'mlr3'. 23 | Visualizations can be done using the R package 'shapviz'.", 24 | `Authors@R` = 25 | "c(person('Michael', family='Mayer', role=c('aut', 'cre'), email='mayermichael79@gmail.com', comment=c(ORCID='0009-0007-2540-9629')), 26 | person('David', family='Watson', role='aut', email='david.s.watson11@gmail.com', comment=c(ORCID='0000-0001-9632-2159')), 27 | person('Przemyslaw', family='Biecek', email='przemyslaw.biecek@gmail.com', role='ctb', comment=c(ORCID='0000-0001-8423-1823')) 28 | )", 29 | Depends = "R (>= 3.2.0)", 30 | LazyData = NULL 31 | ), 32 | roxygen = TRUE 33 | ) 34 | 35 | use_package("foreach", "Imports") 36 | use_package("MASS", "Imports") 37 | use_package("stats", "Imports") 38 | use_package("utils", "Imports") 39 | 40 | use_package("doFuture", "Suggests") 41 | 42 | use_gpl_license(2) 43 | 44 | # Your files that do not belong to the package itself (others are added by "use_* function") 45 | use_build_ignore(c("^packaging.R$", "[.]Rproj$", "^compare_with_python.R$", 46 | "^cran-comments.md$", "^logo.png$"), escape = FALSE) 47 | 48 | # If your code uses the pipe operator %>% 49 | # use_pipe() 50 | 51 | # If your package contains data. Google how to document 52 | # use_data() 53 | 54 | # Add short docu in Markdown (without running R code) 55 | use_readme_md() 56 | 57 | # Longer docu in RMarkdown (with running R code). Often quite similar to readme. 58 | # use_vignette("kernelshap") 59 | 60 | # If you want to add unit tests 61 | use_testthat() 62 | # use_test("kernelshap.R") 63 | # use_test("methods.R") 64 | 65 | # On top of NEWS.md, describe changes made to the package 66 | use_news_md() 67 | 68 | # Add logo 69 | use_logo("logo.png") 70 | 71 | # If package goes to CRAN: infos (check results etc.) 
for CRAN 72 | use_cran_comments() 73 | 74 | use_github_links() # use this if this project is on github 75 | 76 | # Github actions 77 | use_github_action("check-standard") 78 | use_github_action("test-coverage") 79 | use_github_action("pkgdown") 80 | 81 | # Revdep 82 | use_revdep() 83 | 84 | #============================================================================= 85 | # Finish package building (can use fresh session) 86 | #============================================================================= 87 | 88 | library(devtools) 89 | 90 | document() 91 | test() 92 | check(manual = TRUE, cran = TRUE) 93 | build() 94 | # build(binary = TRUE) 95 | install(upgrade = FALSE) 96 | 97 | # Run only if package is public(!) and should go to CRAN 98 | if (FALSE) { 99 | check_win_devel() 100 | check_rhub() 101 | 102 | # Takes long 103 | revdepcheck::revdep_check(num_workers = 4L, bioc = FALSE) 104 | 105 | # Wait until above checks are passed without relevant notes/warnings 106 | # then submit to CRAN 107 | release() 108 | } 109 | -------------------------------------------------------------------------------- /pkgdown/_pkgdown.yml: -------------------------------------------------------------------------------- 1 | template: 2 | package: DrWhyTemplate 3 | default_assets: false 4 | params: 5 | ganalytics: UA-5650686-14 6 | noindex: true 7 | -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-120x120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/apple-touch-icon-120x120.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-152x152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/apple-touch-icon-152x152.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-180x180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/apple-touch-icon-180x180.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-60x60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/apple-touch-icon-60x60.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon-76x76.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/apple-touch-icon-76x76.png -------------------------------------------------------------------------------- /pkgdown/favicon/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/apple-touch-icon.png -------------------------------------------------------------------------------- 
/pkgdown/favicon/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/favicon-16x16.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/favicon-32x32.png -------------------------------------------------------------------------------- /pkgdown/favicon/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/kernelshap/585c02ff8187ca4cd6ff568350617c70a25d817f/pkgdown/favicon/favicon.ico -------------------------------------------------------------------------------- /revdep/.gitignore: -------------------------------------------------------------------------------- 1 | checks 2 | library 3 | checks.noindex 4 | library.noindex 5 | cloud.noindex 6 | data.sqlite 7 | *.html 8 | -------------------------------------------------------------------------------- /revdep/README.md: -------------------------------------------------------------------------------- 1 | # Platform 2 | 3 | |field |value | 4 | |:--------|:--------------------------------------------------------| 5 | |version |R version 4.4.1 (2024-06-14 ucrt) | 6 | |os |Windows 11 x64 (build 22631) | 7 | |system |x86_64, mingw32 | 8 | |ui |RStudio | 9 | |language |(EN) | 10 | |collate |German_Switzerland.utf8 | 11 | |ctype |German_Switzerland.utf8 | 12 | |tz |Europe/Zurich | 13 | |date |2024-08-09 | 14 | |rstudio |2024.04.2+764 Chocolate Cosmos (desktop) | 15 | |pandoc |3.1.6 @ C:\Users\Michael\AppData\Local\Pandoc\pandoc.exe | 16 | 17 | # Dependencies 18 | 19 | |package |old |new |Δ | 20 | |:----------|:------|:------|:--| 21 | |kernelshap |0.6.0 |0.7.0 |* | 22 | |foreach |1.5.2 |1.5.2 | | 23 | |iterators |1.0.14 |1.0.14 | | 24 | 25 | # Revdeps 26 | 27 | -------------------------------------------------------------------------------- /revdep/cran.md: -------------------------------------------------------------------------------- 1 | ## revdepcheck results 2 | 3 | We checked 1 reverse dependencies, comparing R CMD check results across CRAN and dev versions of this package. 4 | 5 | * We saw 0 new problems 6 | * We failed to check 0 packages 7 | 8 | -------------------------------------------------------------------------------- /revdep/email.yml: -------------------------------------------------------------------------------- 1 | release_date: ??? 2 | rel_release_date: ??? 3 | my_news_url: ??? 4 | release_version: ??? 5 | release_details: ??? 6 | -------------------------------------------------------------------------------- /revdep/failures.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /revdep/problems.md: -------------------------------------------------------------------------------- 1 | *Wow, no problems at all. :)* -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | # This file is part of the standard setup for testthat. 
2 | # It is recommended that you do not modify it. 3 | # 4 | # Where should you do additional test configuration? 5 | # Learn more about the roles of various files in: 6 | # * https://r-pkgs.org/tests.html 7 | # * https://testthat.r-lib.org/reference/test_package.html#special-files 8 | 9 | library(testthat) 10 | library(kernelshap) 11 | 12 | test_check("kernelshap") 13 | -------------------------------------------------------------------------------- /tests/testthat/test-additive_shap.R: -------------------------------------------------------------------------------- 1 | test_that("Additive formulas give same as agnostic SHAP with full training data as bg data", { 2 | formulas <- list( 3 | Sepal.Length ~ ., 4 | Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2) + Petal.Length, 5 | Sepal.Length ~ log(Sepal.Width) + Species + poly(Petal.Length, 2) 6 | ) 7 | xvars <- list( 8 | setdiff(colnames(iris), "Sepal.Length"), 9 | c("Sepal.Width", "Petal.Length"), 10 | c("Sepal.Width", "Petal.Length", "Species") 11 | ) 12 | 13 | for (j in seq_along(formulas)) { 14 | fit <- list( 15 | lm = lm(formulas[[j]], data = iris), 16 | glm = glm(formulas[[j]], data = iris, family = quasipoisson) 17 | ) 18 | 19 | shap1 <- lapply(fit, additive_shap, head(iris), verbose = FALSE) 20 | shap2 <- lapply( 21 | fit, permshap, head(iris), bg_X = iris, verbose = FALSE, feature_names = xvars[[j]] 22 | ) 23 | shap3 <- lapply( 24 | fit, kernelshap, head(iris), bg_X = iris, verbose = FALSE, feature_names = xvars[[j]] 25 | ) 26 | 27 | for (i in seq_along(fit)) { 28 | expect_equal(shap1[[i]]$S, shap2[[i]]$S) 29 | expect_equal(shap1[[i]]$S, shap3[[i]]$S) 30 | } 31 | } 32 | }) 33 | 34 | test_that("formulas with more than one covariate per term fail", { 35 | formulas_bad <- list( 36 | Sepal.Length ~ Species * Petal.Length, 37 | Sepal.Length ~ Species + Petal.Length + Species:Petal.Length, 38 | Sepal.Length ~ log(Petal.Length / Petal.Width) 39 | ) 40 | 41 | for (formula in formulas_bad) { 42 | fit <- list( 43 | lm = lm(formula, data = iris), 44 | glm = glm(formula, data = iris, family = quasipoisson) 45 | ) 46 | for (f in fit) 47 | expect_error(additive_shap(f, head(iris), verbose = FALSE)) 48 | } 49 | }) 50 | 51 | -------------------------------------------------------------------------------- /tests/testthat/test-basic.R: -------------------------------------------------------------------------------- 1 | # Model with non-linearities and interactions 2 | fit <- lm( 3 | Sepal.Length ~ poly(Petal.Width, degree = 2L) * Species + Petal.Length, data = iris 4 | ) 5 | x <- c("Petal.Width", "Species", "Petal.Length") 6 | preds <- unname(predict(fit, iris)) 7 | J <- c(1L, 51L, 101L) 8 | 9 | shap <- list( 10 | kernelshap(fit, iris[x], bg_X = iris, verbose = FALSE), 11 | permshap(fit, iris[x], bg_X = iris, verbose = FALSE) 12 | ) 13 | 14 | test_that("baseline equals average prediction on background data", { 15 | for (s in shap) 16 | expect_equal(s$baseline, mean(iris$Sepal.Length)) 17 | }) 18 | 19 | test_that("SHAP + baseline = prediction for exact mode", { 20 | for (s in shap) 21 | expect_equal(rowSums(s$S) + s$baseline, preds) 22 | }) 23 | 24 | test_that("auto-selection of background data works", { 25 | # Here, the background data equals the full X 26 | shap2 <- list( 27 | kernelshap(fit, iris[x], verbose = FALSE), 28 | permshap(fit, iris[x], verbose = FALSE) 29 | ) 30 | 31 | for (i in 1:2) { 32 | expect_equal(shap[[i]]$S, shap2[[i]]$S) 33 | } 34 | }) 35 | 36 | test_that("missing bg_X gives error if X is very small", { 37 | for
(algo in c(kernelshap, permshap)) 38 | expect_error(algo(fit, iris[1:10, x], verbose = FALSE)) 39 | 40 | }) 41 | 42 | test_that("missing bg_X gives warning if X is quite small", { 43 | for (algo in c(kernelshap, permshap)) 44 | expect_warning(algo(fit, iris[1:30, x], verbose = FALSE)) 45 | }) 46 | 47 | test_that("selection of bg_X can be controlled via bg_n", { 48 | for (algo in c(kernelshap, permshap)) { 49 | s <- algo(fit, iris[x], verbose = FALSE, bg_n = 20L) 50 | expect_equal(nrow(s$bg_X), 20L) 51 | } 52 | }) 53 | 54 | test_that("using foreach (non-parallel) gives the same as normal mode", { 55 | for (algo in c(kernelshap, permshap)) { 56 | s <- algo(fit, iris[J, x], bg_X = iris, verbose = FALSE) 57 | s2 <- suppressWarnings( 58 | algo(fit, iris[J, x], bg_X = iris, verbose = FALSE, parallel = TRUE) 59 | ) 60 | expect_equal(s, s2) 61 | } 62 | }) 63 | 64 | test_that("verbose is chatty", { 65 | for (algo in c(kernelshap, permshap)) { 66 | capture_output(expect_message(algo(fit, iris[J, x], bg_X = iris, verbose = TRUE))) 67 | } 68 | }) 69 | 70 | test_that("large background data cause warning", { 71 | # Takes a bit of time, thus only for one algo 72 | large_bg <- iris[rep(1:150, 230), ] 73 | expect_warning( 74 | kernelshap(fit, iris[1L, x], bg_X = large_bg, verbose = FALSE) 75 | ) 76 | }) 77 | 78 | test_that("Decomposing a single row works", { 79 | for (algo in c(kernelshap, permshap)) { 80 | s <- algo(fit, iris[1L, x], bg_X = iris, verbose = FALSE) 81 | expect_equal(s$baseline, mean(iris$Sepal.Length)) 82 | expect_equal(rowSums(s$S) + s$baseline, preds[1]) 83 | } 84 | }) 85 | 86 | test_that("Background data can contain additional columns", { 87 | for (algo in c(kernelshap, permshap)) { 88 | s <- algo(fit, iris[1L, x], bg_X = cbind(d = 1, iris), verbose = FALSE) 89 | expect_true(is.kernelshap(s)) 90 | } 91 | }) 92 | 93 | test_that("Background data can contain only one single row", { 94 | for (algo in c(kernelshap, permshap)) 95 | expect_no_error(algo(fit, iris[1L, x], bg_X = iris[150L, ], verbose = FALSE)) 96 | }) 97 | 98 | test_that("feature_names can drop columns from SHAP calculations", { 99 | for (algo in c(kernelshap, permshap)) { 100 | s <- algo(fit, iris[J, ], bg_X = iris, feature_names = x, verbose = FALSE) 101 | expect_equal(colnames(s$S), x) 102 | } 103 | }) 104 | 105 | test_that("feature_names can rearrange column names in result", { 106 | for (algo in c(kernelshap, permshap)) { 107 | s <- algo(fit, iris[J, ], bg_X = iris, feature_names = rev(x), verbose = FALSE) 108 | expect_equal(colnames(s$S), rev(x)) 109 | } 110 | }) 111 | 112 | test_that("feature_names must be in colnames(X) and colnames(bg_X)", { 113 | for (algo in c(kernelshap, permshap)) { 114 | expect_error(algo(fit, iris, bg_X = cbind(iris, a = 1), feature_names = "a")) 115 | expect_error(algo(fit, cbind(iris, a = 1), bg_X = iris, feature_names = "a")) 116 | } 117 | }) 118 | 119 | test_that("Matrix input is fine", { 120 | X <- data.matrix(iris) 121 | pred_fun <- function(m, X) { 122 | data <- as.data.frame(X) |> 123 | transform(Species = factor(Species, labels = levels(iris$Species))) 124 | predict(m, data) 125 | } 126 | 127 | for (algo in c(kernelshap, permshap)) { 128 | s <- algo(fit, X[J, x], pred_fun = pred_fun, bg_X = X, verbose = FALSE) 129 | 130 | expect_equal(s$baseline, mean(iris$Sepal.Length)) # baseline is mean of bg 131 | expect_equal(rowSums(s$S) + s$baseline, preds[J]) # sum shap = centered preds 132 | expect_no_error( # additional cols in bg are ok 133 | algo(fit, X[J, x], pred_fun = pred_fun, bg_X = 
134 |     )
135 |     expect_error(  # feature_names are less flexible
136 |       algo(fit, X[J, ], pred_fun = pred_fun, bg_X = X,
137 |         verbose = FALSE, feature_names = "Sepal.Width")
138 |     )
139 |   }
140 | })
141 | 
142 | test_that("Special case p = 1 works only for kernelshap()", {
143 |   capture_output(
144 |     expect_message(
145 |       s <- kernelshap(fit, X = iris[J, ], bg_X = iris, feature_names = "Petal.Width")
146 |     )
147 |   )
148 |   expect_equal(s$baseline, mean(iris$Sepal.Length))
149 |   expect_equal(unname(rowSums(s$S)) + s$baseline, preds[J])
150 |   expect_equal(s$SE[1L], 0)
151 | 
152 |   expect_error(  # Not implemented
153 |     permshap(
154 |       fit, iris[J, ], bg_X = iris, verbose = FALSE, feature_names = "Petal.Width"
155 |     )
156 |   )
157 | })
158 | 
159 | test_that("exact hybrid kernelshap() is similar to exact (non-hybrid)", {
160 |   s1 <- kernelshap(
161 |     fit, iris[J, x], bg_X = iris, exact = FALSE, hybrid_degree = 1L, verbose = FALSE
162 |   )
163 |   expect_equal(s1$S, shap[[1L]]$S[J, ])
164 | })
165 | 
166 | test_that("baseline equals average prediction on background data in sampling mode", {
167 |   s2 <- kernelshap(
168 |     fit, iris[J, x], bg_X = iris, hybrid_degree = 0L, verbose = FALSE, exact = FALSE
169 |   )
170 |   expect_equal(s2$baseline, mean(iris$Sepal.Length))
171 | })
172 | 
173 | test_that("SHAP + baseline = prediction for sampling mode", {
174 |   s2 <- kernelshap(
175 |     fit, iris[J, x], bg_X = iris, hybrid_degree = 0L, verbose = FALSE, exact = FALSE
176 |   )
177 |   expect_equal(rowSums(s2$S) + s2$baseline, preds[J])
178 | })
179 | 
180 | test_that("kernelshap works for large p (hybrid case)", {
181 |   set.seed(9L)
182 |   X <- data.frame(matrix(rnorm(20000L), ncol = 100L))
183 |   y <- X[, 1L] * X[, 2L] * X[, 3L]
184 |   fit <- lm(y ~ X1:X2:X3 + ., data = cbind(y = y, X))
185 |   s <- kernelshap(fit, X[1L, ], bg_X = X, verbose = FALSE)
186 | 
187 |   expect_equal(s$baseline, mean(y))
188 |   expect_equal(rowSums(s$S) + s$baseline, unname(predict(fit, X[1L, ])))
189 | })
190 | 
--------------------------------------------------------------------------------
/tests/testthat/test-kernelshap-utils.R:
--------------------------------------------------------------------------------
1 | test_that("sum of kernel weights is 1", {
2 |   for (p in 2:10) {
3 |     expect_equal(sum(kernel_weights(p)), 1.0)
4 |   }
5 | })
6 | 
7 | test_that("Sum of kernel weights is 1, even for subset of domain", {
8 |   expect_equal(sum(kernel_weights(10L, S = 2:5)), 1.0)
9 | })
10 | 
11 | p <- 10L
12 | m <- 100L
13 | 
14 | test_that("Random z have right output dim and the sums are between 1 and p-1", {
15 |   Z <- sample_Z(p, m = m, feature_names = LETTERS[1:p])
16 | 
17 |   expect_equal(dim(Z), c(m, p))
18 |   expect_true(all(rowSums(Z) %in% 1:(p - 1L)))
19 | })
20 | 
21 | test_that("Random z have right output dim and the sums are in subset S", {
22 |   S <- 2:3
23 |   Z <- sample_Z(p, m = m, feature_names = LETTERS[1:p], S = S)
24 | 
25 |   expect_equal(dim(Z), c(m, p))
26 |   expect_true(all(rowSums(Z) %in% S))
27 | })
28 | 
29 | test_that("Sampling input structure is ok (deg = 0)", {
30 |   input <- input_sampling(
31 |     p, m = m, deg = 0L, paired = TRUE, feature_names = LETTERS[1:p]
32 |   )
33 | 
34 |   expect_equal(dim(input$Z), c(m, p))
35 |   expect_equal(sum(input$w), 1.0)
36 |   expect_equal(dim(input$A), c(p, p))
37 |   expect_equal(unname(diag(input$A)), rep(0.5, p))
38 | })
39 | 
40 | test_that("Sampling input structure is ok (deg = 0, unpaired)", {
41 |   input <- input_sampling(
42 |     p, m = m, deg = 0L, paired = FALSE, feature_names = LETTERS[1:p]
43 |   )
44 | 
45 |   expect_equal(dim(input$Z), c(m, p))
46 |   expect_equal(sum(input$w), 1.0)
47 |   expect_equal(dim(input$A), c(p, p))
48 |   # expect_equal(diag(input$A), rep(0.5, p))  # not TRUE in the unpaired case
49 | })
50 | 
51 | test_that("Sampling input structure is ok (deg = 1)", {
52 |   input <- input_sampling(
53 |     p, m = m, deg = 1L, paired = TRUE, feature_names = LETTERS[1:p]
54 |   )
55 | 
56 |   expect_equal(dim(input$Z), c(m, p))
57 |   expect_true(sum(input$w) < 1.0)
58 |   expect_equal(dim(input$A), c(p, p))
59 |   expect_true(all(diag(input$A) < 0.5))
60 | })
61 | 
62 | test_that("Sampling input structure is ok (deg = 2)", {
63 |   input <- input_sampling(
64 |     p, m = m, deg = 2L, paired = TRUE, feature_names = LETTERS[1:p]
65 |   )
66 | 
67 |   expect_equal(dim(input$Z), c(m, p))
68 |   expect_true(sum(input$w) < 1.0)
69 |   expect_equal(dim(input$A), c(p, p))
70 |   expect_true(all(diag(input$A) < 0.5))
71 | })
72 | 
73 | test_that("Partly exact A, w, Z equal exact for sufficiently large deg", {
74 |   for (p in 2:10) {
75 |     pa <- input_partly_exact(p, deg = trunc(p / 2), feature_names = LETTERS[1:p])
76 |     ex <- input_exact(p, feature_names = LETTERS[1:p])
77 |     pa_rs <- rowSums(pa$Z)
78 |     ex_rs <- rowSums(ex$Z)
79 | 
80 |     expect_equal(pa$A, ex$A)
81 |     expect_equal(pa$w[order(pa_rs)], ex$w[order(ex_rs)])
82 |     expect_equal(tabulate(pa_rs), tabulate(ex_rs))
83 |   }
84 | })
85 | 
86 | test_that("hybrid weights sum to 1 for different p and degree 1", {
87 |   deg <- 1L
88 |   expect_error(input_sampling(2L, deg = deg, feature_names = LETTERS[1:p]))
89 |   expect_error(input_sampling(3L, deg = deg, feature_names = LETTERS[1:p]))
90 | 
91 |   for (p in 4:20) {
92 |     pa <- input_partly_exact(p, deg = deg, feature_names = LETTERS[1:p])
93 |     sa <- input_sampling(
94 |       p, m = 100L, deg = deg, paired = TRUE, feature_names = LETTERS[1:p]
95 |     )
96 |     expect_equal(sum(pa$w) + sum(sa$w), 1.0)
97 |   }
98 | })
99 | 
100 | test_that("hybrid weights sum to 1 for different p and degree 2", {
101 |   deg <- 2L
102 |   expect_error(input_sampling(4L, deg = deg, feature_names = LETTERS[1:p]))
103 |   expect_error(input_sampling(5L, deg = deg, feature_names = LETTERS[1:p]))
104 | 
105 |   for (p in 6:20) {
106 |     pa <- input_partly_exact(p, deg = deg, feature_names = LETTERS[1:p])
107 |     sa <- input_sampling(
108 |       p, m = 100L, deg = deg, paired = FALSE, feature_names = LETTERS[1:p]
109 |     )
110 |     expect_equal(sum(pa$w) + sum(sa$w), 1.0)
111 |   }
112 | })
113 | 
114 | test_that("partly_exact_Z(p, k) fails for bad p or k", {
115 |   expect_error(partly_exact_Z(0L, k = 1L, feature_names = LETTERS[1:p]))
116 |   expect_error(partly_exact_Z(5L, k = 3L, feature_names = LETTERS[1:p]))
117 |   expect_error(partly_exact_Z(5L, k = 0L, feature_names = LETTERS[1:p]))
118 | })
119 | 
120 | test_that("input_partly_exact(p, deg) fails for bad p or deg", {
121 |   expect_error(input_partly_exact(2L, deg = 0L, feature_names = LETTERS[1:p]))
122 |   expect_error(input_partly_exact(5L, deg = 3L, feature_names = LETTERS[1:p]))
123 | })
124 | 
--------------------------------------------------------------------------------
/tests/testthat/test-methods.R:
--------------------------------------------------------------------------------
1 | fit <- lm(Sepal.Length ~ ., data = iris)
2 | 
3 | set.seed(1)
4 | 
5 | shap <- list(
6 |   kernelshap(
7 |     fit, iris[1:2, -1L], bg_X = iris, verbose = FALSE, exact = FALSE, hybrid_degree = 1
8 |   ),
9 |   permshap(fit, iris[1:2, -1L], bg_X = iris, verbose = FALSE),
10 |   additive_shap(fit, iris, verbose = FALSE)
11 | )
12 | 
test_that("is.kernelshap() works", { 14 | for (s in shap) { 15 | expect_true(is.kernelshap(s)) 16 | expect_false(is.kernelshap(1)) 17 | } 18 | }) 19 | 20 | test_that("print() and summary() do not give an error", { 21 | for (s in shap) { 22 | capture_output(expect_no_error(print(s))) 23 | capture_output(expect_no_error(summary(s))) 24 | capture_output(expect_no_error(summary(s, compact = TRUE))) 25 | } 26 | }) 27 | 28 | -------------------------------------------------------------------------------- /tests/testthat/test-multioutput.R: -------------------------------------------------------------------------------- 1 | # Model with non-linearities and interactions 2 | y <- iris$Sepal.Length 3 | Y <- as.matrix(iris[, c("Sepal.Length", "Sepal.Width")]) 4 | 5 | fit_y <- lm(y ~ poly(Petal.Width, degree = 2L) * Species, data = iris) 6 | fit_Y <- lm(Y ~ poly(Petal.Width, degree = 2L) * Species, data = iris) 7 | 8 | x <- c("Petal.Width", "Species") 9 | J <- c(1L, 51L, 101L) 10 | 11 | preds_y <- unname(predict(fit_y, iris)) 12 | preds_Y <- unname(predict(fit_Y, iris)) 13 | 14 | shap_y <- list( 15 | kernelshap(fit_y, iris[J, x], bg_X = iris, verbose = FALSE), 16 | permshap(fit_y, iris[J, x], bg_X = iris, verbose = FALSE) 17 | ) 18 | 19 | shap_Y <- list( 20 | kernelshap(fit_Y, iris[J, x], bg_X = iris, verbose = FALSE), 21 | permshap(fit_Y, iris[J, x], bg_X = iris, verbose = FALSE) 22 | ) 23 | 24 | test_that("Baseline equals average prediction on background data", { 25 | for (i in 1:2) { 26 | expect_equal(shap_Y[[i]]$baseline, unname(colMeans(Y))) 27 | } 28 | }) 29 | 30 | test_that("SHAP + baseline = prediction", { 31 | for (i in 1:2) { 32 | s <- shap_Y[[i]] 33 | expect_equal(rowSums(s$S[[1L]]) + s$baseline[1L], preds_Y[J, 1L]) 34 | expect_equal(rowSums(s$S[[2L]]) + s$baseline[2L], preds_Y[J, 2L]) 35 | } 36 | }) 37 | 38 | test_that("First dimension of multioutput model equals single output", { 39 | for (i in 1:2) { 40 | expect_equal(shap_Y[[i]]$baseline[1L], shap_y[[i]]$baseline) 41 | expect_equal(shap_Y[[i]]$S[[1L]], shap_y[[i]]$S) 42 | } 43 | }) 44 | 45 | 46 | -------------------------------------------------------------------------------- /tests/testthat/test-permshap-utils.R: -------------------------------------------------------------------------------- 1 | test_that("rowpaste() does what it should", { 2 | M <- cbind(c(0, 0), c(1, 0), c(1, 1)) 3 | expect_equal(rowpaste(M), c("011", "001")) 4 | }) 5 | 6 | test_that("shapley_weights() does what it should", { 7 | expect_equal(shapley_weights(5, 2), factorial(2) * factorial(5 - 2 - 1) / factorial(5)) 8 | }) 9 | 10 | -------------------------------------------------------------------------------- /tests/testthat/test-utils.R: -------------------------------------------------------------------------------- 1 | # Helper functions 2 | test_that("head_list(x) = head(x) for matrix x", { 3 | x <- cbind(1:10, 2:11) 4 | expect_equal(head_list(x), utils::head(x)) 5 | }) 6 | 7 | test_that("head_list(x)[[1L]] = head(x[[1L]]) for list of matries x", { 8 | x1 <- cbind(1:10, 2:11) 9 | x2 <- cbind(1:7, 2:8) 10 | x <- list(x1, x2) 11 | expect_equal(head_list(x)[[1L]], utils::head(x[[1L]])) 12 | }) 13 | 14 | test_that("reorganize_list() fails for non-list inputs", { 15 | expect_error(reorganize_list(alist = 1:10)) 16 | }) 17 | 18 | test_that("wcolMeans() gives the same as colMeans() in unweighted case", { 19 | X <- cbind(1:3, 2:4) 20 | expect_equal(c(wcolMeans(X)), colMeans(X)) 21 | expect_equal(c(wcolMeans(X, w = c(1, 1, 1))), colMeans(X)) 22 | 
22 |   expect_equal(c(wcolMeans(X, w = c(2, 2, 2))), colMeans(X))
23 | })
24 | 
25 | test_that("exact_Z() works for both kernel- and permshap", {
26 |   p <- 5
27 |   nm <- LETTERS[1:p]
28 |   r1 <- exact_Z(p, feature_names = nm, keep_extremes = TRUE)
29 |   r2 <- exact_Z(p, feature_names = nm, keep_extremes = FALSE)
30 |   expect_equal(r1[2:(nrow(r1) - 1L), ], r2)
31 |   expect_equal(colnames(r1), nm)
32 |   expect_equal(dim(r1), c(2^p, p))
33 | })
34 | 
35 | test_that("rep_rows() gives the same as usual subsetting (except rownames)", {
36 |   setrn <- function(x) {rownames(x) <- 1:nrow(x); x}
37 | 
38 |   expect_equal(rep_rows(iris, 1), iris[1, ])
39 |   expect_equal(rep_rows(iris, 2:1), setrn(iris[2:1, ]))
40 |   expect_equal(rep_rows(iris, c(1, 1, 1)), setrn(iris[c(1, 1, 1), ]))
41 | 
42 |   ir <- iris[1, ]
43 |   ir$y <- list(list(a = 1, b = 2))
44 |   expect_equal(rep_rows(ir, c(1, 1)), setrn(ir[c(1, 1), ]))
45 | })
46 | 
47 | test_that("rep_rows() gives the same as usual subsetting for matrices", {
48 |   ir <- data.matrix(iris[1:4])
49 | 
50 |   expect_equal(rep_rows(ir, c(1, 1, 2)), ir[c(1, 1, 2), ])
51 |   expect_equal(rep_rows(ir, 1), ir[1, , drop = FALSE])
52 | })
53 | 
54 | 
55 | # Unit tests copied from {hstats}
56 | 
57 | test_that("rep_each() works", {
58 |   expect_equal(rep_each(3, 10), rep_each(3L, 10L))
59 |   expect_equal(rep_each(3, 10), rep(1:3, each = 10))
60 |   expect_true(is.integer(rep_each(100, 100)))
61 | })
62 | 
63 | test_that("wrowmean_vector() works for 1D matrices", {
64 |   x2 <- cbind(a = 6:1)
65 |   out2 <- wrowmean_vector(x2, ngroups = 2L)
66 |   expec <- rowsum(x2, group = rep(1:2, each = 3)) / 3
67 |   rownames(expec) <- NULL
68 | 
69 |   expect_error(wrowmean_vector(matrix(1:4, ncol = 2L)))
70 |   expect_equal(out2, expec)
71 | 
72 |   expect_equal(wrowmean_vector(x2, ngroups = 3L), cbind(a = c(5.5, 3.5, 1.5)))
73 | 
74 |   # Constant weights have no effect
75 |   expect_equal(wrowmean_vector(x2, ngroups = 2L, w = c(1, 1, 1)), out2)
76 |   expect_equal(wrowmean_vector(x2, ngroups = 2L, w = c(4, 4, 4)), out2)
77 | 
78 |   # Non-constant weights
79 |   a <- weighted.mean(6:4, 1:3)
80 |   b <- weighted.mean(3:1, 1:3)
81 |   out2 <- wrowmean_vector(x2, ngroups = 2L, w = 1:3)
82 |   expect_equal(out2, cbind(a = c(a, b)))
83 | })
84 | 
85 | test_that("align_pred() works", {
86 |   expect_error(align_pred(c("A", "B")))
87 |   expect_error(align_pred(factor(c("A", "B"))))
88 |   expect_equal(align_pred(1:4), as.matrix(1:4))
89 | })
90 | 
--------------------------------------------------------------------------------
/tests/testthat/test-weights.R:
--------------------------------------------------------------------------------
1 | # Model with non-linearities and interactions
2 | fit <- lm(
3 |   Sepal.Length ~ poly(Petal.Width, degree = 2L) * Species,
4 |   data = iris,
5 |   weights = Petal.Length
6 | )
7 | x <- c("Petal.Width", "Species")
8 | preds <- unname(predict(fit, iris))
9 | J <- c(1L, 51L, 101L)
10 | w <- iris$Petal.Length
11 | 
12 | shap <- list(
13 |   kernelshap(fit, iris[x], bg_X = iris, bg_w = w, verbose = FALSE),
14 |   permshap(fit, iris[x], bg_X = iris, bg_w = w, verbose = FALSE)
15 | )
16 | 
17 | test_that("constant weights give the same as no weights", {
18 |   shap_unweighted <- list(
19 |     kernelshap(fit, iris[x], bg_X = iris, verbose = FALSE),
20 |     permshap(fit, iris[x], bg_X = iris, verbose = FALSE)
21 |   )
22 | 
23 |   w2 <- rep(3, nrow(iris))
24 |   shap2 <- list(
25 |     kernelshap(fit, iris[x], bg_X = iris, bg_w = w2, verbose = FALSE),
26 |     permshap(fit, iris[x], bg_X = iris, bg_w = w2, verbose = FALSE)
27 |   )
28 | 
29 |   for (i in seq_along(shap))
30 |     expect_equal(shap2[[i]]$S, shap_unweighted[[i]]$S)
31 | })
32 | 
33 | test_that("baseline equals average prediction on background data", {
34 |   for (s in shap)
35 |     expect_equal(s$baseline, weighted.mean(iris$Sepal.Length, w))
36 | })
37 | 
38 | test_that("SHAP + baseline = prediction for exact mode", {
39 |   for (s in shap)
40 |     expect_equal(rowSums(s$S) + s$baseline, preds)
41 | })
42 | 
43 | test_that("Decomposing a single row works", {
44 |   for (algo in c(kernelshap, permshap)) {
45 |     s <- algo(fit, iris[1L, x], bg_X = iris, bg_w = w, verbose = FALSE)
46 |     expect_equal(s$baseline, weighted.mean(iris$Sepal.Length, w))
47 |     expect_equal(rowSums(s$S) + s$baseline, preds[1])
48 |   }
49 | })
50 | 
51 | test_that("auto-selection of background data works", {
52 |   # Here, the background data equals the full X
53 |   shap2 <- list(
54 |     kernelshap(fit, iris[x], bg_w = w, verbose = FALSE),
55 |     permshap(fit, iris[x], bg_w = w, verbose = FALSE)
56 |   )
57 | 
58 |   for (i in 1:2) {
59 |     expect_equal(shap[[i]]$S, shap2[[i]]$S)
60 |   }
61 | })
62 | 
63 | test_that("selection of bg_X can be controlled via bg_n", {
64 |   n <- 20L
65 |   for (algo in c(kernelshap, permshap)) {
66 |     s <- algo(fit, iris, bg_w = w, verbose = FALSE, bg_n = n)
67 |     expect_equal(nrow(s$bg_X), n)
68 |   }
69 | })
70 | 
71 | test_that("weights must have correct length", {
72 |   for (algo in c(kernelshap, permshap)) {
73 |     expect_error(algo(fit, iris[J, ], bg_X = iris, bg_w = 1:3, verbose = FALSE))
74 |   }
75 | })
76 | 
77 | test_that("weights can't be all 0", {
78 |   for (algo in c(kernelshap, permshap)) {
79 |     expect_error(
80 |       algo(fit, iris[J, ], bg_X = iris, bg_w = rep(0, nrow(iris)), verbose = FALSE)
81 |     )
82 |   }
83 | })
84 | 
85 | test_that("weights can't be negative", {
86 |   for (algo in c(kernelshap, permshap)) {
87 |     expect_error(
88 |       algo(fit, iris[J, ], bg_X = iris, bg_w = rep(-1, nrow(iris)), verbose = FALSE)
89 |     )
90 |   }
91 | })
92 | 
93 | 
--------------------------------------------------------------------------------