├── .github └── workflows │ ├── publish_documentation.yml │ ├── publish_pypi.yml │ └── unit_tests.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bokbokbok ├── __init__.py ├── eval_metrics │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ ├── binary_eval_metrics.py │ │ └── multiclass_eval_metrics.py │ └── regression │ │ ├── __init__.py │ │ └── regression_eval_metrics.py ├── loss_functions │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ └── classification_loss_functions.py │ └── regression │ │ ├── __init__.py │ │ └── regression_loss_functions.py └── utils │ ├── __init__.py │ └── functions.py ├── docs ├── derivations │ ├── focal.md │ ├── log_cosh.md │ ├── note.md │ └── wce.md ├── getting_started │ └── install.md ├── img │ └── bokbokbok.png ├── index.md ├── reference │ ├── eval_metrics_binary.md │ ├── eval_metrics_multiclass.md │ ├── eval_metrics_regression.md │ ├── loss_functions_classification.md │ └── loss_functions_regression.md └── tutorials │ ├── F1_score.ipynb │ ├── RMSPE.ipynb │ ├── focal_loss.ipynb │ ├── log_cosh_loss.ipynb │ ├── quadratic_weighted_kappa.ipynb │ └── weighted_cross_entropy.ipynb ├── mkdocs.yml ├── setup.py └── tests ├── test_focal.py ├── test_utils.py └── test_wce.py /.github/workflows/publish_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Publish Documentation 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [created] 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | 14 | - name: Setup conda 15 | uses: conda-incubator/setup-miniconda@v2 16 | with: 17 | miniconda-version: "latest" 18 | python-version: "3.7" 19 | activate-environment: deploydocs 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install --upgrade setuptools wheel twine 25 | pip install ".[all]" 26 | # In our docs, we need to output static images 27 | # That requires additional setup 28 | conda install --yes -c anaconda psutil 29 | conda install --yes -c plotly plotly-orca 30 | - name: Deploy mkdocs site 31 | run: | 32 | mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install --upgrade setuptools wheel twine 20 | - name: Make sure unit tests succeed 21 | run: | 22 | pip3 install ".[all]" 23 | pytest 24 | - name: Build package & publish to PyPi 25 | env: 26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 28 | run: | 29 | python setup.py sdist bdist_wheel 30 | twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/unit_tests.yml: -------------------------------------------------------------------------------- 1 | name: Development 2 | on: 3 | # Trigger the workflow on push or pull request, 4 | # but only for the main branch 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | jobs: 10 | run: 11 | name: Run unit tests 12 | runs-on: ${{ matrix.os 
}} 13 | strategy: 14 | matrix: 15 | build: [ubuntu] 16 | include: 17 | - build: ubuntu 18 | os: ubuntu-latest 19 | python-version: [3.6, 3.7, 3.8] 20 | steps: 21 | - uses: actions/checkout@master 22 | - name: Setup Python 23 | uses: actions/setup-python@master 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | pip3 install --upgrade setuptools pip 29 | pip3 install ".[all]" 30 | - name: Run unit tests and linters 31 | run: | 32 | pytest tests/ 33 | - name: Static code checking with pyflakes 34 | run: | 35 | pip3 install pyflakes 36 | pyflakes bokbokbok 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /notebooks/private/* 2 | notebooks/* 3 | .idea 4 | *.log 5 | tmd/csv_files/ 6 | tmd/fig_private/* 7 | notebooks/private/ 8 | .gitignore.swp 9 | # Created by https://www.gitignore.io/api/macos,python 10 | # Edit at https://www.gitignore.io/?templates=macos,python 11 | 12 | ### DATA ### 13 | # to not push utils 14 | ### macOS ### 15 | # General 16 | .DS_Store 17 | .AppleDouble 18 | .LSOverride 19 | .README.md.swp 20 | # Icon must end with two \r 21 | Icon 22 | # Thumbnails 23 | ._* 24 | # Files that might appear in the root of a volume 25 | .DocumentRevisions-V100 26 | .fseventsd 27 | .Spotlight-V100 28 | .TemporaryItems 29 | .Trashes 30 | .VolumeIcon.icns 31 | .com.apple.timemachine.donotpresent 32 | # Directories potentially created on remote AFP share 33 | .AppleDB 34 | .AppleDesktop 35 | Network Trash Folder 36 | Temporary Items 37 | .apdisk 38 | ### Python ### 39 | # Byte-compiled / optimized / DLL files 40 | __pycache__/ 41 | *.py[cod] 42 | *.py.swp 43 | *$py.class 44 | # C extensions 45 | *.so 46 | # Distribution / packaging 47 | .Python 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 | var/ 59 | wheels/ 60 | pip-wheel-metadata/ 61 | share/python-wheels/ 62 | *.egg-info/ 63 | .installed.cfg 64 | *.egg 65 | MANIFEST 66 | # PyInstaller 67 | # Usually these files are written by a python script from a template 68 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
69 | *.manifest 70 | *.spec 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | # Unit test / coverage reports 75 | htmlcov/ 76 | .tox/ 77 | .nox/ 78 | .coverage 79 | .coverage.* 80 | .cache 81 | nosetests.xml 82 | coverage.xml 83 | *.cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | # Translations 87 | *.mo 88 | *.pot 89 | # Django stuff: 90 | local_settings.py 91 | db.sqlite3 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | # Scrapy stuff: 96 | .scrapy 97 | # Sphinx documentation 98 | docs/_build/ 99 | # PyBuilder 100 | target/ 101 | # Jupyter Notebook 102 | .ipynb_checkpoints 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | # pyenv 107 | .python-version 108 | # celery beat schedule file 109 | celerybeat-schedule 110 | # SageMath parsed files 111 | *.sage.py 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | # Rope project settings 124 | .ropeproject 125 | # mkdocs documentation 126 | /site 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | # Pyre type checker 132 | .pyre/ 133 | ### Python Patch ### 134 | .venv/ 135 | # End of https://www.gitignore.io/api/macos,python 136 | notebooks/demo/.ipynb_checkpoints/ 137 | #TMD documentation 138 | *.ipynb_checkpoints/ 139 | site/ 140 | \!docs/** 141 | # Created by https://www.toptal.com/developers/gitignore/api/r,macos,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks 142 | # Edit at https://www.toptal.com/developers/gitignore?templates=r,macos,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks 143 | ### JupyterNotebooks ### 144 | # gitignore template for Jupyter Notebooks 145 | # website: http://jupyter.org/ 146 | */.ipynb_checkpoints/* 147 | # Remove previous ipynb_checkpoints 148 | # git rm -r .ipynb_checkpoints/ 149 | Icon 150 | ### PyCharm ### 151 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 152 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 153 | # User-specific stuff 154 | .idea/**/workspace.xml 155 | .idea/**/tasks.xml 156 | .idea/**/usage.statistics.xml 157 | .idea/**/dictionaries 158 | .idea/**/shelf 159 | # Generated files 160 | .idea/**/contentModel.xml 161 | # Sensitive or high-churn files 162 | .idea/**/dataSources/ 163 | .idea/**/dataSources.ids 164 | .idea/**/dataSources.local.xml 165 | .idea/**/sqlDataSources.xml 166 | .idea/**/dynamic.xml 167 | .idea/**/uiDesigner.xml 168 | .idea/**/dbnavigator.xml 169 | # Gradle 170 | .idea/**/gradle.xml 171 | .idea/**/libraries 172 | # Gradle and Maven with auto-import 173 | # When using Gradle or Maven with auto-import, you should exclude module files, 174 | # since they will be recreated, and may cause churn. Uncomment if using 175 | # auto-import. 
176 | # .idea/artifacts 177 | # .idea/compiler.xml 178 | # .idea/jarRepositories.xml 179 | # .idea/modules.xml 180 | # .idea/*.iml 181 | # .idea/modules 182 | # *.iml 183 | # *.ipr 184 | # CMake 185 | cmake-build-*/ 186 | # Mongo Explorer plugin 187 | .idea/**/mongoSettings.xml 188 | # File-based project format 189 | *.iws 190 | # IntelliJ 191 | out/ 192 | # mpeltonen/sbt-idea plugin 193 | .idea_modules/ 194 | # JIRA plugin 195 | atlassian-ide-plugin.xml 196 | # Cursive Clojure plugin 197 | .idea/replstate.xml 198 | # Crashlytics plugin (for Android Studio and IntelliJ) 199 | com_crashlytics_export_strings.xml 200 | crashlytics.properties 201 | crashlytics-build.properties 202 | fabric.properties 203 | # Editor-based Rest Client 204 | .idea/httpRequests 205 | # Android studio 3.1+ serialized cache file 206 | .idea/caches/build_file_checksums.ser 207 | ### PyCharm Patch ### 208 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 209 | # modules.xml 210 | # .idea/misc.xml 211 | # Sonarlint plugin 212 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 213 | .idea/**/sonarlint/ 214 | # SonarQube Plugin 215 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 216 | .idea/**/sonarIssues.xml 217 | # Markdown Navigator plugin 218 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 219 | .idea/**/markdown-navigator.xml 220 | .idea/**/markdown-navigator-enh.xml 221 | .idea/**/markdown-navigator/ 222 | # Cache file creation bug 223 | # See https://youtrack.jetbrains.com/issue/JBR-2257 224 | .idea/$CACHE_FILE$ 225 | # CodeStream plugin 226 | # https://plugins.jetbrains.com/plugin/12206-codestream 227 | .idea/codestream.xml 228 | *.py,cover 229 | pytestdebug.log 230 | db.sqlite3-journal 231 | doc/_build/ 232 | # pipenv 233 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 234 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 235 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 236 | # install all needed dependencies. 237 | #Pipfile.lock 238 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 239 | __pypackages__/ 240 | # Celery stuff 241 | celerybeat.pid 242 | # pytype static type analyzer 243 | .pytype/ 244 | ### R ### 245 | # History files 246 | .Rhistory 247 | .Rapp.history 248 | # Session Data files 249 | .RData 250 | # User-specific files 251 | .Ruserdata 252 | # Example code in package build process 253 | *-Ex.R 254 | # Output files from R CMD build 255 | /*.tar.gz 256 | # Output files from R CMD check 257 | /*.Rcheck/ 258 | # RStudio files 259 | .Rproj.user/ 260 | # produced vignettes 261 | vignettes/*.html 262 | vignettes/*.pdf 263 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 264 | .httr-oauth 265 | # knitr and R markdown default cache directories 266 | *_cache/ 267 | /cache/ 268 | # Temporary files created by R markdown 269 | *.utf8.md 270 | *.knit.md 271 | # R Environment Variables 272 | .Renviron 273 | ### R.Bookdown Stack ### 274 | # R package: bookdown caching files 275 | /*_files/ 276 | ### VisualStudioCode ### 277 | .vscode/* 278 | *.code-workspace 279 | ### VisualStudioCode Patch ### 280 | # Ignore all local history of files 281 | .history 282 | ### Windows ### 283 | # Windows thumbnail cache files 284 | Thumbs.db 285 | Thumbs.db:encryptable 286 | ehthumbs.db 287 | ehthumbs_vista.db 288 | # Dump file 289 | *.stackdump 290 | # Folder config file 291 | [Dd]esktop.ini 292 | # Recycle Bin used on file shares 293 | $RECYCLE.BIN/ 294 | # Windows Installer files 295 | *.cab 296 | *.msi 297 | *.msix 298 | *.msm 299 | *.msp 300 | # Windows shortcuts 301 | *.lnk 302 | ### VisualStudio ### 303 | ## Ignore Visual Studio temporary files, build results, and 304 | ## files generated by popular Visual Studio add-ons. 305 | ## 306 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 307 | *.rsuser 308 | *.suo 309 | *.user 310 | *.userosscache 311 | *.sln.docstates 312 | # User-specific files (MonoDevelop/Xamarin Studio) 313 | *.userprefs 314 | # Mono auto generated files 315 | mono_crash.* 316 | # Build results 317 | [Dd]ebug/ 318 | [Dd]ebugPublic/ 319 | [Rr]elease/ 320 | [Rr]eleases/ 321 | x64/ 322 | x86/ 323 | [Aa][Rr][Mm]/ 324 | [Aa][Rr][Mm]64/ 325 | bld/ 326 | [Bb]in/ 327 | [Oo]bj/ 328 | [Ll]og/ 329 | [Ll]ogs/ 330 | # Visual Studio 2015/2017 cache/options directory 331 | .vs/ 332 | # Uncomment if you have tasks that create the project's static files in wwwroot 333 | #wwwroot/ 334 | # Visual Studio 2017 auto generated files 335 | Generated\ Files/ 336 | # MSTest test Results 337 | [Tt]est[Rr]esult*/ 338 | [Bb]uild[Ll]og.* 339 | # NUnit 340 | *.VisualState.xml 341 | TestResult.xml 342 | nunit-*.xml 343 | # Build Results of an ATL Project 344 | [Dd]ebugPS/ 345 | [Rr]eleasePS/ 346 | dlldata.c 347 | # Benchmark Results 348 | BenchmarkDotNet.Artifacts/ 349 | # .NET Core 350 | project.lock.json 351 | project.fragment.lock.json 352 | artifacts/ 353 | # StyleCop 354 | StyleCopReport.xml 355 | # Files built by Visual Studio 356 | *_i.c 357 | *_p.c 358 | *_h.h 359 | *.ilk 360 | *.meta 361 | *.obj 362 | *.iobj 363 | *.pch 364 | *.pdb 365 | *.ipdb 366 | *.pgc 367 | *.pgd 368 | *.rsp 369 | *.sbr 370 | *.tlb 371 | *.tli 372 | *.tlh 373 | *.tmp 374 | *.tmp_proj 375 | *_wpftmp.csproj 376 | *.vspscc 377 | *.vssscc 378 | .builds 379 | *.pidb 380 | *.svclog 381 | *.scc 382 | # Chutzpah Test files 383 | _Chutzpah* 384 | # Visual C++ cache files 385 | ipch/ 386 | *.aps 387 | *.ncb 388 | *.opendb 389 | *.opensdf 390 | *.sdf 391 | *.cachefile 392 | *.VC.db 393 | *.VC.VC.opendb 394 | # Visual Studio 
profiler 395 | *.psess 396 | *.vsp 397 | *.vspx 398 | *.sap 399 | # Visual Studio Trace Files 400 | *.e2e 401 | # TFS 2012 Local Workspace 402 | $tf/ 403 | # Guidance Automation Toolkit 404 | *.gpState 405 | # ReSharper is a .NET coding add-in 406 | _ReSharper*/ 407 | *.[Rr]e[Ss]harper 408 | *.DotSettings.user 409 | # TeamCity is a build add-in 410 | _TeamCity* 411 | # DotCover is a Code Coverage Tool 412 | *.dotCover 413 | # AxoCover is a Code Coverage Tool 414 | .axoCover/* 415 | !.axoCover/settings.json 416 | # Coverlet is a free, cross platform Code Coverage Tool 417 | coverage*[.json, .xml, .info] 418 | # Visual Studio code coverage results 419 | *.coverage 420 | *.coveragexml 421 | # NCrunch 422 | _NCrunch_* 423 | .*crunch*.local.xml 424 | nCrunchTemp_* 425 | # MightyMoose 426 | *.mm.* 427 | AutoTest.Net/ 428 | # Web workbench (sass) 429 | .sass-cache/ 430 | # Installshield output folder 431 | [Ee]xpress/ 432 | # DocProject is a documentation generator add-in 433 | DocProject/buildhelp/ 434 | DocProject/Help/*.HxT 435 | DocProject/Help/*.HxC 436 | DocProject/Help/*.hhc 437 | DocProject/Help/*.hhk 438 | DocProject/Help/*.hhp 439 | DocProject/Help/Html2 440 | DocProject/Help/html 441 | # Click-Once directory 442 | publish/ 443 | # Publish Web Output 444 | *.[Pp]ublish.xml 445 | *.azurePubxml 446 | # Note: Comment the next line if you want to checkin your web deploy settings, 447 | # but database connection strings (with potential passwords) will be unencrypted 448 | *.pubxml 449 | *.publishproj 450 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 451 | # checkin your Azure Web App publish settings, but sensitive information contained 452 | # in these scripts will be unencrypted 453 | PublishScripts/ 454 | # NuGet Packages 455 | *.nupkg 456 | # NuGet Symbol Packages 457 | *.snupkg 458 | # The packages folder can be ignored because of Package Restore 459 | **/[Pp]ackages/* 460 | # except build/, which is used as an MSBuild target. 461 | !**/[Pp]ackages/build/ 462 | # Uncomment if necessary however generally it will be regenerated when needed 463 | #!**/[Pp]ackages/repositories.config 464 | # NuGet v3's project.json files produces more ignorable files 465 | *.nuget.props 466 | *.nuget.targets 467 | # Microsoft Azure Build Output 468 | csx/ 469 | *.build.csdef 470 | # Microsoft Azure Emulator 471 | ecf/ 472 | rcf/ 473 | # Windows Store app package directories and files 474 | AppPackages/ 475 | BundleArtifacts/ 476 | Package.StoreAssociation.xml 477 | _pkginfo.txt 478 | *.appx 479 | *.appxbundle 480 | *.appxupload 481 | # Visual Studio cache files 482 | # files ending in .cache can be ignored 483 | *.[Cc]ache 484 | # but keep track of directories ending in .cache 485 | !?*.[Cc]ache/ 486 | # Others 487 | ClientBin/ 488 | ~$* 489 | *~ 490 | *.dbmdl 491 | *.dbproj.schemaview 492 | *.jfm 493 | *.pfx 494 | *.publishsettings 495 | orleans.codegen.cs 496 | # Including strong name files can present a security risk 497 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 498 | #*.snk 499 | # Since there are multiple workflows, uncomment next line to ignore bower_components 500 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 501 | #bower_components/ 502 | # RIA/Silverlight projects 503 | Generated_Code/ 504 | # Backup & report files from converting an old project file 505 | # to a newer Visual Studio version. 
Backup files are not needed,
506 | # because we have git ;-)
507 | _UpgradeReport_Files/
508 | Backup*/
509 | UpgradeLog*.XML
510 | UpgradeLog*.htm
511 | ServiceFabricBackup/
512 | *.rptproj.bak
513 | # SQL Server files
514 | *.mdf
515 | *.ldf
516 | *.ndf
517 | # Business Intelligence projects
518 | *.rdl.data
519 | *.bim.layout
520 | *.bim_*.settings
521 | *.rptproj.rsuser
522 | *- [Bb]ackup.rdl
523 | *- [Bb]ackup ([0-9]).rdl
524 | *- [Bb]ackup ([0-9][0-9]).rdl
525 | # Microsoft Fakes
526 | FakesAssemblies/
527 | # GhostDoc plugin setting file
528 | *.GhostDoc.xml
529 | # Node.js Tools for Visual Studio
530 | .ntvs_analysis.dat
531 | node_modules/
532 | # Visual Studio 6 build log
533 | *.plg
534 | # Visual Studio 6 workspace options file
535 | *.opt
536 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
537 | *.vbw
538 | # Visual Studio LightSwitch build output
539 | **/*.HTMLClient/GeneratedArtifacts
540 | **/*.DesktopClient/GeneratedArtifacts
541 | **/*.DesktopClient/ModelManifest.xml
542 | **/*.Server/GeneratedArtifacts
543 | **/*.Server/ModelManifest.xml
544 | _Pvt_Extensions
545 | # Paket dependency manager
546 | .paket/paket.exe
547 | paket-files/
548 | # FAKE - F# Make
549 | .fake/
550 | # CodeRush personal settings
551 | .cr/personal
552 | # Python Tools for Visual Studio (PTVS)
553 | *.pyc
554 | # Cake - Uncomment if you are using it
555 | # tools/**
556 | # !tools/packages.config
557 | # Tabs Studio
558 | *.tss
559 | # Telerik's JustMock configuration file
560 | *.jmconfig
561 | # BizTalk build output
562 | *.btp.cs
563 | *.btm.cs
564 | *.odx.cs
565 | *.xsd.cs
566 | # OpenCover UI analysis results
567 | OpenCover/
568 | # Azure Stream Analytics local run output
569 | ASALocalRun/
570 | # MSBuild Binary and Structured Log
571 | *.binlog
572 | # NVidia Nsight GPU debugger configuration file
573 | *.nvuser
574 | # MFractors (Xamarin productivity tool) working folder
575 | .mfractor/
576 | # Local History for Visual Studio
577 | .localhistory/
578 | # BeatPulse healthcheck temp database
579 | healthchecksdb
580 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
581 | MigrationBackup/
582 | # Ionide (cross platform F# VS Code tools) working folder
583 | .ionide/
584 | # End of https://www.toptal.com/developers/gitignore/api/r,macos,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks
585 | 
586 | *.ipynb
587 | *.pptx
588 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guide
2 | 
3 | `bokbokbok` aims to integrate custom loss functions into LightGBM, XGBoost and CatBoost.
4 | To add a loss function or eval metric, or to contribute in general, please follow these steps:
5 | 
6 | - Discuss the feature you want to add on GitHub before you write a PR for it. On disagreements, maintainer(s) will have the final word.
7 | - If you’re going to add a loss function, please contribute the derivations of the gradients and Hessians.
8 | - When issues or pull requests are not going to be resolved or merged, they should be closed as soon as possible.
9 |   This is kinder than deciding this after a long period. Our issue tracker should reflect work to be done.
10 | 
11 | That said, there are many ways to contribute to bokbokbok, including:
12 | 
13 | - Contributing to the code
14 | - Improving the documentation
15 | - Reviewing merge requests
16 | - Investigating bugs
17 | - Reporting issues
18 | 
19 | Starting out with open source? See the guide [How to Contribute to Open Source](https://opensource.guide/how-to-contribute/) and have a look at [our issues labelled *good first issue*](https://github.com/orchardbirds/bokbokbok/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
20 | 
21 | ## Setup
22 | 
23 | Development install:
24 | 
25 | ```shell
26 | pip install -e '.[all]'
27 | ```
28 | 
29 | Run the unit tests with
30 | 
31 | ```shell
32 | pytest
33 | ```
34 | 
35 | ## Standards
36 | 
37 | - Python 3.6+
38 | - Follow [PEP8](http://pep8.org/) as closely as possible (except line length)
39 | - [google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/)
40 | - Git: Include a short description of *what* and *why* was done; *how* can be seen in the code. Use present tense, imperative mood
41 | - Git: limit the length of the first line to 72 chars. You can use multiple messages to specify a second (longer) line: `git commit -m "Patch load function" -m "This is a much longer explanation of what was done"`
42 | 
43 | 
44 | ### Derivations
45 | 
46 | We use [Code cogs](https://www.codecogs.com/latex/eqneditor.php) to generate equations that are compatible with Git and markdown.
47 | To use an equation, choose the svg format and HTML embedding, and copy the link at the bottom of the page.
48 | 
49 | ## Releases and versioning
50 | 
51 | We use [semver](https://semver.org/) for versioning. When we are ready for a release, the maintainer runs:
52 | 
53 | ```shell
54 | git tag -a v0.1 -m "bokbokbok v0.1" && git push origin v0.1
55 | ```
56 | 
57 | When we create a new GitHub release, a [github action](https://github.com/ing-bank/skorecard/blob/main/.github/workflows/publish_pypi.yml) is triggered that:
58 | 
59 | - deploys a new version to PyPI
60 | - re-builds and deploys the docs
61 | 
62 | 
63 | ### Documentation
64 | 
65 | Documentation is a crucial part of the project because it ensures usability of the package. We develop the docs in the following way:
66 | 
67 | * We use [mkdocs](https://www.mkdocs.org/) with the [mkdocs-material](https://squidfunk.github.io/mkdocs-material/) theme. The `docs/` folder contains all the relevant documentation.
68 | * We use `mkdocs serve` to view the documentation locally. Use it to test the documentation every time you make any changes.
69 | * Maintainers can deploy the docs using `mkdocs gh-deploy`. The documentation is deployed to `https://orchardbirds.github.io/`.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Dan Timbrell
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![PyPi Version](https://img.shields.io/pypi/pyversions/bokbokbok)](#) 4 | [![PyPI](https://img.shields.io/pypi/v/bokbokbok)](#) 5 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/bokbokbok)](#) 6 | ![GitHub contributors](https://img.shields.io/github/contributors/orchardbirds/bokbokbok) 7 | 8 | 9 | # bokbokbok 10 | 11 | ## Overview 12 | 13 | **bokbokbok** is a python package that enables you to use Custom Loss Functions and Evaluation Metrics for XGBoost and LightGBM. 14 | Main features: 15 | 16 | - [Weighted Cross Entropy](https://orchardbirds.github.io/bokbokbok/tutorials/weighted_cross_entropy.html) 17 | - [Weighted Focal Loss](https://orchardbirds.github.io/bokbokbok/tutorials/focal_loss.html) 18 | - [Log Cosh Loss](https://orchardbirds.github.io/bokbokbok/tutorials/log_cosh_loss.html) 19 | - [Root Mean Squared Percentage Error](https://orchardbirds.github.io/bokbokbok/tutorials/RMSPE.html) 20 | - [F1 score](https://orchardbirds.github.io/bokbokbok/tutorials/F1_score.html) 21 | - [Quadratic Weighted Kappa](https://orchardbirds.github.io/bokbokbok/tutorials/quadratic_weighted_kappa.html) 22 | 23 | ## Installation 24 | 25 | ```bash 26 | pip install bokbokbok 27 | ``` 28 | 29 | ## Documentation 30 | 31 | The documentation can [be found here.](https://orchardbirds.github.io/bokbokbok/) 32 | 33 | ## Contributing 34 | 35 | To learn more about making a contribution to bokbokbok, please see [CONTRIBUTING.md.](https://github.com/orchardbirds/bokbokbok/blob/main/CONTRIBUTING.md) 36 | -------------------------------------------------------------------------------- /bokbokbok/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/bokbokbok/__init__.py -------------------------------------------------------------------------------- /bokbokbok/eval_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/bokbokbok/eval_metrics/__init__.py -------------------------------------------------------------------------------- /bokbokbok/eval_metrics/classification/__init__.py: -------------------------------------------------------------------------------- 1 | """Import required metrics.""" 2 | 3 | 4 | from .binary_eval_metrics import ( 5 | WeightedCrossEntropyMetric, 6 | WeightedFocalMetric, 7 | F1_Score_Binary, 8 | ) 9 | 10 | from .multiclass_eval_metrics import ( 11 | QuadraticWeightedKappaMetric, 12 | ) 13 | 14 | __all__ = [ 15 | "WeightedCrossEntropyMetric", 16 | "WeightedFocalMetric", 17 | "F1_Score_Binary", 18 | "QuadraticWeightedKappaMetric", 19 | ] -------------------------------------------------------------------------------- 
/bokbokbok/eval_metrics/classification/binary_eval_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import f1_score 3 | from bokbokbok.utils import clip_sigmoid 4 | 5 | 6 | def WeightedCrossEntropyMetric(alpha=0.5, XGBoost=False): 7 | """ 8 | Calculates the Weighted Cross Entropy Metric by applying a weighting factor alpha, allowing one to 9 | trade off recall and precision by up- or down-weighting the cost of a positive error relative to a 10 | negative error. 11 | 12 | A value alpha > 1 decreases the false negative count, hence increasing the recall. 13 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision. 14 | 15 | Args: 16 | alpha (float): The scale to be applied. 17 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use. 18 | Note that you should also set `maximize=False` in the XGBoost train function 19 | 20 | """ 21 | 22 | 23 | def weighted_cross_entropy_metric(yhat, dtrain, alpha=alpha, XGBoost=XGBoost): 24 | """ 25 | Weighted Cross Entropy Metric. 26 | 27 | Args: 28 | yhat: Predictions 29 | dtrain: The XGBoost / LightGBM dataset 30 | alpha (float): Scale applied 31 | XGBoost (Bool): If XGBoost is to be implemented 32 | 33 | Returns: 34 | Name of the eval metric, Eval score, Bool to minimise function 35 | 36 | """ 37 | y = dtrain.get_label() 38 | yhat = clip_sigmoid(yhat) 39 | elements = - alpha * y * np.log(yhat) - (1 - y) * np.log(1 - yhat) 40 | if XGBoost: 41 | return f'WCE_alpha{alpha}', (np.sum(elements) / len(y)) 42 | else: 43 | return f'WCE_alpha{alpha}', (np.sum(elements) / len(y)), False 44 | 45 | return weighted_cross_entropy_metric 46 | 47 | 48 | def WeightedFocalMetric(alpha=1.0, gamma=2.0, XGBoost=False): 49 | """ 50 | Implements [alpha-weighted Focal Loss](https://arxiv.org/pdf/1708.02002.pdf) 51 | 52 | The more gamma is increased, the more the model is focussed on the hard, misclassified examples. 53 | 54 | A value alpha > 1 decreases the false negative count, hence increasing the recall. 55 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision. 56 | 57 | Args: 58 | alpha (float): The scale to be applied. 59 | gamma (float): The focusing parameter to be applied 60 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use. 61 | Note that you should also set `maximize=False` in the XGBoost train function 62 | """ 63 | 64 | def focal_metric(yhat, dtrain, alpha=alpha, gamma=gamma, XGBoost=XGBoost): 65 | """ 66 | Weighted Focal Loss Metric. 
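        Evaluates the mean of
        -alpha * y * (1 - yhat)**gamma * log(yhat) - (1 - y) * yhat**gamma * log(1 - yhat)
        over the dataset.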
67 | 
68 |         Args:
69 |             yhat: Predictions
70 |             dtrain: The XGBoost / LightGBM dataset
71 |             alpha (float): Scale applied
72 |             gamma (float): Focusing parameter
73 |             XGBoost (Bool): If XGBoost is to be implemented
74 | 
75 |         Returns:
76 |             Name of the eval metric, Eval score, Bool to minimise function
77 | 
78 |         """
79 |         y = dtrain.get_label()
80 |         yhat = clip_sigmoid(yhat)
81 | 
82 |         elements = (- alpha * y * np.log(yhat) * np.power(1 - yhat, gamma) -
83 |                     (1 - y) * np.log(1 - yhat) * np.power(yhat, gamma))
84 | 
85 |         if XGBoost:
86 |             return f'Focal_alpha{alpha}_gamma{gamma}', (np.sum(elements) / len(y))
87 |         else:
88 |             return f'Focal_alpha{alpha}_gamma{gamma}', (np.sum(elements) / len(y)), False
89 | 
90 |     return focal_metric
91 | 
92 | 
93 | def F1_Score_Binary(XGBoost=False, *args, **kwargs):
94 |     """
95 |     Implements the f1_score metric
96 |     [from scikit learn](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn-metrics-f1-score)
97 | 
98 |     Args:
99 |         *args: The arguments to be fed into the scikit learn metric.
100 |         XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
101 |                         Note that you should also set `maximize=True` in the XGBoost train function
102 | 
103 |     """
104 |     def binary_f1_score(yhat, data, XGBoost=XGBoost):
105 |         """
106 |         F1 Score.
107 | 
108 |         Args:
109 |             yhat: Predictions
110 |             data: The XGBoost / LightGBM dataset
111 |             XGBoost (Bool): If XGBoost is to be implemented
112 | 
113 |         Returns:
114 |             Name of the eval metric, Eval score, Bool to maximise function
115 |         """
116 |         y_true = data.get_label()
117 |         yhat = np.round(yhat)
118 |         if XGBoost:
119 |             return 'F1', f1_score(y_true, yhat, *args, **kwargs)
120 |         else:
121 |             return 'F1', f1_score(y_true, yhat, *args, **kwargs), True
122 | 
123 |     return binary_f1_score
124 | 
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/classification/multiclass_eval_metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import cohen_kappa_score
2 | import numpy as np
3 | 
4 | 
5 | def QuadraticWeightedKappaMetric(XGBoost=False):
6 |     """
7 |     Calculates the Quadratic Weighted Kappa metric, i.e. Cohen's kappa with quadratic weights,
8 |     which measures the agreement between predicted and actual ordinal class labels while
9 |     penalising large disagreements more heavily than small ones.
10 | 
11 |     A value of 1 indicates perfect agreement, 0 indicates agreement no better than chance,
12 |     and negative values indicate agreement worse than chance.
13 | 
14 |     Args:
15 |         XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
16 |                         Note that you should also set `maximize=True` in the XGBoost train function
17 | 
18 | 
19 |     """
20 | 
21 | 
22 |     def quadratic_weighted_kappa_metric(yhat, dtrain, XGBoost=XGBoost):
23 |         """
24 |         Quadratic Weighted Kappa Metric.
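        Computes Cohen's kappa with quadratic weights between the true labels and the
        argmax of the predicted class scores.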
25 | 
26 |         Args:
27 |             yhat: Predictions
28 |             dtrain: The XGBoost / LightGBM dataset
29 |             XGBoost (Bool): If XGBoost is to be implemented
30 | 
31 |         Returns:
32 |             Name of the eval metric, Eval score, Bool to maximise function
33 | 
34 |         """
35 |         y = dtrain.get_label()
36 |         num_class = len(np.unique(dtrain.get_label()))
37 | 
38 |         if not XGBoost:
39 |             # LightGBM needs extra reshaping
40 |             yhat = yhat.reshape(num_class, len(y)).T
41 |         yhat = yhat.argmax(axis=1)
42 | 
43 |         qwk = cohen_kappa_score(y, yhat, weights="quadratic")
44 | 
45 |         if XGBoost:
46 |             return 'QWK', qwk
47 |         else:
48 |             return 'QWK', qwk, True
49 | 
50 |     return quadratic_weighted_kappa_metric
51 | 
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/regression/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required metrics."""
2 | 
3 | 
4 | from .regression_eval_metrics import (
5 |     LogCoshMetric,
6 |     RMSPEMetric,
7 | )
8 | 
9 | __all__ = [
10 |     "LogCoshMetric",
11 |     "RMSPEMetric",
12 | ]
13 | 
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/regression/regression_eval_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def LogCoshMetric(XGBoost=False):
5 |     """
6 |     Calculates the [Log Cosh Error](https://openreview.net/pdf?id=rkglvsC9Ym) as an alternative to
7 |     Mean Absolute Error.
8 |     Args:
9 |         XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
10 |                         Note that you should also set `maximize=False` in the XGBoost train function
11 | 
12 |     """
13 |     def log_cosh_error(yhat, dtrain, XGBoost=XGBoost):
14 |         """
15 |         Log Cosh Error.
16 |         Evaluates the mean of log(cosh(yhat - y)) over the dataset.
17 | 
18 |         yhat: Predictions
19 |         dtrain: The XGBoost / LightGBM dataset
20 |         XGBoost (Bool): If XGBoost is to be implemented
21 |         """
22 | 
23 |         y = dtrain.get_label()
24 |         elements = np.log(np.cosh(yhat - y))
25 |         if XGBoost:
26 |             return 'LogCosh', float(np.sum(elements) / len(y))
27 |         else:
28 |             return 'LogCosh', float(np.sum(elements) / len(y)), False
29 | 
30 |     return log_cosh_error
31 | 
32 | 
33 | def RMSPEMetric(XGBoost=False):
34 |     """
35 |     Calculates the Root Mean Squared Percentage Error:
36 |     https://www.kaggle.com/c/optiver-realized-volatility-prediction/overview/evaluation
37 | 
38 |     The corresponding Loss function is Squared Percentage Error.
39 |     Args:
40 |         XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
41 |                         Note that you should also set `maximize=False` in the XGBoost train function
42 | 
43 |     """
44 |     def RMSPE(yhat, dtrain, XGBoost=XGBoost):
45 |         """
46 |         Root Mean Squared Percentage Error.
47 |         All input labels are required to be non-zero.
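        Computed as sqrt(sum(((y - yhat) / y) ** 2) / len(y)).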
48 | 49 | yhat: Predictions 50 | dtrain: The XGBoost / LightGBM dataset 51 | XGBoost (Bool): If XGBoost is to be implemented 52 | """ 53 | 54 | y = dtrain.get_label() 55 | elements = ((y - yhat) / y) ** 2 56 | if XGBoost: 57 | return 'RMSPE', float(np.sqrt(np.sum(elements) / len(y))) 58 | else: 59 | return 'RMSPE', float(np.sqrt(np.sum(elements) / len(y))), False 60 | 61 | return RMSPE -------------------------------------------------------------------------------- /bokbokbok/loss_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/bokbokbok/loss_functions/__init__.py -------------------------------------------------------------------------------- /bokbokbok/loss_functions/classification/__init__.py: -------------------------------------------------------------------------------- 1 | """Import required losses.""" 2 | 3 | 4 | from .classification_loss_functions import ( 5 | WeightedCrossEntropyLoss, 6 | WeightedFocalLoss, 7 | ) 8 | 9 | __all__ = [ 10 | "WeightedCrossEntropyLoss", 11 | "WeightedFocalLoss" 12 | ] -------------------------------------------------------------------------------- /bokbokbok/loss_functions/classification/classification_loss_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from bokbokbok.utils import clip_sigmoid 3 | 4 | 5 | def WeightedCrossEntropyLoss(alpha=0.5): 6 | """ 7 | Calculates the Weighted Cross-Entropy Loss, which applies a factor alpha, allowing one to 8 | trade off recall and precision by up- or down-weighting the cost of a positive error relative 9 | to a negative error. 10 | 11 | A value alpha > 1 decreases the false negative count, hence increasing the recall. 12 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision. 13 | """ 14 | 15 | def _gradient(yhat, dtrain, alpha): 16 | """Compute the weighted cross-entropy gradient. 17 | 18 | Args: 19 | yhat (np.array): Margin predictions 20 | dtrain: The XGBoost / LightGBM dataset 21 | alpha (float): Scale applied 22 | 23 | Returns: 24 | grad: Weighted cross-entropy gradient 25 | """ 26 | y = dtrain.get_label() 27 | 28 | yhat = clip_sigmoid(yhat) 29 | 30 | grad = (y * yhat * (alpha - 1)) + yhat - (alpha * y) 31 | 32 | return grad 33 | 34 | def _hessian(yhat, dtrain, alpha): 35 | """Compute the weighted cross-entropy hessian. 
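        The closed form is (y * (alpha - 1) + 1) * yhat * (1 - yhat), where yhat is the
        clipped sigmoid of the margin; see docs/derivations/wce.md.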
36 | 
37 |         Args:
38 |             yhat (np.array): Margin predictions
39 |             dtrain: The XGBoost / LightGBM dataset
40 |             alpha (float): Scale applied
41 | 
42 |         Returns:
43 |             hess: Weighted cross-entropy Hessian
44 |         """
45 |         y = dtrain.get_label()
46 |         yhat = clip_sigmoid(yhat)
47 | 
48 |         hess = (y * (alpha - 1) + 1) * yhat * (1 - yhat)
49 | 
50 |         return hess
51 | 
52 |     def weighted_cross_entropy(
53 |             yhat,
54 |             dtrain,
55 |             alpha=alpha
56 |     ):
57 |         """
58 |         Calculate gradient and hessian for weighted cross-entropy.
59 | 
60 |         Args:
61 |             yhat (np.array): Predictions
62 |             dtrain: The XGBoost / LightGBM dataset
63 |             alpha (float): Scale applied
64 | 
65 |         Returns:
66 |             grad: Weighted cross-entropy gradient
67 |             hess: Weighted cross-entropy Hessian
68 |         """
69 |         grad = _gradient(yhat, dtrain, alpha=alpha)
70 | 
71 |         hess = _hessian(yhat, dtrain, alpha=alpha)
72 | 
73 |         return grad, hess
74 | 
75 |     return weighted_cross_entropy
76 | 
77 | 
78 | def WeightedFocalLoss(alpha=1.0, gamma=2.0):
79 |     """
80 |     Calculates the [Weighted Focal Loss.](https://arxiv.org/pdf/1708.02002.pdf)
81 | 
82 |     Note that if using alpha = 1 and gamma = 0,
83 |     this is the same as using regular Cross Entropy.
84 | 
85 |     The more gamma is increased, the more the model is focussed on the hard, misclassified examples.
86 | 
87 |     A value alpha > 1 decreases the false negative count, hence increasing the recall.
88 |     Conversely, setting alpha < 1 decreases the false positive count and increases the precision.
89 | 
90 |     """
91 | 
92 |     def _gradient(yhat, dtrain, alpha, gamma):
93 |         """Compute the weighted focal gradient.
94 | 
95 |         Args:
96 |             yhat (np.array): Margin predictions
97 |             dtrain: The XGBoost / LightGBM dataset
98 |             alpha (float): Scale applied
99 |             gamma (float): Focusing parameter
100 | 
101 |         Returns:
102 |             grad: Weighted Focal Loss gradient
103 |         """
104 |         y = dtrain.get_label()
105 | 
106 |         yhat = clip_sigmoid(yhat)
107 | 
108 |         grad = (
109 |             alpha * y * np.power(1 - yhat, gamma) * (gamma * yhat * np.log(yhat) + yhat - 1) +
110 |             (1 - y) * np.power(yhat, gamma) * (yhat - gamma * np.log(1 - yhat) * (1 - yhat))
111 |         )
112 | 
113 |         return grad
114 | 
115 |     def _hessian(yhat, dtrain, alpha, gamma):
116 |         """Compute the weighted focal hessian.
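        See docs/derivations/focal.md for the closed form; it reduces to the standard
        cross-entropy Hessian yhat * (1 - yhat) when alpha = 1 and gamma = 0.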
117 | 
118 |         Args:
119 |             yhat (np.array): Margin predictions
120 |             dtrain: The XGBoost / LightGBM dataset
121 |             alpha (float): Scale applied
122 |             gamma (float): Focusing parameter
123 | 
124 |         Returns:
125 |             hess: Weighted Focal Loss Hessian
126 |         """
127 |         y = dtrain.get_label()
128 | 
129 |         yhat = clip_sigmoid(yhat)
130 | 
131 |         hess = (
132 |             alpha * y * yhat * np.power(1 - yhat, gamma) *
133 |             (gamma * (1 - yhat) * np.log(yhat) + 2 * gamma * (1 - yhat) -
134 |              np.power(gamma, 2) * yhat * np.log(yhat) + 1 - yhat) +
135 |             (1 - y) * np.power(yhat, gamma + 1) * (1 - yhat) * (2 * gamma + gamma * np.log(1 - yhat) + 1)
136 |         )
137 | 
138 |         return hess
139 | 
140 |     def focal_loss(
141 |             yhat,
142 |             dtrain,
143 |             alpha=alpha,
144 |             gamma=gamma):
145 |         """
146 |         Calculate gradient and hessian for Focal Loss.
147 | 
148 |         Args:
149 |             yhat (np.array): Margin predictions
150 |             dtrain: The XGBoost / LightGBM dataset
151 |             alpha (float): Scale applied
152 |             gamma (float): Focusing parameter
153 | 
154 |         Returns:
155 |             grad: Focal Loss gradient
156 |             hess: Focal Loss Hessian
157 |         """
158 | 
159 |         grad = _gradient(yhat, dtrain, alpha=alpha, gamma=gamma)
160 | 
161 |         hess = _hessian(yhat, dtrain, alpha=alpha, gamma=gamma)
162 | 
163 |         return grad, hess
164 | 
165 |     return focal_loss
166 | 
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/regression/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required losses."""
2 | 
3 | 
4 | from .regression_loss_functions import (
5 |     LogCoshLoss,
6 |     SPELoss,
7 | )
8 | 
9 | __all__ = [
10 |     "LogCoshLoss",
11 |     "SPELoss",
12 | ]
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/regression/regression_loss_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def LogCoshLoss():
5 |     """
6 |     [Log Cosh Loss](https://openreview.net/pdf?id=rkglvsC9Ym) is an alternative to Mean Absolute Error.
7 |     """
8 | 
9 |     def _gradient(yhat, dtrain):
10 |         """Compute the log cosh gradient.
11 | 
12 |         Args:
13 |             yhat (np.array): Predictions
14 |             dtrain: The XGBoost / LightGBM dataset
15 | 
16 |         Returns:
17 |             log cosh gradient
18 |         """
19 | 
20 |         y = dtrain.get_label()
21 |         return -np.tanh(y - yhat)
22 | 
23 |     def _hessian(yhat, dtrain):
24 |         """Compute the log cosh hessian.
25 | 
26 |         Args:
27 |             yhat (np.array): Predictions
28 |             dtrain: The XGBoost / LightGBM dataset
29 | 
30 |         Returns:
31 |             log cosh Hessian
32 |         """
33 | 
34 |         y = dtrain.get_label()
35 |         return 1. / np.power(np.cosh(y - yhat), 2)
36 | 
37 |     def log_cosh_loss(
38 |             yhat,
39 |             dtrain
40 |     ):
41 |         """
42 |         Calculate gradient and hessian for log cosh loss.
43 | 
44 |         Args:
45 |             yhat (np.array): Predictions
46 |             dtrain: The XGBoost / LightGBM dataset
47 | 
48 |         Returns:
49 |             grad: log cosh loss gradient
50 |             hess: log cosh loss Hessian
51 |         """
52 |         grad = _gradient(yhat, dtrain)
53 | 
54 |         hess = _hessian(yhat, dtrain)
55 | 
56 |         return grad, hess
57 | 
58 |     return log_cosh_loss
59 | 
60 | 
61 | def SPELoss():
62 |     """
63 |     Squared Percentage Error loss.
64 |     """
65 | 
66 |     def _gradient(yhat, dtrain):
67 |         """
68 |         Compute the gradient of the squared percentage error.
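        The per-sample gradient with respect to yhat is -2 * (y - yhat) / y ** 2.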
69 |         Args:
70 |             yhat (np.array): Predictions
71 |             dtrain: The XGBoost / LightGBM dataset
72 | 
73 |         Returns:
74 |             SPE Gradient
75 |         """
76 |         y = dtrain.get_label()
77 |         return -2 * (y - yhat) / (y ** 2)
78 | 
79 |     def _hessian(yhat, dtrain):
80 |         """
81 |         Compute the hessian for squared percentage error.
82 |         Args:
83 |             yhat (np.array): Predictions
84 |             dtrain: The XGBoost / LightGBM dataset
85 | 
86 |         Returns:
87 |             SPE Hessian
88 |         """
89 |         y = dtrain.get_label()
90 |         return 2 / (y ** 2)
91 | 
92 |     def squared_percentage(yhat, dtrain):
93 |         """
94 |         Calculate gradient and hessian for squared percentage error.
95 | 
96 |         Args:
97 |             yhat (np.array): Predictions
98 |             dtrain: The XGBoost / LightGBM dataset
99 | 
100 |         Returns:
101 |             grad: SPE loss gradient
102 |             hess: SPE loss Hessian
103 |         """
104 |         #yhat[yhat < -1] = -1 + 1e-6
105 |         grad = _gradient(yhat, dtrain)
106 | 
107 |         hess = _hessian(yhat, dtrain)
108 | 
109 |         return grad, hess
110 | 
111 |     return squared_percentage
--------------------------------------------------------------------------------
/bokbokbok/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required functions."""
2 | 
3 | 
4 | from .functions import (
5 |     clip_sigmoid
6 | )
7 | 
8 | __all__ = [
9 |     "clip_sigmoid",
10 | ]
11 | 
--------------------------------------------------------------------------------
/bokbokbok/utils/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def clip_sigmoid(yhat):
5 |     """
6 |     Applies the sigmoid function and ensures that the values lie in the range
7 |     1e-15 < yhat < 1 - 1e-15.
8 |     We clip to avoid dividing by zero in the loss functions.
9 | 
10 |     Args:
11 |         yhat: The raw margins, to which the sigmoid function is applied
12 | 
13 |     Returns:
14 |         yhat: The clipped probabilities
15 |     """
16 |     yhat = 1. / (1. + np.exp(-yhat))
17 |     yhat[yhat >= 1] = 1 - 1e-15
18 |     yhat[yhat <= 0] = 1e-15
19 |     return yhat
20 | 
--------------------------------------------------------------------------------
/docs/derivations/focal.md:
--------------------------------------------------------------------------------
1 | ## Weighted Focal Loss
2 | 
3 | Weighted Focal Loss applies a scaling parameter *alpha* and a focusing parameter *gamma* to [Binary Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression).
4 | 
5 | We take the definition of the Focal Loss from [this paper](https://arxiv.org/pdf/1708.02002.pdf):
6 | 
7 | $$FL(p_t) = -\alpha (1 - p_t)^{\gamma} \log(p_t)$$
8 | 
9 | where:
10 | 
11 | $$p_t = \begin{cases} \hat{y} & \text{if } y = 1 \\ 1 - \hat{y} & \text{otherwise} \end{cases}$$
12 | 
13 | This is equivalent to writing:
14 | 
15 | $$L = -\alpha y (1 - \hat{y})^{\gamma} \log(\hat{y}) - (1 - y) \hat{y}^{\gamma} \log(1 - \hat{y})$$
16 | 
17 | We calculate the Gradient:
18 | 
19 | $$\frac{\partial L}{\partial x} = \alpha y (1 - \hat{y})^{\gamma} \left( \gamma \hat{y} \log(\hat{y}) + \hat{y} - 1 \right) + (1 - y) \hat{y}^{\gamma} \left( \hat{y} - \gamma (1 - \hat{y}) \log(1 - \hat{y}) \right)$$
20 | 
21 | We also need to calculate the Hessian:
22 | 
23 | $$\frac{\partial^2 L}{\partial x^2} = \alpha y \hat{y} (1 - \hat{y})^{\gamma} \left( \gamma (1 - \hat{y}) \log(\hat{y}) + 2 \gamma (1 - \hat{y}) - \gamma^2 \hat{y} \log(\hat{y}) + 1 - \hat{y} \right) + (1 - y) \hat{y}^{\gamma + 1} (1 - \hat{y}) \left( 2 \gamma + \gamma \log(1 - \hat{y}) + 1 \right)$$
24 | 
25 | By setting *alpha* = 1 and *gamma* = 0 we obtain the Gradient and Hessian for Binary Cross Entropy Loss, as expected.
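
To sanity-check the gradient above, it can be compared against a central finite-difference approximation of the loss. A minimal sketch (the helper names, sample data and tolerance below are illustrative, not part of bokbokbok):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def focal_loss(x, y, alpha=0.7, gamma=2.0):
    """Per-sample weighted focal loss as a function of the margin x."""
    p = sigmoid(x)
    return -alpha * y * (1 - p) ** gamma * np.log(p) - (1 - y) * p ** gamma * np.log(1 - p)

def focal_grad(x, y, alpha=0.7, gamma=2.0):
    """Analytic gradient from the derivation above."""
    p = sigmoid(x)
    return (alpha * y * (1 - p) ** gamma * (gamma * p * np.log(p) + p - 1)
            + (1 - y) * p ** gamma * (p - gamma * (1 - p) * np.log(1 - p)))

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
y = rng.integers(0, 2, size=1000).astype(float)

# Central difference approximates dL/dx; it should match the analytic form closely.
eps = 1e-6
finite_diff = (focal_loss(x + eps, y) - focal_loss(x - eps, y)) / (2 * eps)
assert np.allclose(finite_diff, focal_grad(x, y), atol=1e-6)
```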
--------------------------------------------------------------------------------
/docs/derivations/log_cosh.md:
--------------------------------------------------------------------------------
1 | ## Log Cosh Error
2 | 
3 | The equation for Log Cosh Error is:
4 | 
5 | $$L = \log\left(\cosh(\hat{y} - y)\right)$$
6 | 
7 | We calculate the Gradient:
8 | 
9 | $$\frac{\partial L}{\partial \hat{y}} = \tanh(\hat{y} - y)$$
10 | 
11 | We also need to calculate the Hessian:
12 | 
13 | $$\frac{\partial^2 L}{\partial \hat{y}^2} = \frac{1}{\cosh^2(\hat{y} - y)}$$
--------------------------------------------------------------------------------
/docs/derivations/note.md:
--------------------------------------------------------------------------------
1 | For the gradient boosting packages we [have to calculate the gradient of the Loss function with respect to the raw margin scores](https://github.com/Microsoft/LightGBM/blob/master/examples/python-guide/advanced_example.py).
2 | 
3 | In this case, we must calculate
4 | 
5 | $$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial x}$$
6 | 
7 | The Hessian is similarly calculated:
8 | 
9 | $$\frac{\partial^2 L}{\partial x^2} = \frac{\partial}{\partial x} \left( \frac{\partial L}{\partial x} \right)$$
10 | 
11 | **Where y-hat is the sigmoid function, unless stated otherwise**:
12 | 
13 | $$\hat{y} = \frac{1}{1 + e^{-x}}$$
14 | 
15 | We will make use of the following property for the calculations of the Gradients and Hessians:
16 | 
17 | $$\frac{\partial \hat{y}}{\partial x} = \hat{y} (1 - \hat{y})$$
18 | 
19 | Note that to avoid divide-by-zero errors, we clip the values of the sigmoid such that its output is bounded by $10^{-15}$ from below and $1 - 10^{-15}$ from above.
--------------------------------------------------------------------------------
/docs/derivations/wce.md:
--------------------------------------------------------------------------------
1 | ## Weighted Cross Entropy Loss
2 | 
3 | Weighted Cross Entropy applies a scaling parameter *alpha* to [Binary Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression),
4 | allowing us to penalise false positives or false negatives more harshly. If you want false
5 | negatives to be penalised more than false positives, *alpha* must be greater than 1. Otherwise,
6 | it must be less than 1.
7 | 
8 | The equations for Binary and Weighted Cross Entropy Loss are the following:
9 | 
10 | $$BCE = -y \log(\hat{y}) - (1 - y) \log(1 - \hat{y})$$
11 | 
12 | $$WCE = -\alpha y \log(\hat{y}) - (1 - y) \log(1 - \hat{y})$$
13 | 
14 | We calculate the Gradient:
15 | 
16 | $$\frac{\partial WCE}{\partial x} = y \hat{y} (\alpha - 1) + \hat{y} - \alpha y$$
17 | 
18 | We also need to calculate the Hessian:
19 | 
20 | $$\frac{\partial^2 WCE}{\partial x^2} = \left( y (\alpha - 1) + 1 \right) \hat{y} (1 - \hat{y})$$
21 | 
22 | By setting *alpha* = 1 we obtain the Gradient and Hessian for Binary Cross Entropy Loss, as expected.
--------------------------------------------------------------------------------
/docs/getting_started/install.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | 
3 | In order to install bokbokbok you need to use Python 3.7 or higher.
4 | 
5 | Install `bokbokbok` via pip with:
6 | 
7 | ```bash
8 | pip install bokbokbok
9 | ```
10 | 
11 | Alternatively you can fork/clone and run:
12 | 
13 | ```bash
14 | git clone https://gitlab.com/orchardbirds/bokbokbok.git
15 | cd bokbokbok
16 | pip install .
17 | ```
--------------------------------------------------------------------------------
/docs/img/bokbokbok.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/docs/img/bokbokbok.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Welcome to the bokbokbok doks!
2 | 
3 | 
4 | 
5 | **bokbokbok** is a Python library that lets us easily implement custom loss functions and eval metrics in LightGBM and XGBoost.
6 | 7 | ## Example Usage - Weighted Cross Entropy 8 | 9 | ```python 10 | clf = lgb.train(params=params, 11 | train_set=train, 12 | valid_sets=[train, valid], 13 | valid_names=['train','valid'], 14 | fobj=WeightedCrossEntropyLoss(alpha=alpha), 15 | feval=WeightedCrossEntropyMetric(alpha=alpha), 16 | early_stopping_rounds=100) 17 | ``` 18 | ## Licence 19 | bokbokbok is created under the MIT License, see more in the LICENSE file 20 | 21 | -------------------------------------------------------------------------------- /docs/reference/eval_metrics_binary.md: -------------------------------------------------------------------------------- 1 | ::: bokbokbok.eval_metrics.classification.binary_eval_metrics -------------------------------------------------------------------------------- /docs/reference/eval_metrics_multiclass.md: -------------------------------------------------------------------------------- 1 | ::: bokbokbok.eval_metrics.classification.multiclass_eval_metrics -------------------------------------------------------------------------------- /docs/reference/eval_metrics_regression.md: -------------------------------------------------------------------------------- 1 | ::: bokbokbok.eval_metrics.regression.regression_eval_metrics -------------------------------------------------------------------------------- /docs/reference/loss_functions_classification.md: -------------------------------------------------------------------------------- 1 | ::: bokbokbok.loss_functions.classification.classification_loss_functions -------------------------------------------------------------------------------- /docs/reference/loss_functions_regression.md: -------------------------------------------------------------------------------- 1 | ::: bokbokbok.loss_functions.regression.regression_loss_functions -------------------------------------------------------------------------------- /docs/tutorials/F1_score.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### When to use F-Score\n", 8 | "\n", 9 | "The F-score or F-measure is a measure of a test's accuracy. It is calculated from the precision and recall of the test, where the precision is the number of true positive results divided by the number of all positive results, including those not identified correctly, and the recall is the number of true positive results divided by the number of all samples that should have been identified as positive. Precision is also known as positive predictive value, and recall is also known as sensitivity in diagnostic binary classification.\n", 10 | "\n", 11 | "The highest possible value of an F-score is 1.0, indicating perfect precision and recall, and the lowest possible value is 0, if either the precision or the recall is zero. 
" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from sklearn.datasets import make_classification\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "from sklearn.metrics import roc_auc_score\n", 23 | "from bokbokbok.eval_metrics.classification import F1_Score_Binary\n", 24 | "from bokbokbok.utils import clip_sigmoid\n", 25 | "\n", 26 | "X, y = make_classification(n_samples=1000, \n", 27 | " n_features=10, \n", 28 | " random_state=41114)\n", 29 | "\n", 30 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n", 31 | " y, \n", 32 | " test_size=0.25, \n", 33 | " random_state=41114)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Usage in LightGBM" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import lightgbm as lgb\n", 50 | "\n", 51 | "train = lgb.Dataset(X_train, y_train)\n", 52 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n", 53 | "params = {\n", 54 | " 'n_estimators': 300,\n", 55 | " 'objective': 'binary',\n", 56 | " 'seed': 41114,\n", 57 | " 'n_jobs': 8,\n", 58 | " 'learning_rate': 0.1,\n", 59 | " }\n", 60 | "\n", 61 | "clf = lgb.train(params=params,\n", 62 | " train_set=train,\n", 63 | " valid_sets=[train, valid],\n", 64 | " valid_names=['train','valid'],\n", 65 | " feval=F1_Score_Binary(average='micro'),\n", 66 | " early_stopping_rounds=100)\n", 67 | "\n", 68 | "roc_auc_score(y_valid, clf.predict(X_valid))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "### Usage in XGBoost" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "import xgboost as xgb\n", 85 | "\n", 86 | "dtrain = xgb.DMatrix(X_train, y_train)\n", 87 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n", 88 | "\n", 89 | "params = {\n", 90 | " 'seed': 41114,\n", 91 | " 'objective':'binary:logistic',\n", 92 | " 'learning_rate': 0.1,\n", 93 | " 'disable_default_eval_metric': 1\n", 94 | " }\n", 95 | "\n", 96 | "bst = xgb.train(params,\n", 97 | " dtrain=dtrain,\n", 98 | " num_boost_round=300,\n", 99 | " early_stopping_rounds=10,\n", 100 | " verbose_eval=10,\n", 101 | " maximize=True,\n", 102 | " feval=F1_Score_Binary(average='micro', XGBoost=True),\n", 103 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n", 104 | "\n", 105 | "roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))" 106 | ] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Python [conda env:skorecard_py37]", 112 | "language": "python", 113 | "name": "conda-env-skorecard_py37-py" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.7.7" 126 | } 127 | }, 128 | "nbformat": 4, 129 | "nbformat_minor": 2 130 | } 131 | -------------------------------------------------------------------------------- /docs/tutorials/RMSPE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### When to use (Root Mean) Squared Percentage Error?\n", 8 | "\n", 9 | "This 
function is defined according to [this Kaggle competition](https://www.kaggle.com/c/optiver-realized-volatility-prediction/overview/evaluation) for volatility calculation. \n", 10 | "\n", 11 | "RMSPE cannot be used as a loss function: its gradient is piecewise constant, and hence its Hessian is 0. Nevertheless, it can still be used as an evaluation metric while the model trains. To obtain a usable loss function we simply remove the square root, i.e. we use the Squared Percentage Error ((y - y_hat)/y)^2 per sample, which has a non-zero Hessian." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from sklearn.datasets import make_regression\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "from sklearn.metrics import mean_absolute_error\n", 23 | "from bokbokbok.eval_metrics.regression import RMSPEMetric\n", 24 | "from bokbokbok.loss_functions.regression import SPELoss\n", 25 | "\n", 26 | "X, y = make_regression(n_samples=10000, \n", 27 | "                       n_features=10, \n", 28 | "                       random_state=41114)\n", 29 | "\n", 30 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n", 31 | "                                                      y, \n", 32 | "                                                      test_size=0.25, \n", 33 | "                                                      random_state=41114)\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Usage in LightGBM" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import lightgbm as lgb\n", 50 | "\n", 51 | "train = lgb.Dataset(X_train, y_train)\n", 52 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n", 53 | "params = {\n", 54 | "         'n_estimators': 3000,\n", 55 | "         'seed': 41114,\n", 56 | "         'n_jobs': 8,\n", 57 | "         'max_leaves':10,\n", 58 | "         }\n", 59 | "\n", 60 | "clf = lgb.train(params=params,\n", 61 | "                train_set=train,\n", 62 | "                valid_sets=[train, valid],\n", 63 | "                valid_names=['train','valid'],\n", 64 | "                fobj=SPELoss(),\n", 65 | "                feval=RMSPEMetric(),\n", 66 | "                early_stopping_rounds=100)\n", 67 | "\n", 68 | "mean_absolute_error(y_valid, clf.predict(X_valid))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "### Usage in XGBoost" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "import xgboost as xgb\n", 85 | "\n", 86 | "dtrain = xgb.DMatrix(X_train, y_train)\n", 87 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n", 88 | "\n", 89 | "params = {\n", 90 | "     'seed': 41114,\n", 91 | "     'learning_rate': 0.1,\n", 92 | "     'disable_default_eval_metric': 1\n", 93 | "     }\n", 94 | "\n", 95 | "bst = xgb.train(params,\n", 96 | "          dtrain=dtrain,\n", 97 | "          num_boost_round=3000,\n", 98 | "          early_stopping_rounds=100,\n", 99 | "          verbose_eval=100,\n", 100 | "          obj=SPELoss(),\n", 101 | "          maximize=False,\n", 102 | "          feval=RMSPEMetric(XGBoost=True),\n", 103 | "          evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n", 104 | "\n", 105 | "mean_absolute_error(y_valid, bst.predict(dvalid))" 106 | ] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Python [conda env:skorecard_py37]", 112 | "language": "python", 113 | "name": "conda-env-skorecard_py37-py" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.7.7" 126 | } 127 | }, 128 | "nbformat": 4, 129 | 
"nbformat_minor": 4 130 | } 131 | -------------------------------------------------------------------------------- /docs/tutorials/focal_loss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### When to use Focal Loss?\n", 8 | "\n", 9 | "Focal Loss addresses class imbalance in tasks such as object detection. Focal loss applies a modulating term to the Cross Entropy loss in order to focus learning on hard negative examples. It is a dynamically scaled Cross Entropy loss, where the scaling factor decays to zero as confidence in the correct class increases. Intuitively, this scaling factor can automatically down-weight the contribution of easy examples during training and rapidly focus the model on hard examples. This scaling factor is *gamma*. The more *gamma* is increased, the more the model is focussed on the hard, misclassified examples.\n", 10 | "\n", 11 | "We employ Weighted Focal Loss, which further allows us to reduce false positives or false negatives depending on our value of *alpha*:\n", 12 | "\n", 13 | "A value *alpha* > 1 decreases the false negative count, hence increasing the recall. Conversely, setting *alpha* < 1 decreases the false positive count and increases the precision. " 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from sklearn.datasets import make_classification\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "from sklearn.metrics import roc_auc_score\n", 25 | "from bokbokbok.loss_functions.classification import WeightedFocalLoss\n", 26 | "from bokbokbok.eval_metrics.classification import WeightedFocalMetric\n", 27 | "from bokbokbok.utils import clip_sigmoid\n", 28 | "\n", 29 | "X, y = make_classification(n_samples=1000, \n", 30 | " n_features=10, \n", 31 | " random_state=41114)\n", 32 | "\n", 33 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n", 34 | " y, \n", 35 | " test_size=0.25, \n", 36 | " random_state=41114)\n", 37 | "\n", 38 | "alpha = 0.7 # Reduce False Positives\n", 39 | "gamma = 2 # Focus on misclassified examples more strictly" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### Usage in LightGBM" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import lightgbm as lgb\n", 56 | "\n", 57 | "train = lgb.Dataset(X_train, y_train)\n", 58 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n", 59 | "params = {\n", 60 | " 'n_estimators': 300,\n", 61 | " 'seed': 41114,\n", 62 | " 'n_jobs': 8,\n", 63 | " 'learning_rate': 0.1,\n", 64 | " }\n", 65 | "\n", 66 | "clf = lgb.train(params=params,\n", 67 | " train_set=train,\n", 68 | " valid_sets=[train, valid],\n", 69 | " valid_names=['train','valid'],\n", 70 | " fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma),\n", 71 | " feval=WeightedFocalMetric(alpha=alpha, gamma=gamma),\n", 72 | " early_stopping_rounds=100)\n", 73 | "\n", 74 | "roc_auc_score(y_valid, clip_sigmoid(clf.predict(X_valid)))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Usage in XGBoost" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "import xgboost as xgb\n", 91 | "\n", 92 | "dtrain = 
xgb.DMatrix(X_train, y_train)\n", 93 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n", 94 | "\n", 95 | "params = {\n", 96 | "     'seed': 41114,\n", 97 | "     'learning_rate': 0.1,\n", 98 | "     'disable_default_eval_metric': 1\n", 99 | "     }\n", 100 | "\n", 101 | "bst = xgb.train(params,\n", 102 | "          dtrain=dtrain,\n", 103 | "          num_boost_round=300,\n", 104 | "          early_stopping_rounds=10,\n", 105 | "          verbose_eval=10,\n", 106 | "          obj=WeightedFocalLoss(alpha=alpha, gamma=gamma),\n", 107 | "          maximize=False,\n", 108 | "          feval=WeightedFocalMetric(alpha=alpha, gamma=gamma, XGBoost=True),\n", 109 | "          evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n", 110 | "\n", 111 | "roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))" 112 | ] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "Python [conda env:skorecard_py37]", 118 | "language": "python", 119 | "name": "conda-env-skorecard_py37-py" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.7.7" 132 | } 133 | }, 134 | "nbformat": 4, 135 | "nbformat_minor": 2 136 | } 137 | -------------------------------------------------------------------------------- /docs/tutorials/log_cosh_loss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### When to use Log Cosh Loss?\n", 8 | "\n", 9 | "Log Cosh Loss addresses the problems that can arise from using Mean Absolute Error in gradient boosting: MAE has a sharp kink at zero, and its second derivative is zero everywhere it is defined. Log(cosh(x)) very closely approximates Mean Absolute Error while remaining a smooth function with a non-zero Hessian.\n", 10 | "\n", 11 | "Do note that large y-values can cause numerical overflow in cosh, which is why the y-values are scaled below." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from sklearn.datasets import make_regression\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "from sklearn.metrics import mean_absolute_error\n", 23 | "from bokbokbok.eval_metrics.regression import LogCoshMetric\n", 24 | "from bokbokbok.loss_functions.regression import LogCoshLoss\n", 25 | "\n", 26 | "X, y = make_regression(n_samples=1000, \n", 27 | "                       n_features=10, \n", 28 | "                       random_state=41114)\n", 29 | "\n", 30 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n", 31 | "                                                      y/100, \n", 32 | "                                                      test_size=0.25, \n", 33 | "                                                      random_state=41114)\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### Usage in LightGBM" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import lightgbm as lgb\n", 50 | "\n", 51 | "train = lgb.Dataset(X_train, y_train)\n", 52 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n", 53 | "params = {\n", 54 | "         'n_estimators': 3000,\n", 55 | "         'seed': 41114,\n", 56 | "         'n_jobs': 8,\n", 57 | "         'learning_rate': 0.1,\n", 58 | "         'verbose': 100,\n", 59 | "         }\n", 60 | "\n", 61 | "clf = lgb.train(params=params,\n", 62 | "                train_set=train,\n", 63 | "                valid_sets=[train, valid],\n", 64 | "                valid_names=['train','valid'],\n", 65 | "                fobj=LogCoshLoss(),\n", 66 | "                feval=LogCoshMetric(),\n", 67 | "                
early_stopping_rounds=100,\n", 68 | " verbose_eval=100)\n", 69 | "\n", 70 | "mean_absolute_error(y_valid, clf.predict(X_valid))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### Usage in XGBoost" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "import xgboost as xgb\n", 87 | "\n", 88 | "dtrain = xgb.DMatrix(X_train, y_train)\n", 89 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n", 90 | "\n", 91 | "params = {\n", 92 | " 'seed': 41114,\n", 93 | " 'learning_rate': 0.1,\n", 94 | " 'disable_default_eval_metric': 1\n", 95 | " }\n", 96 | "\n", 97 | "bst = xgb.train(params,\n", 98 | " dtrain=dtrain,\n", 99 | " num_boost_round=3000,\n", 100 | " early_stopping_rounds=10,\n", 101 | " verbose_eval=100,\n", 102 | " obj=LogCoshLoss(),\n", 103 | " maximize=False,\n", 104 | " feval=LogCoshMetric(XGBoost=True),\n", 105 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n", 106 | "\n", 107 | "mean_absolute_error(y_valid, bst.predict(dvalid))" 108 | ] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "Python [conda env:skorecard_py37]", 114 | "language": "python", 115 | "name": "conda-env-skorecard_py37-py" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.7.7" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 2 132 | } 133 | -------------------------------------------------------------------------------- /docs/tutorials/quadratic_weighted_kappa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from sklearn.datasets import make_multilabel_classification\n", 10 | "from sklearn.model_selection import train_test_split\n", 11 | "from sklearn.metrics import cohen_kappa_score\n", 12 | "from bokbokbok.eval_metrics.classification import QuadraticWeightedKappaMetric\n", 13 | "from bokbokbok.utils import clip_sigmoid\n", 14 | "\n", 15 | "X, y = make_multilabel_classification(n_samples=1000, \n", 16 | " n_features=10,\n", 17 | " n_classes=2,\n", 18 | " n_labels=1,\n", 19 | " random_state=41114)\n", 20 | "y = y.sum(axis=1)\n", 21 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n", 22 | " y, \n", 23 | " test_size=0.25, \n", 24 | " random_state=41114)\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "### Usage in LightGBM" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import lightgbm as lgb\n", 41 | "\n", 42 | "train = lgb.Dataset(X_train, y_train)\n", 43 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n", 44 | "params = {\n", 45 | " 'n_estimators': 300,\n", 46 | " 'seed': 41114,\n", 47 | " 'n_jobs': 8,\n", 48 | " 'learning_rate': 0.1,\n", 49 | " 'objective':'multiclass',\n", 50 | " 'num_class': 3\n", 51 | " }\n", 52 | "\n", 53 | "clf = lgb.train(params=params,\n", 54 | " train_set=train,\n", 55 | " valid_sets=[train, valid],\n", 56 | " valid_names=['train','valid'],\n", 57 | " feval=QuadraticWeightedKappaMetric(),\n", 58 | " early_stopping_rounds=100)\n" 59 
| ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "### Usage in XGBoost" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "import xgboost as xgb\n", 75 | "\n", 76 | "dtrain = xgb.DMatrix(X_train, y_train)\n", 77 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n", 78 | "\n", 79 | "params = {\n", 80 | "     'seed': 41114,\n", 81 | "     'learning_rate': 0.1,\n", 82 | "     'disable_default_eval_metric': 1,\n", 83 | "     'objective': 'multi:softprob',\n", 84 | "     'num_class': 3\n", 85 | "     }\n", 86 | "\n", 87 | "bst = xgb.train(params,\n", 88 | "          dtrain=dtrain,\n", 89 | "          num_boost_round=300,\n", 90 | "          early_stopping_rounds=100,\n", 91 | "          verbose_eval=10,\n", 92 | "          feval=QuadraticWeightedKappaMetric(XGBoost=True),\n", 93 | "          maximize=True,\n", 94 | "          evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "Python [conda env:skorecard_py37]", 101 | "language": "python", 102 | "name": "conda-env-skorecard_py37-py" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.7.7" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 5 119 | } 120 | -------------------------------------------------------------------------------- /docs/tutorials/weighted_cross_entropy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### When to use Weighted Cross Entropy?\n", 8 | "\n", 9 | "A factor *alpha* is added to Cross Entropy, allowing one to trade off recall and precision by up- or down-weighting the cost of a positive error relative to a negative error.\n", 10 | "\n", 11 | "A value *alpha* > 1 decreases the false negative count, hence increasing the recall. Conversely, setting *alpha* < 1 decreases the false positive count and increases the precision. 
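To make the role of *alpha* concrete, here is a minimal NumPy sketch of the weighted cross entropy formula (for intuition only, not bokbokbok's exact implementation):

```python
import numpy as np

def weighted_cross_entropy(y_true, y_prob, alpha=1.0, eps=1e-15):
    """WCE = -mean(alpha * y * log(p) + (1 - y) * log(1 - p))."""
    y_prob = np.clip(y_prob, eps, 1 - eps)  # guard against log(0)
    return -np.mean(alpha * y_true * np.log(y_prob)
                    + (1 - y_true) * np.log(1 - y_prob))

y = np.array([1, 0, 1])
p = np.array([0.9, 0.2, 0.6])
weighted_cross_entropy(y, p, alpha=1.0)  # alpha=1 recovers plain cross entropy
weighted_cross_entropy(y, p, alpha=2.0)  # alpha>1 makes errors on positives costlier
```

With *alpha* > 1 the y = 1 term dominates the loss, so the booster works harder to avoid false negatives; with *alpha* < 1 the pressure shifts towards avoiding false positives.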
" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from sklearn.datasets import make_classification\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "from sklearn.metrics import roc_auc_score\n", 23 | "from bokbokbok.loss_functions.classification import WeightedCrossEntropyLoss\n", 24 | "from bokbokbok.eval_metrics.classification import WeightedCrossEntropyMetric\n", 25 | "from bokbokbok.utils import clip_sigmoid\n", 26 | "\n", 27 | "X, y = make_classification(n_samples=1000, \n", 28 | " n_features=10, \n", 29 | " random_state=41114)\n", 30 | "\n", 31 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n", 32 | " y, \n", 33 | " test_size=0.25, \n", 34 | " random_state=41114)\n", 35 | "\n", 36 | "alpha = 0.7 # Reduce False Positives" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "### Usage in LightGBM" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "import lightgbm as lgb\n", 53 | "\n", 54 | "train = lgb.Dataset(X_train, y_train)\n", 55 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n", 56 | "params = {\n", 57 | " 'n_estimators': 300,\n", 58 | " 'seed': 41114,\n", 59 | " 'n_jobs': 8,\n", 60 | " 'learning_rate': 0.1,\n", 61 | " }\n", 62 | "\n", 63 | "clf = lgb.train(params=params,\n", 64 | " train_set=train,\n", 65 | " valid_sets=[train, valid],\n", 66 | " valid_names=['train','valid'],\n", 67 | " fobj=WeightedCrossEntropyLoss(alpha=alpha),\n", 68 | " feval=WeightedCrossEntropyMetric(alpha=alpha),\n", 69 | " early_stopping_rounds=100)\n", 70 | "\n", 71 | "roc_auc_score(y_valid, clip_sigmoid(clf.predict(X_valid)))" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "### Usage in XGBoost" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "import xgboost as xgb\n", 88 | "\n", 89 | "dtrain = xgb.DMatrix(X_train, y_train)\n", 90 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n", 91 | "\n", 92 | "params = {\n", 93 | " 'seed': 41114,\n", 94 | " 'learning_rate': 0.1,\n", 95 | " 'disable_default_eval_metric': 1\n", 96 | " }\n", 97 | "\n", 98 | "bst = xgb.train(params,\n", 99 | " dtrain=dtrain,\n", 100 | " num_boost_round=300,\n", 101 | " early_stopping_rounds=10,\n", 102 | " verbose_eval=10,\n", 103 | " obj=WeightedCrossEntropyLoss(alpha=alpha),\n", 104 | " maximize=False,\n", 105 | " feval=WeightedCrossEntropyMetric(alpha=alpha, XGBoost=True),\n", 106 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n", 107 | "\n", 108 | "roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": "Python [conda env:skorecard_py37]", 115 | "language": "python", 116 | "name": "conda-env-skorecard_py37-py" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.7.7" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /mkdocs.yml: 
-------------------------------------------------------------------------------- 1 | site_name: bokbokbok doks 2 | 3 | repo_url: https://github.com/orchardbirds/bokbokbok/ 4 | site_url: https://orchardbirds.github.io/bokbokbok/ 5 | site_description: Implementing Custom Loss Functions and Eval Metrics in LightGBM and XGBoost 6 | site_author: Daniel Timbrell 7 | 8 | use_directory_urls: false 9 | 10 | nav: 11 | - Home: index.md 12 | - Getting started: 13 | - getting_started/install.md 14 | - How To: 15 | - Use Weighted Cross Entropy: tutorials/weighted_cross_entropy.ipynb 16 | - Use Weighted Focal Loss: tutorials/focal_loss.ipynb 17 | - Use F1 Score: tutorials/F1_score.ipynb 18 | - Use Log Cosh Score: tutorials/log_cosh_loss.ipynb 19 | - Use Root Mean Squared Percentage Error: tutorials/RMSPE.ipynb 20 | - Use Quadratic Weighted Kappa: tutorials/quadratic_weighted_kappa.ipynb 21 | - Derivations: 22 | - A Note About Gradients in Classification Problems: derivations/note.md 23 | - Weighted Cross Entropy: derivations/wce.md 24 | - Focal Loss: derivations/focal.md 25 | - Log Cosh Error: derivations/log_cosh.md 26 | - Reference: 27 | - Evaluation Metrics: 28 | - bokbokbok.eval_metrics.binary_classification: reference/eval_metrics_binary.md 29 | - bokbokbok.eval_metrics.multiclass_classification: reference/eval_metrics_multiclass.md 30 | - bokbokbok.eval_metrics.regression: reference/eval_metrics_regression.md 31 | - Loss Functions: 32 | - bokbokbok.loss_functions.classification: reference/loss_functions_classification.md 33 | - bokbokbok.loss_functions.regression: reference/loss_functions_regression.md 34 | 35 | plugins: 36 | - mkdocstrings: 37 | handlers: 38 | python: 39 | selection: 40 | inherited_members: true 41 | filters: 42 | - "!^Base" 43 | - "!^_" # exclude all members starting with _ 44 | - "^__init__$" # but always include __init__ modules and methods 45 | rendering: 46 | show_root_toc_entry: false 47 | watch: 48 | - bokbokbok 49 | - search 50 | - mknotebooks: 51 | enable_default_jupyter_cell_styling: true 52 | enable_default_pandas_dataframe_styling: true 53 | 54 | copyright: Copyright © 2020 55 | 56 | theme: 57 | name: material 58 | logo: img/bokbokbok.png 59 | favicon: img/bokbokbok.png 60 | font: 61 | text: Ubuntu 62 | code: Ubuntu Mono 63 | features: 64 | - navigation.tabs 65 | palette: 66 | scheme: default 67 | primary: teal 68 | accent: yellow 69 | 70 | 71 | markdown_extensions: 72 | - codehilite 73 | - pymdownx.highlight 74 | - pymdownx.inlinehilite 75 | - pymdownx.superfences 76 | - pymdownx.details 77 | - pymdownx.tabbed 78 | - pymdownx.snippets 79 | - pymdownx.highlight: 80 | use_pygments: true 81 | - toc: 82 | permalink: true -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r", encoding="UTF-8") as fh: 4 | long_description = fh.read() 5 | 6 | base_packages = [ 7 | "numpy>=1.19.2", 8 | "scikit-learn>=0.23.2", 9 | ] 10 | 11 | dev_dep = [ 12 | "flake8>=3.8.3", 13 | "black>=19.10b0", 14 | "pre-commit>=2.5.0", 15 | "mypy>=0.770", 16 | "flake8-docstrings>=1.4.0", "pytest>=6.0.0", 17 | "pytest-cov>=2.10.0", 18 | "lightgbm>=3.0.0", 19 | "xgboost>=1.3.3", 20 | ] 21 | 22 | docs_dep = [ 23 | "mkdocs-material>=6.1.0", 24 | "mkdocs-git-revision-date-localized-plugin>=0.7.2", 25 | "mkdocs-git-authors-plugin>=0.3.2", 26 | "mkdocs-table-reader-plugin>=0.4.1", 27 | 
"mkdocs-enumerate-headings-plugin>=0.4.3", 28 | "mkdocs-awesome-pages-plugin>=2.4.0", 29 | "mkdocs-minify-plugin>=0.3.0", 30 | "mknotebooks>=0.6.2", 31 | "mkdocstrings>=0.13.6", 32 | "mkdocs-print-site-plugin>=0.8.2", 33 | "mkdocs-markdownextradata-plugin>=0.1.9", 34 | ] 35 | 36 | setup( 37 | name="bokbokbok", 38 | version="0.6.1", 39 | description="Custom Losses and Metrics for XGBoost, LightGBM, CatBoost", 40 | long_description=long_description, 41 | long_description_content_type="text/markdown", 42 | author="Daniel Timbrell", 43 | author_email="dantimbrell@gmail.com", 44 | license="MIT", 45 | python_requires=">=3.6", 46 | classifiers=[ 47 | "License :: OSI Approved :: MIT License", 48 | "Operating System :: OS Independent", 49 | "Programming Language :: Python", 50 | "Programming Language :: Python :: 3", 51 | "Programming Language :: Python :: 3.6", 52 | "Programming Language :: Python :: 3.7", 53 | "Programming Language :: Python :: 3.8", 54 | "Programming Language :: Python :: 3 :: Only", 55 | ], 56 | include_package_data=True, 57 | install_requires=base_packages, 58 | extras_require={ 59 | "base": base_packages, 60 | "all": base_packages + dev_dep + docs_dep 61 | }, 62 | url="https://github.com/orchardbirds/bokbokbok", 63 | packages=find_packages(".", exclude=["tests", "notebooks", "docs"]), 64 | ) 65 | -------------------------------------------------------------------------------- /tests/test_focal.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import make_classification 2 | from sklearn.model_selection import train_test_split 3 | from sklearn.metrics import mean_absolute_error 4 | from bokbokbok.loss_functions.classification import WeightedFocalLoss, WeightedCrossEntropyLoss 5 | from bokbokbok.eval_metrics.classification import WeightedFocalMetric, WeightedCrossEntropyMetric 6 | from bokbokbok.utils import clip_sigmoid 7 | import lightgbm as lgb 8 | 9 | 10 | def test_focal_lgb_implementation(): 11 | """ 12 | Assert that there is no difference between running Focal with alpha=1 and gamma=0 13 | and LightGBM's internal CE loss: with gamma=0 the focal modulating factor (1 - p)**gamma is identically 1, and with alpha=1 the class weighting vanishes, so the weighted focal loss reduces term-by-term to plain cross entropy. 
14 | """ 15 | X, y = make_classification(n_samples=1000, 16 | n_features=10, 17 | random_state=41114) 18 | 19 | X_train, X_valid, y_train, y_valid = train_test_split(X, 20 | y, 21 | test_size=0.25, 22 | random_state=41114) 23 | 24 | alpha = 1.0 25 | gamma = 0 26 | 27 | train = lgb.Dataset(X_train, y_train) 28 | valid = lgb.Dataset(X_valid, y_valid, reference=train) 29 | 30 | params_wfl = { 31 | 'n_estimators': 300, 32 | 'seed': 41114, 33 | 'n_jobs': 8, 34 | 'learning_rate': 0.1, 35 | } 36 | 37 | wfl_clf = lgb.train(params=params_wfl, 38 | train_set=train, 39 | valid_sets=[train, valid], 40 | valid_names=['train','valid'], 41 | fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma), 42 | feval=WeightedFocalMetric(alpha=alpha, gamma=gamma), 43 | early_stopping_rounds=100) 44 | 45 | 46 | params = { 47 | 'n_estimators': 300, 48 | 'objective': 'cross_entropy', 49 | 'seed': 41114, 50 | 'n_jobs': 8, 51 | 'metric': 'cross_entropy', 52 | 'learning_rate': 0.1, 53 | 'boost_from_average': False 54 | } 55 | 56 | clf = lgb.train(params=params, 57 | train_set=train, 58 | valid_sets=[train, valid], 59 | valid_names=['train','valid'], 60 | early_stopping_rounds=100) 61 | 62 | wfl_preds = clip_sigmoid(wfl_clf.predict(X_valid)) 63 | preds = clf.predict(X_valid) 64 | assert mean_absolute_error(wfl_preds, preds) == 0.0 65 | 66 | 67 | def test_focal_wce_comparison(): 68 | """ 69 | Assert that there is no difference between running Focal with alpha=3 and gamma=0 70 | and running WCE with alpha=3. 71 | """ 72 | X, y = make_classification(n_samples=1000, 73 | n_features=10, 74 | random_state=41114) 75 | 76 | X_train, X_valid, y_train, y_valid = train_test_split(X, 77 | y, 78 | test_size=0.25, 79 | random_state=41114) 80 | 81 | alpha = 3.0 82 | gamma = 0 83 | 84 | train = lgb.Dataset(X_train, y_train) 85 | valid = lgb.Dataset(X_valid, y_valid, reference=train) 86 | 87 | params_wfl = { 88 | 'n_estimators': 300, 89 | 'seed': 41114, 90 | 'n_jobs': 8, 91 | 'learning_rate': 0.1, 92 | } 93 | 94 | wfl_clf = lgb.train(params=params_wfl, 95 | train_set=train, 96 | valid_sets=[train, valid], 97 | valid_names=['train','valid'], 98 | fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma), 99 | feval=WeightedFocalMetric(alpha=alpha, gamma=gamma), 100 | early_stopping_rounds=100) 101 | 102 | 103 | params_wce = { 104 | 'n_estimators': 300, 105 | 'seed': 41114, 106 | 'n_jobs': 8, 107 | 'learning_rate': 0.1, 108 | } 109 | 110 | wce_clf = lgb.train(params=params_wce, 111 | train_set=train, 112 | valid_sets=[train, valid], 113 | valid_names=['train','valid'], 114 | fobj=WeightedCrossEntropyLoss(alpha=alpha), 115 | feval=WeightedCrossEntropyMetric(alpha=alpha), 116 | early_stopping_rounds=100) 117 | 118 | wfl_preds = clip_sigmoid(wfl_clf.predict(X_valid)) 119 | wce_preds = clip_sigmoid(wce_clf.predict(X_valid)) 120 | assert mean_absolute_error(wfl_preds, wce_preds) == 0.0 121 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from bokbokbok.utils import clip_sigmoid 4 | 5 | def test_clip_sigmoid(): 6 | assert np.allclose(a=clip_sigmoid(np.array([100, 0, -100])), 7 | b=[1 - 1e-15, 0.5, 1e-15]) -------------------------------------------------------------------------------- /tests/test_wce.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import make_classification 2 | from sklearn.model_selection import train_test_split 3 | 
from sklearn.metrics import mean_absolute_error 4 | from bokbokbok.loss_functions.classification import WeightedCrossEntropyLoss 5 | from bokbokbok.eval_metrics.classification import WeightedCrossEntropyMetric 6 | from bokbokbok.utils import clip_sigmoid 7 | import lightgbm as lgb 8 | 9 | 10 | def test_wce_lgb_implementation(): 11 | """ 12 | Assert that there is no difference between running WCE with alpha=1 13 | and LightGBM's internal CE loss. 14 | """ 15 | X, y = make_classification(n_samples=1000, 16 | n_features=10, 17 | random_state=41114) 18 | 19 | X_train, X_valid, y_train, y_valid = train_test_split(X, 20 | y, 21 | test_size=0.25, 22 | random_state=41114) 23 | 24 | alpha = 1.0 25 | 26 | train = lgb.Dataset(X_train, y_train) 27 | valid = lgb.Dataset(X_valid, y_valid, reference=train) 28 | 29 | params_wce = { 30 | 'n_estimators': 300, 31 | 'seed': 41114, 32 | 'n_jobs': 8, 33 | 'learning_rate': 0.1, 34 | } 35 | 36 | wce_clf = lgb.train(params=params_wce, 37 | train_set=train, 38 | valid_sets=[train, valid], 39 | valid_names=['train','valid'], 40 | fobj=WeightedCrossEntropyLoss(alpha=1.0), 41 | feval=WeightedCrossEntropyMetric(alpha=1.0), 42 | early_stopping_rounds=100) 43 | 44 | 45 | params = { 46 | 'n_estimators': 300, 47 | 'objective': 'cross_entropy', 48 | 'seed': 41114, 49 | 'n_jobs': 8, 50 | 'metric': 'cross_entropy', 51 | 'learning_rate': 0.1, 52 | 'boost_from_average': False 53 | } 54 | 55 | clf = lgb.train(params=params, 56 | train_set=train, 57 | valid_sets=[train, valid], 58 | valid_names=['train','valid'], 59 | early_stopping_rounds=100) 60 | 61 | wce_preds = clip_sigmoid(wce_clf.predict(X_valid)) 62 | preds = clf.predict(X_valid) 63 | assert mean_absolute_error(wce_preds, preds) == 0.0 --------------------------------------------------------------------------------
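Both test modules above drive `lgb.train` through the `fobj`/`feval` hooks. For intuition, here is a minimal sketch of the kind of callable a factory such as `WeightedCrossEntropyLoss` is expected to return under the LightGBM 3.x API (an illustrative sketch of the protocol, not bokbokbok's actual code):

```python
import numpy as np

def make_wce_objective(alpha=1.0):
    """Illustrative LightGBM 3.x custom objective factory, not bokbokbok's code."""
    def objective(raw_preds, dataset):
        y = dataset.get_label()
        p = 1.0 / (1.0 + np.exp(-raw_preds))        # sigmoid of the raw margins
        grad = p * (alpha * y + 1 - y) - alpha * y  # first derivative of WCE w.r.t. raw margin
        hess = (alpha * y + 1 - y) * p * (1 - p)    # second derivative
        return grad, hess
    return objective
```

With alpha=1 this reduces to grad = p - y and hess = p * (1 - p), the plain cross entropy derivatives, which is exactly why the tests above expect a mean absolute difference of 0.0 between the custom objective and LightGBM's built-in `cross_entropy`.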