├── .bumpversion.cfg ├── .coveragerc ├── .github └── workflows │ ├── release.yaml │ └── tox.yaml ├── .gitignore ├── .gitlab-ci.yml ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── badges ├── .gitignore └── coverage.svg ├── build_scripts ├── release-version.sh ├── run_pylint.py └── update_docs.py ├── config.json ├── config.py ├── docs ├── .gitignore ├── conf.py ├── getting-started.rst ├── index.rst ├── kyle │ ├── datasets.rst │ ├── evaluation.rst │ ├── evaluation │ │ ├── continuous.rst │ │ └── discrete.rst │ ├── index.rst │ ├── integrals.rst │ ├── metrics.rst │ ├── metrics │ │ └── calibration_metrics.rst │ ├── models.rst │ ├── models │ │ ├── calibratable_model.rst │ │ └── resnet.rst │ ├── sampling.rst │ ├── sampling │ │ └── fake_clf.rst │ ├── transformations.rst │ └── util.rst └── requirements.txt ├── notebooks ├── calibration_demo.ipynb ├── evaluating_cal_methods.ipynb ├── fake_classifiers.ipynb ├── metric_convergence_analysis.ipynb ├── test_notebooks.py └── trained_models │ └── lenet5.ckpt ├── notebooks_needing_refactoring └── fitting_fake_classifiers.ipynb ├── public ├── .nojekyll ├── coverage │ └── .gitignore ├── docs │ └── .gitignore └── index.html ├── pyproject.toml ├── pytest.ini ├── requirements-dev.txt ├── requirements-torch.txt ├── requirements.txt ├── setup.py ├── src ├── __init__.py └── kyle │ ├── __init__.py │ ├── calibration │ ├── __init__.py │ ├── calibration_methods.py │ └── model_calibrator.py │ ├── datasets.py │ ├── evaluation │ ├── __init__.py │ ├── continuous.py │ ├── discrete.py │ └── reliabilities.py │ ├── integrals.py │ ├── metrics │ ├── __init__.py │ └── calibration_metrics.py │ ├── models │ ├── __init__.py │ ├── calibratable_model.py │ └── resnet.py │ ├── sampling │ ├── __init__.py │ └── fake_clf.py │ ├── transformations.py │ └── util.py ├── tests ├── conftest.py └── kyle │ ├── calibration │ ├── calibration_methods │ │ └── test_calibration_methods.py │ └── test_model_calibrator.py │ ├── metrics │ 
└── test_metrics.py │ ├── sampling │ └── test_fake_clf.py │ └── test_util.py └── tox.ini /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.1.8-dev0 3 | commit = False 4 | tag = False 5 | allow_dirty = False 6 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))? 7 | serialize = 8 | {major}.{minor}.{patch}-{release}{build} 9 | {major}.{minor}.{patch} 10 | 11 | [bumpversion:part:release] 12 | optional_value = prod 13 | first_value = dev 14 | values = 15 | dev 16 | prod 17 | 18 | [bumpversion:part:build] 19 | 20 | [bumpversion:file:setup.py] 21 | 22 | [bumpversion:file:src/kyle/__init__.py] 23 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = src 3 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | # Sequence of patterns matched against refs/tags 4 | tags: 5 | - 'v*' # Push events to matching v*, i.e.
v1.0, v20.15.10 6 | 7 | name: Create Release 8 | 9 | jobs: 10 | build: 11 | name: Create GitHub Release 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v2 16 | - name: Create Release 17 | id: create_release 18 | uses: actions/create-release@v1 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # This token is provided by Actions, you do not need to create your own token 21 | with: 22 | tag_name: ${{ github.ref }} 23 | release_name: Release ${{ github.ref }} 24 | body: | 25 | Changes in this Release 26 | - First Change 27 | - Second Change 28 | draft: false 29 | prerelease: false 30 | deploy: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - uses: actions/checkout@v2 34 | - name: Set up Python for PyPI Release 35 | uses: actions/setup-python@v1 36 | with: 37 | python-version: '3.8' 38 | - name: Install dependencies for PyPI Release 39 | run: | 40 | python -m pip install --upgrade pip 41 | pip install setuptools wheel twine 42 | - name: Build and publish to PyPI 43 | env: 44 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 45 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 46 | run: | 47 | python setup.py sdist bdist_wheel 48 | twine upload dist/* 49 | -------------------------------------------------------------------------------- /.github/workflows/tox.yaml: -------------------------------------------------------------------------------- 1 | name: Merge develop, run tests and build documentation 2 | 3 | on: 4 | pull_request: 5 | branches: [develop] 6 | push: 7 | branches: [develop, master] 8 | workflow_dispatch: 9 | inputs: 10 | reason: 11 | description: Why did you trigger the pipeline? 
12 | required: False 13 | default: Check if it runs again due to external changes 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | # pandoc needed for docu, see https://nbsphinx.readthedocs.io/en/0.7.1/installation.html?highlight=pandoc#pandoc 21 | - name: Install Non-Python Packages 22 | run: sudo apt-get update -yq && sudo apt-get -yq install pandoc 23 | - uses: actions/checkout@v2.3.1 24 | with: 25 | fetch-depth: 0 26 | lfs: true 27 | persist-credentials: false 28 | # lfs=true is not enough, see https://stackoverflow.com/questions/61463578/github-actions-actions-checkoutv2-lfs-true-flag-not-converting-pointers-to-act 29 | - name: Checkout LFS Objects 30 | run: git lfs pull 31 | - name: Merge develop into current branch 32 | if: github.ref != 'refs/heads/develop' 33 | run: | 34 | git fetch origin develop:develop --update-head-ok 35 | git merge develop 36 | - name: Setup Python 3.8 37 | uses: actions/setup-python@v1 38 | with: 39 | python-version: "3.8" 40 | - name: Install Tox and Python Packages 41 | run: pip install tox 42 | - name: Run Tox 43 | run: tox 44 | - name: Prepare Pages 45 | if: github.ref == 'refs/heads/develop' 46 | run: | 47 | mv docs/_build/html/* public/docs 48 | mv htmlcov/* public/coverage 49 | - name: Deploy Pages 50 | uses: JamesIves/github-pages-deploy-action@3.7.1 51 | if: github.ref == 'refs/heads/develop' 52 | with: 53 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 54 | BRANCH: gh-pages 55 | FOLDER: public 56 | TARGET_FOLDER: . 
57 | CLEAN: true 58 | SINGLE_COMMIT: true 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | .idea 3 | config_local.json 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # reports 136 | pylint.html 137 | 138 | data -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: "python:3.8-buster" 2 | 3 | stages: 4 | - tox 5 | - documentation 6 | - build 7 | - publish 8 | - update-tox-cache 9 | 10 | variables: 11 | PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" 12 | 13 | cache: &global_cache 14 | paths: 15 | - .cache/pip 16 | - .venv/ 17 | - .tox 18 | - apt-cache/ 19 | key: ${CI_COMMIT_REF_SLUG} 20 | 21 | # Pip's cache doesn't store the python packages 22 | # https://pip.pypa.io/en/stable/reference/pip_install/#caching 23 | before_script: 24 | - mkdir -p apt-cache 25 | # pandoc needed for docu, see https://nbsphinx.readthedocs.io/en/0.7.1/installation.html?highlight=pandoc#pandoc 26 | - apt-get update -yq && apt-get -o dir::cache::archives="$(pwd)/apt-cache" -yq install pandoc 27 | - if [ -e "$CONFIG_LOCAL" ]; then mv "$CONFIG_LOCAL" ./config_local.json && echo "retrieved local config"; fi 28 | - pip install virtualenv 29 | - virtualenv
.venv 30 | - source .venv/bin/activate 31 | 32 | .tox_job: &tox_job 33 | stage: tox 34 | script: 35 | - pip install tox 36 | - tox 37 | artifacts: 38 | paths: 39 | - badges 40 | - docs/_build 41 | - htmlcov 42 | - pylint.html 43 | 44 | tox_recreate: 45 | only: 46 | changes: 47 | - requirements.txt 48 | cache: 49 | # push cache if dependencies have changed 50 | <<: *global_cache 51 | policy: push 52 | <<: *tox_job 53 | 54 | tox_use_cache: 55 | except: 56 | changes: 57 | - requirements.txt 58 | cache: 59 | # use cache if dependencies haven't changed 60 | <<: *global_cache 61 | policy: pull 62 | <<: *tox_job 63 | 64 | pages: 65 | cache: {} 66 | stage: documentation 67 | script: 68 | - mv docs/_build/html/* public/docs 69 | - mkdir -p public/pylint && mv pylint.html public/pylint/index.html 70 | - mv htmlcov/* public/coverage 71 | artifacts: 72 | paths: 73 | - public 74 | only: 75 | - develop 76 | 77 | package: 78 | cache: 79 | paths: 80 | - .cache/pip 81 | - .venv/ 82 | key: "$CI_JOB_NAME-$CI_COMMIT_REF_SLUG" 83 | stage: build 84 | script: 85 | - | 86 | # Bump version number of develop branch 87 | if [ "$CI_COMMIT_BRANCH" = "develop" ]; then 88 | # Git config 89 | git config user.name "Gitlab CI" 90 | git config user.email "gitlab@example.org" 91 | chmod 0600 $GITLAB_DEPLOY_KEY 92 | 93 | # HTTPS clone URL -> git+ssh URL for pushing 94 | export GIT_REPO_URL_SSH=$(echo -n $CI_REPOSITORY_URL | sed -r 's%https?://.*@([^/]+)/%git@\1:%' -) 95 | git remote set-url origin $GIT_REPO_URL_SSH 96 | export GIT_SSH_COMMAND='ssh -i $GITLAB_DEPLOY_KEY -o IdentitiesOnly=yes -o StrictHostKeyChecking=no' 97 | 98 | pip install bump2version 99 | apt-get update && apt-get -o dir::cache::archives="$(pwd)/apt-cache" -yq install git-lfs 100 | 101 | bump2version build --commit 102 | git push -o ci.skip origin HEAD:develop 103 | fi 104 | - pip install setuptools wheel 105 | - python setup.py sdist bdist_wheel 106 | artifacts: 107 | paths: 108 | - dist/*.tar.gz 109 | - dist/*.whl 110 | 111 | publish_package: 112 | cache:
{} 113 | only: 114 | - tags 115 | - develop 116 | stage: publish 117 | needs: [package] 118 | script: 119 | - pip install twine 120 | - export TWINE_REPOSITORY_URL=$PYPI_REPO_URL 121 | - export TWINE_USERNAME=$PYPI_REPO_USER 122 | - export TWINE_PASSWORD=$PYPI_REPO_PASS 123 | - twine upload dist/* 124 | 125 | update_tox_cache: 126 | needs: [] 127 | except: 128 | changes: 129 | - requirements.txt 130 | when: manual 131 | allow_failure: true 132 | cache: 133 | <<: *global_cache 134 | policy: push 135 | stage: update-tox-cache 136 | script: 137 | - pip install tox 138 | - tox -r 139 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.10.0 4 | hooks: 5 | - id: black-jupyter 6 | language_version: python3 7 | 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.10.1 10 | hooks: 11 | - id: isort 12 | 13 | - repo: https://github.com/kynan/nbstripout 14 | rev: 0.6.1 15 | hooks: 16 | - id: nbstripout 17 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGE CONTROL] 2 | disable = 3 | I0011 # reasoning 4 | 5 | [MASTER] 6 | load-plugins=pylint_json2html 7 | 8 | [REPORTS] 9 | output-format=jsonextended 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | kyle - a python library for classifier calibration 2 | 3 | Copyright 2021-2021 by appliedAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.* 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kyle - a Calibration Toolkit 2 | 3 | ## Note: 4 | This library is currently in the alpha stage and breaking changes can happen at any time. Some 5 | central features are currently missing and will be added soon. 6 | 7 | ## Overview 8 | This library contains utils for measuring and visualizing calibration of probabilistic classifiers as well as for 9 | recalibrating them. Currently, only methods for recalibration through post-processing are supported, although we plan 10 | to include calibration specific training algorithms as well in the future. 11 | 12 | Kyle is model agnostic, any probabilistic classifier can be wrapped with a thin wrapper called `CalibratableModel` which 13 | supports multiple calibration algorithms. 
For a quick overview of the API, have a look at the calibration demo 14 | notebook (the notebook with executed cells can be found in the documentation). 15 | 16 | Apart from tools for analysing models, kyle also offers support for developing and testing custom calibration metrics 17 | and algorithms. To avoid having to rely on evaluation data sets and trained models for delivering labels and confidence 18 | vectors, kyle lets you construct custom samplers based on fake classifiers. A note explaining the 19 | theory behind fake classifiers will be published soon. 20 | These samplers can 21 | also be fit to a data set in case you want to mimic it. Using the fake classifiers, an arbitrary number of ground 22 | truth labels and miscalibrated confidence vectors can be generated to help you analyse your algorithms (common use cases 23 | will be the analysis of variance and bias of calibration metrics and the benchmarking of recalibration algorithms). 24 | 25 | 26 | Currently, several algorithms in kyle use the [calibration framework library](https://github.com/fabiankueppers/calibration-framework) under the hood, although this is subject 27 | to change. 28 | 29 | ## Installation 30 | Kyle can be installed from PyPI, e.g.
with 31 | ``` 32 | pip install kyle-calibration 33 | ``` -------------------------------------------------------------------------------- /badges/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /badges/coverage.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 28% 19 | 28% 20 | 21 | 22 | -------------------------------------------------------------------------------- /build_scripts/release-version.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | 5 | ## TTY colors and attributes 6 | #normal=$(tput sgr0) # normal text 7 | normal=$'\e[0m' # (works better sometimes) 8 | bold=$(tput bold) # make colors bold/bright 9 | red="$bold$(tput setaf 1)" # bright red text 10 | green=$(tput setaf 2) # dim green text 11 | fawn=$(tput setaf 3); beige="$fawn" # dark yellow text 12 | yellow="$bold$fawn" # bright yellow text 13 | darkblue=$(tput setaf 4) # dim blue text 14 | blue="$bold$darkblue" # bright blue text 15 | purple=$(tput setaf 5); magenta="$purple" # magenta text 16 | pink="$bold$purple" # bright magenta text 17 | darkcyan=$(tput setaf 6) # dim cyan text 18 | cyan="$bold$darkcyan" # bright cyan text 19 | gray=$(tput setaf 7) # dim white text 20 | darkgray="$bold"$(tput setaf 0) # bold black = dark gray text 21 | white="$bold$gray" # bright white text 22 | 23 | 24 | function fail() { 25 | echo "${red}$1${normal}" 26 | exit 1 27 | } 28 | 29 | function usage() { 30 | cat > /dev/stderr <" 55 | :param overwrite: whether to overwrite existing rst files. 
This should be used with caution as it will delete 56 | all manual changes to documentation files 57 | :return: 58 | """ 59 | library_basedir = basedir.split(os.path.sep, 1)[1] # splitting off the "src" part 60 | for file in os.listdir(basedir): 61 | if file.startswith("_"): 62 | continue 63 | 64 | library_file_path = os.path.join(library_basedir, file) 65 | full_path = os.path.join(basedir, file) 66 | file_name, ext = os.path.splitext(file) 67 | docs_file_path = os.path.join("docs", library_basedir, f"{file_name}.rst") 68 | if os.path.exists(docs_file_path) and not overwrite: 69 | log.debug(f"{docs_file_path} already exists, skipping it") 70 | if os.path.isdir(full_path): 71 | make_docu(basedir=full_path, overwrite=overwrite) 72 | continue 73 | os.makedirs(os.path.dirname(docs_file_path), exist_ok=True) 74 | 75 | if ext == ".py": 76 | log.info(f"writing module docu to {docs_file_path}") 77 | write_to_file(module_template(library_file_path), docs_file_path) 78 | elif os.path.isdir(full_path): 79 | log.info(f"writing package docu to {docs_file_path}") 80 | write_to_file(package_template(library_file_path), docs_file_path) 81 | make_docu(basedir=full_path, overwrite=overwrite) 82 | 83 | 84 | if __name__ == "__main__": 85 | logging.basicConfig(level=logging.INFO) 86 | make_docu() 87 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "sample_key": "sample_value" 3 | } 4 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging.handlers 3 | import os 4 | from typing import Dict, List, Union 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | __config_instance = None 9 | 10 | source_path = os.path.dirname(__file__) 11 | 12 | 13 | class __Configuration: 14 | """ 15 | Holds essential 
configuration entries 16 | """ 17 | 18 | log = log.getChild(__qualname__) 19 | 20 | def __init__(self, config_files: List[str] = None): 21 | """ 22 | :param config_files: list of JSON configuration files (relative to root) from which to read. 23 | If None, reads from './config.json' and './config_local.json' (latter files have precedence) 24 | """ 25 | if config_files is None: 26 | config_files = ["config.json", "config_local.json"] 27 | self.config = {} 28 | for filename in config_files: 29 | file_path = os.path.join(source_path, filename) 30 | if os.path.exists(file_path): 31 | self.log.info("Reading configuration from %s" % file_path) 32 | with open(file_path, "r") as f: 33 | self.config.update(json.load(f)) 34 | if not self.config: 35 | raise Exception( 36 | "No configuration entries could be read from %s" % config_files 37 | ) 38 | 39 | def _get_non_empty_entry( 40 | self, key: Union[str, List[str]] 41 | ) -> Union[float, str, List, Dict]: 42 | """ 43 | Retrieves an entry from the configuration 44 | 45 | :param key: key or list of keys to go through hierarchically 46 | :return: the queried json object 47 | """ 48 | if isinstance(key, str): 49 | key = [key] 50 | value = self.config 51 | for k in key: 52 | value = value.get(k) 53 | if value is None: 54 | raise Exception(f"Value for key '{key}' not set in configuration") 55 | return value 56 | 57 | def _get_path(self, key: Union[str, List[str]], create=False) -> str: 58 | """ 59 | Retrieves an existing local path from the configuration 60 | 61 | :param key: key or list of keys to go through hierarchically 62 | :param create: if True, a directory with the given path will be created on the fly. 
63 | :return: the queried path 64 | """ 65 | path_string = self._get_non_empty_entry(key) 66 | path = os.path.abspath(path_string) 67 | if not os.path.exists(path): 68 | if isinstance(key, list): 69 | key = ".".join(key) # purely for logging 70 | if create: 71 | log.info( 72 | f"Configured directory {key}='{path}' not found; will create it" 73 | ) 74 | os.makedirs(path) 75 | else: 76 | raise FileNotFoundError( 77 | f"Configured directory {key}='{path}' does not exist." 78 | ) 79 | return path.replace("/", os.sep) 80 | 81 | @property 82 | def sample_key(self): 83 | return self._get_non_empty_entry("sample_key") 84 | 85 | 86 | def get_config(reload=False) -> __Configuration: 87 | """ 88 | :param reload: if True, the configuration will be reloaded from the json files 89 | :return: the configuration instance 90 | """ 91 | global __config_instance 92 | if __config_instance is None or reload: 93 | __config_instance = __Configuration() 94 | return __config_instance 95 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # kyle documentation build configuration file 4 | # 5 | # This file is execfile()d with the current directory set to its containing dir. 6 | # 7 | # All configuration values have a default; values that are commented out 8 | # serve to show the default. 9 | 10 | import ast 11 | import logging 12 | import os 13 | import sys 14 | 15 | import pkg_resources 16 | 17 | log = logging.getLogger("docs") 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. 
If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here. 22 | sys.path.insert(0, os.path.abspath("../src")) 23 | print(sys.path) 24 | 25 | # -- General configuration ----------------------------------------------------- 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # needs_sphinx = '1.0' 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be extensions 31 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 32 | extensions = [ 33 | "sphinx.ext.napoleon", 34 | "sphinx.ext.autodoc", 35 | "sphinx.ext.doctest", 36 | "sphinx.ext.linkcode", 37 | "sphinx_rtd_theme", 38 | "nbsphinx", 39 | # see https://github.com/spatialaudio/nbsphinx/issues/24 for an explanation why this extension is necessary 40 | "IPython.sphinxext.ipython_console_highlighting", 41 | ] 42 | 43 | 44 | # adding links to source files (this works for gitlab and github like hosts and might need to be adjusted for others) 45 | # see https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html#module-sphinx.ext.linkcode 46 | def linkcode_resolve(domain, info): 47 | link_prefix = "https://gitlab.aai.lab/tl/calibration/kyle/blob/develop" 48 | if domain != "py": 49 | return None 50 | if not info["module"]: 51 | return None 52 | 53 | path, link_extension = get_path_and_link_extension(info["module"]) 54 | object_name = info["fullname"] 55 | if ( 56 | "." 
in object_name 57 | ): # don't add source link to methods within classes (you might want to change that) 58 | return None 59 | lineno = lineno_from_object_name(path, object_name) 60 | return f"{link_prefix}/{link_extension}#L{lineno}" 61 | 62 | 63 | def get_path_and_link_extension(module: str): 64 | """ 65 | :return: tuple of the form (path, link_extension) where 66 | the first entry is the local path to a given module or to __init__.py of the package 67 | and the second entry is the corresponding path from the top level directory 68 | """ 69 | filename = module.replace(".", "/") 70 | docs_dir = os.path.dirname(os.path.realpath(__file__)) 71 | source_path_prefix = os.path.join(docs_dir, f"../src/{filename}") 72 | 73 | if os.path.exists(source_path_prefix + ".py"): 74 | link_extension = f"src/{filename}.py" 75 | return source_path_prefix + ".py", link_extension 76 | elif os.path.exists(os.path.join(source_path_prefix, "__init__.py")): 77 | link_extension = f"src/{filename}/__init__.py" 78 | return os.path.join(source_path_prefix, "__init__.py"), link_extension 79 | else: 80 | raise Exception( 81 | f"{source_path_prefix} is neither a module nor a package with init - " 82 | f"did you forget to add an __init__.py?" 83 | ) 84 | 85 | 86 | def lineno_from_object_name(source_file, object_name): 87 | desired_node_name = object_name.split(".")[0] 88 | with open(source_file, "r") as f: 89 | source_node = ast.parse(f.read()) 90 | desired_node = next( 91 | ( 92 | node 93 | for node in source_node.body 94 | if getattr(node, "name", "") == desired_node_name 95 | ), 96 | None, 97 | ) 98 | if desired_node is None: 99 | log.warning(f"Could not find object {desired_node_name} in {source_file}") 100 | return 0 101 | else: 102 | return desired_node.lineno 103 | 104 | 105 | # this is useful for keeping the docs build environment small. 
Add heavy requirements here 106 | # and all other requirements to docs/requirements.txt 107 | autodoc_mock_imports = [ 108 | "netcal", 109 | "torch", 110 | "kornia", 111 | "torchvision", 112 | "pytorch-lightning", 113 | "matplotlib", 114 | ] 115 | 116 | autodoc_default_options = { 117 | "exclude-members": "log", 118 | "member-order": "bysource", 119 | "show-inheritance": True, 120 | } 121 | 122 | # Add any paths that contain templates here, relative to this directory. 123 | templates_path = ["_templates"] 124 | 125 | # The suffix of source filenames. 126 | source_suffix = ".rst" 127 | 128 | # The encoding of source files. 129 | # source_encoding = 'utf-8-sig' 130 | 131 | # The master toctree document. 132 | master_doc = "index" 133 | 134 | # General information about the project. 135 | package_name = "kyle-calibration" 136 | 137 | # The version info for the project you're documenting, acts as replacement for 138 | # |version| and |release|, also used in various other places throughout the 139 | # built documents. 140 | # 141 | # The full version, including alpha/beta/rc tags. 142 | version = pkg_resources.get_distribution(package_name).version 143 | release = version 144 | # The short X.Y version. 145 | major_v, minor_v = version.split(".")[:2] 146 | version = f"{major_v}.{minor_v}" 147 | 148 | # The language for content autogenerated by Sphinx. Refer to documentation 149 | # for a list of supported languages. 150 | # language = None 151 | 152 | # There are two options for replacing |today|: either, you set today to some 153 | # non-false value, then it is used: 154 | # today = '' 155 | # Else, today_fmt is used as the format for a strftime call. 156 | # today_fmt = '%B %d, %Y' 157 | 158 | # List of patterns, relative to source directory, that match files and 159 | # directories to ignore when looking for source files. 160 | exclude_patterns = ["_build"] 161 | 162 | # The reST default role (used for this markup: `text`) to use for all documents. 
163 | # default_role = None 164 | 165 | # If true, '()' will be appended to :func: etc. cross-reference text. 166 | # add_function_parentheses = True 167 | 168 | # If true, the current module name will be prepended to all description 169 | # unit titles (such as .. function::). 170 | add_module_names = False 171 | 172 | # If true, sectionauthor and moduleauthor directives will be shown in the 173 | # output. They are ignored by default. 174 | # show_authors = False 175 | 176 | # The name of the Pygments (syntax highlighting) style to use. 177 | pygments_style = "sphinx" 178 | 179 | # A list of ignored prefixes for module index sorting. 180 | # modindex_common_prefix = [] 181 | 182 | 183 | # -- Options for HTML output --------------------------------------------------- 184 | 185 | # The theme to use for HTML and HTML Help pages. See the documentation for 186 | # a list of builtin themes. 187 | html_theme = "sphinx_rtd_theme" 188 | 189 | # Theme options are theme-specific and customize the look and feel of a theme 190 | # further. For a list of options available for each theme, see the 191 | # documentation. 192 | # html_theme_options = {} 193 | 194 | # Add any paths that contain custom themes here, relative to this directory. 195 | # html_theme_path = [] 196 | 197 | # The name for this set of Sphinx documents. If None, it defaults to 198 | # " v documentation". 199 | # html_title = None 200 | 201 | # A shorter title for the navigation bar. Default is the same as html_title. 202 | # html_short_title = None 203 | 204 | # The name of an image file (relative to this directory) to place at the top 205 | # of the sidebar. 206 | # html_logo = None 207 | 208 | # The name of an image file (within the static path) to use as favicon of the 209 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 210 | # pixels large. 
211 | # html_favicon = None 212 | 213 | # Add any paths that contain custom static files (such as style sheets) here, 214 | # relative to this directory. They are copied after the builtin static files, 215 | # so a file named "default.css" will overwrite the builtin "default.css". 216 | html_static_path = [] 217 | 218 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 219 | # using the given strftime format. 220 | # html_last_updated_fmt = '%b %d, %Y' 221 | 222 | # If true, SmartyPants will be used to convert quotes and dashes to 223 | # typographically correct entities. 224 | # html_use_smartypants = True 225 | 226 | # Custom sidebar templates, maps document names to template names. 227 | # html_sidebars = {} 228 | 229 | # Additional templates that should be rendered to pages, maps page names to 230 | # template names. 231 | # html_additional_pages = {} 232 | 233 | # If false, no module index is generated. 234 | # html_domain_indices = True 235 | 236 | # If false, no index is generated. 237 | # html_use_index = True 238 | 239 | # If true, the index is split into individual pages for each letter. 240 | # html_split_index = False 241 | 242 | # If true, links to the reST sources are added to the pages. 243 | # html_show_sourcelink = True 244 | 245 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 246 | # html_show_sphinx = True 247 | 248 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 249 | # html_show_copyright = True 250 | 251 | # If true, an OpenSearch description file will be output, and all pages will 252 | # contain a tag referring to it. The value of this option must be the 253 | # base URL from which the finished HTML is served. 254 | # html_use_opensearch = '' 255 | 256 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 257 | # html_file_suffix = None 258 | 259 | # Output file base name for HTML help builder. 
260 | htmlhelp_basename = "kyle_doc" 261 | 262 | 263 | # -- Options for LaTeX output -------------------------------------------------- 264 | 265 | latex_elements = { 266 | # The paper size ('letterpaper' or 'a4paper'). 267 | # 'papersize': 'letterpaper', 268 | # The font size ('10pt', '11pt' or '12pt'). 269 | # 'pointsize': '10pt', 270 | # Additional stuff for the LaTeX preamble. 271 | # 'preamble': '', 272 | } 273 | 274 | # Grouping the document tree into LaTeX files. List of tuples 275 | # (source start file, target name, title, author, documentclass [howto/manual]). 276 | # latex_documents = [] 277 | 278 | # The name of an image file (relative to this directory) to place at the top of 279 | # the title page. 280 | # latex_logo = None 281 | 282 | # For "manual" documents, if this is true, then toplevel headings are parts, 283 | # not chapters. 284 | # latex_use_parts = False 285 | 286 | # If true, show page references after internal links. 287 | # latex_show_pagerefs = False 288 | 289 | # If true, show URL addresses after external links. 290 | # latex_show_urls = False 291 | 292 | # Documents to append as an appendix to all manuals. 293 | # latex_appendices = [] 294 | 295 | # If false, no module index is generated. 296 | # latex_domain_indices = True 297 | 298 | 299 | # -- Options for manual page output -------------------------------------------- 300 | 301 | # One entry per manual page. List of tuples 302 | # (source start file, name, description, authors, manual section). 303 | man_pages = [("index", "kyle", "", ["appliedAI"], 1)] 304 | 305 | # If true, show URL addresses after external links. 306 | # man_show_urls = False 307 | 308 | 309 | # -- Options for Texinfo output ------------------------------------------------ 310 | 311 | # Grouping the document tree into Texinfo files. 
List of tuples 312 | # (source start file, target name, title, author, 313 | # dir menu entry, description, category) 314 | # texinfo_documents = [] 315 | 316 | # Documents to append as an appendix to all manuals. 317 | # texinfo_appendices = [] 318 | 319 | # If false, no module index is generated. 320 | # texinfo_domain_indices = True 321 | 322 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 323 | # texinfo_show_urls = 'footnote' 324 | -------------------------------------------------------------------------------- /docs/getting-started.rst: -------------------------------------------------------------------------------- 1 | Getting started 2 | =============== 3 | 4 | This library works with python>=3.8. Install it by running 5 | ``python setup.py install`` 6 | from the root directory. 7 | 8 | For development, using tox is encouraged. Run ``tox`` from the root directory in order to build the package and 9 | these docs and to run the tests. You should not merge to master without tox having executed successfully! 10 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | kyle library and game 2 | ===================== 3 | 4 | .. toctree:: 5 | :caption: Guides and Tutorials 6 | :glob: 7 | 8 | * 9 | 10 | 11 | .. toctree:: 12 | :caption: Modules 13 | 14 | kyle/index 15 | 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/kyle/datasets.rst: -------------------------------------------------------------------------------- 1 | datasets 2 | ======== 3 | 4 | ..
automodule:: kyle.datasets 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/evaluation.rst: -------------------------------------------------------------------------------- 1 | evaluation 2 | ========== 3 | 4 | .. automodule:: kyle.evaluation 5 | :members: 6 | :undoc-members: 7 | 8 | .. toctree:: 9 | :glob: 10 | 11 | evaluation/* 12 | -------------------------------------------------------------------------------- /docs/kyle/evaluation/continuous.rst: -------------------------------------------------------------------------------- 1 | continuous 2 | ========== 3 | 4 | .. automodule:: kyle.evaluation.continuous 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/evaluation/discrete.rst: -------------------------------------------------------------------------------- 1 | discrete 2 | ======== 3 | 4 | .. automodule:: kyle.evaluation.discrete 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/index.rst: -------------------------------------------------------------------------------- 1 | Library Modules 2 | =============== 3 | 4 | .. automodule:: kyle 5 | :members: 6 | :undoc-members: 7 | 8 | .. toctree:: 9 | :glob: 10 | 11 | * 12 | -------------------------------------------------------------------------------- /docs/kyle/integrals.rst: -------------------------------------------------------------------------------- 1 | integrals 2 | ========= 3 | 4 | .. automodule:: kyle.integrals 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/metrics.rst: -------------------------------------------------------------------------------- 1 | metrics 2 | ======= 3 | 4 | .. automodule:: kyle.metrics 5 | :members: 6 | :undoc-members: 7 | 8 | .. 
toctree:: 9 | :glob: 10 | 11 | metrics/* 12 | -------------------------------------------------------------------------------- /docs/kyle/metrics/calibration_metrics.rst: -------------------------------------------------------------------------------- 1 | calibration\_metrics 2 | ==================== 3 | 4 | .. automodule:: kyle.metrics.calibration_metrics 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/models.rst: -------------------------------------------------------------------------------- 1 | models 2 | ====== 3 | 4 | .. automodule:: kyle.models 5 | :members: 6 | :undoc-members: 7 | 8 | .. toctree:: 9 | :glob: 10 | 11 | models/* 12 | -------------------------------------------------------------------------------- /docs/kyle/models/calibratable_model.rst: -------------------------------------------------------------------------------- 1 | calibratable\_model 2 | =================== 3 | 4 | .. automodule:: kyle.models.calibratable_model 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/models/resnet.rst: -------------------------------------------------------------------------------- 1 | resnet 2 | ====== 3 | 4 | .. automodule:: kyle.models.resnet 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/sampling.rst: -------------------------------------------------------------------------------- 1 | sampling 2 | ======== 3 | 4 | .. automodule:: kyle.sampling 5 | :members: 6 | :undoc-members: 7 | 8 | .. toctree:: 9 | :glob: 10 | 11 | sampling/* 12 | -------------------------------------------------------------------------------- /docs/kyle/sampling/fake_clf.rst: -------------------------------------------------------------------------------- 1 | fake\_clf 2 | ========= 3 | 4 | .. 
automodule:: kyle.sampling.fake_clf 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/transformations.rst: -------------------------------------------------------------------------------- 1 | transformations 2 | =============== 3 | 4 | .. automodule:: kyle.transformations 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/kyle/util.rst: -------------------------------------------------------------------------------- 1 | util 2 | ==== 3 | 4 | .. automodule:: kyle.util 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/docs/requirements.txt -------------------------------------------------------------------------------- /notebooks/calibration_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "\n", 21 | "from sklearn import datasets\n", 22 | "from sklearn.model_selection import train_test_split\n", 23 | "from sklearn.neural_network import MLPClassifier\n", 24 | "from sklearn.metrics import accuracy_score\n", 25 | "\n", 26 | "from kyle.calibration import ModelCalibrator\n", 27 | "from kyle.models import CalibratableModel\n", 28 | "from kyle.metrics import ECE\n", 29 | "from kyle.calibration.calibration_methods import TemperatureScaling\n", 30 | "from 
kyle.sampling.fake_clf import DirichletFC\n", 31 | "from kyle.transformations import MaxComponentSimplexAut\n", 32 | "from kyle.evaluation import EvalStats" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "# What is calibration?" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "When we talk about how good a machine learning model is, what we (generally) mean to ask is: How accurate is the model?\n", 47 | "While accuracy is a good enough metric in many cases, it leaves out important information about the model.\n", 48 | "One such piece of information is whether the confidence of the model is in line with its accuracy.\n", 49 | "If it is, we say the model is calibrated.\n", 50 | "\n", 51 | "To explain this concept in detail, let's begin with an example. Suppose we want to predict whether a patient has cancer.\n", 52 | "We can model this with two classes, i.e. $y \\in \\{0, 1\\}$, where $y=0$ denotes a healthy patient and $y=1$ denotes a\n", 53 | "patient who has cancer. (The synthetic dataset generated below actually has `n_classes = 3` classes; everything that follows applies to any number of classes.)"
54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "n_samples = 2000\n", 63 | "n_classes = 3" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "X, y = datasets.make_classification(\n", 73 | " n_samples=n_samples,\n", 74 | " n_features=20,\n", 75 | " n_informative=7,\n", 76 | " n_redundant=10,\n", 77 | " n_classes=n_classes,\n", 78 | " random_state=42,\n", 79 | ")\n", 80 | "X_train, X_test, y_train, y_test = train_test_split(\n", 81 | " X, y, test_size=0.5, random_state=42\n", 82 | ")" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "We can then train a neural network on our data:" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "model = MLPClassifier(hidden_layer_sizes=(20, 20, 10))\n", 99 | "model.fit(X_train, y_train)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "and make predictions on new samples. Let's see how our model performs on unseen examples:" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "y_pred = model.predict(X_test)\n", 116 | "model_accuracy = accuracy_score(y_test, y_pred)\n", 117 | "\n", 118 | "f\"Model accuracy: {model_accuracy*100}%\"" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "That seems pretty good! One might think our job here is done: After all, the model predicts whether a person has cancer\n", 126 | "or not with decent accuracy.\n", 127 | "Unfortunately, accuracy of a model does not tell us the full story. 
This is because, at inference time,\n", 128 | "for a given sample a model outputs confidence scores for each class. We then take the class with the highest confidence\n", 129 | "and interpret that as the prediction of the model.\n", 130 | "\n", 131 | "This conversion of continuous (probability) to discrete (label) values can hide certain properties of the model.\n", 132 | "To illustrate this, let's take two models -- $A$ and $B$ -- trained on the same data. Let's further assume they have\n", 133 | "similar accuracy. Suppose we test both models with 10 healthy samples. $A$ assigns probabilities $(0.49, 0.51)$ to all\n", 134 | "samples, whereas $B$ assigns $(0.1, 0.9)$. While both $A$ and $B$ will be wrong 100% of the time, notice that $A$ is much closer\n", 135 | "to assigning the samples to the correct class than $B$." 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "Continuing with our previous example: Imagine that on all examples where the model was $95$% confident that the subject\n", 143 | "has cancer, it was correct only $70$% of the time. Intuitively, it seems there's something not quite right with the model:\n", 144 | "the model is over-confident in its predictions. This notion is formalized by the concept of calibration.\n", 145 | "We say a model is calibrated (more precisely, confidence-calibrated) when, for any confidence value $p \\in [0, 1]$,\n", 146 | "a prediction $\\widehat{y}$ made with confidence $\\widehat{p}=p$ is correct with probability $p$:\n", 147 | "\n", 148 | "\\begin{equation}\n", 149 | "P(\\widehat{y}=y|\\widehat{p}=p) = p \\quad \\forall p \\in [0, 1]\n", 150 | "\\end{equation}" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "So, is our model calibrated? As we can see in the equation above, $\\widehat{p}$ is continuous, which means we cannot\n", 158 | "evaluate the equation exactly with finite data.
We can, however, develop empirical measures that approximate the true measure\n", 159 | "of (mis)calibration.\n", 160 | "\n", 161 | "One simple way to get an empirical estimate of the model's accuracy and confidence is to discretize the probability\n", 162 | "space. This is done by partitioning the confidence range $[0, 1]$ into $K$ equal-width bins, where $B_k$ denotes the set of samples whose confidence $\\widehat{p}_m$ falls into the $k$-th bin. We can then calculate the accuracy and confidence for each\n", 163 | "bin:\n", 164 | "\n", 165 | "\\begin{equation}\n", 166 | "accuracy_{B_k} = \\frac{1}{|B_k|} \\sum_{m \\in B_k}\\mathbb{1}(\\widehat{y}_m=y_m)\n", 167 | "\\end{equation}\n", 168 | "\n", 169 | "\\begin{equation}\n", 170 | "confidence_{B_k} = \\frac{1}{|B_k|} \\sum_{m \\in B_k}\\widehat{p}_m\n", 171 | "\\end{equation}" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "We can now simply calculate the weighted average of the absolute differences between the accuracy and confidence of the model over all bins:\n", 179 | "\n", 180 | "\\begin{equation}\n", 181 | "\\sum_{k=1}^{K} \\frac{|B_k|}{n} \\Big|\\:accuracy_{B_k} - confidence_{B_k} \\Big|\n", 182 | "\\end{equation}\n", 183 | "\n", 184 | "where $n$ is the total number of samples. This is known as the **Expected Calibration Error** ($ECE$). As can be seen, $ECE=0$ if a model is perfectly calibrated.\n", 185 | "Let's calculate the $ECE$ for our model with $12$ bins:" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "ece = ECE(bins=12)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# Evaluate uncalibrated predictions\n", 204 | "y_pred = model.predict_proba(X_test)\n", 205 | "\n", 206 | "pre_calibration_ece = ece.compute(y_pred, y_test)\n", 207 | "\n", 208 | "f\"ECE before calibration: {pre_calibration_ece}\"" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "We can also visualize the extent of
miscalibration by plotting the model's confidence *(x-axis)* vs. the ground truth\n", 216 | "probability *(y-axis)*. For a perfectly calibrated model, the plot should be $y=x$. Let's see how our model fares:" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "eval_stats = EvalStats(y_test, y_pred)\n", 226 | "class_labels = range(n_classes)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "fig = eval_stats.plot_reliability_curves(\n", 236 | " [\"top_class\", 0], display_weights=True, strategy=\"uniform\", n_bins=8\n", 237 | ")" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "The predictions are distributed very inhomogeneously over the unit interval: some bins have\n", 245 | "few members, so the reliability estimates for those bins have high variance. This can be mitigated by employing\n", 246 | "the \"quantile\" binning strategy, also called adaptive binning:" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "fig = eval_stats.plot_reliability_curves(\n", 256 | " [0, \"top_class\"], display_weights=True, n_bins=8, strategy=\"quantile\"\n", 257 | ")" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "Now all bins have (roughly) the same weight but different widths. The pointwise reliability estimates\n", 265 | "have lower variance, but there can be wide gaps between them, which require more interpolation.\n", 266 | "Both binning strategies have their advantages and disadvantages." 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "Okay, so our model does not seem to be well calibrated, since $ECE>0$ (keep in mind, though, that with finitely many samples the estimated $ECE$ is positive even for a perfectly calibrated model, as we will see below). Can we do anything to remedy the situation?"
274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "# Model calibration" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "Indeed, we can improve the calibration of our model using various techniques. What's more, we don't need to train our\n", 288 | "model again; many calibration techniques are post-processing methods i.e. operating on the trained model's output\n", 289 | "confidence scores. The output scores for calibration are typically obtained on a validation set.\n", 290 | "\n", 291 | "In `kyle`, we have provided a `CalibratableModel` class which takes a model and, as the name suggests, makes it possible\n", 292 | "to calibrate that model. By default, we use a technique called [*Temperature scaling*](https://arxiv.org/abs/1706.04599)\n", 293 | "for calibration." 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "# Create calibratable model\n", 303 | "calibration_method = TemperatureScaling()\n", 304 | "calibratable_model = CalibratableModel(model, calibration_method)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "We also provide a `ModelCalibrator` class which holds the data to calibrate models:" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "# Create model calibrator and calibrate model\n", 321 | "calibrator = ModelCalibrator(\n", 322 | " X_calibrate=X_test, y_calibrate=y_test, X_fit=X_train, y_fit=y_train\n", 323 | ")" 324 | ] 325 | }, 326 | { 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "We now have everything ready to calibrate our model:" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": {}, 337 | "outputs": [], 
338 | "source": [ 339 | "calibrator.calibrate(calibratable_model)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "Let's see if calibrating the model improved the $ECE$ score:" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "# Evaluate on X_test/y_test so the result is directly comparable to the pre-calibration ECE.\n", 356 | "# Note: X_test/y_test were also passed as X_calibrate/y_calibrate above, so this ECE is measured on the calibration data.\n", 357 | "calibrated_confidences = calibratable_model.predict_proba(X_test)\n", 358 | "\n", 359 | "post_calibration_ece = ece.compute(calibrated_confidences, y_test)\n", 360 | "\n", 361 | "f\"ECE before calibration: {pre_calibration_ece}, ECE after calibration: {post_calibration_ece}\"" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "metadata": {}, 367 | "source": [ 368 | "Great! $ECE$ has improved. Let's also plot reliability curves to visually confirm the improvement in calibration." 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "eval_stats = EvalStats(y_test, calibrated_confidences)\n", 378 | "\n", 379 | "eval_stats.plot_reliability_curves(class_labels)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "Wonderful! We have successfully improved our model's calibration. (Note that it was evaluated on the same data that was used for calibration; a separate held-out set would give a less optimistic estimate.)"
387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "# Model-agnostic calibration" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "You may have noticed that to evaluate (mis)calibration of a model, we don't require the model itself.\n", 401 | "Rather, it is sufficient to have the confidence scores predicted by the model.\n", 402 | "This means we can abstract away the model and generate both the ground truth and confidence scores via sampling processes.\n", 403 | "\n", 404 | "In `kyle` we have provided samplers that simulate different kinds of calibration properties.\n", 405 | "One such sampler is the `DirichletFC` class, which by default produces ground truth labels and confidences that are calibrated." 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "sampler = DirichletFC(num_classes=2)\n", 415 | "\n", 416 | "# Get 1000 calibrated fake confidence scores\n", 417 | "calibrated_samples = sampler.get_sample_arrays(1000)\n", 418 | "ground_truth, confidences = calibrated_samples" 419 | ] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": {}, 424 | "source": [ 425 | "Let's evaluate the $ECE$ for these samples:" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "ece.compute(confidences, ground_truth)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "Wait, $ECE>0$, so how can we say that the samples are calibrated?\n", 442 | "\n", 443 | "As mentioned earlier, we only have a finite number of samples, so the true (mis)calibration can only be recovered asymptotically.\n", 444 | "This means that the more samples we have, the more accurate the $ECE$ estimate becomes.\n", 445 | "We can test this by generating *5x* as many samples as before
and evaluating $ECE$ again:" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "calibrated_samples = sampler.get_sample_arrays(5000)\n", 455 | "ground_truth, confidences = calibrated_samples\n", 456 | "\n", 457 | "ece.compute(confidences, ground_truth)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "As expected, $ECE$ goes down with more samples.\n", 465 | "\n", 466 | "We can also systematically generate uncalibrated samples. For instance, we can pass a simplex automorphism such as\n", 467 | "`MaxComponentSimplexAut`, built from the function `overestimating_max` defined below, to `DirichletFC`; the resulting sampler no longer produces calibrated confidences." 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "def overestimating_max(x: np.ndarray):\n", 477 | " x = x.copy()\n", 478 | " mask = x > 1 / 2\n", 479 | " x[mask] = x[mask] - (1 / 4 - (1 - x[mask]) ** 2)\n", 480 | " return x\n", 481 | "\n", 482 | "\n", 483 | "automorphism = MaxComponentSimplexAut(overestimating_max)\n", 484 | "shifted_sampler = DirichletFC(num_classes=2, simplex_automorphism=automorphism)\n", 485 | "\n", 486 | "# Get 10000 uncalibrated fake confidence scores\n", 487 | "uncalibrated_samples = shifted_sampler.get_sample_arrays(10000)\n", 488 | "ground_truth, confidences = uncalibrated_samples" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "Let's see whether $ECE$ reflects the miscalibration of these samples:" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "ece.compute(confidences, ground_truth)" 505 | ] 506 | } 507 | ], 508 | "metadata": { 509 | "kernelspec": { 510 | "display_name": "Python 3 (ipykernel)", 511 | "language": "python", 512 |
"name": "python3" 513 | }, 514 | "language_info": { 515 | "codemirror_mode": { 516 | "name": "ipython", 517 | "version": 3 518 | }, 519 | "file_extension": ".py", 520 | "mimetype": "text/x-python", 521 | "name": "python", 522 | "nbconvert_exporter": "python", 523 | "pygments_lexer": "ipython3", 524 | "version": "3.8.13" 525 | } 526 | }, 527 | "nbformat": 4, 528 | "nbformat_minor": 1 529 | } 530 | -------------------------------------------------------------------------------- /notebooks/evaluating_cal_methods.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Note - this cell should be executed only once per session\n", 10 | "import sys, os\n", 11 | "\n", 12 | "# in order to get top level modules and to have paths relative to repo root\n", 13 | "\n", 14 | "if os.path.basename(os.getcwd()) != \"notebooks\":\n", 15 | " raise Exception(f\"Wrong directory. Did you execute this cell twice?\")\n", 16 | "os.chdir(\"..\")\n", 17 | "sys.path.append(os.path.abspath(\".\"))\n", 18 | "\n", 19 | "%load_ext autoreload\n", 20 | "%autoreload 2" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Class-wise and Reduced Calibration Methods\n", 28 | "\n", 29 | "In this notebook we demonstrate two new strategies for calibrating probabilistic classifiers. These strategies act\n", 30 | "as wrappers around any calibration algorithm and therefore are implemented as wrappers. We test the improvements\n", 31 | "in different calibration errors due to these wrappers where the non-wrapped calibration methods serve as baselines.\n", 32 | "\n", 33 | "The tests are performed on random forests trained on two synthetic data sets (balanced and imbalanced) as well as\n", 34 | "on resnet20 trained on the CIFAR10 data set." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from collections import defaultdict\n", 44 | "\n", 45 | "from sklearn.datasets import make_classification\n", 46 | "from sklearn.metrics import accuracy_score\n", 47 | "from sklearn.ensemble import RandomForestClassifier\n", 48 | "\n", 49 | "import os\n", 50 | "import requests\n", 51 | "import logging\n", 52 | "\n", 53 | "from kyle.calibration.calibration_methods import *\n", 54 | "from kyle.evaluation import EvalStats\n", 55 | "\n", 56 | "from scipy.special import softmax\n", 57 | "\n", 58 | "from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score\n", 59 | "\n", 60 | "import numpy as np\n", 61 | "import matplotlib.pyplot as plt" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Helper functions for evaluation" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "DEFAULT_WRAPPERS = {\n", 78 | " \"Baseline\": lambda method_factory: method_factory(),\n", 79 | " \"Class-wise\": lambda method_factory: ClassWiseCalibration(method_factory),\n", 80 | " \"Reduced\": lambda method_factory: ConfidenceReducedCalibration(method_factory()),\n", 81 | " \"Class-wise reduced\": lambda method_factory: ClassWiseCalibration(\n", 82 | " lambda: ConfidenceReducedCalibration(method_factory())\n", 83 | " ),\n", 84 | "}\n", 85 | "\n", 86 | "DEFAULT_CV = 6\n", 87 | "DEFAULT_BINS = 25\n", 88 | "\n", 89 | "ALL_CALIBRATION_METHOD_FACTORIES = (\n", 90 | " # TemperatureScaling,\n", 91 | " BetaCalibration,\n", 92 | " # LogisticCalibration,\n", 93 | " IsotonicRegression,\n", 94 | " HistogramBinning,\n", 95 | ")\n", 96 | "ALL_METRICS = (\n", 97 | " \"ECE\",\n", 98 | " \"cwECE\",\n", 99 | ")\n", 100 | "\n", 101 | "\n", 102 | "def compute_score(scaler, confs: np.ndarray, labels: np.ndarray, bins, 
metric=\"ECE\"):\n", 103 | " calibrated_confs = scaler.get_calibrated_confidences(confs)\n", 104 | " eval_stats = EvalStats(labels, calibrated_confs)\n", 105 | " if metric == \"ECE\":\n", 106 | " return eval_stats.expected_calibration_error(n_bins=bins)\n", 107 | " elif metric == \"cwECE\":\n", 108 | " return eval_stats.class_wise_expected_calibration_error(n_bins=bins)\n", 109 | " elif isinstance(metric, int):\n", 110 | " return eval_stats.expected_calibration_error(class_label=metric, n_bins=bins)\n", 111 | " else:\n", 112 | " raise ValueError(f\"Unknown metric {metric}\")\n", 113 | "\n", 114 | "\n", 115 | "def get_scores(scaler, metric, cv, bins, confs, labels):\n", 116 | " scoring = lambda *args: compute_score(*args, bins=bins, metric=metric)\n", 117 | " return cross_val_score(scaler, confs, labels, scoring=scoring, cv=cv)\n", 118 | "\n", 119 | "\n", 120 | "def plot_scores(wrapper_scores_dict: dict, title=\"\", ax=None, y_lim=None):\n", 121 | " labels = wrapper_scores_dict.keys()\n", 122 | " scores_collection = wrapper_scores_dict.values()\n", 123 | "\n", 124 | " if ax is None:\n", 125 | " plt.figure(figsize=(14, 7))\n", 126 | " ax = plt.gca()\n", 127 | " ax.set_title(title)\n", 128 | " ax.boxplot(scores_collection, labels=labels)\n", 129 | " if y_lim is not None:\n", 130 | " ax.set_ylim(y_lim)\n", 131 | "\n", 132 | "\n", 133 | "def evaluate_calibration_wrappers(\n", 134 | " method_factory,\n", 135 | " confidences,\n", 136 | " gt_labels,\n", 137 | " wrappers_dict=None,\n", 138 | " metric=\"ECE\",\n", 139 | " cv=DEFAULT_CV,\n", 140 | " method_name=None,\n", 141 | " bins=DEFAULT_BINS,\n", 142 | " short_description=False,\n", 143 | "):\n", 144 | " if method_name is None:\n", 145 | " method_name = method_factory.__name__\n", 146 | " if short_description:\n", 147 | " description = f\"{method_name}\"\n", 148 | " else:\n", 149 | " description = (\n", 150 | " f\"Evaluating wrappers of {method_name} on metric {metric} with {bins} bins\\n \"\n", 151 | " f\"CV with {cv} 
folds on {len(confidences)} data points.\"\n", 152 | " )\n", 153 | " if wrappers_dict is None:\n", 154 | " wrappers_dict = DEFAULT_WRAPPERS\n", 155 | "\n", 156 | " wrapper_scores_dict = {}\n", 157 | " for wrapper_name, wrapper in wrappers_dict.items():\n", 158 | " method = wrapper(method_factory)\n", 159 | " scores = get_scores(\n", 160 | " method, metric, cv=cv, bins=bins, confs=confidences, labels=gt_labels\n", 161 | " )\n", 162 | " wrapper_scores_dict[wrapper_name] = scores\n", 163 | " return wrapper_scores_dict, description\n", 164 | "\n", 165 | "\n", 166 | "# taken such that minimum and maximum are visible in all plots\n", 167 | "DEFAULT_Y_LIMS_DICT = {\n", 168 | " \"ECE\": (0.004, 0.032),\n", 169 | " \"cwECE\": (0.005, 0.018),\n", 170 | "}\n", 171 | "\n", 172 | "\n", 173 | "def perform_default_evaluation(\n", 174 | " confidences,\n", 175 | " gt_labels,\n", 176 | " method_factories=ALL_CALIBRATION_METHOD_FACTORIES,\n", 177 | " metrics=ALL_METRICS,\n", 178 | "):\n", 179 | " evaluation_results = defaultdict(list)\n", 180 | " for metric in metrics:\n", 181 | " print(f\"Creating evaluation for {metric}\")\n", 182 | " for method_factory in method_factories:\n", 183 | " print(f\"Computing scores for {method_factory.__name__}\", end=\"\\r\")\n", 184 | " result = evaluate_calibration_wrappers(\n", 185 | " method_factory,\n", 186 | " confidences=confidences,\n", 187 | " gt_labels=gt_labels,\n", 188 | " metric=metric,\n", 189 | " short_description=True,\n", 190 | " )\n", 191 | " evaluation_results[metric].append(result)\n", 192 | " return evaluation_results\n", 193 | "\n", 194 | "\n", 195 | "def plot_default_evaluation_results(\n", 196 | " evaluation_results: dict, figsize=(25, 7), y_lims_dict=None, title_addon=None\n", 197 | "):\n", 198 | " if y_lims_dict is None:\n", 199 | " y_lims_dict = DEFAULT_Y_LIMS_DICT\n", 200 | " ncols = len(list(evaluation_results.values())[0])\n", 201 | " for metric, results in evaluation_results.items():\n", 202 | " fig, axes = 
plt.subplots(nrows=1, ncols=ncols, figsize=figsize)\n", 203 | " y_lim = y_lims_dict[metric]\n", 204 | " if ncols == 1: # axes fails to be a list if ncols=1\n", 205 | " axes = [axes]\n", 206 | " for col, result in zip(axes, results):\n", 207 | " wrapper_scores_dict, description = result\n", 208 | " plot_scores(wrapper_scores_dict, title=description, ax=col, y_lim=y_lim)\n", 209 | "\n", 210 | " title = f\"Evaluation with {metric} ({DEFAULT_CV} folds; {DEFAULT_BINS} bins)\"\n", 211 | " if title_addon is not None:\n", 212 | " title += f\"\\n{title_addon}\"\n", 213 | " fig.suptitle(title)\n", 214 | " plt.show()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "## Part 1: Random Forest\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "## Load Data" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "def get_calibration_dataset(\n", 238 | " n_classes=5,\n", 239 | " weights=None,\n", 240 | " n_samples=30000,\n", 241 | " n_informative=15,\n", 242 | " model=RandomForestClassifier(),\n", 243 | "):\n", 244 | " n_dataset_samples = 2 * n_samples\n", 245 | " test_size = 0.5\n", 246 | " X, y = make_classification(\n", 247 | " n_samples=n_dataset_samples,\n", 248 | " n_classes=n_classes,\n", 249 | " n_informative=n_informative,\n", 250 | " weights=weights,\n", 251 | " )\n", 252 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)\n", 253 | "\n", 254 | " train_index, test_index = list(sss.split(X, y))[0]\n", 255 | " X_train, y_train = X[train_index], y[train_index]\n", 256 | " X_test, y_test = X[test_index], y[test_index]\n", 257 | " model.fit(X_train, y_train)\n", 258 | " confidences = model.predict_proba(X_test)\n", 259 | " y_pred = confidences.argmax(1)\n", 260 | " accuracy = accuracy_score(y_pred, y_test)\n", 261 | " print(f\"Model accuracy: {accuracy}\")\n", 
262 | " return confidences, y_test" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "# this takes a while\n", 272 | "print(f\"Creating balanced dataset\")\n", 273 | "balanced_confs, balanced_gt = get_calibration_dataset()\n", 274 | "print(f\"Creating unbalanced dataset\")\n", 275 | "unbalanced_confs, unbalanced_gt = get_calibration_dataset(\n", 276 | " weights=(0.3, 0.1, 0.25, 0.15)\n", 277 | ")" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "## Evaluating wrappers on a single calibration method" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "balanced_scores_ECE, description = evaluate_calibration_wrappers(\n", 294 | " HistogramBinning,\n", 295 | " confidences=balanced_confs,\n", 296 | " gt_labels=balanced_gt,\n", 297 | " metric=\"ECE\",\n", 298 | " cv=4,\n", 299 | ")\n", 300 | "\n", 301 | "plot_scores(balanced_scores_ECE, title=description)\n", 302 | "plt.show()" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "unbalanced_scores_ECE, description = evaluate_calibration_wrappers(\n", 312 | " TemperatureScaling,\n", 313 | " confidences=unbalanced_confs,\n", 314 | " gt_labels=unbalanced_gt,\n", 315 | " metric=\"ECE\",\n", 316 | ")\n", 317 | "\n", 318 | "plot_scores(unbalanced_scores_ECE, title=description)\n", 319 | "plt.show()" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "## Evaluating wrappers on multiple metrics and plotting next to each other" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "eval_results = perform_default_evaluation(\n", 336 | " 
confidences=balanced_confs, gt_labels=balanced_gt\n", 337 | ")" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "plot_default_evaluation_results(eval_results, title_addon=\"Balanced\")" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "unbalanced_eval_results = perform_default_evaluation(\n", 356 | " confidences=unbalanced_confs, gt_labels=unbalanced_gt\n", 357 | ")" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "plot_default_evaluation_results(unbalanced_eval_results, title_addon=\"Unbalanced\")" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "# Part 2: Resnet\n", 374 | "\n", 375 | "Here we will repeat the evaluation of calibration methods on a neural network, specifically\n", 376 | "on resnet20 trained on the CIFAR10 data set.\n", 377 | "\n", 378 | "Important: in order to run the resnet part you will need the packages from `requirements-torch.txt`" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "from kyle.models.resnet import load_weights, resnet20, resnet56\n", 388 | "from kyle.datasets import get_cifar10_dataset" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "selected_resnet = \"resnet20\"\n", 398 | "\n", 399 | "weights_file_names = {\n", 400 | " \"resnet20\": \"resnet20-12fca82f.th\",\n", 401 | " \"resnet56\": \"resnet56-4bfd9763.th\",\n", 402 | "}\n", 403 | "\n", 404 | "models_dict = {\n", 405 | " \"resnet20\": resnet20(),\n", 406 | " \"resnet56\": resnet56(),\n", 407 | "}\n", 408 | "\n", 409 | "\n", 
410 | "resnet_path = os.path.join(\"data\", \"artifacts\", weights_file_names[selected_resnet])\n", 411 | "cifar_10_data_path = os.path.join(\"data\", \"raw\", \"cifar10\")\n", 412 | "logits_save_path = os.path.join(\n", 413 | " \"data\", \"processed\", \"cifar10\", f\"logits_{selected_resnet}.npy\"\n", 414 | ")\n", 415 | "\n", 416 | "if not os.path.isfile(resnet_path):\n", 417 | " print(\n", 418 | " f\"Downloading weights for {selected_resnet} to {os.path.abspath(resnet_path)}\"\n", 419 | " )\n", 420 | " os.makedirs(os.path.dirname(resnet_path), exist_ok=True)\n", 421 | " url = f\"https://github.com/akamaster/pytorch_resnet_cifar10/raw/master/pretrained_models/{weights_file_names[selected_resnet]}\"\n", 422 | " r = requests.get(url)\n", 423 | " with open(resnet_path, \"wb\") as file:\n", 424 | " file.write(r.content)\n", 425 | "\n", 426 | "resnet = models_dict[selected_resnet]\n", 427 | "load_weights(resnet_path, resnet)\n", 428 | "resnet.eval()\n", 429 | "\n", 430 | "\n", 431 | "def get_cifar10_confidences():\n", 432 | " cifar_10_X, cifar_10_Y = get_cifar10_dataset(cifar_10_data_path)\n", 433 | "\n", 434 | " if os.path.isfile(logits_save_path):\n", 435 | " logits = np.load(logits_save_path)\n", 436 | " else:\n", 437 | " # processing all at once may not fit into ram\n", 438 | " batch_boundaries = range(0, len(cifar_10_X) + 1, 1000)\n", 439 | "\n", 440 | " logits = []\n", 441 | " for i in range(len(batch_boundaries) - 1):\n", 442 | " print(f\"Processing batch {i+1}/{len(batch_boundaries)-1}\", end=\"\\r\")\n", 443 | " lower, upper = batch_boundaries[i], batch_boundaries[i + 1]\n", 444 | " logits.append(resnet(cifar_10_X[lower:upper]).detach().numpy())\n", 445 | "\n", 446 | " logits = np.vstack(logits)\n", 447 | " os.makedirs(os.path.dirname(logits_save_path), exist_ok=True)\n", 448 | " np.save(logits_save_path, logits, allow_pickle=False)\n", 449 | "\n", 450 | " confidences = softmax(logits, axis=1)\n", 451 | " gt_labels = cifar_10_Y.numpy()\n", 452 | " return 
confidences, gt_labels" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "cifar_confs, cifar_gt = get_cifar10_confidences()\n", 462 | "\n", 463 | "## Evaluating wrappers on a single calibration method" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [ 472 | "resnet_scores_ECE, description = evaluate_calibration_wrappers(\n", 473 | " HistogramBinning, confidences=cifar_confs, gt_labels=cifar_gt, metric=\"ECE\", cv=4\n", 474 | ")\n", 475 | "\n", 476 | "plot_scores(resnet_scores_ECE, title=description)\n", 477 | "plt.show()" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "resnet_scores_ECE, description = evaluate_calibration_wrappers(\n", 487 | " TemperatureScaling, confidences=cifar_confs, gt_labels=cifar_gt, metric=\"ECE\", cv=4\n", 488 | ")\n", 489 | "\n", 490 | "plot_scores(resnet_scores_ECE, title=description)\n", 491 | "plt.show()" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "## Evaluating wrappers on multiple metrics and plotting next to each other" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "eval_results = perform_default_evaluation(\n", 508 | " confidences=balanced_confs, gt_labels=balanced_gt\n", 509 | ")" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "plot_default_evaluation_results(\n", 519 | " eval_results, title_addon=f\"{selected_resnet} on CIFAR10\"\n", 520 | ")" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | 
"source": [] 529 | } 530 | ], 531 | "metadata": { 532 | "kernelspec": { 533 | "display_name": "Python 3 (ipykernel)", 534 | "language": "python", 535 | "name": "python3" 536 | }, 537 | "language_info": { 538 | "codemirror_mode": { 539 | "name": "ipython", 540 | "version": 3 541 | }, 542 | "file_extension": ".py", 543 | "mimetype": "text/x-python", 544 | "name": "python", 545 | "nbconvert_exporter": "python", 546 | "pygments_lexer": "ipython3", 547 | "version": "3.8.13" 548 | } 549 | }, 550 | "nbformat": 4, 551 | "nbformat_minor": 1 552 | } 553 | -------------------------------------------------------------------------------- /notebooks/fake_classifiers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "\n", 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import logging\n", 22 | "\n", 23 | "from kyle.evaluation import EvalStats\n", 24 | "from kyle.sampling.fake_clf import DirichletFC\n", 25 | "from kyle.transformations import *\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "\n", 28 | "logging.basicConfig(level=logging.INFO)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "n_samples = 100000" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Dirichlet fake classifiers\n", 45 | "\n", 46 | "Add explanation about the model and integrals\n", 47 | "\n", 48 | "## Computing properties with integrals\n", 49 | "\n", 50 | "The asymptotic values for ECE and accuracy can be computed through (numerical or analytical)\n", 51 | "integration." 
52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "n_classes = 3\n", 61 | "alpha = [0.2, 0.3, 0.4]\n", 62 | "\n", 63 | "dirichlet_fc = DirichletFC(n_classes, alpha=alpha)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "print(\n", 73 | " \"mostly overestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n", 74 | ")\n", 75 | "transform = PowerLawSimplexAut(np.array([2, 2, 2]))\n", 76 | "dirichlet_fc.set_simplex_automorphism(transform)\n", 77 | "\n", 78 | "\n", 79 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 80 | "\n", 81 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", 82 | "print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n", 83 | "ece_approx = eval_stats.expected_confidence() - eval_stats.accuracy()\n", 84 | "print(f\"{ece_approx=}\")\n", 85 | "eval_stats.plot_reliability_curves(\n", 86 | " [0, 1, \"top_class\"], display_weights=True, n_bins=200\n", 87 | ")\n", 88 | "plt.show()\n", 89 | "\n", 90 | "\n", 91 | "# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", 92 | "# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n", 93 | "# print(f\"{theoretical_acc=} , {theoretical_ece=}\")" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "print(\n", 103 | " \"mostly underestimating all classes (starting at 1/n_classes) with PowerLawSimplexAut\"\n", 104 | ")\n", 105 | "print(\"Note the variance and the resulting sensitivity to binning\")\n", 106 | "\n", 107 | "transform = PowerLawSimplexAut(np.array([0.3, 0.1, 0.2]))\n", 108 | "dirichlet_fc.set_simplex_automorphism(transform)\n", 109 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 110 | "\n", 111 | "print(f\"Accuracy is 
{eval_stats.accuracy()}\")\n", 112 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", 113 | "ece_approx = -eval_stats.expected_confidence() + eval_stats.accuracy()\n", 114 | "print(f\"{ece_approx=}\")\n", 115 | "eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n", 116 | "plt.show()\n", 117 | "\n", 118 | "\n", 119 | "# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", 120 | "# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n", 121 | "# print(f\"{theoretical_acc=} , {theoretical_ece=}\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "print(\"Underestimating predictions with MaxComponent\")\n", 131 | "\n", 132 | "\n", 133 | "def overestimating_max(x: np.ndarray):\n", 134 | " x = x.copy()\n", 135 | " mask = x > 1 / 2\n", 136 | " x[mask] = x[mask] - (1 / 4 - (1 - x[mask]) ** 2)\n", 137 | " return x\n", 138 | "\n", 139 | "\n", 140 | "transform = MaxComponentSimplexAut(overestimating_max)\n", 141 | "dirichlet_fc.set_simplex_automorphism(transform)\n", 142 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 143 | "\n", 144 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", 145 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", 146 | "eval_stats.plot_reliability_curves([0, 1, \"top_class\"], display_weights=True)\n", 147 | "plt.show()\n", 148 | "\n", 149 | "# Integrals converge pretty slowly, this takes time\n", 150 | "# theoretical_acc = compute_accuracy(dirichlet_fc, opts={\"limit\": 75})[0]\n", 151 | "# theoretical_ece = compute_ECE(dirichlet_fc, opts={\"limit\": 75})[0]\n", 152 | "# print(f\"{theoretical_acc=} , {theoretical_ece=}\")" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "# Analytical results\n", 160 | "\n", 161 | "For top-class overconfident classifiers we have\n", 162 | "\n", 163 | "$ECE_i = \\int_{A_i} \\ 
(c_i - h_i(\\vec c)) \\cdot p(\\vec c) \\, d\\vec c$\n", 164 | "\n", 165 | "$acc_i = \\int_{A_i} \\ h_i(\\vec c) \\cdot p(\\vec c) \\, d\\vec c$\n", 166 | "\n", 167 | "In many relevant regimes, the DirichletFC can approximately be regarded as sufficiently confident, i.e. its top-class confidence usually exceeds $1/2$.\n", 168 | "This means we can restrict the integrals to $\\tilde A_i = \\{\\vec c : c_i > 1/2\\}$, which is contained in $A_i$ because a component larger than $1/2$ is necessarily the largest one, and approximate ECE and accuracy as:\n", 169 | "\n", 170 | "$ECE_i \\ \\lessapprox \\ \\int_{\\tilde A_i} \\ (c_i - h_i(\\vec c)) \\cdot p(\\vec c) \\, d\\vec c$\n", 171 | "\n", 172 | "$acc_i \\ \\lessapprox \\ \\int_{\\tilde A_i} \\ h_i(\\vec c) \\cdot p(\\vec c) \\, d\\vec c$\n", 173 | "\n", 174 | "We can explicitly calculate the first part of the ECE:\n", 175 | "\n", 176 | "$ \\int_{\\tilde A_i} \\ c_i \\cdot p(\\vec c) \\, d\\vec c = \\frac{\\alpha_i}{\\alpha_0}\n", 177 | "\\left(1 - (\\alpha_0-\\alpha_i) \\ \\beta(1/2;\\ \\alpha_i + 1, \\alpha_0-\\alpha_i) \\ \\binom{\\alpha_0}{\\alpha_i} \\right)$\n", 178 | "\n", 179 | "Here $\\beta(x;\\ a, b)$ denotes the (non-regularized) incomplete beta function and $\\binom{\\alpha_0}{\\alpha_i}$ the generalized binomial coefficient. As expected, when $\\alpha_i \\rightarrow \\alpha_0$, this expression goes to one." 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "The second part depends on the simplex automorphism $h$.\n", 187 | "It can be computed, at least approximately, for the RestrictedPowerSimplexAut and for some MaxComponentSimplexAut transforms.\n", 188 | "However, both of these transforms seem rather pathological."
189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "print(\"mostly overestimating first two classes with RestrictedPowerSimplexAut\")\n", 198 | "\n", 199 | "transform = RestrictedPowerSimplexAut(np.array([2, 4]))\n", 200 | "dirichlet_fc.set_simplex_automorphism(transform)\n", 201 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 202 | "\n", 203 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", 204 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", 205 | "print(\"Theoretical approximation of ECE\")\n", 206 | "print(eval_stats.expected_confidence() - eval_stats.accuracy())\n", 207 | "eval_stats.plot_reliability_curves([0, 1, 2, \"top_class\"], display_weights=True)\n", 208 | "plt.show()\n", 209 | "\n", 210 | "\n", 211 | "# theoretical_acc = compute_accuracy(dirichlet_fc)[0]\n", 212 | "# theoretical_ece = compute_ECE(dirichlet_fc)[0]\n", 213 | "# print(f\"{theoretical_acc=} , {theoretical_ece=}\")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## The Calibration Game\n", 221 | "\n", 222 | "Below are potential 5-class classifiers that we will use in the calibration game.\n", 223 | "They all have roughly the same accuracy but very different ECEs, corresponding to\n", 224 | "different difficulty settings for the game."
225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "n_classes = 5\n", 234 | "n_samples = 500000" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "print(\"hardest setting: accuracy 80, ECE 18\")\n", 244 | "\n", 245 | "exponents = np.array([0.05, 0.4, 0.1, 0.2, 0.1]) * 2 / 3\n", 246 | "alpha = np.ones(5) * 1 / 150\n", 247 | "\n", 248 | "# exponents = np.ones(5) * 1/5\n", 249 | "# alpha = np.ones(5) * 1/45\n", 250 | "\n", 251 | "dirichlet_fc = DirichletFC(n_classes, alpha=alpha)\n", 252 | "transform = PowerLawSimplexAut(exponents)\n", 253 | "dirichlet_fc.set_simplex_automorphism(transform)\n", 254 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 255 | "\n", 256 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", 257 | "print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n", 258 | "eval_stats.plot_reliability_curves([0, \"top_class\"], display_weights=True)\n", 259 | "plt.show()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "print(\"medium setting: accuracy 80, ECE 10\")\n", 269 | "\n", 270 | "exponents = np.array([0.5, 1, 1, 1, 0.5]) * 1 / 1.8\n", 271 | "alpha = np.array([0.5, 2, 3, 4, 5]) * 1 / 65\n", 272 | "\n", 273 | "n_samples = 300000\n", 274 | "n_classes = 5\n", 275 | "\n", 276 | "\n", 277 | "dirichlet_fc = DirichletFC(n_classes, alpha=alpha)\n", 278 | "transform = PowerLawSimplexAut(exponents)\n", 279 | "dirichlet_fc.set_simplex_automorphism(transform)\n", 280 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 281 | "\n", 282 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", 283 | "print(f\"ECE is {eval_stats.expected_calibration_error(n_bins=200)}\")\n", 284 | 
"eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)\n", 285 | "plt.show()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "print(\"mostly underestimating all classes (starting at 1/n_classes)\")\n", 295 | "\n", 296 | "\n", 297 | "# accuracy 80, ECE 0\n", 298 | "alpha = np.array([1, 2, 3, 2, 3]) * 1 / 19\n", 299 | "\n", 300 | "n_samples = 300000\n", 301 | "n_classes = 5\n", 302 | "\n", 303 | "dirichlet_fc = DirichletFC(n_classes, alpha=alpha)\n", 304 | "eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))\n", 305 | "\n", 306 | "print(f\"Accuracy is {eval_stats.accuracy()}\")\n", 307 | "print(f\"ECE is {eval_stats.expected_calibration_error()}\")\n", 308 | "eval_stats.plot_reliability_curves([4, \"top_class\"], display_weights=True)\n", 309 | "plt.show()" 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "kernelspec": { 315 | "display_name": "Python 3", 316 | "language": "python", 317 | "name": "python3" 318 | }, 319 | "language_info": { 320 | "codemirror_mode": { 321 | "name": "ipython", 322 | "version": 2 323 | }, 324 | "file_extension": ".py", 325 | "mimetype": "text/x-python", 326 | "name": "python", 327 | "nbconvert_exporter": "python", 328 | "pygments_lexer": "ipython2", 329 | "version": "2.7.6" 330 | } 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 0 334 | } 335 | -------------------------------------------------------------------------------- /notebooks/metric_convergence_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "!pip install tinydb" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 
| "source": [ 21 | "from dataclasses import dataclass\n", 22 | "from itertools import product\n", 23 | "from typing import List\n", 24 | "from typing import Optional\n", 25 | "\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from matplotlib.axis import Axis\n", 28 | "import numpy as np\n", 29 | "from sklearn import datasets\n", 30 | "from sklearn.ensemble import RandomForestClassifier\n", 31 | "from sklearn.metrics import accuracy_score\n", 32 | "from sklearn.model_selection import train_test_split\n", 33 | "from sklearn.neural_network import MLPClassifier\n", 34 | "from tinydb import TinyDB, Query\n", 35 | "from tinydb.storages import MemoryStorage\n", 36 | "from tinydb.table import Table\n", 37 | "\n", 38 | "from kyle.evaluation.reliabilities import (\n", 39 | " expected_calibration_error,\n", 40 | " class_wise_expected_calibration_error,\n", 41 | ")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## Case 1: RF and MLP on Synthetic Data\n", 49 | "\n", 50 | "Here the sample size is increased by simply including more synthetic data" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "n_samples = 15000\n", 60 | "n_classes = 4" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "n_features = 20\n", 70 | "n_informative = 7\n", 71 | "n_redundant = 10\n", 72 | "\n", 73 | "X, y = datasets.make_classification(\n", 74 | " n_samples=n_samples,\n", 75 | " n_features=n_features,\n", 76 | " n_informative=n_informative,\n", 77 | " n_redundant=n_redundant,\n", 78 | " n_classes=n_classes,\n", 79 | " random_state=42,\n", 80 | ")\n", 81 | "\n", 82 | "X_train, X_test, y_train, y_test = train_test_split(\n", 83 | " X, y, test_size=0.8, random_state=42\n", 84 | ")\n", 85 | "\n", 86 | "print(f\"Training set size: {len(X_train)}, calibration set size: 
{len(X_test)}\")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "MODELS = {\n", 96 | " \"mlp\": MLPClassifier(hidden_layer_sizes=(30, 30, 20), max_iter=500),\n", 97 | " \"rf\": RandomForestClassifier(),\n", 98 | "}\n", 99 | "\n", 100 | "predicted_confs = {}\n", 101 | "\n", 102 | "for model_name, model in MODELS.items():\n", 103 | " print(f\"Fitting {model_name} on {len(X_train)} samples.\")\n", 104 | " model.fit(X_train, y_train)\n", 105 | "\n", 106 | " confs = model.predict_proba(X_test)\n", 107 | " predicted_confs[model_name] = confs\n", 108 | " y_pred = confs.argmax(axis=1)\n", 109 | " model_accuracy = accuracy_score(y_test, y_pred)\n", 110 | " print(f\"Test accuracy of {model_name}: {model_accuracy}\")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# Consistency resampling to get calibrated classifiers\n", 120 | "\n", 121 | "calibrated_y_true = {\n", 122 | " \"mlp\": np.zeros(len(X_test)),\n", 123 | " \"rf\": np.zeros(len(X_test)),\n", 124 | "}\n", 125 | "for model, confs in predicted_confs.items():\n", 126 | " for i, conf in enumerate(confs):\n", 127 | " calibrated_y_true[model][i] = np.random.choice(n_classes, p=conf)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "METRICS = {\n", 137 | " \"ECE\": expected_calibration_error,\n", 138 | " \"cwECE\": class_wise_expected_calibration_error,\n", 139 | "}\n", 140 | "\n", 141 | "\n", 142 | "def get_scores(\n", 143 | " evaluation_set_size: int,\n", 144 | " num_samples: int,\n", 145 | " model: str,\n", 146 | " metric: str,\n", 147 | " consistency_resampling=False,\n", 148 | "):\n", 149 | " results = []\n", 150 | " for _ in range(num_samples):\n", 151 | " sample_indices = np.random.choice(\n", 152 | " len(X_test), 
evaluation_set_size, replace=False\n", 153 | " )\n", 154 | " confs = predicted_confs[model][sample_indices]\n", 155 | " if not consistency_resampling:\n", 156 | " y_true = y_test[sample_indices]\n", 157 | " else:\n", 158 | " y_true = calibrated_y_true[model][sample_indices]\n", 159 | " score = METRICS[metric](y_true, confs)\n", 160 | " results.append(score)\n", 161 | " return np.array(results)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "@dataclass\n", 171 | "class MetricEvaluation:\n", 172 | " model: str\n", 173 | " metric: str\n", 174 | " n_bins: int\n", 175 | " strategy: int\n", 176 | " scores: Optional[np.ndarray] = None\n", 177 | " set_size: Optional[int] = None\n", 178 | " num_samples: Optional[int] = None\n", 179 | " consistency_resampling: bool = False\n", 180 | "\n", 181 | " def perform_evaluation(\n", 182 | " self, set_size: int, num_samples: int, consistency_resampling=False\n", 183 | " ):\n", 184 | " self.set_size = set_size\n", 185 | " self.num_samples = num_samples\n", 186 | " self.consistency_resampling = consistency_resampling\n", 187 | " self.scores = get_scores(\n", 188 | " set_size,\n", 189 | " num_samples,\n", 190 | " self.model,\n", 191 | " self.metric,\n", 192 | " consistency_resampling=consistency_resampling,\n", 193 | " )\n", 194 | "\n", 195 | " def mean(self):\n", 196 | " self._assert_nonempty()\n", 197 | " return self.scores.mean()\n", 198 | "\n", 199 | " def std(self):\n", 200 | " self._assert_nonempty()\n", 201 | " return self.scores.std()\n", 202 | "\n", 203 | " def _assert_nonempty(self):\n", 204 | " if self.scores is None:\n", 205 | " raise RuntimeError(\n", 206 | " f\"You must run `perform_evaluation` before computing statistics: {self}\"\n", 207 | " )" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "# collect all evaluations 
to an in-memory database\n", 217 | "n_bins_options = [5, 40]\n", 218 | "binning_strategy_options = [\"uniform\", \"quantile\"]\n", 219 | "\n", 220 | "\n", 221 | "def save_evaluations_to_db(\n", 222 | " set_sizes: List[int], num_samples: int, db: TinyDB, consistency_resampling=False\n", 223 | "):\n", 224 | " for set_size, model, metric, n_bins, strategy in product(\n", 225 | " set_sizes, MODELS, METRICS, n_bins_options, binning_strategy_options\n", 226 | " ):\n", 227 | " metric_evaluation = MetricEvaluation(model, metric, n_bins, strategy)\n", 228 | " metric_evaluation.perform_evaluation(\n", 229 | " set_size, num_samples, consistency_resampling=consistency_resampling\n", 230 | " )\n", 231 | " db.insert(metric_evaluation.__dict__)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "# customization of tinydb\n", 241 | "class EvaluationsTable(Table):\n", 242 | " def search(self, cond: Query) -> List[MetricEvaluation]:\n", 243 | " results = super().search(cond)\n", 244 | " return [MetricEvaluation(**eval_dict) for eval_dict in results]\n", 245 | "\n", 246 | "\n", 247 | "TinyDB.table_class = EvaluationsTable\n", 248 | "TinyDB.default_storage_class = MemoryStorage" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "set_sizes = range(500, 8000, 500)\n", 258 | "num_samples = 10\n", 259 | "db = TinyDB()\n", 260 | "evalQ = Query()" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "save_evaluations_to_db(\n", 270 | " set_sizes=set_sizes,\n", 271 | " num_samples=num_samples,\n", 272 | " db=db,\n", 273 | " consistency_resampling=False,\n", 274 | ")\n", 275 | "\n", 276 | "save_evaluations_to_db(\n", 277 | " set_sizes=set_sizes,\n", 278 | " num_samples=num_samples,\n", 279 | " 
db=db,\n", 280 | " consistency_resampling=True,\n", 281 | ")" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "def get_query(\n", 291 | " model: str,\n", 292 | " n_bins: int = None,\n", 293 | " strategy: str = None,\n", 294 | " metric: str = None,\n", 295 | " consistency_resampling=False,\n", 296 | "):\n", 297 | " q = evalQ.model == model\n", 298 | " q = q & (evalQ.consistency_resampling == consistency_resampling)\n", 299 | " if n_bins:\n", 300 | " q = q & (evalQ.n_bins == n_bins)\n", 301 | " if strategy:\n", 302 | " q = q & (evalQ.strategy == strategy)\n", 303 | " if metric:\n", 304 | " q = q & (evalQ.metric == metric)\n", 305 | " return q\n", 306 | "\n", 307 | "\n", 308 | "def get_evaluations(\n", 309 | " model: str,\n", 310 | " n_bins: int = None,\n", 311 | " strategy: str = None,\n", 312 | " metric: str = None,\n", 313 | " consistency_resampling=False,\n", 314 | ") -> List[MetricEvaluation]:\n", 315 | " evaluations = db.search(\n", 316 | " get_query(\n", 317 | " model,\n", 318 | " n_bins,\n", 319 | " strategy,\n", 320 | " metric,\n", 321 | " consistency_resampling=consistency_resampling,\n", 322 | " )\n", 323 | " )\n", 324 | " return sorted(evaluations, key=lambda ev: ev.set_size)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "def plot_convergence(\n", 334 | " model: str,\n", 335 | " n_bins: int,\n", 336 | " strategy: str,\n", 337 | " metric: str,\n", 338 | " consistency_resampling=False,\n", 339 | " delta_x=0,\n", 340 | " color=0,\n", 341 | " ax: Axis = None,\n", 342 | "):\n", 343 | " selected_evaluations = get_evaluations(\n", 344 | " model, n_bins, strategy, metric, consistency_resampling\n", 345 | " )\n", 346 | "\n", 347 | " selected_set_sizes = np.zeros(len(selected_evaluations))\n", 348 | " means = np.zeros(len(selected_evaluations))\n", 349 | " stds = 
np.zeros(len(selected_evaluations))\n", 350 | " for i, ev in enumerate(selected_evaluations):\n", 351 | " selected_set_sizes[i] = ev.set_size\n", 352 | " means[i] = ev.mean()\n", 353 | " stds[i] = ev.std()\n", 354 | "\n", 355 | " if isinstance(color, int):\n", 356 | " color = f\"C{color}\"\n", 357 | " x_values = selected_set_sizes + delta_x\n", 358 | " ymin = means - stds\n", 359 | " ymax = means + stds\n", 360 | " if ax is None:\n", 361 | " ax = plt.gca()\n", 362 | " title = f\"{metric} for model: {model}\"\n", 363 | " if consistency_resampling:\n", 364 | " title += f\" (cons-res. labels)\"\n", 365 | " ax.set_title(title)\n", 366 | " ax.set_xlabel(\"sample size\")\n", 367 | " ax.plot(x_values, means, \".\", color=color, label=f\"{n_bins} bins, {strategy}\")\n", 368 | " ax.vlines(x_values, ymin=ymin, ymax=ymax, color=color, linewidth=2)\n", 369 | " return ax\n", 370 | "\n", 371 | "\n", 372 | "def plot_all_convergences(\n", 373 | " model: str, metric: str, ax: Axis = None, delta_x=60, consistency_resampling=False\n", 374 | "):\n", 375 | " for i, (n_bins, strategy) in enumerate(\n", 376 | " product(n_bins_options, binning_strategy_options)\n", 377 | " ):\n", 378 | " ax = plot_convergence(\n", 379 | " model,\n", 380 | " n_bins,\n", 381 | " strategy,\n", 382 | " metric,\n", 383 | " delta_x=i * delta_x,\n", 384 | " color=i,\n", 385 | " ax=ax,\n", 386 | " consistency_resampling=consistency_resampling,\n", 387 | " )\n", 388 | " ax.legend()\n", 389 | " return ax" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "def to_win_path(path: str):\n", 399 | " return path.replace(\"/c/\", \"C:\\\\\").replace(\"/\", \"\\\\\")\n", 400 | "\n", 401 | "\n", 402 | "fig, axs = plt.subplots(2, 4, figsize=(16, 9))\n", 403 | "\n", 404 | "plot_all_convergences(\"mlp\", \"ECE\", ax=axs[0, 0])\n", 405 | "plot_all_convergences(\"rf\", \"ECE\", ax=axs[0, 1])\n", 406 | 
"plot_all_convergences(\"mlp\", \"cwECE\", ax=axs[0, 2])\n", 407 | "plot_all_convergences(\"rf\", \"cwECE\", ax=axs[0, 3])\n", 408 | "plot_all_convergences(\"mlp\", \"ECE\", consistency_resampling=True, ax=axs[1, 0])\n", 409 | "plot_all_convergences(\"rf\", \"ECE\", consistency_resampling=True, ax=axs[1, 1])\n", 410 | "plot_all_convergences(\"mlp\", \"cwECE\", consistency_resampling=True, ax=axs[1, 2])\n", 411 | "plot_all_convergences(\"rf\", \"cwECE\", consistency_resampling=True, ax=axs[1, 3])\n", 412 | "\n", 413 | "fig.tight_layout(pad=2.0)\n", 414 | "plt.show()" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [] 423 | } 424 | ], 425 | "metadata": { 426 | "kernelspec": { 427 | "display_name": "Python 3", 428 | "language": "python", 429 | "name": "python3" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 2 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython2", 441 | "version": "2.7.6" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 0 446 | } 447 | -------------------------------------------------------------------------------- /notebooks/test_notebooks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import nbformat 5 | import pytest 6 | from nbconvert.preprocessors import ExecutePreprocessor 7 | 8 | NOTEBOOKS_DIR = "notebooks" 9 | DOCS_DIR = "docs" 10 | resources = {"metadata": {"path": NOTEBOOKS_DIR}} 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | notebooks_to_ignore = ["evaluating_cal_methods.ipynb"] 16 | 17 | notebooks_to_test = [ 18 | file 19 | for file in os.listdir(NOTEBOOKS_DIR) 20 | if file.endswith(".ipynb") and file not in notebooks_to_ignore 21 | ] 22 | 23 | 24 | 
@pytest.mark.parametrize("notebook", notebooks_to_test) 25 | def test_notebook(notebook): 26 | notebook_path = os.path.join(NOTEBOOKS_DIR, notebook) 27 | log.info(f"Reading jupyter notebook from {notebook_path}") 28 | with open(notebook_path) as f: 29 | nb = nbformat.read(f, as_version=4) 30 | ep = ExecutePreprocessor(timeout=600, resource=resources) 31 | # HACK: this is needed because some really smart person didn't test correctly the init of ExecutePreprocessor 32 | ep.nb = nb 33 | with ep.setup_kernel(): 34 | for i, cell in enumerate(nb["cells"]): 35 | log.info(f"processing cell {i} from {notebook}") 36 | ep.preprocess_cell(cell, resources=resources, index=i) 37 | 38 | # saving the executed notebook to docs 39 | output_path = os.path.join(DOCS_DIR, notebook) 40 | log.info(f"Saving executed notebook to {output_path} for documentation purposes") 41 | with open(output_path, "w", encoding="utf-8") as f: 42 | nbformat.write(nb, f) 43 | -------------------------------------------------------------------------------- /notebooks/trained_models/lenet5.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/notebooks/trained_models/lenet5.ckpt -------------------------------------------------------------------------------- /public/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/public/.nojekyll -------------------------------------------------------------------------------- /public/coverage/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/public/coverage/.gitignore -------------------------------------------------------------------------------- /public/docs/.gitignore: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/public/docs/.gitignore -------------------------------------------------------------------------------- /public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | kyle project pages 5 | 6 | 7 |

Welcome to the kyle project pages!

8 |

This page hosts the documentation and reports from the develop branch of the project

9 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # Must be kept in sync with `requirements.txt` 3 | requires = [ 4 | "setuptools >= 46.0.0", 5 | "setuptools_scm >= 2.0.0, <3" 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | # Black-compatible settings for isort 9 | # See https://black.readthedocs.io/en/stable/compatible_configs.html 10 | [tool.isort] 11 | profile = "black" 12 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = 3 | tests 4 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | tox 2 | jupyter 3 | pytest 4 | pylint 5 | bump2version 6 | anybadge 7 | pandas 8 | tqdm 9 | -------------------------------------------------------------------------------- /requirements-torch.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | kornia~=0.5 4 | pytorch-lightning==1.2.8 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.20.0 2 | scikit-learn>=1.0 3 | matplotlib>=3.2.1 4 | netcal==1.2.0 5 | scipy>=1.7.1 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | test_requirements = ["pytest"] 4 | docs_requirements = [ 5 | "Sphinx==3.2.1", 6 | "sphinxcontrib-websupport==1.2.4", 7 | "sphinx_rtd_theme", 8 | ] 9 | 10 | 
setup( 11 | name="kyle-calibration", 12 | package_dir={"": "src"}, 13 | packages=find_packages(where="src"), 14 | python_requires=">=3.8", 15 | license="MIT", 16 | url="https://github.com/appliedAI-Initiative/kyle", 17 | include_package_data=True, 18 | version="0.1.8-dev0", 19 | description="appliedAI classifier calibration library", 20 | install_requires=open("requirements.txt").readlines(), 21 | setup_requires=["wheel"], 22 | tests_require=test_requirements, 23 | extras_require={"test": test_requirements, "docs": docs_requirements}, 24 | author="appliedAI", 25 | ) 26 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/src/__init__.py -------------------------------------------------------------------------------- /src/kyle/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.8-dev0" 2 | -------------------------------------------------------------------------------- /src/kyle/calibration/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_calibrator import ModelCalibrator 2 | -------------------------------------------------------------------------------- /src/kyle/calibration/calibration_methods.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, List, Optional, TypeVar 3 | 4 | import netcal.binning as bn 5 | import netcal.scaling as scl 6 | import numpy as np 7 | from netcal import AbstractCalibration 8 | from sklearn.base import BaseEstimator 9 | 10 | 11 | class BaseCalibrationMethod(ABC, BaseEstimator): 12 | @abstractmethod 13 | def fit(self, confidences: np.ndarray, ground_truth: np.ndarray): 14 | pass 15 | 16 | 
@abstractmethod 17 | def get_calibrated_confidences(self, confidences: np.ndarray): 18 | pass 19 | 20 | def __str__(self): 21 | return self.__class__.__name__ 22 | 23 | 24 | def _get_confidences_from_netcal_calibrator( 25 | confidences: np.ndarray, calibrator: AbstractCalibration 26 | ): 27 | calibrated_confs = calibrator.transform(confidences) 28 | 29 | # TODO: there is a whole bunch of hacks here. I want to get rid of netcal, don't like the code there 30 | # unfortunately, for 2-dim input netcal gives only the probabilities for the second class, 31 | # changing the dimension of the output array 32 | if calibrated_confs.ndim < 2: 33 | second_class_confs = calibrated_confs 34 | first_class_confs = 1 - second_class_confs 35 | calibrated_confs = np.stack([first_class_confs, second_class_confs], axis=1) 36 | 37 | if ( 38 | len(confidences) == 1 39 | ): # Netcal has a bug for single data points, this is a dirty fix 40 | calibrated_confs = calibrated_confs[None, 0] 41 | 42 | if calibrated_confs.shape != confidences.shape: 43 | raise RuntimeError( 44 | f"Shape mismatch for input {confidences}, output {calibrated_confs}. " 45 | f"Netcal output: {second_class_confs}" 46 | ) 47 | 48 | return calibrated_confs 49 | 50 | 51 | TNetcalModel = TypeVar("TNetcalModel", bound=AbstractCalibration) 52 | 53 | 54 | # TODO: this is definitely not the final class structure. 
For now its ok, I want to completely decouple from netcal soon 55 | class NetcalBasedCalibration(BaseCalibrationMethod, Generic[TNetcalModel]): 56 | def __init__(self, netcal_model: TNetcalModel): 57 | self.netcal_model = netcal_model 58 | 59 | def fit(self, confidences: np.ndarray, ground_truth: np.ndarray): 60 | self.netcal_model.fit(confidences, ground_truth) 61 | 62 | def get_calibrated_confidences(self, confidences: np.ndarray) -> np.ndarray: 63 | return _get_confidences_from_netcal_calibrator(confidences, self.netcal_model) 64 | 65 | 66 | class TemperatureScaling(NetcalBasedCalibration[scl.TemperatureScaling]): 67 | def __init__(self): 68 | super().__init__(scl.TemperatureScaling()) 69 | 70 | 71 | class BetaCalibration(NetcalBasedCalibration[scl.BetaCalibration]): 72 | def __init__(self): 73 | super().__init__(scl.BetaCalibration()) 74 | 75 | 76 | class LogisticCalibration(NetcalBasedCalibration[scl.LogisticCalibration]): 77 | def __init__(self): 78 | super().__init__(scl.LogisticCalibration()) 79 | 80 | 81 | class IsotonicRegression(NetcalBasedCalibration[bn.IsotonicRegression]): 82 | def __init__(self): 83 | super().__init__(bn.IsotonicRegression()) 84 | 85 | 86 | class HistogramBinning(NetcalBasedCalibration[bn.HistogramBinning]): 87 | def __init__(self, bins=20): 88 | super().__init__(bn.HistogramBinning(bins=bins)) 89 | self.bins = bins 90 | 91 | 92 | class ClassWiseCalibration(BaseCalibrationMethod): 93 | def __init__(self, calibration_method_factory=TemperatureScaling): 94 | self.calibration_method_factory = calibration_method_factory 95 | self.n_classes: Optional[int] = None 96 | self.calibration_methods: Optional[List[BaseCalibrationMethod]] = None 97 | 98 | # TODO: maybe parallelize this and predict 99 | def fit(self, confidences: np.ndarray, labels: np.ndarray): 100 | self.n_classes = confidences.shape[1] 101 | self.calibration_methods = [] 102 | for class_label in range(self.n_classes): 103 | calibration_method = self.calibration_method_factory() 
104 | selected_confs, selected_labels = get_class_confs_labels( 105 | class_label, confidences, labels 106 | ) 107 | calibration_method.fit(selected_confs, selected_labels) 108 | self.calibration_methods.append(calibration_method) 109 | 110 | def get_calibrated_confidences(self, confs: np.ndarray): 111 | result = np.zeros(confs.shape) 112 | argmax = confs.argmax(1) 113 | for class_label in range(self.n_classes): 114 | scaler = self.calibration_methods[class_label] 115 | indices = argmax == class_label 116 | selected_confs = confs[indices] 117 | calibrated_confs = scaler.get_calibrated_confidences(selected_confs) 118 | assert calibrated_confs.shape == selected_confs.shape, ( 119 | f"Expected shape {selected_confs.shape} but got {calibrated_confs.shape}. Confs: " 120 | f"{selected_confs}, output: {calibrated_confs}" 121 | ) 122 | 123 | result[indices] = calibrated_confs 124 | return result 125 | 126 | 127 | class ConfidenceReducedCalibration(BaseCalibrationMethod, BaseEstimator): 128 | def __init__(self, calibration_method=TemperatureScaling()): 129 | self.calibration_method = calibration_method 130 | 131 | def fit(self, confidences: np.ndarray, ground_truth: np.ndarray): 132 | reduced_confs, reduced_gt = get_binary_classification_data( 133 | confidences, ground_truth 134 | ) 135 | self.calibration_method.fit(reduced_confs, reduced_gt) 136 | 137 | def get_calibrated_confidences(self, confidences: np.ndarray): 138 | reduced_confs = get_reduced_confidences(confidences) 139 | reduced_predictions = self.calibration_method.get_calibrated_confidences( 140 | reduced_confs 141 | ) 142 | reduced_predictions = reduced_predictions[:, 0] # take only 0-class prediction 143 | n_classes = confidences.shape[1] 144 | non_predicted_class_confidences = (1 - reduced_predictions) / (n_classes - 1) 145 | 146 | # using broadcasting here 147 | calibrated_confidences = ( 148 | non_predicted_class_confidences * np.ones(confidences.shape).T 149 | ) 150 | calibrated_confidences = 
calibrated_confidences.T 151 | 152 | argmax_indices = np.expand_dims(confidences.argmax(axis=1), axis=1) 153 | np.put_along_axis( 154 | calibrated_confidences, argmax_indices, reduced_predictions[:, None], axis=1 155 | ) 156 | assert np.all( 157 | np.isclose(calibrated_confidences.sum(1), np.ones(len(confidences))) 158 | ) 159 | assert calibrated_confidences.shape == confidences.shape 160 | return calibrated_confidences 161 | 162 | 163 | def get_class_confs_labels(c: int, confidences: np.ndarray, labels: np.ndarray): 164 | indices = confidences.argmax(1) == c 165 | return confidences[indices], labels[indices] 166 | 167 | 168 | def get_reduced_confidences(confidences: np.ndarray): 169 | top_class_predictions = confidences.max(axis=1) 170 | return np.stack([top_class_predictions, 1 - top_class_predictions], axis=1) 171 | 172 | 173 | def get_binary_classification_data(confidences: np.ndarray, labels: np.ndarray): 174 | new_confidences = get_reduced_confidences(confidences) 175 | pred_was_correct = labels == confidences.argmax(axis=1) 176 | # this is a hack - we predict class 0 if pred was correct, else class 1 177 | new_gt = (np.logical_not(pred_was_correct)).astype(int) 178 | return new_confidences, new_gt 179 | -------------------------------------------------------------------------------- /src/kyle/calibration/model_calibrator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from kyle.models import CalibratableModel 4 | 5 | 6 | class ModelCalibrator: 7 | def __init__( 8 | self, 9 | X_calibrate: np.ndarray, 10 | y_calibrate: np.ndarray, 11 | X_fit: np.ndarray = None, 12 | y_fit: np.ndarray = None, 13 | ): 14 | self.X_calibrate = X_calibrate 15 | self.y_calibrate = y_calibrate 16 | self.X_fit = X_fit 17 | self.y_fit = y_fit 18 | 19 | def calibrate(self, calibratable_model: CalibratableModel, fit: bool = False): 20 | if fit: 21 | if self.X_fit is None or self.y_fit is None: 22 | raise 
AttributeError("No dataset for fitting provided") 23 | calibratable_model.fit(self.X_fit, self.y_fit) 24 | 25 | calibratable_model.calibrate(self.X_calibrate, self.y_calibrate) 26 | 27 | def __str__(self): 28 | return self.__class__.__name__ 29 | -------------------------------------------------------------------------------- /src/kyle/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.enhance import denormalize 3 | from torch import Tensor, tensor 4 | from torch.utils.data import DataLoader 5 | from torchvision.datasets import CIFAR10 6 | from torchvision.transforms import transforms 7 | 8 | # see https://github.com/akamaster/pytorch_resnet_cifar10 9 | resnet_normalize_transform = transforms.Normalize( 10 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 11 | ) 12 | 13 | 14 | def resnet_denormalize_transform(data: Tensor): 15 | is_batch = len(data.shape) == 4 16 | if not is_batch: 17 | data = data[None, :] # transform only works on batches 18 | result = denormalize( 19 | data, 20 | tensor(resnet_normalize_transform.mean), 21 | tensor(resnet_normalize_transform.std), 22 | ) 23 | if not is_batch: 24 | result = result[0] 25 | return result 26 | 27 | 28 | def get_cifar10_dataloader(path: str, train=False): 29 | dataset = CIFAR10( 30 | path, 31 | train=train, 32 | download=True, 33 | transform=transforms.Compose( 34 | [transforms.ToTensor(), resnet_normalize_transform] 35 | ), 36 | ) 37 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1) 38 | return dataloader 39 | 40 | 41 | def get_cifar10_dataset(path: str, train=False): 42 | dataset = CIFAR10( 43 | path, 44 | train=train, 45 | download=True, 46 | transform=transforms.Compose( 47 | [transforms.ToTensor(), resnet_normalize_transform] 48 | ), 49 | ) 50 | images = [] 51 | targets = [] 52 | # Quick hack, can't find a nice way of doing that. 
Datasets cannot be sliced and we need the transform 53 | # Alternative is to retrieve the .data array and do the transformations and reshaping ourselves but this is brittle 54 | for image, target in dataset: 55 | images.append(image) 56 | targets.append(target) 57 | return torch.stack(images), torch.tensor(targets) 58 | -------------------------------------------------------------------------------- /src/kyle/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous import compute_accuracy, compute_ECE, compute_expected_max 2 | from .discrete import EvalStats 3 | -------------------------------------------------------------------------------- /src/kyle/evaluation/continuous.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.integrate import quad 3 | from scipy.stats import dirichlet 4 | 5 | from kyle.integrals import ( 6 | dirichlet_exp_value, 7 | simplex_integral_fixed_comp, 8 | simplex_integral_fixed_max, 9 | ) 10 | from kyle.sampling.fake_clf import DirichletFC 11 | from kyle.transformations import SimplexAut 12 | 13 | 14 | def _prob_correct_prediction(conf: np.ndarray, simplex_aut: SimplexAut): 15 | conf = conf.squeeze() 16 | gt_probabilities = simplex_aut.transform(conf, check_io=False) 17 | return gt_probabilities[np.argmax(conf)] 18 | 19 | 20 | def _prob_class(conf: np.ndarray, simplex_aut: SimplexAut, selected_class: int): 21 | conf = conf.squeeze() 22 | gt_probabilities = simplex_aut.transform(conf, check_io=False) 23 | return gt_probabilities[selected_class] 24 | 25 | 26 | def _probability_vector(*parametrization: float): 27 | return np.array(list(parametrization) + [1 - np.sum(parametrization)]) 28 | 29 | 30 | def compute_accuracy(dirichlet_fc: DirichletFC, **kwargs): 31 | def integrand(*parametrization): 32 | conf = _probability_vector(*parametrization) 33 | return _prob_correct_prediction(conf, 
dirichlet_fc.simplex_automorphism) 34 | 35 | return dirichlet_exp_value(integrand, dirichlet_fc.alpha, **kwargs) 36 | 37 | 38 | def compute_expected_max(dirichlet_fc: DirichletFC, **kwargs): 39 | def integrand(*parametrization): 40 | conf = _probability_vector(*parametrization) 41 | return np.max(conf) 42 | 43 | return dirichlet_exp_value(integrand, dirichlet_fc.alpha, **kwargs) 44 | 45 | 46 | def compute_ECE(dirichlet_fc: DirichletFC, conditioned="full", **kwargs): 47 | """ 48 | Computes theoretical ECE of dirichlet_fc Fake Classifier conditioned on full confidence vector, conditioned on the 49 | confidence in prediction or conditioned on each class confidence separately (see [1]_ for further details) 50 | 51 | :param dirichlet_fc: Dirichlet fake classifier to calculate ECE for 52 | :param conditioned: Quantity to condition ECE on 53 | :param kwargs: passed to integrator function 54 | :return: * If conditioned on full confidence vector returns: result, abserr, (further scipy.nquad output) 55 | * If conditioned on the confidence in prediction returns: result, abserr, (further scipy.quad output) 56 | * If conditioned on each class separately returns: List of num_classes+1 entries. First entry contains 57 | average of all "i-class ECEs". Subsequent entries contain results for each "i-class ECE" 58 | separately: result, abserr, (further scipy.quad output) 59 | 60 | References 61 | ---------- 62 | .. [1] Kull, M., Perello-Nieto, M., Kängsepp, M., Filho, T. S., Song, H., & Flach, P. (2019). Beyond temperature 63 | scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration. 
64 | """ 65 | 66 | if conditioned == "full": 67 | return _compute_ECE_full(dirichlet_fc, **kwargs) 68 | elif conditioned == "confidence": 69 | return _compute_ECE_conf(dirichlet_fc, **kwargs) 70 | elif conditioned == "class": 71 | return _compute_ECE_class(dirichlet_fc, **kwargs) 72 | else: 73 | raise ValueError("ECE has to be one of fully, confidence or class conditioned") 74 | 75 | 76 | def _compute_ECE_full(dirichlet_fc: DirichletFC, **kwargs): 77 | def integrand(*parametrization): 78 | conf = _probability_vector(*parametrization) 79 | return np.abs( 80 | np.max(conf) 81 | - _prob_correct_prediction(conf, dirichlet_fc.simplex_automorphism) 82 | ) 83 | 84 | return dirichlet_exp_value(integrand, dirichlet_fc.alpha, **kwargs) 85 | 86 | 87 | def _compute_ECE_conf(dirichlet_fc: DirichletFC, **kwargs): 88 | # Need higher precision for accurate result due to nesting of two quad/nquad calls 89 | # Sets higher precision if precision not already set in **kwargs 90 | opts = {"epsabs": 1e-4} 91 | opts.update(kwargs.pop("opts", {})) 92 | kwargs.update({"opts": opts}) 93 | 94 | num_classes = len(dirichlet_fc.alpha) 95 | 96 | def p_c(*parametrization): 97 | return dirichlet.pdf(parametrization, dirichlet_fc.alpha) 98 | 99 | def p_y_c(*parametrization): 100 | conf = _probability_vector(*parametrization) 101 | return _prob_correct_prediction(conf, dirichlet_fc.simplex_automorphism) * p_c( 102 | *parametrization 103 | ) 104 | 105 | def integrand(max_conf): 106 | int_p_c = simplex_integral_fixed_max(p_c, num_classes, max_conf, **kwargs)[0] 107 | int_p_y_c = simplex_integral_fixed_max(p_y_c, num_classes, max_conf, **kwargs)[ 108 | 0 109 | ] 110 | return np.abs(int_p_y_c / int_p_c - max_conf) * int_p_c 111 | 112 | # At exactly 1/num_classes or 1 get 0/0 113 | boundary_offset = 1e-2 114 | 115 | return quad( 116 | integrand, 117 | 1 / num_classes + boundary_offset, 118 | 1 - boundary_offset, 119 | epsabs=opts["epsabs"], 120 | ) 121 | 122 | 123 | def _compute_ECE_class(dirichlet_fc: 
DirichletFC, **kwargs): 124 | # Need higher precision for accurate result due to nesting of two quad/nquad calls 125 | # Sets higher precision if precision not already set in **kwargs 126 | opts = {"epsabs": 1e-4} 127 | opts.update(kwargs.pop("opts", {})) 128 | kwargs.update({"opts": opts}) 129 | 130 | num_classes = len(dirichlet_fc.alpha) 131 | 132 | integral_results = [] 133 | 134 | for i in range(num_classes): 135 | 136 | def p_c(*parametrization): 137 | return dirichlet.pdf(parametrization, dirichlet_fc.alpha) 138 | 139 | def p_y_c(*parametrization): 140 | conf = _probability_vector(*parametrization) 141 | return _prob_class(conf, dirichlet_fc.simplex_automorphism, i) * p_c( 142 | *parametrization 143 | ) 144 | 145 | def integrand(comp_conf): 146 | int_p_c = simplex_integral_fixed_comp( 147 | p_c, num_classes, i, comp_conf, **kwargs 148 | )[0] 149 | int_p_y_c = simplex_integral_fixed_comp( 150 | p_y_c, num_classes, i, comp_conf, **kwargs 151 | )[0] 152 | return np.abs(int_p_y_c / int_p_c - comp_conf) * int_p_c 153 | 154 | # At exactly 0 or 1 get 0/0 155 | boundary_offset = 1e-2 156 | 157 | result = quad( 158 | integrand, 159 | 1 / num_classes + boundary_offset, 160 | 1 - boundary_offset, 161 | epsabs=opts["epsabs"], 162 | ) 163 | 164 | integral_results.append(result) 165 | 166 | integral_results.insert(0, sum(S[0] for S in integral_results) / num_classes) 167 | 168 | return integral_results 169 | -------------------------------------------------------------------------------- /src/kyle/evaluation/discrete.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Literal, Sequence, Union 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | from kyle.evaluation.reliabilities import ( 8 | _assert_1d, 9 | _binary_classifier_reliability, 10 | _plot_reliability_curves, 11 | _to2d, 12 | classifier_reliability, 13 | ) 14 | from kyle.util import safe_accuracy_score 15 | 16 | log 
= logging.getLogger(__name__) 17 | 18 | 19 | class EvalStats: 20 | """ 21 | Class for computing evaluation statistics of classifiers, including calibration metrics 22 | """ 23 | 24 | def __init__( 25 | self, 26 | y_true: np.ndarray, 27 | confidences: np.ndarray, 28 | ): 29 | """ 30 | 31 | :param y_true: integer array of shape (n_samples,). Assumed to contain true labels in range [0, n_classes-1] 32 | :param confidences: array of shape (n_samples, n_classes) 33 | """ 34 | _assert_1d(y_true, "y_true") 35 | self.y_true = y_true 36 | self.confidences = _to2d(confidences) 37 | 38 | # Saving some fields, so they don't have to be recomputed 39 | self.num_samples = len(y_true) 40 | self.num_classes = confidences.shape[1] 41 | self.y_pred = confidences.argmax(axis=1) 42 | # noinspection PyArgumentList 43 | self._top_class_confidences = np.take_along_axis( 44 | confidences, self.y_pred[:, None], axis=1 45 | ).T[0] 46 | self._argmax_predicted_mask = self.y_true == self.y_pred 47 | # reshaping b/c the mask can be computed through numpy broadcasting 48 | possible_labels = np.arange(self.num_classes).reshape((self.num_classes, 1)) 49 | # from this field we can get a mask for whether the label i was predicted 50 | # by calling self._label_predicted_masks[i] 51 | # noinspection PyTypeChecker 52 | self._label_predicted_masks: np.ndarray = self.y_true == possible_labels 53 | 54 | @property 55 | def top_class_confidences(self): 56 | return self._top_class_confidences 57 | 58 | def expected_confidence( 59 | self, class_label: Union[int, Literal["top_class"]] = "top_class" 60 | ): 61 | """ 62 | Returns the expected confidence for the selected class or for the predictions (default) 63 | 64 | :param class_label: either the class label as int or "top_class" 65 | :return: 66 | """ 67 | if class_label == "top_class": 68 | confs = self._top_class_confidences 69 | else: 70 | confs = self.confidences[:, class_label] 71 | return float(np.mean(confs)) 72 | 73 | def accuracy(self): 74 | return 
safe_accuracy_score(self.y_true, self.y_pred) 75 | 76 | def expected_calibration_error( 77 | self, 78 | class_label: Union[int, Literal["top_class"]] = "top_class", 79 | n_bins=12, 80 | strategy: Literal["uniform", "quantile"] = "uniform", 81 | ): 82 | """ 83 | :param class_label: if "top_class", will be the usual (confidence) ECE. Otherwise, it will be the 84 | marginal class-wise ECE for the selected class. 85 | :param n_bins: 86 | :param strategy: 87 | :return: 88 | """ 89 | reliabilities = self.reliabilities( 90 | class_label=class_label, 91 | n_bins=n_bins, 92 | strategy=strategy, 93 | ) 94 | sum_members = np.sum(reliabilities.n_members) 95 | if sum_members == 0: 96 | return 0.0 97 | weights = reliabilities.n_members / sum_members 98 | abs_diff = np.abs(reliabilities.prob_pred - reliabilities.prob_true) 99 | return np.dot(abs_diff, weights) 100 | 101 | def average_calibration_error( 102 | self, n_bins=12, strategy: Literal["uniform", "quantile"] = "uniform" 103 | ): 104 | reliabilities = self.reliabilities( 105 | class_label="top_class", n_bins=n_bins, strategy=strategy 106 | ) 107 | abs_distances = np.abs(reliabilities.prob_pred - reliabilities.prob_true) 108 | return np.mean(abs_distances) 109 | 110 | def max_calibration_error( 111 | self, n_bins=12, strategy: Literal["uniform", "quantile"] = "uniform" 112 | ): 113 | reliabilities = self.reliabilities( 114 | class_label="top_class", n_bins=n_bins, strategy=strategy 115 | ) 116 | abs_distances = np.abs(reliabilities.prob_pred - reliabilities.prob_true) 117 | return np.max(abs_distances) 118 | 119 | def class_wise_expected_calibration_error( 120 | self, n_bins=12, strategy: Literal["uniform", "quantile"] = "uniform" 121 | ): 122 | sum_marginal_errors = sum( 123 | self.expected_calibration_error(k, n_bins=n_bins, strategy=strategy) 124 | for k in range(self.num_classes) 125 | ) 126 | return sum_marginal_errors / self.num_classes 127 | 128 | # TODO or not TODO: could in principle work for any 1-dim. 
reduction but we might not need this generality 129 | def reliabilities( 130 | self, 131 | class_label: Union[int, Literal["top_class"]], 132 | n_bins=12, 133 | strategy: Literal["uniform", "quantile"] = "uniform", 134 | ): 135 | """ 136 | Computes arrays related to the reliabilities of the provided confidences. They can be used e.g. for computing 137 | calibration errors or for visualizing reliability curves. 138 | 139 | :param n_bins: 140 | :param class_label: either an integer label for the class-wise reliabilities, or "top_class" for the 141 | reliabilities in predictions. 142 | :param strategy: 143 | 144 | :return: named tuple containing arrays with confidences, accuracies, members, bin_edges 145 | """ 146 | # Reducing here to save time on re-computation 147 | if class_label == "top_class": 148 | reduced_y_pred = self.top_class_confidences 149 | reduced_y_true = self._argmax_predicted_mask 150 | else: 151 | reduced_y_pred = self.confidences[:, class_label] 152 | reduced_y_true = self._label_predicted_masks[class_label] 153 | return _binary_classifier_reliability( 154 | reduced_y_true, 155 | reduced_y_pred, 156 | n_bins=n_bins, 157 | strategy=strategy, 158 | ) 159 | 160 | def plot_reliability_curves( 161 | self, 162 | class_labels: Sequence[Union[int, Literal["top_class"]]], 163 | display_weights=False, 164 | n_bins=12, 165 | strategy: Literal["uniform", "quantile"] = "uniform", 166 | ): 167 | """ 168 | 169 | :param class_labels: 170 | :param display_weights: If True, for each reliability curve the weights of each bin will be 171 | plotted as histogram. The weights have been scaled for the sake of display, only relative differences 172 | between them have an interpretable meaning. 173 | The errors containing "expected" in the name take these weights into account. 
174 | :param strategy: 175 | :param n_bins: 176 | :return: figure 177 | """ 178 | return _plot_reliability_curves( 179 | self.reliabilities, class_labels, display_weights, n_bins, strategy 180 | ) 181 | 182 | def plot_gt_distribution(self, label_names: Dict[int, str] = None): 183 | class_labels, counts = np.unique(self.y_true, return_counts=True) 184 | if label_names is not None: 185 | class_labels = [ 186 | label_names.get(label_id, label_id) for label_id in class_labels 187 | ] 188 | 189 | fig, ax = plt.subplots() 190 | ax.pie(counts, labels=class_labels, autopct="%1.1f%%", startangle=90) 191 | ax.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. 192 | ax.set_title("Ground Truth Distribution") 193 | return fig 194 | -------------------------------------------------------------------------------- /src/kyle/evaluation/reliabilities.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Literal, NamedTuple, Protocol, Sequence, Union 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from matplotlib.colors import ListedColormap 7 | from sklearn.utils import check_consistent_length, column_or_1d 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class ReliabilityResult(NamedTuple): 13 | prob_true: np.ndarray 14 | prob_pred: np.ndarray 15 | n_members: np.ndarray 16 | bin_edges: np.ndarray 17 | 18 | 19 | def _assert_1d(arr: np.ndarray, arr_name: str, error_type=ValueError): 20 | if arr.ndim != 1: 21 | raise error_type(f"{arr_name} should be a 1d array but got shape: {arr.shape}") 22 | 23 | 24 | def _to2d(confidences: np.ndarray): 25 | """ 26 | If a 1d array is passed, we assume that it corresponds to the confidence in True class, i.e. in label=1. 
27 | Then to get the confs in the right order, we prepend 1-confidences for the label 0 28 | """ 29 | if confidences.ndim == 2: 30 | return confidences 31 | if confidences.ndim == 1: 32 | return np.stack([1 - confidences, confidences]).T 33 | raise ValueError( 34 | f"Cannot turn array of shape: {confidences.shape} to two-dim array." 35 | ) 36 | 37 | 38 | def _binary_classifier_reliability( 39 | y_true: np.ndarray, 40 | y_prob: np.ndarray, 41 | n_bins=12, 42 | strategy: Literal["uniform", "quantile"] = "uniform", 43 | ) -> ReliabilityResult: 44 | """ 45 | The implementation is essentially a copy of `sklearn.calibration.calibration_curve` but contains additional 46 | quantities in the return: `n_members` and `bin_edges`. Unfortunately, the sklearn implementation cannot be 47 | used as it doesn't return these quantities, and they would have to be recomputed. 48 | Compute arrays related to the reliabilities of a binary classifiers. They can be used e.g. for computing 49 | calibration errors or for visualizing reliability curves. 50 | 51 | :param n_bins: 52 | :param y_true: the confidences in the prediction of the class **True** 53 | (or class 1, if labels passed as integers). Should be a numpy array of shape (n_samples, ) 54 | :param y_prob: array of shape (n_samples, ) containing True and False or integers (1, 0) respectively 55 | :param strategy: Strategy used to define the widths of the bins. 56 | **uniform** The bins have identical widths. 57 | **quantile** The bins have the same number of samples and depend on y_prob. 
58 | :return: named tuple containing arrays with confidences, accuracies, members, bin_edges 59 | """ 60 | y_true = column_or_1d(y_true) 61 | y_prob = column_or_1d(y_prob) 62 | check_consistent_length(y_true, y_prob) 63 | 64 | if y_prob.min() < 0 or y_prob.max() > 1: 65 | raise ValueError("y_prob has values outside [0, 1].") 66 | 67 | uniform_bins = np.linspace(0.0, 1.0, n_bins + 1) 68 | if strategy == "quantile": # Determine bin edges by distribution of data 69 | bins = np.quantile(y_prob, uniform_bins) 70 | elif strategy == "uniform": 71 | bins = uniform_bins 72 | else: 73 | raise ValueError( 74 | "Invalid entry to 'strategy' input. Strategy " 75 | "must be either 'quantile' or 'uniform'." 76 | ) 77 | 78 | binids = np.searchsorted(bins[1:-1], y_prob) 79 | 80 | bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins)) 81 | bin_true = np.bincount(binids, weights=y_true, minlength=len(bins)) 82 | bin_total = np.bincount(binids, minlength=len(bins)) 83 | 84 | is_nonempty = bin_total != 0 85 | n_empty_bins = n_bins - is_nonempty.sum() 86 | if n_empty_bins > 0: 87 | log.debug( 88 | f"{n_empty_bins} of {n_bins} bins were empty, the reliability curve cannot be estimated in them." 
89 | f"This can be prevented by either: \n" 90 | f" 1) reducing the number of bins (current value is {n_bins}) or \n" 91 | f" 2) increasing the sample size (current value is {len(y_true)}) or \n" 92 | f" 3) using strategy='quantile'" 93 | ) 94 | 95 | last_nonempty_bin = np.where(is_nonempty)[0][-1] 96 | if last_nonempty_bin == n_bins: 97 | last_bin_edge = 1.0 98 | else: 99 | last_bin_edge = bins[last_nonempty_bin + 1] 100 | 101 | bin_edges = np.append(bins[is_nonempty], last_bin_edge) 102 | bin_members = bin_total[is_nonempty] 103 | prob_true = bin_true[is_nonempty] / bin_members 104 | prob_pred = bin_sums[is_nonempty] / bin_members 105 | 106 | return ReliabilityResult(prob_true, prob_pred, bin_members, bin_edges) 107 | 108 | 109 | def classifier_reliability( 110 | y_true: np.ndarray, 111 | confidences: np.ndarray, 112 | class_label: Union[int, Literal["top_class"]] = 0, 113 | n_bins=12, 114 | strategy: Literal["uniform", "quantile"] = "uniform", 115 | ): 116 | """ 117 | Computes arrays related to the reliabilities of the provided confidences. They can be used e.g. for computing 118 | calibration errors or for visualizing reliability curves. 119 | 120 | :param y_true: 121 | :param confidences: 122 | :param n_bins: 123 | :param class_label: either an integer label for the class-wise reliabilities, or "top_class" for the 124 | reliabilities in predictions. 
125 | :param strategy: 126 | 127 | :return: named tuple containing arrays with confidences, accuracies, members, bin_edges 128 | """ 129 | confidences = _to2d(confidences) 130 | 131 | y_pred = confidences.argmax(axis=1) 132 | # noinspection PyArgumentList 133 | if class_label == "top_class": 134 | reduced_y_prob = np.take_along_axis(confidences, y_pred[:, None], axis=1).T[0] 135 | reduced_y_true = y_true == y_pred 136 | else: 137 | reduced_y_prob = confidences[:, class_label] 138 | reduced_y_true = y_true == class_label 139 | return _binary_classifier_reliability( 140 | reduced_y_true, 141 | reduced_y_prob, 142 | n_bins=n_bins, 143 | strategy=strategy, 144 | ) 145 | 146 | 147 | def expected_calibration_error( 148 | y_true: np.ndarray, 149 | confidences: np.ndarray, 150 | class_label: Union[int, Literal["top_class"]] = "top_class", 151 | n_bins=12, 152 | strategy: Literal["uniform", "quantile"] = "uniform", 153 | ): 154 | """ 155 | :param class_label: if "top_class", will be the usual (confidence) ECE. Otherwise, it will be the 156 | marginal class-wise ECE for the selected class. 
157 | :param n_bins: 158 | :param strategy: 159 | :return: 160 | """ 161 | reliabilities = classifier_reliability( 162 | y_true, 163 | confidences, 164 | class_label=class_label, 165 | n_bins=n_bins, 166 | strategy=strategy, 167 | ) 168 | sum_members = np.sum(reliabilities.n_members) 169 | if sum_members == 0: 170 | return 0.0 171 | weights = reliabilities.n_members / sum_members 172 | abs_diff = np.abs(reliabilities.prob_pred - reliabilities.prob_true) 173 | return np.dot(abs_diff, weights) 174 | 175 | 176 | def average_calibration_error( 177 | y_true: np.ndarray, 178 | confidences: np.ndarray, 179 | n_bins=12, 180 | strategy: Literal["uniform", "quantile"] = "uniform", 181 | ): 182 | reliabilities = classifier_reliability( 183 | y_true, confidences, class_label="top_class", n_bins=n_bins, strategy=strategy 184 | ) 185 | abs_distances = np.abs(reliabilities.prob_pred - reliabilities.prob_true) 186 | return np.mean(abs_distances) 187 | 188 | 189 | def max_calibration_error( 190 | y_true: np.ndarray, 191 | confidences: np.ndarray, 192 | n_bins=12, 193 | strategy: Literal["uniform", "quantile"] = "uniform", 194 | ): 195 | reliabilities = classifier_reliability( 196 | y_true, confidences, class_label="top_class", n_bins=n_bins, strategy=strategy 197 | ) 198 | abs_distances = np.abs(reliabilities.prob_pred - reliabilities.prob_true) 199 | return np.max(abs_distances) 200 | 201 | 202 | def class_wise_expected_calibration_error( 203 | y_true: np.ndarray, 204 | confidences: np.ndarray, 205 | n_bins=12, 206 | strategy: Literal["uniform", "quantile"] = "uniform", 207 | ): 208 | confidences = _to2d(confidences) 209 | num_classes = confidences.shape[-1] 210 | sum_marginal_errors = sum( 211 | expected_calibration_error( 212 | y_true, confidences, k, n_bins=n_bins, strategy=strategy 213 | ) 214 | for k in range(num_classes) 215 | ) 216 | return sum_marginal_errors / num_classes 217 | 218 | 219 | class ReliabilitiesProviderProtocol(Protocol): 220 | def __call__(self, 
class_label: int, n_bins, strategy) -> ReliabilityResult: 221 | pass 222 | 223 | 224 | def _plot_reliability_curves( 225 | reliabilities_provider: ReliabilitiesProviderProtocol, 226 | class_labels: Sequence[Union[int, Literal["top_class"]]], 227 | display_weights, 228 | n_bins, 229 | strategy: Literal["uniform", "quantile"], 230 | ): 231 | """ 232 | Helper function to plot reliabilities. Within EvalStats part of the reliabilities 233 | is precomputed, and y_true and confidences are known - so we don't want to use the same 234 | provider there. 235 | 236 | :param reliabilities_provider: 237 | :param class_labels: 238 | :param display_weights: 239 | :param n_bins: 240 | :param strategy: 241 | :return: 242 | """ 243 | colors = ListedColormap(["y", "g", "r", "c", "m"]) 244 | 245 | fig = plt.figure() 246 | plt.title(f"Reliability curves ({n_bins} bins)") 247 | plt.xlabel("confidence") 248 | plt.ylabel("ground truth probability") 249 | plt.axis("equal") 250 | 251 | # plotting a diagonal for perfect calibration 252 | plt.plot([0, 1], [0, 1], label="perfect calibration", color="b") 253 | 254 | # for each class, plot curve and weights, cycle through colors 255 | for i, class_label in enumerate(class_labels): 256 | color = colors(i) 257 | if class_label == "top_class": 258 | plot_label = "prediction" 259 | else: 260 | plot_label = f"class {class_label}" 261 | 262 | prob_true, prob_pred, n_members, bin_edges = reliabilities_provider( 263 | class_label, 264 | n_bins=n_bins, 265 | strategy=strategy, 266 | ) 267 | plt.plot(prob_pred, prob_true, marker=".", label=plot_label, color=color) 268 | if display_weights: 269 | # rescale the weights for improved visibility 270 | weights = n_members / (3 * n_members.max()) 271 | width = np.diff(bin_edges) 272 | plt.bar( 273 | bin_edges[:-1], 274 | weights, 275 | align="edge", 276 | alpha=0.2, 277 | width=width, 278 | color=color, 279 | label=f"weights ({plot_label})", 280 | edgecolor="black", 281 | linewidth=0.5, 282 | linestyle="--", 283 
| ) 284 | 285 | axes = plt.gca() 286 | axes.set_xlim([0, 1]) 287 | axes.set_ylim([0, 1]) 288 | plt.legend(loc="best") 289 | return fig 290 | 291 | 292 | def plot_reliability_curves( 293 | y_true: np.ndarray, 294 | confidences: np.ndarray, 295 | class_labels: Sequence[Union[int, Literal["top_class"]]], 296 | display_weights=False, 297 | n_bins=12, 298 | strategy: Literal["uniform", "quantile"] = "uniform", 299 | ): 300 | """ 301 | :param y_true: 302 | :param confidences: 303 | :param class_labels: 304 | :param display_weights: If True, for each reliability curve the weights of each bin will be 305 | plotted as histogram. The weights have been scaled for the sake of display, only relative differences 306 | between them have an interpretable meaning. 307 | The errors containing "expected" in the name take these weights into account. 308 | :param strategy: 309 | :param n_bins: 310 | :return: figure 311 | """ 312 | 313 | def reliabilities_provider( 314 | class_label, 315 | n_bins, 316 | strategy, 317 | ): 318 | return classifier_reliability( 319 | y_true, 320 | confidences, 321 | class_label, 322 | n_bins=n_bins, 323 | strategy=strategy, 324 | ) 325 | 326 | return _plot_reliability_curves( 327 | reliabilities_provider, class_labels, display_weights, n_bins, strategy 328 | ) 329 | -------------------------------------------------------------------------------- /src/kyle/integrals.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Protocol, Sequence 3 | 4 | from scipy.integrate import nquad 5 | from scipy.stats import dirichlet 6 | 7 | 8 | # this is the currently supported way to annotate callables with *args of a certain type, 9 | # see https://mypy.readthedocs.io/en/latest/protocols.html#callback-protocols 10 | # hopefully at some point the pycharm type checker will learn to recognize those. 
11 | # I opened an issue for JetBrains: https://youtrack.jetbrains.com/issue/PY-45438 12 | class Integrand(Protocol): 13 | def __call__(self, *parameters: float) -> float: 14 | ... 15 | 16 | 17 | def simplex_integral( 18 | f: Callable, num_classes: int, boundary_offset=1e-10, coord_sum: float = 1, **kwargs 19 | ): 20 | """ 21 | Performs an integral over num_classes-1 dimensional simplex using scipy 22 | 23 | :param f: function to integrate over the simplex. Should accept num_classes-1 variables 24 | :param num_classes: equals dimension of the simplex + 1 25 | :param boundary_offset: can be used to prevent numerical errors due to singularities at the simplex' boundary 26 | :param coord_sum: sets sum of coordinates of simplex. For standard simplex sum(x1,x2,...) = 1. Mainly useful for 27 | simplex_integral_fixed_max 28 | :param kwargs: will be passed to scipy.integrate.nquad 29 | :return: 30 | """ 31 | if num_classes < 2: 32 | raise ValueError("need at least two classes") 33 | 34 | def nested_variable_boundary(*previous_variables: float): 35 | """ 36 | Any variable for the simplex integral goes from zero to coord_sum (usually 1) - sum(all previous variables). 37 | See docu of nquad for more details on boundaries 38 | """ 39 | return [ 40 | 0 + boundary_offset, 41 | coord_sum - sum(previous_variables) - boundary_offset, 42 | ] 43 | 44 | simplex_boundary = [nested_variable_boundary] * (num_classes - 1) 45 | # we typically don't need higher precision 46 | opts = {"epsabs": 1e-2} 47 | opts.update(kwargs.pop("opts", {})) 48 | return nquad(f, simplex_boundary, opts=opts, **kwargs) 49 | 50 | 51 | def simplex_integral_fixed_comp( 52 | f: Callable, num_classes: int, selected_class: int, x_comp: float, **kwargs 53 | ): 54 | """ 55 | Performs an integral over the subset of a num_classes-1 dimensional simplex defined by the selected_class component 56 | of the confidence vector having a fixed value of x_comp, i.e. marginalises out all other classes. 
57 | 58 | Computing this involves integrating over a num_classes-2 dimensional non-unit simplex with coord_sum set to 1-x_comp 59 | and with the selected_class argument of f being set to x_comp 60 | 61 | :param f: function to integrate over the subset of the simplex. Should accept num_classes-1 variables 62 | :param num_classes: equals dimension of the simplex + 1 63 | :param selected_class: selected confidence vector component [0, num_classes-1] 64 | :param x_comp: fixed value of the selected vector component 65 | :param kwargs: passed to simplex_integral 66 | :return: 67 | """ 68 | 69 | if not (0 <= x_comp <= 1): 70 | raise ValueError("Confidences have to lie in range (0,1)") 71 | 72 | if selected_class == num_classes - 1: 73 | 74 | def constrained_integrand(*args: float): 75 | constrained_args = [1 - x_comp - sum(args[0:]), *args[0:]] 76 | return f(*constrained_args) 77 | 78 | else: 79 | 80 | def constrained_integrand(*args: float): 81 | constrained_args = [*args[0:selected_class], x_comp, *args[selected_class:]] 82 | return f(*constrained_args) 83 | 84 | return simplex_integral( 85 | constrained_integrand, num_classes - 1, coord_sum=1 - x_comp, **kwargs 86 | ) 87 | 88 | 89 | def simplex_integral_fixed_max(f: Callable, num_classes: int, x_max: float, **kwargs): 90 | """ 91 | Performs an integral over the subset of a num_classes-1 dimensional simplex defined by the largest 92 | coordinate/confidence having a fixed value of x_max, i.e. marginalises over all possible confidence vectors with 93 | maximum confidence of x_max. 94 | 95 | Computing this integral involves computing the sum of num_classes integrals each over a num_classes-2 dimensional 96 | simplex. For x_max > 0.5 the integrals are 'true' simplex integrals. For x_max < 0.5 the boundaries become complex 97 | and non-simplex like. The integrals can then be extended to full simplex integrals using an appropiate indicator 98 | function, ``get_argmax_region_char_function``. 
99 | 100 | :param f: function to integrate over the subset of the simplex. Should accept num_classes-1 variables 101 | :param num_classes: equals dimension of the simplex + 1 102 | :param x_max: fixed value of largest coordinate value. defines subset of simplex 103 | :param kwargs: passed to simplex_integral_fixed_comp 104 | :return: 105 | """ 106 | 107 | if not (1 / num_classes < x_max < 1): 108 | return 0, 0 109 | 110 | # For small x_max higher precision is required for accurate results as over large ingtegration range integrand is 0 111 | # Sets higher precision if precision not already set in **kwargs 112 | if x_max < 1 / 2: 113 | opts = {"epsabs": 1e-4} 114 | opts.update(kwargs.pop("opts", {})) 115 | kwargs.update({"opts": opts}) 116 | 117 | integral_result = (0, 0) 118 | 119 | for i in range(num_classes): 120 | 121 | argmax_char_func = get_argmax_region_char_function(i) 122 | 123 | constrained_integral = simplex_integral_fixed_comp( 124 | lambda *args: argmax_char_func(*args) * f(*args), 125 | num_classes, 126 | i, 127 | x_max, 128 | **kwargs, 129 | ) 130 | integral_result = tuple( 131 | sum(p) for p in zip(integral_result, constrained_integral) 132 | ) 133 | 134 | return integral_result 135 | 136 | 137 | def dirichlet_exp_value(f: Callable, alpha: Sequence[float], **kwargs): 138 | """ 139 | Computes expectation value of f over num_classes-1 dimensional simplex using scipy. Note scipy.dirichlet.pdf for 140 | n classes accepts n-1 entries as sum(x_n) = 1. 141 | 142 | :param f: 143 | :param alpha: the parameters of the dirichlet distribution, one for each class 144 | :param kwargs: passed to simplex_integral 145 | :return: 146 | """ 147 | num_classes = len(alpha) 148 | return simplex_integral( 149 | lambda *args: f(*args) * dirichlet.pdf(args, alpha), num_classes, **kwargs 150 | ) 151 | 152 | 153 | def get_argmax_region_char_function(selected_class: int) -> Integrand: 154 | """ 155 | Returns the char. 
function for the area in which the selected class is the argmax of the input args. 156 | The returned function takes a variable number of floats as input. They represent the first N-1 independent 157 | entries of an element of a simplex in N-dimensional space (N classes). 158 | """ 159 | 160 | def char_function(*args: float): 161 | if len(args) < 1: 162 | raise ValueError("need at least two classes/one input") 163 | if not 0 <= selected_class <= len(args): 164 | raise IndexError( 165 | f"selected_class {selected_class} out of bound for input of length {len(args)}" 166 | ) 167 | probabilities = list(args) + [1 - sum(args)] 168 | class_confidence = probabilities[selected_class] 169 | return float(class_confidence == max(probabilities)) 170 | 171 | return char_function 172 | 173 | 174 | log = logging.getLogger(__name__) 175 | -------------------------------------------------------------------------------- /src/kyle/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .calibration_metrics import ACE, ECE, MCE, BaseCalibrationError 2 | -------------------------------------------------------------------------------- /src/kyle/metrics/calibration_metrics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import netcal.metrics 4 | import numpy as np 5 | 6 | from kyle.util import in_simplex 7 | 8 | 9 | # TODO: replace by our own implementations used in EvalStats 10 | class BaseCalibrationError(ABC): 11 | @abstractmethod 12 | def _compute( 13 | self, confidences: np.ndarray, ground_truth: np.ndarray, **kwargs 14 | ) -> float: 15 | pass 16 | 17 | def compute(self, confidences: np.ndarray, ground_truth: np.ndarray, **kwargs): 18 | if not in_simplex(confidences): 19 | raise ValueError("Invalid confidences array") 20 | return self._compute(confidences, ground_truth, **kwargs) 21 | 22 | def __str__(self): 23 | return self.__class__.__name__ 
24 | 25 | 26 | class NetcalCalibrationError(BaseCalibrationError): 27 | def __init__(self, netcal_metric): 28 | """ 29 | Instance of a netcal metric class, e.g. netcal.metrics.ECE 30 | """ 31 | self.netcal_metric = netcal_metric 32 | 33 | def _compute( 34 | self, confidences: np.ndarray, ground_truth: np.ndarray, **kwargs 35 | ) -> float: 36 | return self.netcal_metric.measure(confidences, ground_truth, **kwargs) 37 | 38 | 39 | class ACE(NetcalCalibrationError): 40 | """Average Calibration Error. Wraps around netcal's implementation - for further reading refer to netcal's docs.""" 41 | 42 | def __init__(self, bins: int = 10): 43 | super(ACE, self).__init__(netcal.metrics.ACE(bins)) 44 | 45 | 46 | class ECE(NetcalCalibrationError): 47 | """Expected Calibration Error. Wraps around netcal's implementation - for further reading refer to netcal's docs.""" 48 | 49 | def __init__(self, bins: int = 10): 50 | super().__init__(netcal.metrics.ECE(bins)) 51 | 52 | 53 | class MCE(NetcalCalibrationError): 54 | """Maximum Calibration Error. 
Wraps around netcal's implementation - for further reading refer to netcal's docs.""" 55 | 56 | def __init__(self, bins: int = 10): 57 | super().__init__(netcal.metrics.MCE(bins)) 58 | 59 | 60 | # TODO: get rid of this 61 | -------------------------------------------------------------------------------- /src/kyle/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .calibratable_model import CalibratableModel 2 | -------------------------------------------------------------------------------- /src/kyle/models/calibratable_model.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | import numpy as np 4 | 5 | from kyle.calibration.calibration_methods import ( 6 | BaseCalibrationMethod, 7 | TemperatureScaling, 8 | ) 9 | 10 | 11 | class ClassifierProtocol(Protocol): 12 | def fit(self, X: np.ndarray, y: np.ndarray): 13 | ... 14 | 15 | def predict(self, X: np.ndarray) -> np.ndarray: 16 | ... 17 | 18 | def predict_proba(self, X: np.ndarray) -> np.ndarray: 19 | ... 
20 | 21 | 22 | class CalibratableModel(ClassifierProtocol): 23 | def __init__( 24 | self, 25 | model: ClassifierProtocol, 26 | calibration_method: BaseCalibrationMethod = TemperatureScaling(), 27 | ): 28 | self.model = model 29 | self.calibration_method = calibration_method 30 | 31 | def calibrate(self, X: np.ndarray, y: np.ndarray): 32 | uncalibrated_confidences = self.model.predict_proba(X) 33 | self.calibration_method.fit(uncalibrated_confidences, y) 34 | 35 | def fit(self, X: np.ndarray, y: np.ndarray): 36 | self.model.fit(X, y) 37 | 38 | def predict(self, X: np.ndarray) -> np.ndarray: 39 | calibrated_proba = self.predict_proba(X) 40 | 41 | return np.argmax(calibrated_proba, axis=2) 42 | 43 | def predict_proba(self, X: np.ndarray) -> np.ndarray: 44 | uncalibrated_confidences = self.model.predict_proba(X) 45 | return self.calibration_method.get_calibrated_confidences( 46 | uncalibrated_confidences 47 | ) 48 | 49 | def __str__(self): 50 | return f"{self.__class__.__name__}, method: {self.calibration_method}" 51 | -------------------------------------------------------------------------------- /src/kyle/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code taken from: https://github.com/akamaster/pytorch_resnet_cifar10 3 | Proper implementation of ResNet20 for Cifar10. 
Pytorch only has ResNets for ImageNet which 4 | differ in number of parameters 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.init as init 10 | 11 | 12 | def _weights_init(m): 13 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 14 | init.kaiming_normal_(m.weight) 15 | 16 | 17 | class LambdaLayer(nn.Module): 18 | def __init__(self, lambd): 19 | super(LambdaLayer, self).__init__() 20 | self.lambd = lambd 21 | 22 | def forward(self, x): 23 | return self.lambd(x) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, in_planes, planes, stride=1, option="A"): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = nn.Conv2d( 32 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 33 | ) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.conv2 = nn.Conv2d( 36 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 37 | ) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != planes: 42 | if option == "A": 43 | """ 44 | For CIFAR10 ResNet paper uses option A. 
45 | """ 46 | self.shortcut = LambdaLayer( 47 | lambda x: F.pad( 48 | x[:, :, ::2, ::2], 49 | (0, 0, 0, 0, planes // 4, planes // 4), 50 | "constant", 51 | 0, 52 | ) 53 | ) 54 | elif option == "B": 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d( 57 | in_planes, 58 | self.expansion * planes, 59 | kernel_size=1, 60 | stride=stride, 61 | bias=False, 62 | ), 63 | nn.BatchNorm2d(self.expansion * planes), 64 | ) 65 | 66 | def forward(self, x): 67 | out = F.relu(self.bn1(self.conv1(x))) 68 | out = self.bn2(self.conv2(out)) 69 | out += self.shortcut(x) 70 | out = F.relu(out) 71 | return out 72 | 73 | 74 | class ResNet(nn.Module): 75 | def __init__(self, block, num_blocks, num_classes=10): 76 | super(ResNet, self).__init__() 77 | self.in_planes = 16 78 | 79 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(16) 81 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 84 | self.linear = nn.Linear(64, num_classes) 85 | 86 | self.apply(_weights_init) 87 | 88 | def _make_layer(self, block, planes, num_blocks, stride): 89 | strides = [stride] + [1] * (num_blocks - 1) 90 | layers = [] 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, stride)) 93 | self.in_planes = planes * block.expansion 94 | 95 | return nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | out = F.relu(self.bn1(self.conv1(x))) 99 | out = self.layer1(out) 100 | out = self.layer2(out) 101 | out = self.layer3(out) 102 | out = F.avg_pool2d(out, out.size()[3]) 103 | out = out.view(out.size(0), -1) 104 | out = self.linear(out) 105 | return out 106 | 107 | 108 | def resnet20(): 109 | return ResNet(BasicBlock, [3, 3, 3]) 110 | 111 | 112 | def resnet56(): 113 | return ResNet(BasicBlock, [9, 9, 9]) 114 | 115 | 116 | def load_weights(weights_path: str, model: ResNet): 
117 | weights_dict = torch.load(weights_path, map_location=torch.device("cpu"))[ 118 | "state_dict" 119 | ] 120 | weights_dict = { 121 | key.replace("module.", ""): value for key, value in weights_dict.items() 122 | } 123 | model.load_state_dict(weights_dict) 124 | -------------------------------------------------------------------------------- /src/kyle/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/kyle/067e08d0cd908997159b00832907f50ce5791233/src/kyle/sampling/__init__.py -------------------------------------------------------------------------------- /src/kyle/sampling/fake_clf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence, Union 3 | 4 | import numpy as np 5 | import scipy.optimize 6 | import scipy.stats 7 | 8 | from kyle.transformations import IdentitySimplexAut, SimplexAut 9 | from kyle.util import sample_index 10 | 11 | 12 | class FakeClassifier(ABC): 13 | def __init__( 14 | self, 15 | num_classes: int, 16 | simplex_automorphism: SimplexAut = None, 17 | check_io=True, 18 | ): 19 | if num_classes < 1: 20 | raise ValueError(f"{self.__class__.__name__} requires at least two classes") 21 | self.num_classes = num_classes 22 | self._rng = np.random.default_rng() 23 | 24 | self._simplex_automorphism: SimplexAut = None 25 | self.set_simplex_automorphism(simplex_automorphism) 26 | self.check_io = check_io 27 | 28 | # TODO or not TODO: one could get rid of separate SimplexAut. 
class in favor of passing a function 29 | # pro: the function is less verbose to write, easier for user; contra: naming and state become more convoluted 30 | def set_simplex_automorphism(self, aut: Union[SimplexAut, None]) -> None: 31 | """ 32 | :param aut: if None, the identity automorphism will be set 33 | """ 34 | if aut is None: 35 | aut = IdentitySimplexAut(self.num_classes) 36 | if aut.num_classes is not None and aut.num_classes != self.num_classes: 37 | raise ValueError(f"{aut} has wrong number of classes: {aut.num_classes}") 38 | self._simplex_automorphism = aut 39 | 40 | @abstractmethod 41 | def sample_confidences(self, n_samples: int) -> np.ndarray: 42 | ... 43 | 44 | @property 45 | def simplex_automorphism(self): 46 | return self._simplex_automorphism 47 | 48 | def get_sample_arrays(self, n_samples: int): 49 | """ 50 | Get arrays with ground truth and predicted probabilities 51 | 52 | :param n_samples: 53 | :return: tuple of arrays of shapes (n_samples,), (n_samples, n_classes) 54 | """ 55 | calibrated_confidences = self.sample_confidences(n_samples) 56 | gt_labels = sample_index(calibrated_confidences) 57 | confidences = self.simplex_automorphism.transform(calibrated_confidences) 58 | return gt_labels, confidences 59 | 60 | def __str__(self): 61 | return f"{self.__class__.__name__}_{self.simplex_automorphism}" 62 | 63 | 64 | class DirichletFC(FakeClassifier): 65 | def __init__( 66 | self, 67 | num_classes: int, 68 | alpha: Sequence[float] = None, 69 | simplex_automorphism: SimplexAut = None, 70 | ): 71 | super().__init__(num_classes, simplex_automorphism=simplex_automorphism) 72 | 73 | self._alpha: np.ndarray = None 74 | self.set_alpha(alpha) 75 | 76 | def set_alpha(self, alpha: Union[np.ndarray, None]): 77 | """ 78 | :param alpha: if None, the default value of [1, ..., 1] will be set 79 | """ 80 | if alpha is None: 81 | alpha = np.ones(self.num_classes) 82 | else: 83 | alpha = np.array(alpha) 84 | if not alpha.shape == (self.num_classes,): 85 | raise 
ValueError(f"Wrong shape of alpha: {alpha.shape}") 86 | self._alpha = alpha 87 | 88 | @property 89 | def alpha(self): 90 | return self._alpha 91 | 92 | def sample_confidences(self, n_samples: int) -> np.ndarray: 93 | return self._rng.dirichlet(self.alpha, size=n_samples) 94 | 95 | def pdf(self, confidences, alpha=None): 96 | if alpha is None: 97 | alpha = self.alpha 98 | return scipy.stats.dirichlet.pdf(confidences.T, alpha) 99 | 100 | def fit(self, confidences, initial_alpha=None, alpha_bounds=None, **kwargs): 101 | """ 102 | Fits the dirichlet fake classifier to the provided confidence distribution using maximum likelihood estimation 103 | and sets the fake classifier parameters to the best fit parameters 104 | 105 | :param confidences: Numpy array of shape (num_samples, num_classes); 106 | confidence distribution to fit classifier to 107 | :param initial_alpha: Float; Initial guess for fitting alpha parameters 108 | :param alpha_bounds: Tuple, (lower_bound, upper_bound); Bounds for fitting alpha parameters. 
A lower/upper bound 109 | of None corresponds to unbounded parameter 110 | :param kwargs: passed to ``scipy.optimize.minimize`` 111 | :return: 112 | """ 113 | if initial_alpha is None: 114 | initial_alpha = self.alpha 115 | 116 | if alpha_bounds is None: 117 | alpha_bounds = (0.0001, None) 118 | 119 | # rescale confidences to avoid divergences on sides of simplex and renormalize 120 | confidences = ( 121 | confidences * (confidences.shape[0] - 1) + 1 / self.num_classes 122 | ) / confidences.shape[0] 123 | confidences = confidences / np.sum(confidences, axis=1)[:, None] 124 | 125 | alpha_bounds = [alpha_bounds] * self.num_classes 126 | 127 | nll = lambda parm: -np.sum(np.log(self.pdf(confidences, parm))) 128 | mle_fit = scipy.optimize.minimize( 129 | nll, initial_alpha, bounds=alpha_bounds, **kwargs 130 | ) 131 | self.set_alpha(mle_fit.x) 132 | 133 | return mle_fit 134 | 135 | 136 | class MultiDirichletFC(FakeClassifier): 137 | """ 138 | A fake classifier that first draws from a K categorical distribution and based on the result then draws from 139 | 1 of K Dirichlet Distributions of a restricted form. 140 | The K'th Dirichlet Distribution has parameters of the form: sigma * {1, 1, ..., alpha_k, 1, 1, ...}; alpha > 1 141 | where 'alpha_k' is at the k'th position. 142 | Effectively a distribution with a maximum of variable position and variable variance in each corner of the simplex 143 | 144 | :param num_classes: 145 | :param alpha: numpy array of shape (num_classes,). k'th entry corresponds to alpha_k for the k'th dirichlet 146 | :param sigma: numpy array of shape (num_classes,). k'th entry corresponds to sigma for the k'th dirichlet 147 | :param distribution_weights: numpy array of shape (num_classes,). 
Probabilities used for drawing from K Categorical 148 | :param simplex_automorphism: 149 | """ 150 | 151 | def __init__( 152 | self, 153 | num_classes: int, 154 | alpha: Sequence[float] = None, 155 | sigma: Sequence[float] = None, 156 | distribution_weights: Sequence[float] = None, 157 | simplex_automorphism: SimplexAut = None, 158 | ): 159 | super().__init__(num_classes, simplex_automorphism=simplex_automorphism) 160 | 161 | self._alpha: np.ndarray = None 162 | self._sigma: np.ndarray = None 163 | self._distribution_weights: np.ndarray = None 164 | 165 | self.set_alpha(alpha) 166 | self.set_sigma(sigma) 167 | self.set_distribution_weights(distribution_weights) 168 | 169 | @property 170 | def alpha(self): 171 | return self._alpha 172 | 173 | def set_alpha(self, alpha: Union[np.ndarray, None]): 174 | """ 175 | :param alpha: if None, the default value of [1, ..., 1] will be set. 176 | """ 177 | if alpha is None: 178 | alpha = np.ones(self.num_classes) 179 | else: 180 | alpha = np.array(alpha) 181 | if not alpha.shape == (self.num_classes,): 182 | raise ValueError(f"Wrong shape of alpha: {alpha.shape}") 183 | self._alpha = alpha 184 | 185 | @property 186 | def sigma(self): 187 | return self._sigma 188 | 189 | def set_sigma(self, sigma: Union[np.ndarray, None]): 190 | """ 191 | :param sigma: if None, the default value of [1, ..., 1] will be set 192 | """ 193 | if sigma is None: 194 | sigma = np.ones(self.num_classes) 195 | else: 196 | sigma = np.array(sigma) 197 | if not sigma.shape == (self.num_classes,): 198 | raise ValueError(f"Wrong shape of sigma: {sigma.shape}") 199 | self._sigma = sigma 200 | 201 | @property 202 | def distribution_weights(self): 203 | return self._distribution_weights 204 | 205 | def set_distribution_weights(self, distribution_weights: Union[np.ndarray, None]): 206 | """ 207 | :param distribution_weights: if None, the default value of [1/num_classes, ..., 1/num_classes] will be set 208 | """ 209 | if distribution_weights is None: 210 | 
distribution_weights = np.ones(self.num_classes) / self.num_classes 211 | else: 212 | distribution_weights = np.array(distribution_weights) 213 | if not distribution_weights.shape == (self.num_classes,): 214 | raise ValueError( 215 | f"Wrong shape of predicted_class_weights: {distribution_weights.shape}" 216 | ) 217 | self._distribution_weights = distribution_weights / np.sum(distribution_weights) 218 | 219 | def get_parameters(self): 220 | return self._alpha, self._sigma, self._distribution_weights 221 | 222 | def set_parameters(self, alpha, sigma, distribution_weights): 223 | self.set_alpha(alpha) 224 | self.set_sigma(sigma) 225 | self.set_distribution_weights(distribution_weights) 226 | 227 | def sample_confidences(self, n_samples: int) -> np.ndarray: 228 | 229 | weight_array = np.repeat(self.distribution_weights[None, :], n_samples, axis=0) 230 | chosen_distributions = sample_index(weight_array) 231 | 232 | confidences = np.zeros((n_samples, self.num_classes)) 233 | 234 | for i, chosen_distribution in enumerate(chosen_distributions): 235 | alpha_vector = np.ones(self.num_classes) 236 | alpha_vector[chosen_distribution] = self.alpha[chosen_distribution] 237 | alpha_vector *= self.sigma[chosen_distribution] 238 | 239 | confidences[i, :] = self._rng.dirichlet(alpha_vector) 240 | 241 | return confidences 242 | 243 | def pdf(self, confidences, alpha=None, sigma=None, distribution_weights=None): 244 | """ 245 | Computes pdf of MultiDirichletFC. 
Using K categorical distribution to sample from K dirichlet distributions 246 | is equivalent to sampling from a pdf that is a weighted sum of the K individual dirichlet pdf's 247 | :param confidences: numpy array of shape (num_classes,) or (num_samples, num_classes) 248 | :param distribution_weights: numpy array of shape (num_classes,) uses self.distribution_weights if not provided 249 | :param sigma: numpy array of shape (num_classes,) uses self.sigma if not provided 250 | :param alpha: numpy array of shape (num_classes,) uses self.alpha if not provided 251 | """ 252 | 253 | if alpha is None: 254 | alpha = self.alpha 255 | if sigma is None: 256 | sigma = self.sigma 257 | if distribution_weights is None: 258 | distribution_weights = self.distribution_weights 259 | 260 | confidences = confidences.T 261 | 262 | distributions = np.zeros(confidences.shape) 263 | 264 | for i, (a, s) in enumerate(zip(alpha, sigma)): 265 | alpha_vector = np.ones(self.num_classes) 266 | alpha_vector[i] = a 267 | alpha_vector *= s 268 | 269 | distributions[i] = scipy.stats.dirichlet.pdf(confidences, alpha_vector) 270 | 271 | return np.sum(distribution_weights[:, None] * distributions, axis=0) / np.sum( 272 | distribution_weights 273 | ) 274 | 275 | def fit( 276 | self, 277 | confidences, 278 | initial_parameters=None, 279 | parameter_bounds=None, 280 | simplified_fitting=True, 281 | **kwargs, 282 | ): 283 | """ 284 | Fits a Multi-Dirichlet fake classifier to the provided confidence distribution using maximum likeihood 285 | estimation and sets the fake classifier parameters to the best fit parameters. 286 | If simplified_fitting is set to False all parameters of the fake classifier are fit directly via MLE 287 | If simplified_fitting is set to True each dirichlet is fit separately. Alpha and Sigma of the k'th dirichlet 288 | are fit to the subset of the confidences that predict the k'th class, i.e. for which argmax(c) = k. 
The 289 | distribution weights are not fit, but estimated from the predicted class probabilities of the confidence 290 | distribution. 291 | 292 | :param confidences: Numpy array of shape (num_samples, num_classes); 293 | confidence distribution to fit classifier to 294 | :param initial_parameters: Numpy array of shape (3,) ((2,) for simplified_fitting=True) 295 | Corresponds to initial guesses for each parameter 'class' alpha, sigma and distribution_weights 296 | If None, uses [1, 1, 1/num_classes] 297 | :param parameter_bounds: Sequence of 3 (2 for simplified_fitting=True) tuples (lower_bound, upper_bound) 298 | Corresponds to the bounds on each parameter 'class', alpha, sigma and distribution_weights 299 | A lower/upper bound of None corresponds to unbounded parameters 300 | If None, uses intervals [(0, + infinity), (0, + infinity), (0,1)] 301 | :param simplified_fitting: If False directly fits Multi-Dirichlet FC to confidence distribution 302 | If True fits each dirichlet separately. 
Only fits alpha and sigma, not 303 | distribution_weights 304 | :param kwargs: passed to ``scipy.optimize.minimize`` 305 | :return: If simplfied_fitting=False: scipy OptimizeResult 306 | If simplified_fitting=True: List of num_classes OptimizeResults, one for each separate dirichlet fit 307 | """ 308 | 309 | # rescale confidences to avoid divergences on sides of simplex and renormalize 310 | confidences = ( 311 | confidences * (confidences.shape[0] - 1) + 1 / self.num_classes 312 | ) / confidences.shape[0] 313 | confidences = confidences / np.sum(confidences, axis=1)[:, None] 314 | 315 | if not simplified_fitting: 316 | if initial_parameters is None: 317 | initial_parameters = np.array([1, 1, 1 / self.num_classes]) 318 | if parameter_bounds is None: 319 | # dirichlet distribution undefined for alpha/sigma parameters exactly = 0 320 | parameter_bounds = [(0.0001, None)] * 2 + [(0, 1)] 321 | 322 | # scipy requires an initial guess and a bound (lower, upper) for each parameter 323 | # not just each parameter class 324 | initial_parameters = np.repeat(initial_parameters, self.num_classes) 325 | parameter_bounds = [ 326 | pair for pair in parameter_bounds for i in range(self.num_classes) 327 | ] 328 | 329 | nll = lambda parms: -np.sum( 330 | np.log(self.pdf(confidences, *np.split(parms, 3))) 331 | ) 332 | mle_fit = scipy.optimize.minimize( 333 | nll, initial_parameters, bounds=parameter_bounds 334 | ) 335 | self.set_parameters(*np.split(mle_fit.x, 3)) 336 | 337 | return mle_fit 338 | 339 | if simplified_fitting: 340 | if initial_parameters is None: 341 | initial_parameters = np.array([1, 1]) 342 | if parameter_bounds is None: 343 | # dirichlet distribution undefined for alpha/sigma parameters exactly = 0 344 | parameter_bounds = [(0.0001, None)] * 2 345 | 346 | predicted_class = np.argmax(confidences, axis=1) 347 | class_split_confidences = [ 348 | confidences[predicted_class == i, :] for i in range(self.num_classes) 349 | ] 350 | 351 | estimated_distribution_weights = 
[ 352 | k_class_conf.shape[0] for k_class_conf in class_split_confidences 353 | ] 354 | estimated_distribution_weights = estimated_distribution_weights / np.sum( 355 | estimated_distribution_weights 356 | ) 357 | 358 | mle_fits = [] 359 | 360 | for k, k_class_confidences in enumerate(class_split_confidences): 361 | 362 | def k_dir_nll(alpha_k, sigma_k): 363 | alpha = np.ones(self.num_classes) 364 | alpha[k] = alpha_k 365 | sigma = np.ones(self.num_classes) 366 | sigma[k] = sigma_k 367 | # 'isolate' the k'th dirichlet distribution 368 | distribution_weights = np.zeros(self.num_classes) 369 | distribution_weights[k] = 1 370 | return -np.sum( 371 | np.log( 372 | self.pdf( 373 | k_class_confidences, alpha, sigma, distribution_weights 374 | ) 375 | ) 376 | ) 377 | 378 | k_initial_parameters = initial_parameters 379 | k_parameter_bounds = parameter_bounds 380 | 381 | k_dir_mle_fit = scipy.optimize.minimize( 382 | lambda parms: k_dir_nll(*parms), 383 | k_initial_parameters, 384 | bounds=k_parameter_bounds, 385 | **kwargs, 386 | ) 387 | mle_fits.append(k_dir_mle_fit) 388 | 389 | self.set_alpha(np.array([k_mle_fit.x[0] for k_mle_fit in mle_fits])) 390 | self.set_sigma(np.array([k_mle_fit.x[1] for k_mle_fit in mle_fits])) 391 | self.set_distribution_weights(estimated_distribution_weights) 392 | 393 | return mle_fits 394 | -------------------------------------------------------------------------------- /src/kyle/transformations.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Callable, Sequence 3 | 4 | import numpy as np 5 | from scipy.special import softmax 6 | 7 | from kyle.util import in_simplex 8 | 9 | 10 | class SimplexAut(ABC): 11 | """ 12 | Base class for all simplex automorphisms 13 | 14 | :param num_classes: The dimension of the simplex vector, equals 1 + (dimension of the simplex as manifold). 15 | If provided, will use this for addition I/O checks. 
16 | """ 17 | 18 | def __init__(self, num_classes: int = None): 19 | # Several transformations can be defined without referring to num_classes, which is why it is optional. 20 | self.num_classes = num_classes 21 | 22 | def __str__(self): 23 | return self.__class__.__name__ 24 | 25 | @abstractmethod 26 | def _transform(self, x: np.ndarray) -> np.ndarray: 27 | """ 28 | :param x: array of shape (n_samples, n_classes) 29 | :return: transformed array of shape (n_samples, n_classes) 30 | """ 31 | pass 32 | 33 | def transform(self, x: np.ndarray, check_io=True) -> np.ndarray: 34 | if len(x.shape) == 1: 35 | x = x[None, :] 36 | if check_io and not in_simplex(x, self.num_classes): 37 | raise ValueError(f"Input has to be from a simplex of suitable dimension") 38 | x = self._transform(x) 39 | if check_io and not in_simplex(x, self.num_classes): 40 | raise ValueError( 41 | f"Bad implementation: Output has to be from a simplex of suitable dimension" 42 | ) 43 | return x.squeeze() 44 | 45 | 46 | class LogitsBasedSimplexAut(SimplexAut, ABC): 47 | @abstractmethod 48 | def transform_logits(self, logits: np.ndarray) -> np.ndarray: 49 | pass 50 | 51 | def _transform(self, x: np.ndarray) -> np.ndarray: 52 | logits = np.log(x) 53 | return softmax(self.transform_logits(logits), axis=1) 54 | 55 | 56 | class TempScaling(LogitsBasedSimplexAut): 57 | def __init__(self, temperature: float): 58 | self.temperature = temperature 59 | super().__init__() 60 | 61 | def transform_logits(self, logits: np.ndarray) -> np.ndarray: 62 | return self.temperature * logits 63 | 64 | 65 | class IdentitySimplexAut(SimplexAut): 66 | def _transform(self, x: np.ndarray) -> np.ndarray: 67 | return x 68 | 69 | 70 | class SingleComponentSimplexAut(SimplexAut): 71 | """ 72 | A simplex automorphism resulting from the application of a map on the unit interval to a 73 | single component of x and normalizing the result. 
74 | 75 | :param component: integer in range [0, num_classes - 1], corresponding to the component on which to apply the mapping 76 | :param mapping: map from the unit interval [0,1] to itself, should be applicable to arrays 77 | :param num_classes: The dimension of the simplex vector, equals 1 + (dimension of the simplex as manifold). 78 | If provided, will use this for addition I/O checks. 79 | """ 80 | 81 | def __init__( 82 | self, 83 | component: int, 84 | mapping: Callable[[np.ndarray], np.ndarray], 85 | num_classes: int = None, 86 | ): 87 | assert ( 88 | 0 <= component < num_classes 89 | ), "Selected component should be in the range [0, num_classes - 1]" 90 | self.component = component 91 | self.mapping = mapping 92 | super().__init__(num_classes=num_classes) 93 | 94 | def _transform(self, x: np.ndarray) -> np.ndarray: 95 | x = x.copy() 96 | x[:, self.component] = self.mapping(x[:, self.component]) 97 | return x / x.sum(axis=1)[:, None] 98 | 99 | 100 | class MaxComponentSimplexAut(SimplexAut): 101 | """ 102 | A simplex automorphism resulting from the application of a map on the unit interval to a 103 | the argmax of x and normalizing the remaining components such that the output vector sums to 1. 104 | 105 | :param mapping: map from the unit interval [0,1] to itself, must be applicable to arrays 106 | :param num_classes: The dimension of the simplex vector, equals 1 + (dimension of the simplex as manifold). 107 | If provided, will use this for addition I/O checks. 
108 | """ 109 | 110 | def __init__(self, mapping: Callable[[np.ndarray], np.ndarray], num_classes=None): 111 | self.mapping = mapping 112 | super().__init__(num_classes=num_classes) 113 | 114 | def _transform(self, x: np.ndarray) -> np.ndarray: 115 | # this transform has a singularity if one component exactly equals one, so we add a minor "noise" 116 | x = x + 1e-10 117 | x = x / x.sum(axis=1)[:, None] 118 | 119 | argmax = x.argmax(axis=1) 120 | old_values = np.choose(argmax, x.T) 121 | new_values = self.mapping(old_values) 122 | # the result must sum to 1, so we will rescale the remaining entries of the confidence vectors 123 | remaining_comps_normalization = (1 - new_values) / (1 - old_values) 124 | new_values_compensated_for_norm = new_values / remaining_comps_normalization 125 | np.put_along_axis( 126 | x, argmax[:, None], new_values_compensated_for_norm[:, None], axis=1 127 | ) 128 | return x * remaining_comps_normalization[:, None] 129 | 130 | 131 | class PowerLawSimplexAut(SimplexAut): 132 | """ 133 | An automorphism resulting from taking elementwise powers of the inputs with fixed exponents 134 | and normalizing the result. 135 | 136 | | 137 | | *Intuition*: 138 | 139 | If exponents[j] < exponents[i], then the output will be more shifted towards the j-th direction 140 | than the i-th. 
If all exponents are equal to some number s, then s>1 means a shift towards the boundary 141 | of the simplex whereas 0 np.ndarray: 151 | x = np.float_power(x, self.exponents) 152 | return x / x.sum(axis=1)[:, None] 153 | 154 | 155 | class RestrictedPowerSimplexAut(SimplexAut): 156 | """ 157 | Maybe a bad idea, feels unnatural 158 | """ 159 | 160 | def __init__(self, exponents: np.ndarray): 161 | """ 162 | 163 | :param exponents: numpy array of shape (num_classes - 1, ) 164 | """ 165 | if not np.all(exponents >= 1): 166 | raise ValueError("Only exponents >= 1 are permitted") 167 | self.exponents = exponents[None, :] 168 | super().__init__(len(exponents) + 1) 169 | 170 | def _transform(self, x: np.ndarray) -> np.ndarray: 171 | x = x.copy() 172 | x[:, :-1] = np.float_power(x[:, :-1], self.exponents) 173 | x[:, -1] = 1 - x[:, :-1].sum(axis=1) 174 | return x / x.sum(axis=1)[:, None] 175 | -------------------------------------------------------------------------------- /src/kyle/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | from sklearn.metrics import accuracy_score 5 | 6 | 7 | def safe_accuracy_score(y_true: np.ndarray, y_pred: np.ndarray, **kwargs) -> float: 8 | """ 9 | Wrapper around sklearn accuracy store that returns zero for empty sequences of labels 10 | 11 | :param y_true: Ground truth (correct) labels. 12 | :param y_pred: Predicted labels, as returned by a classifier. 
13 | :param kwargs: 14 | :return: 15 | """ 16 | if len(y_true) == len(y_pred) == 0: 17 | return 0 18 | return accuracy_score(y_true, y_pred, **kwargs) 19 | 20 | 21 | def in_simplex(probabilities: np.ndarray, num_classes=None) -> bool: 22 | """ 23 | 24 | :param probabilities: single vector of probabilities of shape (n_classes,) or multiple 25 | vectors as array of shape (n_samples, n_classes) 26 | :param num_classes: if provided, will check whether probability vectors have the correct number of classes 27 | :return: 28 | """ 29 | if len(probabilities.shape) == 1: 30 | probabilities = probabilities[None, :] 31 | if num_classes is None: 32 | num_classes = probabilities.shape[1] 33 | 34 | return ( 35 | probabilities.shape[1] == num_classes 36 | and np.allclose(np.sum(probabilities, axis=1), 1.0, rtol=0.01) 37 | and (probabilities >= 0).all() 38 | and (probabilities <= 1).all() 39 | ) 40 | 41 | 42 | def sample_index(probabilities: np.ndarray) -> Union[int, np.ndarray]: 43 | """ 44 | Sample indices with the input probabilities. This is essentially a vectorized 45 | version of np.random.choice 46 | 47 | :param probabilities: single vector of probabilities of shape (n_indices-1,) or multiple 48 | vectors as array of shape (n_samples, n_indices-1) 49 | :return: index or array of indices 50 | """ 51 | rng = np.random.default_rng() 52 | if len(probabilities.shape) == 1: 53 | return rng.choice(len(probabilities), p=probabilities) 54 | elif len(probabilities.shape) == 2: 55 | # this is a vectorized implementation of np.random.choice with inverse transform sampling 56 | # see e.g. 
https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html 57 | # and https://stackoverflow.com/questions/47722005/vectorizing-numpy-random-choice-for-given-2d-array-of-probabilities-along-an-a 58 | random_uniform = rng.random(len(probabilities))[:, None] 59 | return (probabilities.cumsum(axis=1) > random_uniform).argmax(axis=1) 60 | else: 61 | raise ValueError( 62 | f"Unsupported input shape: {probabilities.shape}. Can only be 1 or 2 dimensional." 63 | ) 64 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from kyle.sampling.fake_clf import DirichletFC 5 | from kyle.transformations import PowerLawSimplexAut 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def uncalibrated_samples(): 10 | faker = DirichletFC(2, simplex_automorphism=PowerLawSimplexAut(np.array([30, 20]))) 11 | return faker.get_sample_arrays(1000) 12 | 13 | 14 | @pytest.fixture(scope="module") 15 | def calibrated_samples(): 16 | faker = DirichletFC(2) 17 | return faker.get_sample_arrays(1000) 18 | -------------------------------------------------------------------------------- /tests/kyle/calibration/calibration_methods/test_calibration_methods.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sklearn import clone 3 | 4 | from kyle.calibration.calibration_methods import ( 5 | BetaCalibration, 6 | HistogramBinning, 7 | IsotonicRegression, 8 | LogisticCalibration, 9 | TemperatureScaling, 10 | ) 11 | from kyle.metrics import ECE 12 | 13 | 14 | @pytest.fixture(scope="module") 15 | def metric(): 16 | return ECE() 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def calibration_method(): 21 | return TemperatureScaling() 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "calibration_method", 26 | [ 27 | HistogramBinning(), 28 | TemperatureScaling(), 
29 | IsotonicRegression(), 30 | BetaCalibration(), 31 | LogisticCalibration(), 32 | ], 33 | ) 34 | def test_calibration_methods_clonability(calibration_method): 35 | clone(calibration_method) 36 | 37 | 38 | def test_methods_calibrationErrorLessAfterCalibration( 39 | metric, uncalibrated_samples, calibration_method 40 | ): 41 | ground_truth, confidences = uncalibrated_samples 42 | error_pre_calibration = metric.compute(confidences, ground_truth) 43 | calibration_method.fit(confidences, ground_truth) 44 | calibrated_confidences = calibration_method.get_calibrated_confidences(confidences) 45 | error_post_calibration = metric.compute(calibrated_confidences, ground_truth) 46 | 47 | assert error_post_calibration <= error_pre_calibration 48 | -------------------------------------------------------------------------------- /tests/kyle/calibration/test_model_calibrator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sklearn import datasets 3 | from sklearn.model_selection import train_test_split 4 | from sklearn.neural_network import MLPClassifier 5 | 6 | from kyle.calibration import ModelCalibrator 7 | from kyle.metrics import ECE 8 | from kyle.models import CalibratableModel 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def dataset(): 13 | X, y = datasets.make_classification( 14 | n_samples=2000, 15 | n_features=20, 16 | n_informative=7, 17 | n_redundant=10, 18 | n_classes=2, 19 | random_state=42, 20 | ) 21 | X_train, X_test, y_train, y_test = train_test_split( 22 | X, y, test_size=0.2, random_state=42 23 | ) 24 | 25 | return X_train, X_test, y_train, y_test 26 | 27 | 28 | @pytest.fixture(scope="module") 29 | def uncalibrated_model(): 30 | return MLPClassifier(hidden_layer_sizes=(50, 50, 50)) 31 | 32 | 33 | @pytest.fixture(scope="module") 34 | def calibratable_model(uncalibrated_model): 35 | return CalibratableModel(uncalibrated_model) 36 | 37 | 38 | @pytest.fixture(scope="module") 39 | def 
calibrator(dataset): 40 | X_train, X_val, y_train, y_val = dataset 41 | calibrator = ModelCalibrator(X_val, y_val, X_fit=X_train, y_fit=y_train) 42 | return calibrator 43 | 44 | 45 | def test_calibrator_integrationTest(calibrator, calibratable_model): 46 | calibrator.calibrate(calibratable_model, fit=True) 47 | metric = ECE() 48 | predicted_probas = calibratable_model.model.predict_proba(calibrator.X_calibrate) 49 | calibrated_predicted_probas = calibratable_model.predict_proba( 50 | calibrator.X_calibrate 51 | ) 52 | assert metric.compute( 53 | calibrated_predicted_probas, calibrator.y_calibrate 54 | ) < metric.compute(predicted_probas, calibrator.y_calibrate) 55 | -------------------------------------------------------------------------------- /tests/kyle/metrics/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kyle.metrics import ACE, ECE, MCE 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def metrics(): 8 | criteria = [ECE(), MCE(), ACE()] 9 | return criteria 10 | 11 | 12 | def test_metrics_calibratedConfidencesHaveZeroError(metrics, calibrated_samples): 13 | ground_truth, confidences = calibrated_samples 14 | for criterion in metrics: 15 | epsilon = 0.1 16 | assert criterion.compute(confidences, ground_truth) <= epsilon 17 | 18 | 19 | def test_metrics_uncalibratedConfidencesHaveNonZeroError(metrics, uncalibrated_samples): 20 | ground_truth, confidences = uncalibrated_samples 21 | for criterion in metrics: 22 | epsilon = 0.1 23 | assert criterion.compute(confidences, ground_truth) > epsilon 24 | -------------------------------------------------------------------------------- /tests/kyle/sampling/test_fake_clf.py: -------------------------------------------------------------------------------- 1 | from kyle.sampling.fake_clf import DirichletFC 2 | from kyle.util import in_simplex 3 | 4 | 5 | def test_DirichletFC_basics(): 6 | faker = DirichletFC(3) 7 | ground_truth, class_proba = 
faker.get_sample_arrays(10) 8 | assert ground_truth.shape == (10,) 9 | assert class_proba.shape == (10, 3) 10 | assert ground_truth[0] in [0, 1, 2] 11 | assert in_simplex(class_proba) 12 | -------------------------------------------------------------------------------- /tests/kyle/test_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from kyle.util import in_simplex 4 | 5 | 6 | def test_in_simplex_negativeEntriesForbidden(): 7 | assert not in_simplex(np.array([0.5, -0.5])) 8 | 9 | 10 | def test_in_simplex_larger1Forbidden(): 11 | assert not in_simplex(np.array([0, 2])) 12 | 13 | 14 | def test_in_simplex_sumNot1Forbidden(): 15 | assert not in_simplex(np.array([0.4, 0.7])) 16 | assert not in_simplex(np.array([0.1, 0.1])) 17 | assert not in_simplex(np.random.default_rng().random((5, 3))) 18 | 19 | 20 | def test_in_simplex_wrongSizeForbidden(): 21 | assert not in_simplex(np.array([1]), num_classes=2) 22 | assert not in_simplex(np.array([1, 0, 0]), num_classes=2) 23 | assert not in_simplex(np.random.default_rng().random((5, 3)), num_classes=2) 24 | 25 | 26 | def test_in_simplex_correctInputIsCorrect(): 27 | assert in_simplex(np.array([0.5, 0.5]), num_classes=2) 28 | x = np.random.default_rng().random(5) 29 | assert in_simplex(x / x.sum()) 30 | 31 | 32 | def test_in_simplex_correct2DInputIsCorrect(): 33 | x = np.random.default_rng().random((5, 3)) 34 | row_sums = x.sum(axis=1) 35 | x = x / row_sums[:, np.newaxis] 36 | assert in_simplex(x) 37 | assert in_simplex(x, num_classes=3) 38 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py, docs, report 3 | isolated_build = True 4 | 5 | [testenv] 6 | # pytest-cov has an issue when the tests are inside an sdist, as created by tox by default 7 | # despite tests being run, coverage discovers no data, leading to: 
Coverage.py warning: No data was collected 8 | # this issue is resolved by running pytest-cov within tox development mode, thus not creating an sdist 9 | usedevelop = true 10 | commands = 11 | coverage erase 12 | pytest --cov --cov-append --cov-report=term-missing tests 13 | pytest -n 4 notebooks 14 | deps = 15 | pytest 16 | pytest-cov 17 | pytest-xdist 18 | pytest-lazy-fixture 19 | jupyter==1.0.0 20 | nbconvert==6.4.5 21 | -rrequirements.txt 22 | 23 | 24 | [testenv:docs] 25 | ; NOTE: we don't use pytest for running the doctest, even though with pytest no imports have to be written in them 26 | ; The reason is that we want to be running doctest during the docs build (which might happen on a remote machine, 27 | ; like read_the_docs does) with possibly fewer external dependencies and use sphinx' ability to automock the missing ones. 28 | commands = 29 | python build_scripts/update_docs.py 30 | git add docs/* 31 | sphinx-build -W -b html -d "{envtmpdir}/doctrees" docs "docs/_build/html" 32 | sphinx-build -b doctest -d "{envtmpdir}/doctrees" docs "docs/_build/doctest" 33 | deps = 34 | Sphinx==3.2.1 35 | sphinxcontrib-websupport==1.2.4 36 | jinja2<3.1 37 | sphinx_rtd_theme 38 | nbsphinx 39 | ipython 40 | whitelist_externals = 41 | git 42 | 43 | [testenv:report] 44 | commands = 45 | coverage html 46 | coverage-badge -o badges/coverage.svg -f 47 | coverage erase 48 | deps = 49 | coverage 50 | coverage-badge 51 | skip_install = true 52 | --------------------------------------------------------------------------------