├── .github
│   └── workflows
│       ├── publish_documentation.yml
│       ├── publish_pypi.yml
│       └── unit_tests.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── bokbokbok
│   ├── __init__.py
│   ├── eval_metrics
│   │   ├── __init__.py
│   │   ├── classification
│   │   │   ├── __init__.py
│   │   │   ├── binary_eval_metrics.py
│   │   │   └── multiclass_eval_metrics.py
│   │   └── regression
│   │       ├── __init__.py
│   │       └── regression_eval_metrics.py
│   ├── loss_functions
│   │   ├── __init__.py
│   │   ├── classification
│   │   │   ├── __init__.py
│   │   │   └── classification_loss_functions.py
│   │   └── regression
│   │       ├── __init__.py
│   │       └── regression_loss_functions.py
│   └── utils
│       ├── __init__.py
│       └── functions.py
├── docs
│   ├── derivations
│   │   ├── focal.md
│   │   ├── log_cosh.md
│   │   ├── note.md
│   │   └── wce.md
│   ├── getting_started
│   │   └── install.md
│   ├── img
│   │   └── bokbokbok.png
│   ├── index.md
│   ├── reference
│   │   ├── eval_metrics_binary.md
│   │   ├── eval_metrics_multiclass.md
│   │   ├── eval_metrics_regression.md
│   │   ├── loss_functions_classification.md
│   │   └── loss_functions_regression.md
│   └── tutorials
│       ├── F1_score.ipynb
│       ├── RMSPE.ipynb
│       ├── focal_loss.ipynb
│       ├── log_cosh_loss.ipynb
│       ├── quadratic_weighted_kappa.ipynb
│       └── weighted_cross_entropy.ipynb
├── mkdocs.yml
├── setup.py
└── tests
    ├── test_focal.py
    ├── test_utils.py
    └── test_wce.py
/.github/workflows/publish_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Publish Documentation
2 |
3 | on:
4 | workflow_dispatch:
5 | release:
6 | types: [created]
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v2
13 |
14 | - name: Setup conda
15 | uses: conda-incubator/setup-miniconda@v2
16 | with:
17 | miniconda-version: "latest"
18 | python-version: "3.7"
19 | activate-environment: deploydocs
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install --upgrade setuptools wheel twine
25 | pip install ".[all]"
26 | # In our docs, we need to output static images
27 | # That requires additional setup
28 | conda install --yes -c anaconda psutil
29 | conda install --yes -c plotly plotly-orca
30 | - name: Deploy mkdocs site
31 | run: |
32 | mkdocs gh-deploy --force
--------------------------------------------------------------------------------
/.github/workflows/publish_pypi.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python
13 | uses: actions/setup-python@v2
14 | with:
15 | python-version: '3.x'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install --upgrade setuptools wheel twine
20 | - name: Make sure unit tests succeed
21 | run: |
22 | pip3 install ".[all]"
23 | pytest
24 | - name: Build package & publish to PyPi
25 | env:
26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
28 | run: |
29 | python setup.py sdist bdist_wheel
30 | twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/unit_tests.yml:
--------------------------------------------------------------------------------
1 | name: Development
2 | on:
3 | # Trigger the workflow on push or pull request,
4 | # but only for the main branch
5 | push:
6 | branches:
7 | - main
8 | pull_request:
9 | jobs:
10 | run:
11 | name: Run unit tests
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | build: [ubuntu]
16 | include:
17 | - build: ubuntu
18 | os: ubuntu-latest
19 | python-version: [3.6, 3.7, 3.8]
20 | steps:
21 | - uses: actions/checkout@master
22 | - name: Setup Python
23 | uses: actions/setup-python@master
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | pip3 install --upgrade setuptools pip
29 | pip3 install ".[all]"
30 | - name: Run unit tests and linters
31 | run: |
32 | pytest tests/
33 | - name: Static code checking with pyflakes
34 | run: |
35 | pip3 install pyflakes
36 | pyflakes bokbokbok
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /notebooks/private/*
2 | notebooks/*
3 | .idea
4 | *.log
5 | tmd/csv_files/
6 | tmd/fig_private/*
7 | notebooks/private/
8 | .gitignore.swp
9 | # Created by https://www.gitignore.io/api/macos,python
10 | # Edit at https://www.gitignore.io/?templates=macos,python
11 |
12 | ### DATA ###
13 | # to not push utils
14 | ### macOS ###
15 | # General
16 | .DS_Store
17 | .AppleDouble
18 | .LSOverride
19 | .README.md.swp
20 | # Icon must end with two \r
21 | Icon
22 | # Thumbnails
23 | ._*
24 | # Files that might appear in the root of a volume
25 | .DocumentRevisions-V100
26 | .fseventsd
27 | .Spotlight-V100
28 | .TemporaryItems
29 | .Trashes
30 | .VolumeIcon.icns
31 | .com.apple.timemachine.donotpresent
32 | # Directories potentially created on remote AFP share
33 | .AppleDB
34 | .AppleDesktop
35 | Network Trash Folder
36 | Temporary Items
37 | .apdisk
38 | ### Python ###
39 | # Byte-compiled / optimized / DLL files
40 | __pycache__/
41 | *.py[cod]
42 | *.py.swp
43 | *$py.class
44 | # C extensions
45 | *.so
46 | # Distribution / packaging
47 | .Python
48 | build/
49 | develop-eggs/
50 | dist/
51 | downloads/
52 | eggs/
53 | .eggs/
54 | lib/
55 | lib64/
56 | parts/
57 | sdist/
58 | var/
59 | wheels/
60 | pip-wheel-metadata/
61 | share/python-wheels/
62 | *.egg-info/
63 | .installed.cfg
64 | *.egg
65 | MANIFEST
66 | # PyInstaller
67 | # Usually these files are written by a python script from a template
68 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
69 | *.manifest
70 | *.spec
71 | # Installer logs
72 | pip-log.txt
73 | pip-delete-this-directory.txt
74 | # Unit test / coverage reports
75 | htmlcov/
76 | .tox/
77 | .nox/
78 | .coverage
79 | .coverage.*
80 | .cache
81 | nosetests.xml
82 | coverage.xml
83 | *.cover
84 | .hypothesis/
85 | .pytest_cache/
86 | # Translations
87 | *.mo
88 | *.pot
89 | # Django stuff:
90 | local_settings.py
91 | db.sqlite3
92 | # Flask stuff:
93 | instance/
94 | .webassets-cache
95 | # Scrapy stuff:
96 | .scrapy
97 | # Sphinx documentation
98 | docs/_build/
99 | # PyBuilder
100 | target/
101 | # Jupyter Notebook
102 | .ipynb_checkpoints
103 | # IPython
104 | profile_default/
105 | ipython_config.py
106 | # pyenv
107 | .python-version
108 | # celery beat schedule file
109 | celerybeat-schedule
110 | # SageMath parsed files
111 | *.sage.py
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 | # Rope project settings
124 | .ropeproject
125 | # mkdocs documentation
126 | /site
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 | # Pyre type checker
132 | .pyre/
133 | ### Python Patch ###
134 | .venv/
135 | # End of https://www.gitignore.io/api/macos,python
136 | notebooks/demo/.ipynb_checkpoints/
137 | #TMD documentation
138 | *.ipynb_checkpoints/
139 | site/
140 | \!docs/**
141 | # Created by https://www.toptal.com/developers/gitignore/api/r,macos,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks
142 | # Edit at https://www.toptal.com/developers/gitignore?templates=r,macos,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks
143 | ### JupyterNotebooks ###
144 | # gitignore template for Jupyter Notebooks
145 | # website: http://jupyter.org/
146 | */.ipynb_checkpoints/*
147 | # Remove previous ipynb_checkpoints
148 | # git rm -r .ipynb_checkpoints/
149 | Icon
150 | ### PyCharm ###
151 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
152 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
153 | # User-specific stuff
154 | .idea/**/workspace.xml
155 | .idea/**/tasks.xml
156 | .idea/**/usage.statistics.xml
157 | .idea/**/dictionaries
158 | .idea/**/shelf
159 | # Generated files
160 | .idea/**/contentModel.xml
161 | # Sensitive or high-churn files
162 | .idea/**/dataSources/
163 | .idea/**/dataSources.ids
164 | .idea/**/dataSources.local.xml
165 | .idea/**/sqlDataSources.xml
166 | .idea/**/dynamic.xml
167 | .idea/**/uiDesigner.xml
168 | .idea/**/dbnavigator.xml
169 | # Gradle
170 | .idea/**/gradle.xml
171 | .idea/**/libraries
172 | # Gradle and Maven with auto-import
173 | # When using Gradle or Maven with auto-import, you should exclude module files,
174 | # since they will be recreated, and may cause churn. Uncomment if using
175 | # auto-import.
176 | # .idea/artifacts
177 | # .idea/compiler.xml
178 | # .idea/jarRepositories.xml
179 | # .idea/modules.xml
180 | # .idea/*.iml
181 | # .idea/modules
182 | # *.iml
183 | # *.ipr
184 | # CMake
185 | cmake-build-*/
186 | # Mongo Explorer plugin
187 | .idea/**/mongoSettings.xml
188 | # File-based project format
189 | *.iws
190 | # IntelliJ
191 | out/
192 | # mpeltonen/sbt-idea plugin
193 | .idea_modules/
194 | # JIRA plugin
195 | atlassian-ide-plugin.xml
196 | # Cursive Clojure plugin
197 | .idea/replstate.xml
198 | # Crashlytics plugin (for Android Studio and IntelliJ)
199 | com_crashlytics_export_strings.xml
200 | crashlytics.properties
201 | crashlytics-build.properties
202 | fabric.properties
203 | # Editor-based Rest Client
204 | .idea/httpRequests
205 | # Android studio 3.1+ serialized cache file
206 | .idea/caches/build_file_checksums.ser
207 | ### PyCharm Patch ###
208 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
209 | # modules.xml
210 | # .idea/misc.xml
211 | # Sonarlint plugin
212 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
213 | .idea/**/sonarlint/
214 | # SonarQube Plugin
215 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
216 | .idea/**/sonarIssues.xml
217 | # Markdown Navigator plugin
218 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
219 | .idea/**/markdown-navigator.xml
220 | .idea/**/markdown-navigator-enh.xml
221 | .idea/**/markdown-navigator/
222 | # Cache file creation bug
223 | # See https://youtrack.jetbrains.com/issue/JBR-2257
224 | .idea/$CACHE_FILE$
225 | # CodeStream plugin
226 | # https://plugins.jetbrains.com/plugin/12206-codestream
227 | .idea/codestream.xml
228 | *.py,cover
229 | pytestdebug.log
230 | db.sqlite3-journal
231 | doc/_build/
232 | # pipenv
233 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
234 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
235 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
236 | # install all needed dependencies.
237 | #Pipfile.lock
238 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
239 | __pypackages__/
240 | # Celery stuff
241 | celerybeat.pid
242 | # pytype static type analyzer
243 | .pytype/
244 | ### R ###
245 | # History files
246 | .Rhistory
247 | .Rapp.history
248 | # Session Data files
249 | .RData
250 | # User-specific files
251 | .Ruserdata
252 | # Example code in package build process
253 | *-Ex.R
254 | # Output files from R CMD build
255 | /*.tar.gz
256 | # Output files from R CMD check
257 | /*.Rcheck/
258 | # RStudio files
259 | .Rproj.user/
260 | # produced vignettes
261 | vignettes/*.html
262 | vignettes/*.pdf
263 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
264 | .httr-oauth
265 | # knitr and R markdown default cache directories
266 | *_cache/
267 | /cache/
268 | # Temporary files created by R markdown
269 | *.utf8.md
270 | *.knit.md
271 | # R Environment Variables
272 | .Renviron
273 | ### R.Bookdown Stack ###
274 | # R package: bookdown caching files
275 | /*_files/
276 | ### VisualStudioCode ###
277 | .vscode/*
278 | *.code-workspace
279 | ### VisualStudioCode Patch ###
280 | # Ignore all local history of files
281 | .history
282 | ### Windows ###
283 | # Windows thumbnail cache files
284 | Thumbs.db
285 | Thumbs.db:encryptable
286 | ehthumbs.db
287 | ehthumbs_vista.db
288 | # Dump file
289 | *.stackdump
290 | # Folder config file
291 | [Dd]esktop.ini
292 | # Recycle Bin used on file shares
293 | $RECYCLE.BIN/
294 | # Windows Installer files
295 | *.cab
296 | *.msi
297 | *.msix
298 | *.msm
299 | *.msp
300 | # Windows shortcuts
301 | *.lnk
302 | ### VisualStudio ###
303 | ## Ignore Visual Studio temporary files, build results, and
304 | ## files generated by popular Visual Studio add-ons.
305 | ##
306 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
307 | *.rsuser
308 | *.suo
309 | *.user
310 | *.userosscache
311 | *.sln.docstates
312 | # User-specific files (MonoDevelop/Xamarin Studio)
313 | *.userprefs
314 | # Mono auto generated files
315 | mono_crash.*
316 | # Build results
317 | [Dd]ebug/
318 | [Dd]ebugPublic/
319 | [Rr]elease/
320 | [Rr]eleases/
321 | x64/
322 | x86/
323 | [Aa][Rr][Mm]/
324 | [Aa][Rr][Mm]64/
325 | bld/
326 | [Bb]in/
327 | [Oo]bj/
328 | [Ll]og/
329 | [Ll]ogs/
330 | # Visual Studio 2015/2017 cache/options directory
331 | .vs/
332 | # Uncomment if you have tasks that create the project's static files in wwwroot
333 | #wwwroot/
334 | # Visual Studio 2017 auto generated files
335 | Generated\ Files/
336 | # MSTest test Results
337 | [Tt]est[Rr]esult*/
338 | [Bb]uild[Ll]og.*
339 | # NUnit
340 | *.VisualState.xml
341 | TestResult.xml
342 | nunit-*.xml
343 | # Build Results of an ATL Project
344 | [Dd]ebugPS/
345 | [Rr]eleasePS/
346 | dlldata.c
347 | # Benchmark Results
348 | BenchmarkDotNet.Artifacts/
349 | # .NET Core
350 | project.lock.json
351 | project.fragment.lock.json
352 | artifacts/
353 | # StyleCop
354 | StyleCopReport.xml
355 | # Files built by Visual Studio
356 | *_i.c
357 | *_p.c
358 | *_h.h
359 | *.ilk
360 | *.meta
361 | *.obj
362 | *.iobj
363 | *.pch
364 | *.pdb
365 | *.ipdb
366 | *.pgc
367 | *.pgd
368 | *.rsp
369 | *.sbr
370 | *.tlb
371 | *.tli
372 | *.tlh
373 | *.tmp
374 | *.tmp_proj
375 | *_wpftmp.csproj
376 | *.vspscc
377 | *.vssscc
378 | .builds
379 | *.pidb
380 | *.svclog
381 | *.scc
382 | # Chutzpah Test files
383 | _Chutzpah*
384 | # Visual C++ cache files
385 | ipch/
386 | *.aps
387 | *.ncb
388 | *.opendb
389 | *.opensdf
390 | *.sdf
391 | *.cachefile
392 | *.VC.db
393 | *.VC.VC.opendb
394 | # Visual Studio profiler
395 | *.psess
396 | *.vsp
397 | *.vspx
398 | *.sap
399 | # Visual Studio Trace Files
400 | *.e2e
401 | # TFS 2012 Local Workspace
402 | $tf/
403 | # Guidance Automation Toolkit
404 | *.gpState
405 | # ReSharper is a .NET coding add-in
406 | _ReSharper*/
407 | *.[Rr]e[Ss]harper
408 | *.DotSettings.user
409 | # TeamCity is a build add-in
410 | _TeamCity*
411 | # DotCover is a Code Coverage Tool
412 | *.dotCover
413 | # AxoCover is a Code Coverage Tool
414 | .axoCover/*
415 | !.axoCover/settings.json
416 | # Coverlet is a free, cross platform Code Coverage Tool
417 | coverage*[.json, .xml, .info]
418 | # Visual Studio code coverage results
419 | *.coverage
420 | *.coveragexml
421 | # NCrunch
422 | _NCrunch_*
423 | .*crunch*.local.xml
424 | nCrunchTemp_*
425 | # MightyMoose
426 | *.mm.*
427 | AutoTest.Net/
428 | # Web workbench (sass)
429 | .sass-cache/
430 | # Installshield output folder
431 | [Ee]xpress/
432 | # DocProject is a documentation generator add-in
433 | DocProject/buildhelp/
434 | DocProject/Help/*.HxT
435 | DocProject/Help/*.HxC
436 | DocProject/Help/*.hhc
437 | DocProject/Help/*.hhk
438 | DocProject/Help/*.hhp
439 | DocProject/Help/Html2
440 | DocProject/Help/html
441 | # Click-Once directory
442 | publish/
443 | # Publish Web Output
444 | *.[Pp]ublish.xml
445 | *.azurePubxml
446 | # Note: Comment the next line if you want to checkin your web deploy settings,
447 | # but database connection strings (with potential passwords) will be unencrypted
448 | *.pubxml
449 | *.publishproj
450 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
451 | # checkin your Azure Web App publish settings, but sensitive information contained
452 | # in these scripts will be unencrypted
453 | PublishScripts/
454 | # NuGet Packages
455 | *.nupkg
456 | # NuGet Symbol Packages
457 | *.snupkg
458 | # The packages folder can be ignored because of Package Restore
459 | **/[Pp]ackages/*
460 | # except build/, which is used as an MSBuild target.
461 | !**/[Pp]ackages/build/
462 | # Uncomment if necessary however generally it will be regenerated when needed
463 | #!**/[Pp]ackages/repositories.config
464 | # NuGet v3's project.json files produces more ignorable files
465 | *.nuget.props
466 | *.nuget.targets
467 | # Microsoft Azure Build Output
468 | csx/
469 | *.build.csdef
470 | # Microsoft Azure Emulator
471 | ecf/
472 | rcf/
473 | # Windows Store app package directories and files
474 | AppPackages/
475 | BundleArtifacts/
476 | Package.StoreAssociation.xml
477 | _pkginfo.txt
478 | *.appx
479 | *.appxbundle
480 | *.appxupload
481 | # Visual Studio cache files
482 | # files ending in .cache can be ignored
483 | *.[Cc]ache
484 | # but keep track of directories ending in .cache
485 | !?*.[Cc]ache/
486 | # Others
487 | ClientBin/
488 | ~$*
489 | *~
490 | *.dbmdl
491 | *.dbproj.schemaview
492 | *.jfm
493 | *.pfx
494 | *.publishsettings
495 | orleans.codegen.cs
496 | # Including strong name files can present a security risk
497 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
498 | #*.snk
499 | # Since there are multiple workflows, uncomment next line to ignore bower_components
500 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
501 | #bower_components/
502 | # RIA/Silverlight projects
503 | Generated_Code/
504 | # Backup & report files from converting an old project file
505 | # to a newer Visual Studio version. Backup files are not needed,
506 | # because we have git ;-)
507 | _UpgradeReport_Files/
508 | Backup*/
509 | UpgradeLog*.XML
510 | UpgradeLog*.htm
511 | ServiceFabricBackup/
512 | *.rptproj.bak
513 | # SQL Server files
514 | *.mdf
515 | *.ldf
516 | *.ndf
517 | # Business Intelligence projects
518 | *.rdl.data
519 | *.bim.layout
520 | *.bim_*.settings
521 | *.rptproj.rsuser
522 | *- [Bb]ackup.rdl
523 | *- [Bb]ackup ([0-9]).rdl
524 | *- [Bb]ackup ([0-9][0-9]).rdl
525 | # Microsoft Fakes
526 | FakesAssemblies/
527 | # GhostDoc plugin setting file
528 | *.GhostDoc.xml
529 | # Node.js Tools for Visual Studio
530 | .ntvs_analysis.dat
531 | node_modules/
532 | # Visual Studio 6 build log
533 | *.plg
534 | # Visual Studio 6 workspace options file
535 | *.opt
536 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
537 | *.vbw
538 | # Visual Studio LightSwitch build output
539 | **/*.HTMLClient/GeneratedArtifacts
540 | **/*.DesktopClient/GeneratedArtifacts
541 | **/*.DesktopClient/ModelManifest.xml
542 | **/*.Server/GeneratedArtifacts
543 | **/*.Server/ModelManifest.xml
544 | _Pvt_Extensions
545 | # Paket dependency manager
546 | .paket/paket.exe
547 | paket-files/
548 | # FAKE - F# Make
549 | .fake/
550 | # CodeRush personal settings
551 | .cr/personal
552 | # Python Tools for Visual Studio (PTVS)
553 | *.pyc
554 | # Cake - Uncomment if you are using it
555 | # tools/**
556 | # !tools/packages.config
557 | # Tabs Studio
558 | *.tss
559 | # Telerik's JustMock configuration file
560 | *.jmconfig
561 | # BizTalk build output
562 | *.btp.cs
563 | *.btm.cs
564 | *.odx.cs
565 | *.xsd.cs
566 | # OpenCover UI analysis results
567 | OpenCover/
568 | # Azure Stream Analytics local run output
569 | ASALocalRun/
570 | # MSBuild Binary and Structured Log
571 | *.binlog
572 | # NVidia Nsight GPU debugger configuration file
573 | *.nvuser
574 | # MFractors (Xamarin productivity tool) working folder
575 | .mfractor/
576 | # Local History for Visual Studio
577 | .localhistory/
578 | # BeatPulse healthcheck temp database
579 | healthchecksdb
580 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
581 | MigrationBackup/
582 | # Ionide (cross platform F# VS Code tools) working folder
583 | .ionide/
584 | # End of https://www.toptal.com/developers/gitignore/api/r,macos,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks
585 |
586 | *.ipynb
587 | *.pptx
588 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guide
2 |
3 | `bokbokbok` aims to integrate custom loss functions into LightGBM, XGBoost and CatBoost.
4 | To add a loss function or eval metric, or to contribute in general, please follow these steps:
5 |
6 | - Discuss the feature you want to add on GitHub before you write a PR for it. In case of disagreement, the maintainer(s) will have the final word.
7 | - If you’re going to add a loss function, please contribute the derivations of gradients and Hessians.
8 | - When issues or pull requests are not going to be resolved or merged, they should be closed as soon as possible.
9 | This is kinder than deciding this after a long period. Our issue tracker should reflect work to be done.
10 |
11 | That said, there are many ways to contribute to bokbokbok, including:
12 |
13 | - Contribution to code
14 | - Improving the documentation
15 | - Reviewing merge requests
16 | - Investigating bugs
17 | - Reporting issues
18 |
19 | Starting out with open source? See the guide [How to Contribute to Open Source](https://opensource.guide/how-to-contribute/) and have a look at [our issues labelled *good first issue*](https://github.com/orchardbirds/bokbokbok/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
20 |
21 | ## Setup
22 |
23 | Development install:
24 |
25 | ```shell
26 | pip install -e '.[all]'
27 | ```
28 |
29 | Run unit tests with
30 |
31 | ```shell
32 | pytest
33 | ```
34 |
35 | ## Standards
36 |
37 | - Python 3.6+
38 | - Follow [PEP8](http://pep8.org/) as closely as possible (except line length)
39 | - [google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/)
40 | - Git: Include a short description of *what* and *why* was done, *how* can be seen in the code. Use present tense, imperative mood
41 | - Git: limit the length of the first line to 72 chars. You can use multiple messages to specify a second (longer) line: `git commit -m "Patch load function" -m "This is a much longer explanation of what was done"`
42 |
43 |
44 | ### Derivations
45 |
46 | We use [Code cogs](https://www.codecogs.com/latex/eqneditor.php) to generate equations that are compatible with Git and markdown.
47 | To use an equation, choose svg format and HTML embedding and copy the link at the bottom of the page.
48 |
49 | ## Releases and versioning
50 |
51 | We use [semver](https://semver.org/) for versioning. When we are ready for a release, the maintainer runs:
52 |
53 | ```shell
54 | git tag -a v0.1 -m "bokbokbok v0.1" && git push origin v0.1
55 | ```
56 |
57 | When we create a new GitHub release, a [GitHub action](https://github.com/orchardbirds/bokbokbok/blob/main/.github/workflows/publish_pypi.yml) is triggered that:
58 | 
59 | - deploys a new version to PyPI
60 | - rebuilds and deploys the docs
61 |
62 |
63 | ### Documentation
64 |
65 | Documentation is a very crucial part of the project, because it ensures usability of the package. We develop the docs in the following way:
66 |
67 | * We use [mkdocs](https://www.mkdocs.org/) with [mkdocs-material](https://squidfunk.github.io/mkdocs-material/) theme. The `docs/` folder contains all the relevant documentation.
68 | * We use `mkdocs serve` to view the documentation locally. Use it to test the documentation every time you make any changes.
69 | * Maintainers can deploy the docs using `mkdocs gh-deploy`. The documentation is deployed to `https://orchardbirds.github.io/bokbokbok/`.
70 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Dan Timbrell
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [](#)
4 | [](#)
5 | [](#)
6 | 
7 |
8 |
9 | # bokbokbok
10 |
11 | ## Overview
12 |
13 | **bokbokbok** is a Python package that enables you to use Custom Loss Functions and Evaluation Metrics for XGBoost and LightGBM.
14 | Main features:
15 |
16 | - [Weighted Cross Entropy](https://orchardbirds.github.io/bokbokbok/tutorials/weighted_cross_entropy.html)
17 | - [Weighted Focal Loss](https://orchardbirds.github.io/bokbokbok/tutorials/focal_loss.html)
18 | - [Log Cosh Loss](https://orchardbirds.github.io/bokbokbok/tutorials/log_cosh_loss.html)
19 | - [Root Mean Squared Percentage Error](https://orchardbirds.github.io/bokbokbok/tutorials/RMSPE.html)
20 | - [F1 score](https://orchardbirds.github.io/bokbokbok/tutorials/F1_score.html)
21 | - [Quadratic Weighted Kappa](https://orchardbirds.github.io/bokbokbok/tutorials/quadratic_weighted_kappa.html)
22 |
23 | ## Installation
24 |
25 | ```bash
26 | pip install bokbokbok
27 | ```
28 |
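A minimal usage sketch with LightGBM (`params`, `train` and `valid` are placeholders; see the tutorials linked above for complete examples):

```python
import lightgbm as lgb
from bokbokbok.loss_functions.classification import WeightedCrossEntropyLoss
from bokbokbok.eval_metrics.classification import WeightedCrossEntropyMetric

# Plug the custom objective and matching eval metric into lgb.train
clf = lgb.train(params=params,
                train_set=train,
                valid_sets=[train, valid],
                valid_names=['train', 'valid'],
                fobj=WeightedCrossEntropyLoss(alpha=0.7),
                feval=WeightedCrossEntropyMetric(alpha=0.7))
```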
29 | ## Documentation
30 |
31 | The documentation can [be found here.](https://orchardbirds.github.io/bokbokbok/)
32 |
33 | ## Contributing
34 |
35 | To learn more about making a contribution to bokbokbok, please see [CONTRIBUTING.md.](https://github.com/orchardbirds/bokbokbok/blob/main/CONTRIBUTING.md)
36 |
--------------------------------------------------------------------------------
/bokbokbok/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/bokbokbok/__init__.py
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/bokbokbok/eval_metrics/__init__.py
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/classification/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required metrics."""
2 |
3 |
4 | from .binary_eval_metrics import (
5 | WeightedCrossEntropyMetric,
6 | WeightedFocalMetric,
7 | F1_Score_Binary,
8 | )
9 |
10 | from .multiclass_eval_metrics import (
11 | QuadraticWeightedKappaMetric,
12 | )
13 |
14 | __all__ = [
15 | "WeightedCrossEntropyMetric",
16 | "WeightedFocalMetric",
17 | "F1_Score_Binary",
18 | "QuadraticWeightedKappaMetric",
19 | ]
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/classification/binary_eval_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import f1_score
3 | from bokbokbok.utils import clip_sigmoid
4 |
5 |
6 | def WeightedCrossEntropyMetric(alpha=0.5, XGBoost=False):
7 | """
8 | Calculates the Weighted Cross Entropy Metric by applying a weighting factor alpha, allowing one to
9 | trade off recall and precision by up- or down-weighting the cost of a positive error relative to a
10 | negative error.
11 |
12 | A value alpha > 1 decreases the false negative count, hence increasing the recall.
13 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision.
14 |
15 | Args:
16 | alpha (float): The scale to be applied.
17 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
18 | Note that you should also set `maximize=False` in the XGBoost train function
19 |
20 | """
21 |
22 |
23 | def weighted_cross_entropy_metric(yhat, dtrain, alpha=alpha, XGBoost=XGBoost):
24 | """
25 | Weighted Cross Entropy Metric.
26 |
27 | Args:
28 | yhat: Predictions
29 | dtrain: The XGBoost / LightGBM dataset
30 | alpha (float): Scale applied
31 | XGBoost (Bool): If XGBoost is to be implemented
32 |
33 | Returns:
34 | Name of the eval metric, Eval score, Bool to minimise function
35 |
36 | """
37 | y = dtrain.get_label()
38 | yhat = clip_sigmoid(yhat)
39 | elements = - alpha * y * np.log(yhat) - (1 - y) * np.log(1 - yhat)
40 | if XGBoost:
41 | return f'WCE_alpha{alpha}', (np.sum(elements) / len(y))
42 | else:
43 | return f'WCE_alpha{alpha}', (np.sum(elements) / len(y)), False
44 |
45 | return weighted_cross_entropy_metric
46 |
47 |
48 | def WeightedFocalMetric(alpha=1.0, gamma=2.0, XGBoost=False):
49 | """
50 | Implements [alpha-weighted Focal Loss](https://arxiv.org/pdf/1708.02002.pdf)
51 |
52 | The more gamma is increased, the more the model is focussed on the hard, misclassified examples.
53 |
54 | A value alpha > 1 decreases the false negative count, hence increasing the recall.
55 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision.
56 |
57 | Args:
58 | alpha (float): The scale to be applied.
59 | gamma (float): The focusing parameter to be applied
60 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
61 | Note that you should also set `maximize=False` in the XGBoost train function
62 | """
63 |
64 | def focal_metric(yhat, dtrain, alpha=alpha, gamma=gamma, XGBoost=XGBoost):
65 | """
66 | Weighted Focal Loss Metric.
67 |
68 | Args:
69 | yhat: Predictions
70 | dtrain: The XGBoost / LightGBM dataset
71 | alpha (float): Scale applied
72 | gamma (float): Focusing parameter
73 | XGBoost (Bool): If XGBoost is to be implemented
74 |
75 | Returns:
76 | Name of the eval metric, Eval score, Bool to minimise function
77 |
78 | """
79 | y = dtrain.get_label()
80 | yhat = clip_sigmoid(yhat)
81 |
82 | elements = (- alpha * y * np.log(yhat) * np.power(1 - yhat, gamma) -
83 | (1 - y) * np.log(1 - yhat) * np.power(yhat, gamma))
84 |
85 | if XGBoost:
86 | return f'Focal_alpha{alpha}_gamma{gamma}', (np.sum(elements) / len(y))
87 | else:
88 | return f'Focal_alpha{alpha}_gamma{gamma}', (np.sum(elements) / len(y)), False
89 |
90 | return focal_metric
91 |
92 |
93 | def F1_Score_Binary(XGBoost=False, *args, **kwargs):
94 | """
95 | Implements the f1_score metric
96 | [from scikit learn](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn-metrics-f1-score)
97 |
98 | Args:
99 | *args: The arguments to be fed into the scikit learn metric.
100 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
101 | Note that you should also set `maximize=True` in the XGBoost train function
102 |
103 | """
104 | def binary_f1_score(yhat, data, XGBoost=XGBoost):
105 | """
106 | F1 Score.
107 |
108 | Args:
109 | yhat: Predictions
110 |             data: The XGBoost / LightGBM dataset
111 | XGBoost (Bool): If XGBoost is to be implemented
112 |
113 | Returns:
114 | Name of the eval metric, Eval score, Bool to maximise function
115 | """
116 | y_true = data.get_label()
117 | yhat = np.round(yhat)
118 | if XGBoost:
119 | return 'F1', f1_score(y_true, yhat, *args, **kwargs)
120 | else:
121 | return 'F1', f1_score(y_true, yhat, *args, **kwargs), True
122 |
123 | return binary_f1_score
124 |
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/classification/multiclass_eval_metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import cohen_kappa_score
2 | import numpy as np
3 |
4 |
5 | def QuadraticWeightedKappaMetric(XGBoost=False):
6 |     """
7 |     Calculates the Quadratic Weighted Kappa Metric, i.e. Cohen's kappa score with
8 |     quadratic weights, which measures the agreement between the predicted and the
9 |     actual class labels in multiclass problems.
10 | 
11 |     Disagreements further from the true class are penalised more heavily, which
12 |     makes the metric well suited to problems with ordered classes.
13 | 
14 |     Args:
15 |         XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
16 |                         Note that you should also set `maximize=True` in the XGBoost train function
17 | 
18 |     """
19 | 
20 |
21 |
22 | def quadratic_weighted_kappa_metric(yhat, dtrain, XGBoost=XGBoost):
23 | """
24 |         Quadratic Weighted Kappa Metric.
25 |
26 | Args:
27 | yhat: Predictions
28 | dtrain: The XGBoost / LightGBM dataset
29 | XGBoost (Bool): If XGBoost is to be implemented
30 |
31 | Returns:
32 | Name of the eval metric, Eval score, Bool to maximise function
33 |
34 | """
35 | y = dtrain.get_label()
36 | num_class = len(np.unique(dtrain.get_label()))
37 |
38 |         if not XGBoost:
39 | # LightGBM needs extra reshaping
40 | yhat = yhat.reshape(num_class, len(y)).T
41 | yhat = yhat.argmax(axis=1)
42 |
43 | qwk = cohen_kappa_score(y, yhat, weights="quadratic")
44 |
45 | if XGBoost:
46 | return 'QWK', qwk
47 | else:
48 | return 'QWK', qwk, True
49 |
50 | return quadratic_weighted_kappa_metric
51 |
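For illustration, a minimal LightGBM sketch (the `lgb.Dataset` named `train` and the `num_class` value are placeholders):

```python
import lightgbm as lgb
from bokbokbok.eval_metrics.classification import QuadraticWeightedKappaMetric

params = {
    'objective': 'multiclass',
    'num_class': 5,  # placeholder: the number of classes in your data
}

# LightGBM passes flattened multiclass predictions to feval; the metric
# reshapes them to (n_samples, num_class) and takes the argmax internally.
clf = lgb.train(params=params,
                train_set=train,
                feval=QuadraticWeightedKappaMetric())
```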
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/regression/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required metrics."""
2 |
3 |
4 | from .regression_eval_metrics import (
5 | LogCoshMetric,
6 | RMSPEMetric,
7 | )
8 |
9 | __all__ = [
10 | "LogCoshMetric",
11 | "RMSPEMetric",
12 | ]
13 |
--------------------------------------------------------------------------------
/bokbokbok/eval_metrics/regression/regression_eval_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def LogCoshMetric(XGBoost=False):
5 | """
6 | Calculates the [Log Cosh Error](https://openreview.net/pdf?id=rkglvsC9Ym) as an alternative to
7 | Mean Absolute Error.
8 | Args:
9 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
10 | Note that you should also set `maximize=False` in the XGBoost train function
11 |
12 | """
13 | def log_cosh_error(yhat, dtrain, XGBoost=XGBoost):
14 | """
15 |         Log Cosh Error.
16 |         An alternative to Mean Absolute Error.
17 |
18 | yhat: Predictions
19 | dtrain: The XGBoost / LightGBM dataset
20 | XGBoost (Bool): If XGBoost is to be implemented
21 | """
22 |
23 | y = dtrain.get_label()
24 | elements = np.log(np.cosh(yhat - y))
25 | if XGBoost:
26 | return 'LogCosh', float(np.sum(elements) / len(y))
27 | else:
28 | return 'LogCosh', float(np.sum(elements) / len(y)), False
29 |
30 | return log_cosh_error
31 |
32 |
33 | def RMSPEMetric(XGBoost=False):
34 | """
35 | Calculates the Root Mean Squared Percentage Error:
36 | https://www.kaggle.com/c/optiver-realized-volatility-prediction/overview/evaluation
37 |
38 | The corresponding Loss function is Squared Percentage Error.
39 | Args:
40 | XGBoost (Bool): Set to True if using XGBoost. We assume LightGBM as default use.
41 | Note that you should also set `maximize=False` in the XGBoost train function
42 |
43 | """
44 | def RMSPE(yhat, dtrain, XGBoost=XGBoost):
45 | """
46 |         Root Mean Squared Percentage Error.
47 |         All input labels are required to be non-zero.
48 |
49 | yhat: Predictions
50 | dtrain: The XGBoost / LightGBM dataset
51 | XGBoost (Bool): If XGBoost is to be implemented
52 | """
53 |
54 | y = dtrain.get_label()
55 | elements = ((y - yhat) / y) ** 2
56 | if XGBoost:
57 | return 'RMSPE', float(np.sqrt(np.sum(elements) / len(y)))
58 | else:
59 | return 'RMSPE', float(np.sqrt(np.sum(elements) / len(y))), False
60 |
61 | return RMSPE
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/bokbokbok/loss_functions/__init__.py
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/classification/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required losses."""
2 |
3 |
4 | from .classification_loss_functions import (
5 | WeightedCrossEntropyLoss,
6 | WeightedFocalLoss,
7 | )
8 |
9 | __all__ = [
10 | "WeightedCrossEntropyLoss",
11 | "WeightedFocalLoss"
12 | ]
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/classification/classification_loss_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from bokbokbok.utils import clip_sigmoid
3 |
4 |
5 | def WeightedCrossEntropyLoss(alpha=0.5):
6 | """
7 | Calculates the Weighted Cross-Entropy Loss, which applies a factor alpha, allowing one to
8 | trade off recall and precision by up- or down-weighting the cost of a positive error relative
9 | to a negative error.
10 |
11 | A value alpha > 1 decreases the false negative count, hence increasing the recall.
12 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision.
13 | """
14 |
15 | def _gradient(yhat, dtrain, alpha):
16 | """Compute the weighted cross-entropy gradient.
17 |
18 | Args:
19 | yhat (np.array): Margin predictions
20 | dtrain: The XGBoost / LightGBM dataset
21 | alpha (float): Scale applied
22 |
23 | Returns:
24 | grad: Weighted cross-entropy gradient
25 | """
26 | y = dtrain.get_label()
27 |
28 | yhat = clip_sigmoid(yhat)
29 |
30 | grad = (y * yhat * (alpha - 1)) + yhat - (alpha * y)
31 |
32 | return grad
33 |
34 | def _hessian(yhat, dtrain, alpha):
35 | """Compute the weighted cross-entropy hessian.
36 |
37 | Args:
38 | yhat (np.array): Margin predictions
39 | dtrain: The XGBoost / LightGBM dataset
40 | alpha (float): Scale applied
41 |
42 | Returns:
43 | hess: Weighted cross-entropy Hessian
44 | """
45 | y = dtrain.get_label()
46 | yhat = clip_sigmoid(yhat)
47 |
48 | hess = (y * (alpha - 1) + 1) * yhat * (1 - yhat)
49 |
50 | return hess
51 |
52 | def weighted_cross_entropy(
53 | yhat,
54 | dtrain,
55 | alpha=alpha
56 | ):
57 | """
58 |         Calculate gradient and hessian for weighted cross-entropy.
59 |
60 | Args:
61 | yhat (np.array): Predictions
62 | dtrain: The XGBoost / LightGBM dataset
63 | alpha (float): Scale applied
64 |
65 | Returns:
66 | grad: Weighted cross-entropy gradient
67 | hess: Weighted cross-entropy Hessian
68 | """
69 | grad = _gradient(yhat, dtrain, alpha=alpha)
70 |
71 | hess = _hessian(yhat, dtrain, alpha=alpha)
72 |
73 | return grad, hess
74 |
75 | return weighted_cross_entropy
76 |
77 |
78 | def WeightedFocalLoss(alpha=1.0, gamma=2.0):
79 | """
80 | Calculates the [Weighted Focal Loss.](https://arxiv.org/pdf/1708.02002.pdf)
81 |
82 | Note that if using alpha = 1 and gamma = 0,
83 | this is the same as using regular Cross Entropy.
84 |
85 | The more gamma is increased, the more the model is focussed on the hard, misclassified examples.
86 |
87 | A value alpha > 1 decreases the false negative count, hence increasing the recall.
88 | Conversely, setting alpha < 1 decreases the false positive count and increases the precision.
89 |
90 | """
91 |
92 | def _gradient(yhat, dtrain, alpha, gamma):
93 | """Compute the weighted focal gradient.
94 |
95 | Args:
96 | yhat (np.array): Margin predictions
97 | dtrain: The XGBoost / LightGBM dataset
98 | alpha (float): Scale applied
99 | gamma (float): Focusing parameter
100 |
101 | Returns:
102 | grad: Weighted Focal Loss gradient
103 | """
104 | y = dtrain.get_label()
105 |
106 | yhat = clip_sigmoid(yhat)
107 |
108 | grad = (
109 | alpha * y * np.power(1 - yhat, gamma) * (gamma * yhat * np.log(yhat) + yhat - 1) +
110 | (1 - y) * np.power(yhat, gamma) * (yhat - gamma * np.log(1 - yhat) * (1 - yhat))
111 | )
112 |
113 | return grad
114 |
115 | def _hessian(yhat, dtrain, alpha, gamma):
116 | """Compute the weighted focal hessian.
117 |
118 | Args:
119 | yhat (np.array): Margin predictions
120 | dtrain: The XGBoost / LightGBM dataset
121 | alpha (float): Scale applied
122 | gamma (float): Focusing parameter
123 |
124 | Returns:
125 | hess: Weighted Focal Loss Hessian
126 | """
127 | y = dtrain.get_label()
128 |
129 | yhat = clip_sigmoid(yhat)
130 |
131 |         hess = (
132 |             alpha * y * yhat * np.power(1 - yhat, gamma) * (gamma * (1 - yhat) * np.log(yhat) +
133 |                 2 * gamma * (1 - yhat) - np.power(gamma, 2) * yhat * np.log(yhat) + 1 - yhat) +
134 |             (1 - y) * np.power(yhat, gamma) * (1 - yhat) * ((2 * gamma + 1) * yhat +
135 |                 gamma * yhat * np.log(1 - yhat) - np.power(gamma, 2) * (1 - yhat) * np.log(1 - yhat))
136 |         )
137 |
138 | return hess
139 |
140 | def focal_loss(
141 | yhat,
142 | dtrain,
143 | alpha=alpha,
144 | gamma=gamma):
145 | """
146 |         Calculate gradient and hessian for Focal Loss.
147 |
148 | Args:
149 | yhat (np.array): Margin predictions
150 | dtrain: The XGBoost / LightGBM dataset
151 | alpha (float): Scale applied
152 | gamma (float): Focusing parameter
153 |
154 | Returns:
155 | grad: Focal Loss gradient
156 | hess: Focal Loss Hessian
157 | """
158 |
159 | grad = _gradient(yhat, dtrain, alpha=alpha, gamma=gamma)
160 |
161 | hess = _hessian(yhat, dtrain, alpha=alpha, gamma=gamma)
162 |
163 | return grad, hess
164 |
165 | return focal_loss
166 |
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/regression/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required losses."""
2 |
3 |
4 | from .regression_loss_functions import (
5 | LogCoshLoss,
6 | SPELoss,
7 | )
8 |
9 | __all__ = [
10 | "LogCoshLoss",
11 | "SPELoss",
12 | ]
--------------------------------------------------------------------------------
/bokbokbok/loss_functions/regression/regression_loss_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def LogCoshLoss():
5 | """
6 | [Log Cosh Loss](https://openreview.net/pdf?id=rkglvsC9Ym) is an alternative to Mean Absolute Error.
7 | """
8 |
9 | def _gradient(yhat, dtrain):
10 | """Compute the log cosh gradient.
11 |
12 | Args:
13 | yhat (np.array): Predictions
14 | dtrain: The XGBoost / LightGBM dataset
15 |
16 | Returns:
17 | log cosh gradient
18 | """
19 |
20 | y = dtrain.get_label()
21 | return -np.tanh(y - yhat)
22 |
23 | def _hessian(yhat, dtrain):
24 | """Compute the log cosh hessian.
25 |
26 | Args:
27 | yhat (np.array): Predictions
28 | dtrain: The XGBoost / LightGBM dataset
29 |
30 | Returns:
31 | log cosh Hessian
32 | """
33 |
34 | y = dtrain.get_label()
35 | return 1. / np.power(np.cosh(y - yhat), 2)
36 |
37 | def log_cosh_loss(
38 | yhat,
39 | dtrain
40 | ):
41 | """
42 | Calculate gradient and hessian for log cosh loss.
43 |
44 | Args:
45 | yhat (np.array): Predictions
46 | dtrain: The XGBoost / LightGBM dataset
47 |
48 | Returns:
49 | grad: log cosh loss gradient
50 | hess: log cosh loss Hessian
51 | """
52 | grad = _gradient(yhat, dtrain)
53 |
54 | hess = _hessian(yhat, dtrain)
55 |
56 | return grad, hess
57 |
58 | return log_cosh_loss
59 |
60 |
61 | def SPELoss():
62 | """
63 | Squared Percentage Error loss
64 | """
65 |
66 | def _gradient(yhat, dtrain):
67 | """
68 | Compute the gradient squared percentage error.
69 | Args:
70 | yhat (np.array): Predictions
71 | dtrain: The XGBoost / LightGBM dataset
72 |
73 | Returns:
74 | SPE Gradient
75 | """
76 | y = dtrain.get_label()
77 | return -2*(y-yhat)/(y**2)
78 |
79 | def _hessian(yhat, dtrain):
80 | """
81 | Compute the hessian for squared percentage error.
82 | Args:
83 | yhat (np.array): Predictions
84 | dtrain: The XGBoost / LightGBM dataset
85 |
86 | Returns:
87 | SPE Hessian
88 | """
89 | y = dtrain.get_label()
90 | return 2/(y**2)
91 |
92 | def squared_percentage(yhat, dtrain):
93 | """
94 | Calculate gradient and hessian for squared percentage error.
95 |
96 | Args:
97 | yhat (np.array): Predictions
98 | dtrain: The XGBoost / LightGBM dataset
99 |
100 | Returns:
101 | grad: SPE loss gradient
102 | hess: SPE loss Hessian
103 | """
104 | #yhat[yhat < -1] = -1 + 1e-6
105 | grad = _gradient(yhat, dtrain)
106 |
107 | hess = _hessian(yhat, dtrain)
108 |
109 | return grad, hess
110 |
111 | return squared_percentage
--------------------------------------------------------------------------------
/bokbokbok/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Import required functions."""
2 |
3 |
4 | from .functions import (
5 | clip_sigmoid
6 | )
7 |
8 | __all__ = [
9 | "clip_sigmoid",
10 | ]
11 |
--------------------------------------------------------------------------------
/bokbokbok/utils/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def clip_sigmoid(yhat):
5 | """
6 | Applies the sigmoid function and ensures that the values lie in the range
7 | 1e-15 < yhat < 1 - 1e-15.
8 | We clip to avoid dividing by zero in the loss functions.
9 |
10 | Args:
11 |         yhat: The raw margins to which the sigmoid function is applied
12 |
13 | Returns:
14 | yhat: The clipped probabilities
15 | """
16 | yhat = 1. / (1. + np.exp(-yhat))
17 | yhat[yhat >= 1] = 1 - 1e-15
18 | yhat[yhat <= 0] = 1e-15
19 | return yhat
20 |
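For illustration, a quick check of the clipping behaviour (a sketch; the input margins are arbitrary):

```python
import numpy as np
from bokbokbok.utils import clip_sigmoid

# sigmoid(40) rounds to exactly 1.0 in float64 and is clipped back below 1;
# moderate margins pass through unchanged.
print(clip_sigmoid(np.array([-40.0, 0.0, 40.0])))
# -> approximately [4.2e-18, 0.5, 1 - 1e-15]
```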
--------------------------------------------------------------------------------
/docs/derivations/focal.md:
--------------------------------------------------------------------------------
1 | ## Weighted Focal Loss
2 |
3 | Weighted Focal Loss applies a scaling parameter *alpha* and a focusing parameter *gamma* to [Binary Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression)
4 |
5 | We take the definition of the Focal Loss from [this paper](https://arxiv.org/pdf/1708.02002.pdf):
6 |
7 | $$ FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t) $$
8 |
9 | where:
10 |
11 | $$ p_t = \begin{cases} \hat{y} & \text{if } y = 1 \\ 1 - \hat{y} & \text{otherwise} \end{cases} $$
12 |
13 | This is equivalent to writing:
14 |
15 | $$ L_{FL} = -\alpha \, y \, (1 - \hat{y})^{\gamma} \log(\hat{y}) - (1 - y) \, \hat{y}^{\gamma} \log(1 - \hat{y}) $$
16 |
17 |
18 | We calculate the Gradient:
19 |
20 | $$ \frac{\partial L_{FL}}{\partial x} = \alpha y (1 - \hat{y})^{\gamma} \left( \gamma \hat{y} \log(\hat{y}) + \hat{y} - 1 \right) + (1 - y) \, \hat{y}^{\gamma} \left( \hat{y} - \gamma (1 - \hat{y}) \log(1 - \hat{y}) \right) $$
21 |
22 |
23 | We also need to calculate the Hessian:
24 |
25 | $$ \frac{\partial^2 L_{FL}}{\partial x^2} = \alpha y \hat{y} (1 - \hat{y})^{\gamma} \left( \gamma (1 - \hat{y}) \log(\hat{y}) + 2 \gamma (1 - \hat{y}) - \gamma^2 \hat{y} \log(\hat{y}) + 1 - \hat{y} \right) + (1 - y) \, \hat{y}^{\gamma} (1 - \hat{y}) \left( (2 \gamma + 1) \hat{y} + \gamma \hat{y} \log(1 - \hat{y}) - \gamma^2 (1 - \hat{y}) \log(1 - \hat{y}) \right) $$
26 |
27 | By setting *alpha* = 1 and *gamma* = 0 we obtain the Gradient and Hessian for Binary Cross Entropy Loss, as expected.
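These expressions can be sanity-checked numerically (a sketch, assuming only NumPy; the test values of *x*, *y*, *alpha* and *gamma* are arbitrary):

```python
import numpy as np

alpha, gamma, y, x = 1.5, 2.0, 1.0, 0.3  # arbitrary test point

s = lambda x: 1 / (1 + np.exp(-x))       # sigmoid of the margin
L = lambda x: (-alpha * y * (1 - s(x))**gamma * np.log(s(x))
               - (1 - y) * s(x)**gamma * np.log(1 - s(x)))

# Central finite differences approximate the Gradient and Hessian above
eps = 1e-5
num_grad = (L(x + eps) - L(x - eps)) / (2 * eps)
num_hess = (L(x + eps) - 2 * L(x) + L(x - eps)) / eps**2
print(num_grad, num_hess)
```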
--------------------------------------------------------------------------------
/docs/derivations/log_cosh.md:
--------------------------------------------------------------------------------
1 | ## Log Cosh Error
2 |
3 | The equation for Log Cosh Error is:
4 |
5 | $$ L(y, \hat{y}) = \log\left(\cosh(\hat{y} - y)\right) $$
6 |
7 | We calculate the Gradient:
8 |
9 | $$ \frac{\partial L}{\partial \hat{y}} = \tanh(\hat{y} - y) $$
10 |
11 | We also need to calculate the Hessian:
12 |
13 | $$ \frac{\partial^2 L}{\partial \hat{y}^2} = \frac{1}{\cosh^2(\hat{y} - y)} $$
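These can be checked against finite differences (a sketch, assuming only NumPy; the label and prediction are arbitrary):

```python
import numpy as np

y, yhat = 2.0, 2.7                # arbitrary label and prediction
L = lambda p: np.log(np.cosh(p - y))

eps = 1e-6
print(np.tanh(yhat - y), (L(yhat + eps) - L(yhat - eps)) / (2 * eps))  # gradient
print(1 / np.cosh(yhat - y)**2,
      (L(yhat + eps) - 2 * L(yhat) + L(yhat - eps)) / eps**2)          # Hessian
```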
--------------------------------------------------------------------------------
/docs/derivations/note.md:
--------------------------------------------------------------------------------
1 | For the gradient boosting packages we [have to calculate the gradient of the Loss function with respect to the raw margins](https://github.com/Microsoft/LightGBM/blob/master/examples/python-guide/advanced_example.py).
2 |
3 | In this case, we must calculate
4 |
5 | $$ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial x} $$
6 |
7 |
8 |
9 | The Hessian is similarly calculated:
10 |
11 | $$ \frac{\partial^2 L}{\partial x^2} = \frac{\partial}{\partial \hat{y}} \left( \frac{\partial L}{\partial x} \right) \cdot \frac{\partial \hat{y}}{\partial x} $$
12 |
13 |
14 | **Where y-hat is the sigmoid function, unless stated otherwise**:
15 |
16 | $$ \hat{y} = \sigma(x) = \frac{1}{1 + e^{-x}} $$
17 |
18 | We will make use of the following property for the calculations of the Gradients and Hessians:
19 |
20 | $$ \frac{\partial \hat{y}}{\partial x} = \hat{y} (1 - \hat{y}) $$
21 |
22 | Note that to avoid divide-by-zero errors, we clip the output of the sigmoid so that it is bounded by $10^{-15}$ from below and $1 - 10^{-15}$ from above.
--------------------------------------------------------------------------------
/docs/derivations/wce.md:
--------------------------------------------------------------------------------
1 | ## Weighted Cross Entropy Loss
2 |
3 | Weighted Cross Entropy applies a scaling parameter *alpha* to [Binary Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression),
4 | allowing us to penalise false positives or false negatives more harshly. If you want false
5 | negatives to be penalised more than false positives, *alpha* must be greater than 1. Otherwise,
6 | it must be less than 1.
7 |
8 | The equations for Binary and Weighted Cross Entropy Loss are the following:
9 |
10 | $$ L_{BCE} = -y \log(\hat{y}) - (1 - y) \log(1 - \hat{y}) $$
11 |
12 | $$ L_{WCE} = -\alpha \, y \log(\hat{y}) - (1 - y) \log(1 - \hat{y}) $$
13 |
14 | We calculate the Gradient:
15 |
16 | $$ \frac{\partial L_{WCE}}{\partial x} = \hat{y} - \alpha y + y \hat{y} (\alpha - 1) $$
17 |
18 |
19 | We also need to calculate the Hessian:
20 |
21 | $$ \frac{\partial^2 L_{WCE}}{\partial x^2} = \left( 1 + y (\alpha - 1) \right) \hat{y} (1 - \hat{y}) $$
22 |
23 |
24 | By setting *alpha* = 1 we obtain the Gradient and Hessian for Binary Cross Entropy Loss, as expected.
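As a sanity check, the Gradient and Hessian above can be compared with finite differences (a sketch, assuming only NumPy; the test point is arbitrary):

```python
import numpy as np

alpha, y, x = 0.7, 1.0, 0.3              # arbitrary test point
s = lambda x: 1 / (1 + np.exp(-x))       # sigmoid of the margin

L = lambda x: -alpha * y * np.log(s(x)) - (1 - y) * np.log(1 - s(x))
grad = s(x) - alpha * y + y * s(x) * (alpha - 1)   # closed form above
hess = (1 + y * (alpha - 1)) * s(x) * (1 - s(x))   # closed form above

eps = 1e-5
print(grad, (L(x + eps) - L(x - eps)) / (2 * eps))          # gradients agree
print(hess, (L(x + eps) - 2 * L(x) + L(x - eps)) / eps**2)  # Hessians agree
```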
--------------------------------------------------------------------------------
/docs/getting_started/install.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | In order to install bokbokbok you need to use Python 3.7 or higher.
4 |
5 | Install `bokbokbok` via pip with:
6 |
7 | ```bash
8 | pip install bokbokbok
9 | ```
10 |
11 | Alternatively you can fork/clone and run:
12 |
13 | ```bash
14 | git clone https://github.com/orchardbirds/bokbokbok.git
15 | cd bokbokbok
16 | pip install .
17 | ```
--------------------------------------------------------------------------------
/docs/img/bokbokbok.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/orchardbirds/bokbokbok/3a8c52b7e2f6f80b803ff2e7073f3cfbf8f76b6c/docs/img/bokbokbok.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Welcome to the bokbokbok doks!
2 |
3 |
4 |
5 | **bokbokbok** is a Python library that lets us easily implement custom loss functions and eval metrics in LightGBM and XGBoost.
6 |
7 | ## Example Usage - Weighted Cross Entropy
8 |
9 | ```python
10 | clf = lgb.train(params=params,
11 | train_set=train,
12 | valid_sets=[train, valid],
13 | valid_names=['train','valid'],
14 | fobj=WeightedCrossEntropyLoss(alpha=alpha),
15 | feval=WeightedCrossEntropyMetric(alpha=alpha),
16 | early_stopping_rounds=100)
17 | ```
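The same pattern works in XGBoost (a sketch mirroring the tutorials; `params`, `dtrain` and `dvalid` are placeholders):

```python
import xgboost as xgb
from bokbokbok.loss_functions.classification import WeightedCrossEntropyLoss
from bokbokbok.eval_metrics.classification import WeightedCrossEntropyMetric

bst = xgb.train(params,
                dtrain=dtrain,
                num_boost_round=300,
                obj=WeightedCrossEntropyLoss(alpha=alpha),
                feval=WeightedCrossEntropyMetric(alpha=alpha, XGBoost=True),
                maximize=False,
                evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])
```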
18 | ## License
19 | bokbokbok is released under the MIT License; see the LICENSE file for details.
20 |
21 |
--------------------------------------------------------------------------------
/docs/reference/eval_metrics_binary.md:
--------------------------------------------------------------------------------
1 | ::: bokbokbok.eval_metrics.classification.binary_eval_metrics
--------------------------------------------------------------------------------
/docs/reference/eval_metrics_multiclass.md:
--------------------------------------------------------------------------------
1 | ::: bokbokbok.eval_metrics.classification.multiclass_eval_metrics
--------------------------------------------------------------------------------
/docs/reference/eval_metrics_regression.md:
--------------------------------------------------------------------------------
1 | ::: bokbokbok.eval_metrics.regression.regression_eval_metrics
--------------------------------------------------------------------------------
/docs/reference/loss_functions_classification.md:
--------------------------------------------------------------------------------
1 | ::: bokbokbok.loss_functions.classification.classification_loss_functions
--------------------------------------------------------------------------------
/docs/reference/loss_functions_regression.md:
--------------------------------------------------------------------------------
1 | ::: bokbokbok.loss_functions.regression.regression_loss_functions
--------------------------------------------------------------------------------
/docs/tutorials/F1_score.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### When to use F-Score\n",
8 | "\n",
9 | "The F-score or F-measure is a measure of a test's accuracy. It is calculated from the precision and recall of the test, where the precision is the number of true positive results divided by the number of all positive results, including those not identified correctly, and the recall is the number of true positive results divided by the number of all samples that should have been identified as positive. Precision is also known as positive predictive value, and recall is also known as sensitivity in diagnostic binary classification.\n",
10 | "\n",
11 | "The highest possible value of an F-score is 1.0, indicating perfect precision and recall, and the lowest possible value is 0, if either the precision or the recall is zero. "
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from sklearn.datasets import make_classification\n",
21 | "from sklearn.model_selection import train_test_split\n",
22 | "from sklearn.metrics import roc_auc_score\n",
23 | "from bokbokbok.eval_metrics.classification import F1_Score_Binary\n",
24 | "from bokbokbok.utils import clip_sigmoid\n",
25 | "\n",
26 | "X, y = make_classification(n_samples=1000, \n",
27 | " n_features=10, \n",
28 | " random_state=41114)\n",
29 | "\n",
30 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n",
31 | " y, \n",
32 | " test_size=0.25, \n",
33 | " random_state=41114)"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "### Usage in LightGBM"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import lightgbm as lgb\n",
50 | "\n",
51 | "train = lgb.Dataset(X_train, y_train)\n",
52 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n",
53 | "params = {\n",
54 | " 'n_estimators': 300,\n",
55 | " 'objective': 'binary',\n",
56 | " 'seed': 41114,\n",
57 | " 'n_jobs': 8,\n",
58 | " 'learning_rate': 0.1,\n",
59 | " }\n",
60 | "\n",
61 | "clf = lgb.train(params=params,\n",
62 | " train_set=train,\n",
63 | " valid_sets=[train, valid],\n",
64 | " valid_names=['train','valid'],\n",
65 | " feval=F1_Score_Binary(average='micro'),\n",
66 | " early_stopping_rounds=100)\n",
67 | "\n",
68 | "roc_auc_score(y_valid, clf.predict(X_valid))"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "### Usage in XGBoost"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "import xgboost as xgb\n",
85 | "\n",
86 | "dtrain = xgb.DMatrix(X_train, y_train)\n",
87 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n",
88 | "\n",
89 | "params = {\n",
90 | " 'seed': 41114,\n",
91 | " 'objective':'binary:logistic',\n",
92 | " 'learning_rate': 0.1,\n",
93 | " 'disable_default_eval_metric': 1\n",
94 | " }\n",
95 | "\n",
96 | "bst = xgb.train(params,\n",
97 | " dtrain=dtrain,\n",
98 | " num_boost_round=300,\n",
99 | " early_stopping_rounds=10,\n",
100 | " verbose_eval=10,\n",
101 | " maximize=True,\n",
102 | " feval=F1_Score_Binary(average='micro', XGBoost=True),\n",
103 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n",
104 | "\n",
105 | "roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))"
106 | ]
107 | }
108 | ],
109 | "metadata": {
110 | "kernelspec": {
111 | "display_name": "Python [conda env:skorecard_py37]",
112 | "language": "python",
113 | "name": "conda-env-skorecard_py37-py"
114 | },
115 | "language_info": {
116 | "codemirror_mode": {
117 | "name": "ipython",
118 | "version": 3
119 | },
120 | "file_extension": ".py",
121 | "mimetype": "text/x-python",
122 | "name": "python",
123 | "nbconvert_exporter": "python",
124 | "pygments_lexer": "ipython3",
125 | "version": "3.7.7"
126 | }
127 | },
128 | "nbformat": 4,
129 | "nbformat_minor": 2
130 | }
131 |
--------------------------------------------------------------------------------
/docs/tutorials/RMSPE.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### When to use (Root Mean) Squared Percentage Error?\n",
8 | "\n",
9 | "This function is defined according to [this Kaggle competition](https://www.kaggle.com/c/optiver-realized-volatility-prediction/overview/evaluation) for volatility calculation. \n",
10 | "\n",
11 | "RMSPE cannot be used as a Loss function - the gradient is constant and hence the Hessian is 0. Nevertheless, it can still be used as an evaluation metric as the model trains. To use the loss function, we simply remove the square for a non-zero Hessian."
12 | ]
13 | },
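{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of the quantities involved, assuming SPELoss implements the per-sample squared percentage error (up to constant factors):\n",
"\n",
"$$\\text{RMSPE} = \\sqrt{\\frac{1}{n}\\sum_{i=1}^{n}\\left(\\frac{y_i - \\hat{y}_i}{y_i}\\right)^2}, \\qquad \\text{SPE}_i = \\left(\\frac{y_i - \\hat{y}_i}{y_i}\\right)^2$$\n",
"\n",
"$$\\frac{\\partial\\,\\text{SPE}_i}{\\partial \\hat{y}_i} = \\frac{-2(y_i - \\hat{y}_i)}{y_i^2}, \\qquad \\frac{\\partial^2\\,\\text{SPE}_i}{\\partial \\hat{y}_i^2} = \\frac{2}{y_i^2}$$"
]
},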
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from sklearn.datasets import make_regression\n",
21 | "from sklearn.model_selection import train_test_split\n",
22 | "from sklearn.metrics import mean_absolute_error\n",
23 | "from bokbokbok.eval_metrics.regression import RMSPEMetric\n",
24 | "from bokbokbok.loss_functions.regression import SPELoss\n",
25 | "\n",
26 | "X, y = make_regression(n_samples=10000, \n",
27 | " n_features=10, \n",
28 | " random_state=41114)\n",
29 | "\n",
30 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n",
31 | " y, \n",
32 | " test_size=0.25, \n",
33 | " random_state=41114)\n"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "### Usage in LightGBM"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import lightgbm as lgb\n",
50 | "\n",
51 | "train = lgb.Dataset(X_train, y_train)\n",
52 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n",
53 | "params = {\n",
54 | " 'n_estimators': 3000,\n",
55 | " 'seed': 41114,\n",
56 | " 'n_jobs': 8,\n",
57 | " 'max_leaves':10,\n",
58 | " }\n",
59 | "\n",
60 | "clf = lgb.train(params=params,\n",
61 | " train_set=train,\n",
62 | " valid_sets=[train, valid],\n",
63 | " valid_names=['train','valid'],\n",
64 | " fobj=SPELoss(),\n",
65 | " feval=RMSPEMetric(),\n",
66 | " early_stopping_rounds=100)\n",
67 | "\n",
68 | "mean_absolute_error(y_valid, clf.predict(X_valid))"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "### Usage in XGBoost"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "import xgboost as xgb\n",
85 | "\n",
86 | "dtrain = xgb.DMatrix(X_train, y_train)\n",
87 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n",
88 | "\n",
89 | "params = {\n",
90 | " 'seed': 41114,\n",
91 | " 'learning_rate': 0.1,\n",
92 | " 'disable_default_eval_metric': 1\n",
93 | " }\n",
94 | "\n",
95 | "bst = xgb.train(params,\n",
96 | " dtrain=dtrain,\n",
97 | " num_boost_round=3000,\n",
98 | " early_stopping_rounds=100,\n",
99 | " verbose_eval=100,\n",
100 | " obj=SPELoss(),\n",
101 | " maximize=False,\n",
102 | " feval=RMSPEMetric(XGBoost=True),\n",
103 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n",
104 | "\n",
105 | "mean_absolute_error(y_valid, bst.predict(dvalid))"
106 | ]
107 | }
108 | ],
109 | "metadata": {
110 | "kernelspec": {
111 | "display_name": "Python [conda env:skorecard_py37]",
112 | "language": "python",
113 | "name": "conda-env-skorecard_py37-py"
114 | },
115 | "language_info": {
116 | "codemirror_mode": {
117 | "name": "ipython",
118 | "version": 3
119 | },
120 | "file_extension": ".py",
121 | "mimetype": "text/x-python",
122 | "name": "python",
123 | "nbconvert_exporter": "python",
124 | "pygments_lexer": "ipython3",
125 | "version": "3.7.7"
126 | }
127 | },
128 | "nbformat": 4,
129 | "nbformat_minor": 4
130 | }
131 |
--------------------------------------------------------------------------------
/docs/tutorials/focal_loss.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### When to use Focal Loss?\n",
8 | "\n",
9 | "Focal Loss addresses class imbalance in tasks such as object detection. Focal loss applies a modulating term to the Cross Entropy loss in order to focus learning on hard negative examples. It is a dynamically scaled Cross Entropy loss, where the scaling factor decays to zero as confidence in the correct class increases. Intuitively, this scaling factor can automatically down-weight the contribution of easy examples during training and rapidly focus the model on hard examples. This scaling factor is *gamma*. The more *gamma* is increased, the more the model is focussed on the hard, misclassified examples.\n",
10 | "\n",
11 | "We employ Weighted Focal Loss, which further allows us to reduce false positives or false negatives depending on our value of *alpha*:\n",
12 | "\n",
13 | "A value *alpha* > 1 decreases the false negative count, hence increasing the recall. Conversely, setting *alpha* < 1 decreases the false positive count and increases the precision. "
14 | ]
15 | },
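{
"cell_type": "markdown",
"metadata": {},
"source": [
"One common parameterisation of the weighted focal loss, with p the predicted probability of the positive class (the exact constants used by WeightedFocalLoss may differ; see the derivations section of the docs):\n",
"\n",
"$$\\text{WFL}(y, p) = -\\alpha\\,y\\,(1 - p)^{\\gamma} \\log(p) - (1 - y)\\,p^{\\gamma} \\log(1 - p)$$\n",
"\n",
"Setting $\\alpha = 1$ and $\\gamma = 0$ recovers ordinary Cross Entropy, which is what the unit tests compare against."
]
},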
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "from sklearn.datasets import make_classification\n",
23 | "from sklearn.model_selection import train_test_split\n",
24 | "from sklearn.metrics import roc_auc_score\n",
25 | "from bokbokbok.loss_functions.classification import WeightedFocalLoss\n",
26 | "from bokbokbok.eval_metrics.classification import WeightedFocalMetric\n",
27 | "from bokbokbok.utils import clip_sigmoid\n",
28 | "\n",
29 | "X, y = make_classification(n_samples=1000, \n",
30 | " n_features=10, \n",
31 | " random_state=41114)\n",
32 | "\n",
33 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n",
34 | " y, \n",
35 | " test_size=0.25, \n",
36 | " random_state=41114)\n",
37 | "\n",
38 | "alpha = 0.7 # Reduce False Positives\n",
39 | "gamma = 2 # Focus on misclassified examples more strictly"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "### Usage in LightGBM"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "import lightgbm as lgb\n",
56 | "\n",
57 | "train = lgb.Dataset(X_train, y_train)\n",
58 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n",
59 | "params = {\n",
60 | " 'n_estimators': 300,\n",
61 | " 'seed': 41114,\n",
62 | " 'n_jobs': 8,\n",
63 | " 'learning_rate': 0.1,\n",
64 | " }\n",
65 | "\n",
66 | "clf = lgb.train(params=params,\n",
67 | " train_set=train,\n",
68 | " valid_sets=[train, valid],\n",
69 | " valid_names=['train','valid'],\n",
70 | " fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma),\n",
71 | " feval=WeightedFocalMetric(alpha=alpha, gamma=gamma),\n",
72 | " early_stopping_rounds=100)\n",
73 | "\n",
74 | "roc_auc_score(y_valid, clip_sigmoid(clf.predict(X_valid)))"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "### Usage in XGBoost"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "import xgboost as xgb\n",
91 | "\n",
92 | "dtrain = xgb.DMatrix(X_train, y_train)\n",
93 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n",
94 | "\n",
95 | "params = {\n",
96 | " 'seed': 41114,\n",
97 | " 'learning_rate': 0.1,\n",
98 | " 'disable_default_eval_metric': 1\n",
99 | " }\n",
100 | "\n",
101 | "bst = xgb.train(params,\n",
102 | " dtrain=dtrain,\n",
103 | " num_boost_round=300,\n",
104 | " early_stopping_rounds=10,\n",
105 | " verbose_eval=10,\n",
106 | " obj=WeightedFocalLoss(alpha=alpha, gamma=gamma),\n",
107 | " maximize=False,\n",
108 | " feval=WeightedFocalMetric(alpha=alpha, gamma=gamma, XGBoost=True),\n",
109 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n",
110 | "\n",
111 | "roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))"
112 | ]
113 | }
114 | ],
115 | "metadata": {
116 | "kernelspec": {
117 | "display_name": "Python [conda env:skorecard_py37]",
118 | "language": "python",
119 | "name": "conda-env-skorecard_py37-py"
120 | },
121 | "language_info": {
122 | "codemirror_mode": {
123 | "name": "ipython",
124 | "version": 3
125 | },
126 | "file_extension": ".py",
127 | "mimetype": "text/x-python",
128 | "name": "python",
129 | "nbconvert_exporter": "python",
130 | "pygments_lexer": "ipython3",
131 | "version": "3.7.7"
132 | }
133 | },
134 | "nbformat": 4,
135 | "nbformat_minor": 2
136 | }
137 |
--------------------------------------------------------------------------------
/docs/tutorials/log_cosh_loss.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### When to use Log Cosh Loss?\n",
8 | "\n",
9 | "Log Cosh Loss addresses the small number of problems that can arise from using Mean Absolute Error due to its sharpness. Log(cosh(x)) is a way to very closely approximate Mean Absolute Error while retaining a 'smooth' function.\n",
10 | "\n",
11 | "Do note that large y-values can cause issues here, which is why the y-values are scaled below"
12 | ]
13 | },
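{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small numerical sketch of why the scaling matters (plain NumPy, independent of bokbokbok): cosh overflows for moderately large inputs, while the algebraically equivalent form log(cosh(x)) = logaddexp(x, -x) - log(2) stays finite."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"x = np.array([1.0, 10.0, 800.0])\n",
"\n",
"# Naive form: np.cosh(800) overflows to inf, so the log is inf as well\n",
"naive = np.log(np.cosh(x))\n",
"\n",
"# Stable form: log(cosh(x)) = log((e^x + e^-x) / 2) = logaddexp(x, -x) - log(2)\n",
"stable = np.logaddexp(x, -x) - np.log(2)\n",
"\n",
"print(naive) # third entry is inf\n",
"print(stable) # third entry is ~799.31"
]
},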
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from sklearn.datasets import make_regression\n",
21 | "from sklearn.model_selection import train_test_split\n",
22 | "from sklearn.metrics import mean_absolute_error\n",
23 | "from bokbokbok.eval_metrics.regression import LogCoshMetric\n",
24 | "from bokbokbok.loss_functions.regression import LogCoshLoss\n",
25 | "\n",
26 | "X, y = make_regression(n_samples=1000, \n",
27 | " n_features=10, \n",
28 | " random_state=41114)\n",
29 | "\n",
30 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n",
31 | " y/100, \n",
32 | " test_size=0.25, \n",
33 | " random_state=41114)\n"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "### Usage in LightGBM"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import lightgbm as lgb\n",
50 | "\n",
51 | "train = lgb.Dataset(X_train, y_train)\n",
52 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n",
53 | "params = {\n",
54 | " 'n_estimators': 3000,\n",
55 | " 'seed': 41114,\n",
56 | " 'n_jobs': 8,\n",
57 | " 'learning_rate': 0.1,\n",
58 | " 'verbose': 100,\n",
59 | " }\n",
60 | "\n",
61 | "clf = lgb.train(params=params,\n",
62 | " train_set=train,\n",
63 | " valid_sets=[train, valid],\n",
64 | " valid_names=['train','valid'],\n",
65 | " fobj=LogCoshLoss(),\n",
66 | " feval=LogCoshMetric(),\n",
67 | " early_stopping_rounds=100,\n",
68 | " verbose_eval=100)\n",
69 | "\n",
70 | "mean_absolute_error(y_valid, clf.predict(X_valid))"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "### Usage in XGBoost"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "import xgboost as xgb\n",
87 | "\n",
88 | "dtrain = xgb.DMatrix(X_train, y_train)\n",
89 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n",
90 | "\n",
91 | "params = {\n",
92 | " 'seed': 41114,\n",
93 | " 'learning_rate': 0.1,\n",
94 | " 'disable_default_eval_metric': 1\n",
95 | " }\n",
96 | "\n",
97 | "bst = xgb.train(params,\n",
98 | " dtrain=dtrain,\n",
99 | " num_boost_round=3000,\n",
100 | " early_stopping_rounds=10,\n",
101 | " verbose_eval=100,\n",
102 | " obj=LogCoshLoss(),\n",
103 | " maximize=False,\n",
104 | " feval=LogCoshMetric(XGBoost=True),\n",
105 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n",
106 | "\n",
107 | "mean_absolute_error(y_valid, bst.predict(dvalid))"
108 | ]
109 | }
110 | ],
111 | "metadata": {
112 | "kernelspec": {
113 | "display_name": "Python [conda env:skorecard_py37]",
114 | "language": "python",
115 | "name": "conda-env-skorecard_py37-py"
116 | },
117 | "language_info": {
118 | "codemirror_mode": {
119 | "name": "ipython",
120 | "version": 3
121 | },
122 | "file_extension": ".py",
123 | "mimetype": "text/x-python",
124 | "name": "python",
125 | "nbconvert_exporter": "python",
126 | "pygments_lexer": "ipython3",
127 | "version": "3.7.7"
128 | }
129 | },
130 | "nbformat": 4,
131 | "nbformat_minor": 2
132 | }
133 |
--------------------------------------------------------------------------------
/docs/tutorials/quadratic_weighted_kappa.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from sklearn.datasets import make_multilabel_classification\n",
10 | "from sklearn.model_selection import train_test_split\n",
11 | "from sklearn.metrics import cohen_kappa_score\n",
12 | "from bokbokbok.eval_metrics.classification import QuadraticWeightedKappaMetric\n",
13 | "from bokbokbok.utils import clip_sigmoid\n",
14 | "\n",
15 | "X, y = make_multilabel_classification(n_samples=1000, \n",
16 | " n_features=10,\n",
17 | " n_classes=2,\n",
18 | " n_labels=1,\n",
19 | " random_state=41114)\n",
20 | "y = y.sum(axis=1)\n",
21 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n",
22 | " y, \n",
23 | " test_size=0.25, \n",
24 | " random_state=41114)\n"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "### Usage in LightGBM"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import lightgbm as lgb\n",
41 | "\n",
42 | "train = lgb.Dataset(X_train, y_train)\n",
43 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n",
44 | "params = {\n",
45 | " 'n_estimators': 300,\n",
46 | " 'seed': 41114,\n",
47 | " 'n_jobs': 8,\n",
48 | " 'learning_rate': 0.1,\n",
49 | " 'objective':'multiclass',\n",
50 | " 'num_class': 3\n",
51 | " }\n",
52 | "\n",
53 | "clf = lgb.train(params=params,\n",
54 | " train_set=train,\n",
55 | " valid_sets=[train, valid],\n",
56 | " valid_names=['train','valid'],\n",
57 | " feval=QuadraticWeightedKappaMetric(),\n",
58 | " early_stopping_rounds=100)\n"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "### Usage in XGBoost"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "import xgboost as xgb\n",
75 | "\n",
76 | "dtrain = xgb.DMatrix(X_train, y_train)\n",
77 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n",
78 | "\n",
79 | "params = {\n",
80 | " 'seed': 41114,\n",
81 | " 'learning_rate': 0.1,\n",
82 | " 'disable_default_eval_metric': 1,\n",
83 | " 'objective': 'multi:softprob',\n",
84 | " 'num_class': 3\n",
85 | " }\n",
86 | "\n",
87 | "bst = xgb.train(params,\n",
88 | " dtrain=dtrain,\n",
89 | " num_boost_round=300,\n",
90 | " early_stopping_rounds=100,\n",
91 | " verbose_eval=10,\n",
92 | " feval=QuadraticWeightedKappaMetric(XGBoost=True),\n",
93 | " maximize=True,\n",
94 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n"
95 | ]
96 | }
97 | ],
98 | "metadata": {
99 | "kernelspec": {
100 | "display_name": "Python [conda env:skorecard_py37]",
101 | "language": "python",
102 | "name": "conda-env-skorecard_py37-py"
103 | },
104 | "language_info": {
105 | "codemirror_mode": {
106 | "name": "ipython",
107 | "version": 3
108 | },
109 | "file_extension": ".py",
110 | "mimetype": "text/x-python",
111 | "name": "python",
112 | "nbconvert_exporter": "python",
113 | "pygments_lexer": "ipython3",
114 | "version": "3.7.7"
115 | }
116 | },
117 | "nbformat": 4,
118 | "nbformat_minor": 5
119 | }
120 |
--------------------------------------------------------------------------------
/docs/tutorials/weighted_cross_entropy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### When to use Weighted Cross Entropy?\n",
8 | "\n",
9 | "A factor *alpha* is added in to Cross Entropy, allowing one to trade off recall and precision by up- or down-weighting the cost of a positive error relative to a negative error.\n",
10 | "\n",
11 | "A value *alpha* > 1 decreases the false negative count, hence increasing the recall. Conversely, setting *alpha* < 1 decreases the false positive count and increases the precision. "
12 | ]
13 | },
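{
"cell_type": "markdown",
"metadata": {},
"source": [
"Schematically, with p the predicted probability of the positive class (the exact form used by WeightedCrossEntropyLoss is derived in the derivations section of the docs):\n",
"\n",
"$$\\text{WCE}(y, p) = -\\alpha\\,y \\log(p) - (1 - y) \\log(1 - p)$$\n",
"\n",
"Setting $\\alpha = 1$ recovers ordinary Cross Entropy, which is what the unit tests compare against."
]
},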
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from sklearn.datasets import make_classification\n",
21 | "from sklearn.model_selection import train_test_split\n",
22 | "from sklearn.metrics import roc_auc_score\n",
23 | "from bokbokbok.loss_functions.classification import WeightedCrossEntropyLoss\n",
24 | "from bokbokbok.eval_metrics.classification import WeightedCrossEntropyMetric\n",
25 | "from bokbokbok.utils import clip_sigmoid\n",
26 | "\n",
27 | "X, y = make_classification(n_samples=1000, \n",
28 | " n_features=10, \n",
29 | " random_state=41114)\n",
30 | "\n",
31 | "X_train, X_valid, y_train, y_valid = train_test_split(X, \n",
32 | " y, \n",
33 | " test_size=0.25, \n",
34 | " random_state=41114)\n",
35 | "\n",
36 | "alpha = 0.7 # Reduce False Positives"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### Usage in LightGBM"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "import lightgbm as lgb\n",
53 | "\n",
54 | "train = lgb.Dataset(X_train, y_train)\n",
55 | "valid = lgb.Dataset(X_valid, y_valid, reference=train)\n",
56 | "params = {\n",
57 | " 'n_estimators': 300,\n",
58 | " 'seed': 41114,\n",
59 | " 'n_jobs': 8,\n",
60 | " 'learning_rate': 0.1,\n",
61 | " }\n",
62 | "\n",
63 | "clf = lgb.train(params=params,\n",
64 | " train_set=train,\n",
65 | " valid_sets=[train, valid],\n",
66 | " valid_names=['train','valid'],\n",
67 | " fobj=WeightedCrossEntropyLoss(alpha=alpha),\n",
68 | " feval=WeightedCrossEntropyMetric(alpha=alpha),\n",
69 | " early_stopping_rounds=100)\n",
70 | "\n",
71 | "roc_auc_score(y_valid, clip_sigmoid(clf.predict(X_valid)))"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "### Usage in XGBoost"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "import xgboost as xgb\n",
88 | "\n",
89 | "dtrain = xgb.DMatrix(X_train, y_train)\n",
90 | "dvalid = xgb.DMatrix(X_valid, y_valid)\n",
91 | "\n",
92 | "params = {\n",
93 | " 'seed': 41114,\n",
94 | " 'learning_rate': 0.1,\n",
95 | " 'disable_default_eval_metric': 1\n",
96 | " }\n",
97 | "\n",
98 | "bst = xgb.train(params,\n",
99 | " dtrain=dtrain,\n",
100 | " num_boost_round=300,\n",
101 | " early_stopping_rounds=10,\n",
102 | " verbose_eval=10,\n",
103 | " obj=WeightedCrossEntropyLoss(alpha=alpha),\n",
104 | " maximize=False,\n",
105 | " feval=WeightedCrossEntropyMetric(alpha=alpha, XGBoost=True),\n",
106 | " evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])\n",
107 | "\n",
108 | "roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))"
109 | ]
110 | }
111 | ],
112 | "metadata": {
113 | "kernelspec": {
114 | "display_name": "Python [conda env:skorecard_py37]",
115 | "language": "python",
116 | "name": "conda-env-skorecard_py37-py"
117 | },
118 | "language_info": {
119 | "codemirror_mode": {
120 | "name": "ipython",
121 | "version": 3
122 | },
123 | "file_extension": ".py",
124 | "mimetype": "text/x-python",
125 | "name": "python",
126 | "nbconvert_exporter": "python",
127 | "pygments_lexer": "ipython3",
128 | "version": "3.7.7"
129 | }
130 | },
131 | "nbformat": 4,
132 | "nbformat_minor": 2
133 | }
134 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: bokbokbok doks
2 |
3 | repo_url: https://github.com/orchardbirds/bokbokbok/
4 | site_url: https://orchardbirds.github.io/bokbokbok/
5 | site_description: Implementing Custom Loss Functions and Eval Metrics in LightGBM and XGBoost
6 | site_author: Daniel Timbrell
7 |
8 | use_directory_urls: false
9 |
10 | nav:
11 | - Home: index.md
12 | - Getting started:
13 | - getting_started/install.md
14 | - How To:
15 | - Use Weighted Cross Entropy: tutorials/weighted_cross_entropy.ipynb
16 | - Use Weighted Focal Loss: tutorials/focal_loss.ipynb
17 | - Use F1 Score: tutorials/F1_score.ipynb
18 | - Use Log Cosh Loss: tutorials/log_cosh_loss.ipynb
19 | - Use Root Mean Squared Percentage Error: tutorials/RMSPE.ipynb
20 | - Use Quadratic Weighted Kappa: tutorials/quadratic_weighted_kappa.ipynb
21 | - Derivations:
22 | - A Note About Gradients in Classification Problems: derivations/note.md
23 | - Weighted Cross Entropy: derivations/wce.md
24 | - Focal Loss: derivations/focal.md
25 | - Log Cosh Error: derivations/log_cosh.md
26 | - Reference:
27 | - Evaluation Metrics:
28 | - bokbokbok.eval_metrics.binary_classification: reference/eval_metrics_binary.md
29 | - bokbokbok.eval_metrics.multiclass_classification: reference/eval_metrics_multiclass.md
30 | - bokbokbok.eval_metrics.regression: reference/eval_metrics_regression.md
31 | - Loss Functions:
32 | - bokbokbok.loss_functions.classification: reference/loss_functions_classification.md
33 | - bokbokbok.loss_functions.regression: reference/loss_functions_regression.md
34 |
35 | plugins:
36 | - mkdocstrings:
37 | handlers:
38 | python:
39 | selection:
40 | inherited_members: true
41 | filters:
42 | - "!^Base"
43 | - "!^_" # exlude all members starting with _
44 | - "^__init__$" # but always include __init__ modules and methods
45 | rendering:
46 | show_root_toc_entry: false
47 | watch:
48 | - bokbokbok
49 | - search
50 | - mknotebooks:
51 | enable_default_jupyter_cell_styling: true
52 | enable_default_pandas_dataframe_styling: true
53 |
54 | copyright: Copyright © 2020
55 |
56 | theme:
57 | name: material
58 | logo: img/bokbokbok.png
59 | favicon: img/bokbokbok.png
60 | font:
61 | text: Ubuntu
62 | code: Ubuntu Mono
63 | features:
64 | - navigation.tabs
65 | palette:
66 | scheme: default
67 | primary: teal
68 | accent: yellow
69 |
70 |
71 | markdown_extensions:
72 | - codehilite
73 | - pymdownx.highlight
74 | - pymdownx.inlinehilite
75 | - pymdownx.superfences
76 | - pymdownx.details
77 | - pymdownx.tabbed
78 | - pymdownx.snippets
79 | - pymdownx.highlight:
80 | use_pygments: true
81 | - toc:
82 | permalink: true
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r", encoding="UTF-8") as fh:
4 | long_description = fh.read()
5 |
6 | base_packages = [
7 | "numpy>=1.19.2",
8 | "scikit-learn>=0.23.2",
9 | ]
10 |
11 | dev_dep = [
12 | "flake8>=3.8.3",
13 | "black>=19.10b0",
14 | "pre-commit>=2.5.0",
15 | "mypy>=0.770",
16 | "flake8-docstrings>=1.4.0" "pytest>=6.0.0",
17 | "pytest-cov>=2.10.0",
18 | "lightgbm>=3.0.0",
19 | "xgboost>=1.3.3",
20 | ]
21 |
22 | docs_dep = [
23 | "mkdocs-material>=6.1.0",
24 | "mkdocs-git-revision-date-localized-plugin>=0.7.2",
25 | "mkdocs-git-authors-plugin>=0.3.2",
26 | "mkdocs-table-reader-plugin>=0.4.1",
27 | "mkdocs-enumerate-headings-plugin>=0.4.3",
28 | "mkdocs-awesome-pages-plugin>=2.4.0",
29 | "mkdocs-minify-plugin>=0.3.0",
30 | "mknotebooks>=0.6.2",
31 | "mkdocstrings>=0.13.6",
32 | "mkdocs-print-site-plugin>=0.8.2",
33 | "mkdocs-markdownextradata-plugin>=0.1.9",
34 | ]
35 |
36 | setup(
37 | name="bokbokbok",
38 | version="0.6.1",
39 | description="Custom Losses and Metrics for XGBoost, LightGBM, CatBoost",
40 | long_description=long_description,
41 | long_description_content_type="text/markdown",
42 | author="Daniel Timbrell",
43 | author_email="dantimbrell@gmail.com",
44 | license="MIT",
45 | python_requires=">=3.6",
46 | classifiers=[
47 | "License :: OSI Approved :: MIT License",
48 | "Operating System :: OS Independent",
49 | "Programming Language :: Python",
50 | "Programming Language :: Python :: 3",
51 | "Programming Language :: Python :: 3.6",
52 | "Programming Language :: Python :: 3.7",
53 | "Programming Language :: Python :: 3.8",
54 | "Programming Language :: Python :: 3 :: Only",
55 | ],
56 | include_package_data=True,
57 | install_requires=base_packages,
58 | extras_require={
59 | "base": base_packages,
60 | "all": base_packages + dev_dep + docs_dep
61 | },
62 | url="https://github.com/orchardbirds/bokbokbok",
63 | packages=find_packages(".", exclude=["tests", "notebooks", "docs"]),
64 | )
65 |
--------------------------------------------------------------------------------
/tests/test_focal.py:
--------------------------------------------------------------------------------
1 | from sklearn.datasets import make_classification
2 | from sklearn.model_selection import train_test_split
3 | from sklearn.metrics import mean_absolute_error
4 | from bokbokbok.loss_functions.classification import WeightedFocalLoss, WeightedCrossEntropyLoss
5 | from bokbokbok.eval_metrics.classification import WeightedFocalMetric, WeightedCrossEntropyMetric
6 | from bokbokbok.utils import clip_sigmoid
7 | import lightgbm as lgb
8 |
9 |
10 | def test_focal_lgb_implementation():
11 | """
12 | Assert that there is no difference between running Focal with alpha=1 and gamma=0
13 | and LightGBM's internal CE loss.
14 | """
15 | X, y = make_classification(n_samples=1000,
16 | n_features=10,
17 | random_state=41114)
18 |
19 | X_train, X_valid, y_train, y_valid = train_test_split(X,
20 | y,
21 | test_size=0.25,
22 | random_state=41114)
23 |
24 | alpha = 1.0
25 | gamma = 0
26 |
27 | train = lgb.Dataset(X_train, y_train)
28 | valid = lgb.Dataset(X_valid, y_valid, reference=train)
29 |
30 | params_wfl = {
31 | 'n_estimators': 300,
32 | 'seed': 41114,
33 | 'n_jobs': 8,
34 | 'learning_rate': 0.1,
35 | }
36 |
37 | wfl_clf = lgb.train(params=params_wfl,
38 | train_set=train,
39 | valid_sets=[train, valid],
40 | valid_names=['train','valid'],
41 | fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma),
42 | feval=WeightedFocalMetric(alpha=alpha, gamma=gamma),
43 | early_stopping_rounds=100)
44 |
45 |
46 | params = {
47 | 'n_estimators': 300,
48 | 'objective': 'cross_entropy',
49 | 'seed': 41114,
50 | 'n_jobs': 8,
51 | 'metric': 'cross_entropy',
52 | 'learning_rate': 0.1,
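# Assumption behind this comparison: training with a custom fobj starts
# boosting from a raw score of 0, so boost_from_average is disabled here
# to stop the built-in objective initialising from the class prior.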
53 | 'boost_from_average': False
54 | }
55 |
56 | clf = lgb.train(params=params,
57 | train_set=train,
58 | valid_sets=[train, valid],
59 | valid_names=['train','valid'],
60 | early_stopping_rounds=100)
61 |
62 | wfl_preds = clip_sigmoid(wfl_clf.predict(X_valid))
63 | preds = clf.predict(X_valid)
64 | assert mean_absolute_error(wfl_preds, preds) == 0.0
65 |
66 |
67 | def test_focal_wce_comparison():
68 | """
69 | Assert that there is no difference between running Focal with alpha=3 and gamma=0
70 | and running WCE with alpha=3.
71 | """
72 | X, y = make_classification(n_samples=1000,
73 | n_features=10,
74 | random_state=41114)
75 |
76 | X_train, X_valid, y_train, y_valid = train_test_split(X,
77 | y,
78 | test_size=0.25,
79 | random_state=41114)
80 |
81 | alpha = 3.0
82 | gamma = 0
83 |
84 | train = lgb.Dataset(X_train, y_train)
85 | valid = lgb.Dataset(X_valid, y_valid, reference=train)
86 |
87 | params_wfl = {
88 | 'n_estimators': 300,
89 | 'seed': 41114,
90 | 'n_jobs': 8,
91 | 'learning_rate': 0.1,
92 | }
93 |
94 | wfl_clf = lgb.train(params=params_wfl,
95 | train_set=train,
96 | valid_sets=[train, valid],
97 | valid_names=['train','valid'],
98 | fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma),
99 | feval=WeightedFocalMetric(alpha=alpha, gamma=gamma),
100 | early_stopping_rounds=100)
101 |
102 |
103 | params_wce = {
104 | 'n_estimators': 300,
105 | 'seed': 41114,
106 | 'n_jobs': 8,
107 | 'learning_rate': 0.1,
108 | }
109 |
110 | wce_clf = lgb.train(params=params_wce,
111 | train_set=train,
112 | valid_sets=[train, valid],
113 | valid_names=['train','valid'],
114 | fobj=WeightedCrossEntropyLoss(alpha=alpha),
115 | feval=WeightedCrossEntropyMetric(alpha=alpha),
116 | early_stopping_rounds=100)
117 |
118 | wfl_preds = clip_sigmoid(wfl_clf.predict(X_valid))
119 | wce_preds = clip_sigmoid(wce_clf.predict(X_valid))
120 | assert mean_absolute_error(wfl_preds, wce_preds) == 0.0
121 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from bokbokbok.utils import clip_sigmoid
4 |
5 | def test_clip_sigmoid():
6 | assert np.allclose(a=clip_sigmoid(np.array([100, 0, -100])),
7 | b=[1 - 1e-15, 0.5, 1e-15])
--------------------------------------------------------------------------------
/tests/test_wce.py:
--------------------------------------------------------------------------------
1 | from sklearn.datasets import make_classification
2 | from sklearn.model_selection import train_test_split
3 | from sklearn.metrics import mean_absolute_error
4 | from bokbokbok.loss_functions.classification import WeightedCrossEntropyLoss
5 | from bokbokbok.eval_metrics.classification import WeightedCrossEntropyMetric
6 | from bokbokbok.utils import clip_sigmoid
7 | import lightgbm as lgb
8 |
9 |
10 | def test_wce_lgb_implementation():
11 | """
12 | Assert that there is no difference between running WCE with alpha=1
13 | and LightGBM's internal CE loss.
14 | """
15 | X, y = make_classification(n_samples=1000,
16 | n_features=10,
17 | random_state=41114)
18 |
19 | X_train, X_valid, y_train, y_valid = train_test_split(X,
20 | y,
21 | test_size=0.25,
22 | random_state=41114)
23 |
24 | alpha = 1.0
25 |
26 | train = lgb.Dataset(X_train, y_train)
27 | valid = lgb.Dataset(X_valid, y_valid, reference=train)
28 |
29 | params_wce = {
30 | 'n_estimators': 300,
31 | 'seed': 41114,
32 | 'n_jobs': 8,
33 | 'learning_rate': 0.1,
34 | }
35 |
36 | wce_clf = lgb.train(params=params_wce,
37 | train_set=train,
38 | valid_sets=[train, valid],
39 | valid_names=['train','valid'],
40 | fobj=WeightedCrossEntropyLoss(alpha=1.0),
41 | feval=WeightedCrossEntropyMetric(alpha=1.0),
42 | early_stopping_rounds=100)
43 |
44 |
45 | params = {
46 | 'n_estimators': 300,
47 | 'objective': 'cross_entropy',
48 | 'seed': 41114,
49 | 'n_jobs': 8,
50 | 'metric': 'cross_entropy',
51 | 'learning_rate': 0.1,
52 | 'boost_from_average': False
53 | }
54 |
55 | clf = lgb.train(params=params,
56 | train_set=train,
57 | valid_sets=[train, valid],
58 | valid_names=['train','valid'],
59 | early_stopping_rounds=100)
60 |
61 | wce_preds = clip_sigmoid(wce_clf.predict(X_valid))
62 | preds = clf.predict(X_valid)
63 | assert mean_absolute_error(wce_preds, preds) == 0.0
--------------------------------------------------------------------------------