├── .github └── workflows │ └── deploy.yml ├── .gitignore ├── LICENSE ├── README.md ├── animation.ipynb ├── bayeshist ├── __init__.py ├── bayeshist.py └── plotting.py ├── demo.ipynb ├── doc ├── bayesian-histogram-comp.png └── samples.png └── setup.py /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Upload release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: '3.x' 18 | 19 | - name: Auto-bump version 20 | run: | 21 | # from refs/tags/v1.2.3 get 1.2.3 22 | VERSION=$(echo $GITHUB_REF | sed 's#.*/v##') 23 | PLACEHOLDER="__version__\s*=\s*[\"'](.+)[\"']" 24 | VERSION_FILE="bayeshist/__init__.py" 25 | # ensure the placeholder is there. If grep doesn't find the placeholder 26 | # it exits with exit code 1 and github actions aborts the build. 27 | VERSION_LINE=$(grep -E "$PLACEHOLDER" "$VERSION_FILE") 28 | sed -i "s/$VERSION_LINE/__version__ = \"${VERSION}\"/g" "$VERSION_FILE" 29 | shell: bash 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install setuptools wheel twine 35 | 36 | - name: Build sdist 37 | run: | 38 | python setup.py sdist 39 | 40 | - name: Publish to PyPI 41 | uses: pypa/gh-action-pypi-publish@release/v1 42 | with: 43 | user: __token__ 44 | password: ${{ secrets.PYPI_API_TOKEN }} 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dion Häfner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian histograms 2 | 3 | **Bayesian histograms** are a nifty tool for data mining if: 4 | 5 | - you want to know how the *event rate* (probability) of a binary **rare event** depends on a parameter; 6 | - you have millions or even **billions of data points**, but few positive samples; 7 | - you suspect the event rate depends **highly non-linearly** on the parameter; 8 | - you don't know whether you have *enough data*, so you need **uncertainty information**. 9 | 10 | Thanks to an adaptive bin pruning algorithm, you don't even have to choose the number of bins, and you should get good results out of the box. 11 | 12 | This is how they look in practice ([see full example below](#usage-example)): 13 | 14 |

15 | ![Bayesian histograms in action](doc/bayesian-histogram-comp.png) 16 |
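The shaded bands in these plots are credible intervals of a per-bin posterior. In a nutshell, every bin gets a Beta posterior over its event rate built from the counts it contains, which is exactly what `bayesian_histogram` returns (the symbols $k$, $m$, $\alpha_0$, $\beta_0$ below are just notation for this README, not part of the package API):

$$
p_{\mathrm{bin}} \sim \mathrm{Beta}\bigl(\alpha_0 + k,\ \beta_0 + m\bigr),
$$

where $k$ and $m$ are the numbers of positive and negative samples falling into the bin, and $(\alpha_0, \beta_0)$ are the prior parameters (see `prior_params` in the API reference below). The pruning step then merges neighboring bins whose posteriors are statistically indistinguishable.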

17 | 18 | ## Installation 19 | 20 | ```bash 21 | $ pip install bayeshist 22 | ``` 23 | 24 | ## Usage example 25 | 26 | Assume you have binary samples of a rare event like this: 27 | 28 |

29 | ![Binary samples of a rare event](doc/samples.png) 30 |
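If you don't have data at hand and just want to try the package, you can generate synthetic samples along these lines (a minimal sketch, not part of the package; the sample size and the event-rate curve are arbitrary choices, and `X`, `y` are the arrays used in the snippet below):

```python
import numpy as np

rng = np.random.default_rng(42)

# one continuous feature
X = rng.standard_normal(100_000)

# a rare event whose rate depends non-linearly on X (illustrative choice)
p_true = np.clip(10 ** (-3 + X), 0.0, 1.0)

# one Bernoulli draw per sample
y = rng.random(X.shape) < p_true
```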

31 | 32 | Compute and plot a Bayesian histogram: 33 | 34 | ```python 35 | >>> from bayeshist import bayesian_histogram, plot_bayesian_histogram 36 | 37 | # compute Bayesian histogram from samples 38 | >>> bin_edges, beta_dist = bayesian_histogram(X, y, bins=100, pruning_method="bayes") 39 | 40 | # beta_dist is a frozen `scipy.stats.beta` distribution, so we can get the 41 | # predicted mean event rate for each histogram bin like this: 42 | >>> bin_mean_pred = beta_dist.mean() 43 | 44 | # plot it up 45 | >>> plot_bayesian_histogram(bin_edges, beta_dist) 46 | ``` 47 | 48 | The result is something like this: 49 | 50 |

51 | ![Bayesian histogram of the event rate](doc/bayesian-histogram-comp.png) 52 |
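Because `beta_dist` is a plain (frozen) `scipy.stats` distribution, you can also query the per-bin posterior directly, for example to tabulate credible intervals (a small sketch; the 1%/99% quantiles mirror the default shading of `plot_bayesian_histogram`):

```python
# posterior summary per bin
bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
rate_median = beta_dist.median()
rate_low, rate_high = beta_dist.ppf(0.01), beta_dist.ppf(0.99)

for center, lo, med, hi in zip(bin_centers, rate_low, rate_median, rate_high):
    print(f"x ~ {center:+.2f}: event rate {med:.2e} (98% CI: {lo:.2e} to {hi:.2e})")
```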

53 | 54 | See also [demo.ipynb](demo.ipynb) for a full walkthrough of this example. 55 | 56 | ## But how do they work? 57 | 58 | [Here's the blog post.](https://dionhaefner.github.io/2021/09/bayesian-histograms-for-rare-event-classification/) 59 | 60 | ## API reference 61 | 62 | ### `bayesian_histogram` 63 | 64 | ```python 65 | 66 | def bayesian_histogram( 67 | x: np.ndarray, 68 | y: np.ndarray, 69 | bins: Union[int, Iterable] = 100, 70 | x_range: Optional[Tuple[float, float]] = None, 71 | prior_params: Optional[Tuple[float, float]] = None, 72 | pruning_method: Optional[Literal["bayes", "fisher"]] = "bayes", 73 | pruning_threshold: Optional[float] = None, 74 | max_bin_size: Optional[float] = None, 75 | ) -> Tuple[np.ndarray, FrozenDistType]: 76 | """Compute Bayesian histogram for data x, binary target y. 77 | 78 | The output is a Beta distribution over the event rate for each bin. 79 | 80 | Parameters: 81 | 82 | x: 83 | 1-dim array of data. 84 | 85 | y: 86 | 1-dim array of binary labels (0 or 1). 87 | 88 | bins: 89 | int giving the number of equally spaced initial bins, 90 | or array giving initial bin edges. (default: 100) 91 | 92 | x_range: 93 | Range spanned by binning. Not used if `bins` is an array. 94 | (default: [min(x), max(x)]) 95 | 96 | prior_params: 97 | Parameters to use in Beta prior. First value relates to positive, 98 | second value to negative samples. [0.5, 0.5] represents the Jeffreys prior, [1, 1] a flat 99 | prior. The default is a weakly informative prior based on the global event rate. 100 | (default: `[1, num_neg / num_pos]`) 101 | 102 | pruning_method: 103 | Method to use to decide whether neighboring bins should be merged or not. 104 | Valid values are "bayes" (Bayes factor), "fisher" (exact Fisher test), or None 105 | (no pruning). (default: "bayes") 106 | 107 | pruning_threshold: 108 | Threshold to use in significance test specified by `pruning_method`. 109 | (default: 2 for "bayes", 0.2 for "fisher") 110 | 111 | max_bin_size: 112 | Maximum size (in units of x) above which bins will not be merged 113 | (except empty bins). (default: unlimited size) 114 | 115 | Returns: 116 | 117 | bin_edges: Coordinates of bin edges 118 | beta_dist: n-dimensional Beta distribution (n = number of bins) 119 | 120 | Example: 121 | 122 | >>> x = np.random.randn(1000) 123 | >>> p = 10 ** (-2 + x) 124 | >>> y = np.random.rand(1000) < p 125 | >>> bins, beta_dist = bayesian_histogram(x, y) 126 | >>> plt.plot(0.5 * (bins[1:] + bins[:-1]), beta_dist.mean()) 127 | 128 | """ 129 | ``` 130 | 131 | ### `plot_bayesian_histogram` 132 | 133 | ```python 134 | def plot_bayesian_histogram( 135 | bin_edges: np.ndarray, 136 | data_dist: FrozenDistType, 137 | color: Union[str, Iterable[float], None] = None, 138 | label: Optional[str] = None, 139 | ax: Any = None, 140 | ci: Optional[Tuple[float, float]] = (0.01, 0.99) 141 | ) -> None: 142 | """Plot a Bayesian histogram as horizontal lines with credible intervals. 143 | 144 | Parameters: 145 | 146 | bin_edges: 147 | Coordinates of bin edges 148 | 149 | data_dist: 150 | n-dimensional Beta distribution (n = number of bins) 151 | 152 | color: 153 | Color to use (default: use next in current color cycle) 154 | 155 | label: 156 | Legend label (default: no label) 157 | 158 | ax: 159 | Matplotlib axis to use (default: current axis) 160 | 161 | ci: 162 | Credible interval used for shading; use `None` to disable shading.
163 | 164 | Example: 165 | 166 | >>> x = np.random.randn(1000) 167 | >>> p = 10 ** (-2 + x) 168 | >>> y = np.random.rand() < p 169 | >>> bins, beta_dist = bayesian_histogram(x, y) 170 | >>> plot_bayesian_histogram(bins, beta_dist) 171 | 172 | """ 173 | ``` 174 | 175 | ## Questions? 176 | 177 | [Feel free to open an issue.](https://github.com/dionhaefner/bayesian-histograms/issues) 178 | -------------------------------------------------------------------------------- /animation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "presidential-syndrome", 6 | "metadata": {}, 7 | "source": [ 8 | "# Create an animation of histogram pruning\n", 9 | "\n", 10 | "(messy code warning)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "editorial-samoa", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from functools import partial\n", 21 | "from tqdm.notebook import tqdm\n", 22 | "from matplotlib.animation import FuncAnimation\n", 23 | "from matplotlib.collections import PatchCollection\n", 24 | "from matplotlib.patches import Rectangle\n", 25 | "from IPython.display import display\n", 26 | "\n", 27 | "from bayeshist.bayeshist import _prune_histogram, _bayes_factor_test\n", 28 | "\n", 29 | "\n", 30 | "plt.rcParams[\"animation.html\"] = \"html5\"\n", 31 | "\n", 32 | "bin_edges = np.linspace(-4, 4, 40)\n", 33 | "neg_samples, _ = np.histogram(test_x[test_y == 0], bins=bin_edges)\n", 34 | "pos_samples, _ = np.histogram(test_x[test_y == 1], bins=bin_edges)\n", 35 | "\n", 36 | "pruning_threshold = 2\n", 37 | "prior_params = (1, 1000)\n", 38 | "test = partial(_bayes_factor_test, threshold=pruning_threshold)\n", 39 | "pruner = _prune_histogram(bin_edges, pos_samples, neg_samples, test, prior_params, yield_steps=True)\n", 40 | "\n", 41 | "states = [state for state in pruner if not isinstance(state, tuple)]\n", 42 | "\n", 43 | "fig = plt.figure(figsize=(9, 6))\n", 44 | "ylim = 1e2 * max(pos_samples.max(), neg_samples.max())\n", 45 | "\n", 46 | "pbar = tqdm()\n", 47 | "speedup_after = 3\n", 48 | "num_steps = 5\n", 49 | "frame_cutoff = 10 * speedup_after * num_steps\n", 50 | "\n", 51 | "\n", 52 | "def animate(frameno):\n", 53 | " pbar.update(1)\n", 54 | " \n", 55 | " if frameno < frame_cutoff:\n", 56 | " frameno = frameno // 10\n", 57 | " else:\n", 58 | " frameno = frameno - (frame_cutoff - frame_cutoff // 10)\n", 59 | " \n", 60 | " state_idx, step = frameno // num_steps, frameno % num_steps\n", 61 | " \n", 62 | " fig.clear()\n", 63 | " ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])\n", 64 | " \n", 65 | " axprob = fig.add_axes([0.1, 0.1, 0.8, 0.8], frameon=False)\n", 66 | " axprob.grid(False)\n", 67 | " axprob.yaxis.tick_right()\n", 68 | " axprob.yaxis.set_label_position(\"right\")\n", 69 | " axprob.set(\n", 70 | " xlim=(bin_edges[0], bin_edges[-1]),\n", 71 | " xticks=[],\n", 72 | " ylim=(1e-4, 1e1),\n", 73 | " yscale=\"log\"\n", 74 | " )\n", 75 | " axprob.set_ylabel(\"Event rate\", y=0.25)\n", 76 | " axprob.set_yticks([1e-4, 1e-3, 1e-2])\n", 77 | " axprob.tick_params(which=\"both\", right=False)\n", 78 | " \n", 79 | " is_final_state = state_idx >= len(states)\n", 80 | " \n", 81 | " if is_final_state:\n", 82 | " state = states[-1]\n", 83 | " step = 0\n", 84 | " else:\n", 85 | " state = states[state_idx]\n", 86 | " \n", 87 | " i = state[\"i\"]\n", 88 | " bins = state[\"bins\"]\n", 89 | " bin_centers = 0.5 * (bins[1:] + bins[:-1])\n", 90 | "\n", 91 | " if 
is_final_state:\n", 92 | " ax.hist(test_x[test_y == 0], log=True, alpha=0.6, bins=bins, facecolor=\"C0\", label=\"y = 0\")\n", 93 | " ax.hist(test_x[test_y == 1], log=True, alpha=0.6, bins=bins, facecolor=\"C1\", label=\"y = 1\")\n", 94 | " fig.text(0.5, 0.85, \"Final Bayesian histogram\", ha=\"center\", weight=\"bold\")\n", 95 | " else:\n", 96 | " ax.hist(test_x[test_y == 0], log=True, alpha=0.4, bins=bins, facecolor=\"C0\")\n", 97 | " ax.hist(test_x[test_y == 1], log=True, alpha=0.4, bins=bins, facecolor=\"C1\")\n", 98 | "\n", 99 | " ax.hist(test_x[test_y == 0], log=True, alpha=0.8, bins=bins[i:i+3], facecolor=\"C0\", label=\"y = 0\")\n", 100 | " ax.hist(test_x[test_y == 1], log=True, alpha=0.8, bins=bins[i:i+3], facecolor=\"C1\", label=\"y = 1\")\n", 101 | " \n", 102 | " event_dist = scipy.stats.beta(state[\"pos_samples\"] + prior_params[0], state[\"neg_samples\"] + prior_params[1])\n", 103 | " ci_low, ci_high = event_dist.ppf(0.01), event_dist.ppf(0.99)\n", 104 | "\n", 105 | " # background boxes\n", 106 | " errorboxes = [\n", 107 | " Rectangle((x1, y1), x2 - x1, y2 - y1)\n", 108 | " for x1, x2, y1, y2\n", 109 | " in zip(bins[:-1], bins[1:], ci_low, ci_high)\n", 110 | " ]\n", 111 | "\n", 112 | " pc = PatchCollection(errorboxes, facecolor=\"0.2\", alpha=0.2)\n", 113 | " axprob.add_collection(pc)\n", 114 | "\n", 115 | " # median indicator\n", 116 | " axprob.hlines(event_dist.median(), bins[:-1], bins[1:], colors=\"0.2\", label=\"p(y = 1)\")\n", 117 | "\n", 118 | " # box edges\n", 119 | " ax.hlines(ci_low, bins[:-1], bins[1:], colors=\"0.2\", alpha=0.8, linewidth=1)\n", 120 | " ax.hlines(ci_high, bins[:-1], bins[1:], colors=\"0.2\", alpha=0.8, linewidth=1)\n", 121 | " \n", 122 | " fig.legend(loc=\"upper center\", ncol=3, frameon=False)\n", 123 | "\n", 124 | " if step > 0:\n", 125 | " axdist1 = fig.add_axes([0.16, 0.72, 0.2, 0.1])\n", 126 | " axdist1.axis(\"off\") \n", 127 | " dist_x = np.logspace(-5, 0, 100)\n", 128 | "\n", 129 | " with np.errstate(divide='ignore'):\n", 130 | " alpha_1, beta_1 = state[\"samples_1\"]\n", 131 | " axdist1.plot(dist_x, scipy.stats.beta(alpha_1 + prior_params[0], beta_1 + prior_params[1]).pdf(dist_x), c=\"0.2\", label=\"original\")\n", 132 | "\n", 133 | " alpha_2, beta_2 = state[\"samples_2\"]\n", 134 | " axdist1.plot(dist_x, scipy.stats.beta(alpha_2 + prior_params[0], beta_2 + prior_params[1]).pdf(dist_x), c=\"0.2\")\n", 135 | "\n", 136 | " alpha_comb, beta_comb = alpha_1 + alpha_2, beta_1 + beta_2\n", 137 | " axdist1.plot(dist_x, scipy.stats.beta(alpha_comb + prior_params[0], beta_comb + prior_params[1]).pdf(dist_x), c=\"coral\", label=\"merged\")\n", 138 | "\n", 139 | " axdist1.text(0.5, -0.05, \"p(y = 1)\", transform=axdist1.transAxes, va=\"top\", ha=\"center\", color=\"0.2\")\n", 140 | " axdist1.set_xscale(\"log\")\n", 141 | " axdist1.set_title(\"Event rate distributions\", weight=\"bold\")\n", 142 | " axdist1.legend(loc=\"upper right\", frameon=False, labelcolor=\"linecolor\", handlelength=0)\n", 143 | "\n", 144 | " if step > 1:\n", 145 | " p_1 = scipy.stats.betabinom(alpha_1 + beta_1, alpha_1 + prior_params[0], beta_1 + prior_params[1]).logpmf(alpha_1)\n", 146 | " ax.text(bin_centers[i], beta_1, f\"{p_1:.1f}\", ha=\"center\", va=\"bottom\", fontsize=9)\n", 147 | "\n", 148 | " p_2 = scipy.stats.betabinom(alpha_2 + beta_2, alpha_2 + prior_params[0], beta_2 + prior_params[1]).logpmf(alpha_2)\n", 149 | " ax.text(bin_centers[i+1], beta_2, f\"{p_2:.1f}\", ha=\"center\", fontsize=9)\n", 150 | "\n", 151 | " p_c1 = scipy.stats.betabinom(alpha_1 + beta_1, 
alpha_comb + prior_params[0], beta_comb + prior_params[1]).logpmf(alpha_1)\n", 152 | " ax.text(bin_centers[i], beta_1 * 2, f\"{p_c1:.1f}\", color=\"coral\", ha=\"center\", fontsize=9)\n", 153 | "\n", 154 | " p_c2 = scipy.stats.betabinom(alpha_2 + beta_2, alpha_comb + prior_params[0], beta_comb + prior_params[1]).logpmf(alpha_2)\n", 155 | " ax.text(bin_centers[i+1], beta_2 * 2, f\"{p_c2:.1f}\", color=\"coral\", ha=\"center\", fontsize=9)\n", 156 | "\n", 157 | " ax.text(bins[i+1], max(beta_1, beta_2) * 4, \"Data log likelihood\", ha=\"center\", weight=\"bold\")\n", 158 | "\n", 159 | " if step > 2:\n", 160 | " compsign = \"$>$\" if state[\"test_value\"] > pruning_threshold else \"$\\\\ngtr$\" \n", 161 | " fig.text(0.8, 0.85, \"Log likelihood $\\\\Delta$\", ha=\"center\", va=\"top\", weight=\"bold\")\n", 162 | " fig.text(0.8, 0.82, f\"{np.log(state['test_value']):.2f} {compsign} log({pruning_threshold})\", ha=\"center\", va=\"top\")\n", 163 | "\n", 164 | " if step > 3:\n", 165 | " merge_text = \"merge\" if state[\"reverse_split\"] else \"don't merge\"\n", 166 | " ax.annotate(\"\", xy=(0.8, 0.78), xytext=(0.8, 0.7), arrowprops=dict(arrowstyle=\"<-\", color=\"black\"), xycoords=\"figure fraction\", textcoords=\"figure fraction\")\n", 167 | " fig.text(0.8, 0.7, f\"{merge_text}\", ha=\"center\", va=\"top\", weight=\"bold\")\n", 168 | "\n", 169 | " ax.set(\n", 170 | " xlabel=\"x\",\n", 171 | " ylabel=\"Count\",\n", 172 | " xlim=(bin_edges[0], bin_edges[-1]),\n", 173 | " ylim=(0.5, ylim),\n", 174 | " )\n", 175 | " \n", 176 | "animate(100)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "id": "loved-steps", 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "num_frames = frame_cutoff - frame_cutoff // 10 + len(states) * num_steps + 20\n", 187 | "\n", 188 | "with tqdm(total=num_frames) as pbar:\n", 189 | " anim = FuncAnimation(fig, animate, frames=num_frames, interval=100)\n", 190 | " display(anim)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "brown-stephen", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "anim.save(\"bayes-pruning.mp4\")" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.9.6" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 5 225 | } 226 | -------------------------------------------------------------------------------- /bayeshist/__init__.py: -------------------------------------------------------------------------------- 1 | from .bayeshist import bayesian_histogram 2 | from .plotting import plot_bayesian_histogram 3 | 4 | __version__ = "v0.0+dev" 5 | __all__ = ["bayesian_histogram", "plot_bayesian_histogram"] 6 | -------------------------------------------------------------------------------- /bayeshist/bayeshist.py: -------------------------------------------------------------------------------- 1 | """bayeshist.py 2 | 3 | Bayesian histograms for binary targets. 
4 | """ 5 | 6 | from typing import Iterable, Literal, Optional, Tuple, Union 7 | from functools import partial 8 | 9 | import numpy as np 10 | 11 | from scipy.stats import beta as beta_dist, fisher_exact 12 | from scipy.special import betaln 13 | 14 | FrozenDistType = type(beta_dist(0, 0)) 15 | 16 | 17 | def _bayes_factor_test(ps1, ns1, ps2, ns2, prior_p, prior_n, threshold=2): 18 | """Tests whether two binomial datasets come from the same distribution. 19 | 20 | Computes the Bayes Factor of hypotheses: 21 | 22 | H1: Samples are drawn with p_i ~ Beta(alpha_i, beta_i), i={1,2} 23 | H0: Both samples are drawn with p ~ Beta(alpha_1 + alpha_2, beta_1 + beta_2) 24 | 25 | The Bayes Factor gives the relative increase in data likelihood 26 | after the split (higher values -> splitting is more favorable). 27 | """ 28 | # alpha and beta coefficients for distribution of p(y=1) 29 | alpha_1 = ps1 + prior_p 30 | beta_1 = ns1 + prior_n 31 | 32 | alpha_2 = ps2 + prior_p 33 | beta_2 = ns2 + prior_n 34 | 35 | alpha_tot = ps1 + ps2 + prior_p 36 | beta_tot = ns1 + ns2 + prior_n 37 | 38 | # we could use scipy.state.betabinom here, but betaln is faster 39 | def betabinom_logp(ps, ns, alpha, beta): 40 | # this omits choose(n, k), drops out in Bayes factor 41 | return betaln(ps + alpha, ns + beta) - betaln(alpha, beta) 42 | 43 | bayes_factor = np.exp( 44 | -betabinom_logp(ps1, ns1, alpha_tot, beta_tot) 45 | - betabinom_logp(ps2, ns2, alpha_tot, beta_tot) 46 | + betabinom_logp(ps1, ns1, alpha_1, beta_1) 47 | + betabinom_logp(ps2, ns2, alpha_2, beta_2) 48 | ) 49 | 50 | return bayes_factor > threshold, bayes_factor 51 | 52 | 53 | def _fisher_test(ps1, ns1, ps2, ns2, *args, threshold=0.05): 54 | """Tests whether two binomial datasets come from the same distribution. 55 | 56 | Uses an exact Fisher test. Prior parameters are unused. 57 | """ 58 | _, pvalue = fisher_exact([[ps1, ps2], [ns1, ns2]]) 59 | return pvalue < threshold, pvalue 60 | 61 | 62 | def _prune_histogram(bin_edges, pos_samples, neg_samples, test, prior_params, max_bin_size=None, yield_steps=False): 63 | """Perform histogram pruning. 64 | 65 | This iteratively merges neighboring bins until all neighbor pairs pass 66 | the given statistical test. 
67 | """ 68 | if max_bin_size is None: 69 | max_bin_size = float("inf") 70 | 71 | while True: 72 | new_bins = [] 73 | new_pos_samples = [] 74 | new_neg_samples = [] 75 | 76 | num_bins = len(bin_edges) - 1 77 | splits_reversed = 0 78 | 79 | i = 0 80 | 81 | while True: 82 | if i == num_bins: 83 | break 84 | 85 | elif i == num_bins - 1: 86 | # only 1 bin left, nothing to compare to 87 | new_bins.append(bin_edges[i]) 88 | new_pos_samples.append(pos_samples[i]) 89 | new_neg_samples.append(neg_samples[i]) 90 | break 91 | 92 | is_significant, test_value = test( 93 | pos_samples[i], 94 | neg_samples[i], 95 | pos_samples[i + 1], 96 | neg_samples[i + 1], 97 | *prior_params 98 | ) 99 | 100 | reverse_split = ( 101 | not is_significant 102 | # ensure that we stay below max_bin_size 103 | and (bin_edges[i + 1] - bin_edges[i] < max_bin_size) 104 | # but always merge empty bins 105 | or (neg_samples[i] == pos_samples[i] == 0) 106 | or (neg_samples[i + 1] == pos_samples[i + 1] == 0) 107 | ) 108 | 109 | if yield_steps: 110 | yield dict( 111 | i=i - splits_reversed, 112 | samples_1=(pos_samples[i], neg_samples[i]), 113 | samples_2=(pos_samples[i+1], neg_samples[i+1]), 114 | test_value=test_value, 115 | is_significant=is_significant, 116 | reverse_split=reverse_split, 117 | bins=np.concatenate((new_bins, bin_edges[i:])), 118 | pos_samples=np.concatenate((new_pos_samples, pos_samples[i:])), 119 | neg_samples=np.concatenate((new_neg_samples, neg_samples[i:])), 120 | ) 121 | 122 | if reverse_split: 123 | splits_reversed += 1 124 | new_bins.append(bin_edges[i]) 125 | new_pos_samples.append(pos_samples[i] + pos_samples[i + 1]) 126 | new_neg_samples.append(neg_samples[i] + neg_samples[i + 1]) 127 | i += 2 128 | else: 129 | # keep everything and proceed with next pair 130 | new_bins.append(bin_edges[i]) 131 | new_pos_samples.append(pos_samples[i]) 132 | new_neg_samples.append(neg_samples[i]) 133 | i += 1 134 | 135 | new_bins.append(bin_edges[-1]) 136 | 137 | assert len(new_bins) == len(bin_edges) - splits_reversed 138 | 139 | bin_edges = new_bins 140 | pos_samples = new_pos_samples 141 | neg_samples = new_neg_samples 142 | 143 | if not splits_reversed: 144 | # no changes made -> we are done 145 | break 146 | 147 | bin_edges = np.array(bin_edges) 148 | pos_samples = np.array(pos_samples) 149 | neg_samples = np.array(neg_samples) 150 | 151 | yield bin_edges, pos_samples, neg_samples 152 | 153 | 154 | def bayesian_histogram( 155 | x: np.ndarray, 156 | y: np.ndarray, 157 | bins: Union[int, Iterable] = 100, 158 | x_range: Optional[Tuple[float, float]] = None, 159 | prior_params: Optional[Tuple[float, float]] = None, 160 | pruning_method: Optional[Literal["bayes", "fisher"]] = "bayes", 161 | pruning_threshold: Optional[float] = None, 162 | max_bin_size: Optional[float] = None, 163 | ) -> Tuple[np.ndarray, FrozenDistType]: 164 | """Compute Bayesian histogram for data x, binary target y. 165 | 166 | The output is a Beta distribution over the event rate for each bin. 167 | 168 | Parameters: 169 | 170 | x: 171 | 1-dim array of data. 172 | 173 | y: 174 | 1-dim array of binary labels (0 or 1). 175 | 176 | bins: 177 | int giving the number of equally spaced intial bins, 178 | or array giving initial bin edges. (default: 100) 179 | 180 | x_range: 181 | Range spanned by binning. Not used if `bins` is an array. 182 | (default: [min(x), max(x)]) 183 | 184 | prior_params: 185 | Parameters to use in Beta prior. First value relates to positive, 186 | second value to negative samples. 
[0.5, 0.5] represents Jeffrey's prior, [1, 1] a flat 187 | prior. The default is a weakly informative prior based on the global event rate. 188 | (default: `[1, num_neg / num_pos]`) 189 | 190 | pruning_method: 191 | Method to use to decide whether neighboring bins should be merged or not. 192 | Valid values are "bayes" (Bayes factor), "fisher" (exact Fisher test), or None 193 | (no pruning). (default: "bayes") 194 | 195 | pruning_threshold: 196 | Threshold to use in significance test specified by `pruning_method`. 197 | (default: 2 for "bayes", 0.2 for "fisher") 198 | 199 | max_bin_size: 200 | Maximum size (in units of x) above which bins will not be merged 201 | (except empty bins). (default: unlimited size) 202 | 203 | Returns: 204 | 205 | bin_edges: Coordinates of bin edges 206 | beta_dist: n-dimensional Beta distribution (n = number of bins) 207 | 208 | Example: 209 | 210 | >>> x = np.random.randn(1000) 211 | >>> p = 10 ** (-2 + x) 212 | >>> y = np.random.rand() < p 213 | >>> bins, beta_dist = bayesian_histogram(x, y) 214 | >>> plt.plot(0.5 * (bins[1:] + bins[:-1]), beta_dist.mean()) 215 | 216 | """ 217 | x = np.asarray(x) 218 | y = np.asarray(y) 219 | 220 | if not np.all(np.isin(np.unique(y), [0, 1])): 221 | raise ValueError("Binary targets y can only have values 0 and 1") 222 | 223 | if x_range is None: 224 | x_range = (np.min(x), np.max(x)) 225 | 226 | if pruning_method == "bayes": 227 | if pruning_threshold is None: 228 | # default bayes factor threshold 229 | pruning_threshold = 2 230 | 231 | test = partial(_bayes_factor_test, threshold=pruning_threshold) 232 | 233 | elif pruning_method == "fisher": 234 | if pruning_threshold is None: 235 | # default p-value threshold 236 | pruning_threshold = 0.2 237 | 238 | test = partial(_fisher_test, threshold=pruning_threshold) 239 | 240 | elif pruning_method is not None: 241 | raise ValueError('pruning_method must be "bayes", "fisher", or None.') 242 | 243 | if np.isscalar(bins): 244 | bin_edges = np.linspace(*x_range, bins + 1) 245 | else: 246 | bin_edges = np.asarray(bins) 247 | 248 | neg_samples, _ = np.histogram(x[y == 0], bins=bin_edges) 249 | pos_samples, _ = np.histogram(x[y == 1], bins=bin_edges) 250 | 251 | if prior_params is None: 252 | # default prior is weakly informative, using global event rate 253 | num_pos_samples = np.sum(pos_samples) 254 | num_neg_samples = np.sum(neg_samples) 255 | 256 | if num_pos_samples > num_neg_samples: 257 | prior_params = (num_pos_samples / num_neg_samples, 1) 258 | else: 259 | prior_params = (1, num_neg_samples / num_pos_samples) 260 | 261 | if pruning_method is not None: 262 | pruner = _prune_histogram( 263 | bin_edges, pos_samples, neg_samples, test, prior_params, max_bin_size=max_bin_size 264 | ) 265 | bin_edges, pos_samples, neg_samples = next(iter(pruner)) 266 | 267 | return bin_edges, beta_dist( 268 | pos_samples + prior_params[0], neg_samples + prior_params[1] 269 | ) 270 | -------------------------------------------------------------------------------- /bayeshist/plotting.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Iterable, Optional, Any, Tuple 2 | import numpy as np 3 | 4 | from .bayeshist import FrozenDistType 5 | 6 | 7 | def plot_bayesian_histogram( 8 | bin_edges: np.ndarray, 9 | data_dist: FrozenDistType, 10 | color: Union[str, Iterable[float], None] = None, 11 | label: Optional[str] = None, 12 | ax: Any = None, 13 | ci: Optional[Tuple[float, float]] = (0.01, 0.99) 14 | ) -> None: 15 | """Plot a Bayesian histogram 
as horizontal lines with credible intervals. 16 | 17 | Parameters: 18 | 19 | bin_edges: 20 | Coordinates of bin edges 21 | 22 | data_dist: 23 | n-dimensional Beta distribution (n = number of bins) 24 | 25 | color: 26 | Color to use (default: use next in current color cycle) 27 | 28 | label: 29 | Legend label (default: no label) 30 | 31 | ax: 32 | Matplotlib axis to use (default: current axis) 33 | 34 | ci: 35 | Credible interval used for shading, use `None` to disable shading. 36 | 37 | Example: 38 | 39 | >>> x = np.random.randn(1000) 40 | >>> p = 10 ** (-2 + x) 41 | >>> y = np.random.rand() < p 42 | >>> bins, beta_dist = bayesian_histogram(x, y) 43 | >>> plot_bayesian_histogram(bins, beta_dist) 44 | 45 | """ 46 | import matplotlib.pyplot as plt 47 | from matplotlib.collections import PatchCollection 48 | from matplotlib.patches import Rectangle 49 | 50 | if ax is None: 51 | ax = plt.gca() 52 | 53 | if color is None: 54 | # advance color cycle 55 | dummy, = ax.plot([], []) 56 | color = dummy.get_color() 57 | 58 | if ci is not None: 59 | ci_low, ci_high = data_dist.ppf(ci[0]), data_dist.ppf(ci[1]) 60 | 61 | # background boxes 62 | errorboxes = [ 63 | Rectangle((x1, y1), x2 - x1, y2 - y1) 64 | for x1, x2, y1, y2 65 | in zip(bin_edges[:-1], bin_edges[1:], ci_low, ci_high) 66 | ] 67 | 68 | pc = PatchCollection(errorboxes, facecolor=color, alpha=0.2) 69 | ax.add_collection(pc) 70 | 71 | # box edges 72 | ax.hlines(ci_low, bin_edges[:-1], bin_edges[1:], colors=color, alpha=0.8, linewidth=1) 73 | ax.hlines(ci_high, bin_edges[:-1], bin_edges[1:], colors=color, alpha=0.8, linewidth=1) 74 | 75 | # median indicator 76 | ax.hlines(data_dist.median(), bin_edges[:-1], bin_edges[1:], colors=color, label=label) 77 | -------------------------------------------------------------------------------- /doc/bayesian-histogram-comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/bayesian-histograms/f0caff585d2837ae759e798d1d486a4411ff2fdf/doc/bayesian-histogram-comp.png -------------------------------------------------------------------------------- /doc/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dionhaefner/bayesian-histograms/f0caff585d2837ae759e798d1d486a4411ff2fdf/doc/samples.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | from codecs import open 6 | import os 7 | import re 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | 15 | # read version from __init__.py 16 | version_file = os.path.join(here, "bayeshist", "__init__.py") 17 | version_pattern = re.compile(r"__version__\s*=\s*[\"'](.+)[\"']") 18 | 19 | with open(version_file, encoding="utf-8") as f: 20 | for line in f: 21 | match = version_pattern.match(line) 22 | if match: 23 | version = match.group(1) 24 | break 25 | else: 26 | raise RuntimeError("Could not determine version from __init__.py") 27 | 28 | 29 | setup( 30 | name="bayeshist", 31 | license="MIT", 32 | version=version, 33 | description="Bayesian histograms for estimation of binary event rates", 34 | long_description=long_description, 35 | long_description_content_type="text/markdown", 36 | author="Dion 
Häfner", 37 | author_email="mail@dionhaefner.de", 38 | url="https://github.com/dionhaefner/bayeshist", 39 | packages=["bayeshist"], 40 | install_requires=[ 41 | "numpy", 42 | "scipy", 43 | ], 44 | python_requires=">=3.8", 45 | zip_safe=False, 46 | ) 47 | --------------------------------------------------------------------------------