├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── conjugate_prior ├── __init__.py ├── beta.py ├── dirichlet.py ├── gamma.py ├── invgamma.py ├── normal.py └── prior.py ├── setup.cfg └── setup.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | on: [push] 3 | 4 | jobs: 5 | 6 | deploy: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Set up Python 12 | uses: actions/setup-python@v2 13 | with: 14 | python-version: '3.x' 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install --upgrade pip 18 | pip install build 19 | - name: Build package 20 | run: python -m build 21 | - name: Publish package 22 | uses: pypa/gh-action-pypi-publish@release/v1 23 | with: 24 | password: ${{ secrets.PYPI_API }} 25 | 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Uri Goren 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conjugate Prior 2 | Python implementation of the conjugate prior table for Bayesian Statistics 3 | 4 | [![Downloads](http://pepy.tech/badge/conjugate-prior)](http://pepy.tech/count/conjugate-prior) 5 | 6 | See wikipedia page: 7 | 8 | https://en.wikipedia.org/wiki/Conjugate_prior#Table_of_conjugate_distributions 9 | 10 | ## Installation: 11 | `pip install conjugate-prior` 12 | 13 | ## Supported Models: 14 | 1. `BetaBinomial` - Useful for independent trials such as click-trough-rate (ctr), web visitor conversion. 15 | 1. `BetaBernoulli` - Same as above. 16 | 1. `GammaExponential` - Useful for churn-rate analysis, cost, dwell-time. 17 | 1. `GammaPoisson` - Useful for time passed until event, as above. 18 | 1. `NormalNormalKnownVar` - Useful for modeling a centralized distribution with constant noise. 19 | 1. `NormalLogNormalKnownVar` - Useful for modeling a Length of a support phone call. 20 | 1. `InvGammaNormalKnownMean` - Useful for modeling the effect of a noise. 21 | 1. `InvGammaWeibullKnownShape` - Useful for reasoning about particle sizes over time. 22 | 1. `DirichletMultinomial` - Extension of BetaBinomial to more than 2 types of events (Limited support). 23 | 24 | ## Basic API 25 | 1. 
`model = GammaExponential(a, b)` - A Bayesian model with an `Exponential` likelihood and a `Gamma` prior, where `a` and `b` are the prior parameters. 26 | 1. `model.pdf(x)` - Returns the probability-density-function of the prior function at `x`. 27 | 1. `model.cdf(x)` - Returns the cumulative-distribution-function of the prior function at `x`. 28 | 1. `model.mean()` - Returns the prior mean. 29 | 1. `model.plot(l, u)` - Plots the prior distribution between `l` and `u`. 30 | 1. `model.posterior(l, u)` - Returns the credible interval on `(l,u)` (equivalent to `cdf(u)-cdf(l)`). 31 | 1. `model.update(data)` - Returns a *new* model after observing `data`. 32 | 1. `model.predict(x)` - Predicts the likelihood of observing `x` (if a posterior predictive exists). 33 | 1. `model.sample()` - Draw a single sample from the posterior distribution. 34 | 35 | 36 | 37 | ## Coin flip example: 38 | 39 | from conjugate_prior import BetaBinomial 40 | heads = 95 41 | tails = 105 42 | prior_model = BetaBinomial() # Uninformative prior 43 | updated_model = prior_model.update(heads, tails) 44 | credible_interval = updated_model.posterior(0.45, 0.55) 45 | print ("There's {p:.2f}% chance that the coin is fair".format(p=credible_interval*100)) 46 | predictive = updated_model.predict(50, 50) 47 | print ("The chance of flipping 50 Heads and 50 Tails in 100 trials is {p:.2f}%".format(p=predictive*100)) 48 | 49 | ## Variant selection with Multi-armed-bandit 50 | 51 | Assume we have `10` creatives (variants) we can choose for our ad campaign. At first we start with the uninformative prior. 52 | 53 | After getting feedback (i.e. clicks) from displaying the ads, we update our model. 54 | 55 | Then we sample the `DirichletMultinomial` model for the updated distribution. 
class BetaBinomial:
    """Beta distribution used as the conjugate prior of a Binomial likelihood.

    The two slots are the Beta shape parameters: ``positives`` (first shape
    parameter) and ``negatives`` (second shape parameter).

    Construction forms:
        * ``BetaBinomial()`` or any all-falsy / ``None`` first argument --
          the uniform ``Beta(1, 1)`` prior.
        * ``BetaBinomial(rate)`` -- the rate is spread over 100
          pseudo-observations (``rate * 100`` and ``(1 - rate) * 100``).
        * ``BetaBinomial(positives, negatives)`` -- explicit parameters.

    Raises:
        SyntaxError: if more than two arguments are given.
    """

    __slots__ = ["positives", "negatives"]

    def __init__(self, *args):
        if not any(args) or args[0] is None:
            # uninformative prior
            self.positives = self.negatives = 1
        elif len(args) == 1:
            # assuming rate
            self.positives = args[0] * 100.0
            self.negatives = (1 - args[0]) * 100.0
        elif len(args) == 2:
            self.positives = args[0]
            self.negatives = args[1]
        else:
            raise SyntaxError("Illegal number of arguments")

    def __repr__(self):
        return f"{type(self).__name__}({self.positives}, {self.negatives})"

    def __iadd__(self, other):
        """Add counts in place from another model or a ``(positives, negatives)`` tuple."""
        if isinstance(other, BetaBinomial):
            self.positives += other.positives
            self.negatives += other.negatives
        elif isinstance(other, tuple):
            self.positives += other[0]
            self.negatives += other[1]
        else:
            raise TypeError("Unsupported type")
        return self

    def update(self, *args):
        """Return a new model of the same class with the observations added.

        Args:
            *args: either one iterable of outcomes (truthy items count as
                positives, falsy items as negatives) or two numbers
                ``(positives, negatives)``.

        Raises:
            SyntaxError: for any other number of arguments.
        """
        if len(args) == 1:
            p = n = 0
            for outcome in args[0]:
                if outcome:
                    p += 1
                else:
                    n += 1
        elif len(args) == 2:
            p, n = args
        else:
            raise SyntaxError("Illegal number of arguments")
        # type(self) keeps subclasses (e.g. BetaBernoulli) closed under update.
        return type(self)(self.positives + p, self.negatives + n)

    def pdf(self, x):
        """Beta density at ``x``."""
        return stats.beta.pdf(x, self.positives, self.negatives)

    def cdf(self, x):
        """Beta cumulative distribution at ``x``."""
        return stats.beta.cdf(x, self.positives, self.negatives)

    def posterior(self, l, u):
        """Probability mass between ``l`` and ``u``; ``0.0`` when ``l > u``."""
        if l > u:
            return 0.0
        return self.cdf(u) - self.cdf(l)

    def mean(self, n=1):
        """Mean of the distribution multiplied by ``n``."""
        return self.positives * n / (self.positives + self.negatives)

    def plot(self, l=0.0, u=1.0):
        """Plot the density on ``[l, u]``, rescaled so the plotted values sum to 1.

        Uses the module-level ``plt``; it is ``None`` when matplotlib is missing.
        """
        x = np.linspace(u, l, 1001)
        y = stats.beta.pdf(x, self.positives, self.negatives)
        y = y / y.sum()
        plt.plot(x, y)
        plt.xlim((l, u))

    def predict(self, t, f, log=False):
        """Beta-binomial probability of ``t`` positives and ``f`` negatives.

        Computed in log space with ``gammaln`` for numerical stability.

        Args:
            t: number of positive outcomes.
            f: number of negative outcomes.
            log: when true, return the log-probability instead.
        """
        a = self.positives
        b = self.negatives
        log_pmf = (fn.gammaln(t + f + 1) + fn.gammaln(t + a) + fn.gammaln(f + b) + fn.gammaln(a + b)) - \
                  (fn.gammaln(t + 1) + fn.gammaln(f + 1) + fn.gammaln(a) + fn.gammaln(b) + fn.gammaln(t + f + a + b))
        if log:
            return log_pmf
        return np.exp(log_pmf)

    def sample(self, n=1):
        """Draw ``n`` values from the Beta distribution (array of length ``n``)."""
        return np.random.beta(self.positives, self.negatives, n)

    def percentile(self, p):
        """Quantile of the Beta distribution at probability ``p``."""
        return stats.beta.ppf(p, self.positives, self.negatives)


class BetaBernoulli(BetaBinomial):
    """Beta prior with a Bernoulli likelihood.

    ``update`` is inherited; it returns a ``BetaBernoulli`` because the base
    implementation builds ``type(self)``.
    """

    def sample(self, n=1, output_parameter=False):
        """Draw Bernoulli outcomes, each with its own Beta-sampled rate.

        Args:
            n: number of draws.
            output_parameter: when true, return the ``n`` sampled rates
                instead of outcomes.

        Returns:
            A plain ``int`` (0 or 1) when ``n == 1``; otherwise an integer
            array of ``n`` outcomes. The previous implementation called
            ``int()`` on the whole comparison array and so failed for
            ``n > 1``.
        """
        p = np.random.beta(self.positives, self.negatives, n)
        if output_parameter:
            return p
        outcomes = (np.random.random(n) < p).astype(int)
        if n == 1:
            return int(outcomes[0])
        return outcomes
class GammaExponential:
    """Gamma prior (shape ``alpha``, rate ``beta``) with an Exponential likelihood.

    Args:
        alpha: shape parameter, or -- when ``beta`` is omitted -- a value
            treated as the expectancy, from which both parameters are derived.
        beta: rate parameter. SciPy/NumPy calls below receive ``scale=1/beta``.
    """

    __slots__ = ["alpha", "beta"]

    def __init__(self, alpha, beta=None):
        if beta is None:
            print("Assuming first parameter is the Expectancy")
            lamda = 1.0 / alpha
            beta = 0.5
            alpha = lamda * beta
        self.alpha = alpha
        self.beta = beta

    def __iadd__(self, other):
        """Merge another model in place, or add a single ``float`` observation."""
        if isinstance(other, GammaExponential):
            self.alpha += other.alpha
            self.beta += other.beta
        elif isinstance(other, float):
            self.alpha += 1
            self.beta += other
        else:
            raise TypeError("Unsupported type")
        return self

    def update(self, *args):
        """Return a new model after observing data.

        Args:
            *args: one sequence of observations (adds ``len`` to ``alpha``
                and ``sum`` to ``beta``), or two numbers added to ``alpha``
                and ``beta`` respectively.

        Raises:
            SyntaxError: for any other number of arguments.
        """
        if len(args) == 1:
            return GammaExponential(self.alpha + len(args[0]), self.beta + sum(args[0]))
        elif len(args) == 2:
            return GammaExponential(self.alpha + args[0], self.beta + args[1])
        else:
            raise SyntaxError("Illegal number of arguments")

    def pdf(self, x):
        """Gamma density evaluated at ``1 / x``."""
        return stats.gamma.pdf(1.0 / x, self.alpha, scale=1.0 / self.beta)

    def cdf(self, x):
        """One minus the Gamma CDF evaluated at ``1 / x``."""
        return 1 - stats.gamma.cdf(1.0 / x, self.alpha, scale=1.0 / self.beta)

    def posterior(self, l, u):
        """``cdf(u) - cdf(l)``; ``0.0`` when ``l > u``."""
        if l > u:
            return 0.0
        return self.cdf(u) - self.cdf(l)

    def mean(self):
        """Mean of the Gamma distribution, ``alpha / beta``."""
        return self.alpha / self.beta

    def plot(self, l=0, u=10):
        """Plot the Gamma density on ``[l, u]`` using the module-level ``plt``."""
        x = np.linspace(u, l, 1001)
        y = stats.gamma.pdf(x, self.alpha, scale=1.0 / self.beta)
        plt.plot(x, y)
        plt.xlim((l, u))

    def plot_inverse_lambda(self, l=0.0001, u=0.999):
        """Plot the Gamma density against the reciprocal of its argument."""
        x = np.linspace(1.0 / u, 1.0 / l, 1001)
        y = stats.gamma.pdf(x, self.alpha, scale=1.0 / self.beta)
        x = 1 / x
        y = list(reversed(y))
        plt.plot(x, y)
        plt.xlim((l, u))

    def predict(self, x):
        """Lomax CDF (shape ``alpha``, scale ``1 / beta``) evaluated at ``1 / x``."""
        return stats.lomax.cdf(1.0 / x, self.alpha, scale=1.0 / self.beta)

    def sample(self, n=1):
        """Draw one rate from the Gamma, then ``n`` exponential values with that rate."""
        lamda = np.random.gamma(self.alpha, 1 / self.beta)
        return np.random.exponential(1 / lamda, n)

    def percentile(self, p):
        """Quantile of the Gamma distribution at probability ``p``."""
        return stats.gamma.ppf(p, self.alpha, scale=1.0 / self.beta)


class GammaPoisson(GammaExponential):
    """Gamma prior (shape ``alpha``, rate ``beta``) with a Poisson likelihood."""

    def update(self, *args):
        """Return a new model after observing counts.

        Args:
            *args: one sequence of counts (adds ``sum`` to ``alpha`` and
                ``len`` to ``beta``), or two numbers added to ``alpha`` and
                ``beta`` respectively.

        Raises:
            SyntaxError: for any other number of arguments.
        """
        if len(args) == 1:
            return GammaPoisson(self.alpha + sum(args[0]), self.beta + len(args[0]))
        elif len(args) == 2:
            return GammaPoisson(self.alpha + args[0], self.beta + args[1])
        else:
            raise SyntaxError("Illegal number of arguments")

    def predict(self, x):
        """Posterior-predictive probability of observing the count ``x``.

        A Gamma(shape ``alpha``, rate ``beta``) mixture of Poissons is
        negative binomial. ``scipy.stats.nbinom`` is parameterised as
        ``(n, p)`` and has no ``scale`` argument (passing one raised
        ``TypeError``), so the success probability is
        ``p = beta / (1 + beta)``.
        """
        return stats.nbinom.pmf(x, self.alpha, self.beta / (1.0 + self.beta))

    def sample(self, n=1):
        """Draw one rate from the Gamma, then ``n`` Poisson counts with that rate."""
        lamda = np.random.gamma(self.alpha, 1 / self.beta)
        return np.random.poisson(lamda, n)
class InvGammaNormalKnownMean:
    """Inverse-Gamma distribution (``alpha``, scale ``beta``) over a Normal's variance.

    Positional arguments are ``alpha``, then optionally ``beta`` (default 1)
    and ``shape`` (default 1). ``shape`` is used by ``sample`` as the standard
    deviation passed to ``np.random.normal``.

    Raises:
        SyntaxError: if zero or more than three arguments are given.
    """

    __slots__ = ["alpha", "beta", "shape"]

    def __init__(self, *args):
        self.beta = 1
        self.shape = 1
        if len(args) == 1:
            self.alpha = args[0]
        elif len(args) == 2:
            self.alpha = args[0]
            self.beta = args[1]
        elif len(args) == 3:
            self.alpha = args[0]
            self.beta = args[1]
            self.shape = args[2]
        else:
            raise SyntaxError("Illegal number of arguments")

    def update(self, data):
        """Return a new model after observing ``data``.

        ``alpha`` grows by ``n / 2`` and ``beta`` by half the sum of squared
        deviations from the sample mean of ``data``. ``shape`` is carried
        over unchanged (it used to be reset to the default of 1).
        """
        values = np.asarray(data, dtype=float)
        n = len(data)
        # NOTE(review): deviations are taken around the sample mean; the class
        # stores no separate "known mean" -- confirm this is intended.
        beta_update = float(((values - values.mean()) ** 2).sum()) / 2.0
        return InvGammaNormalKnownMean(self.alpha + n / 2.0, self.beta + beta_update, self.shape)

    def pdf(self, x):
        """Inverse-Gamma density at ``x``."""
        return stats.invgamma.pdf(x, a=self.alpha, scale=self.beta)

    def cdf(self, x):
        """Inverse-Gamma cumulative distribution at ``x``."""
        return stats.invgamma.cdf(x, a=self.alpha, scale=self.beta)

    def posterior(self, l, u):
        """Probability mass between ``l`` and ``u``; ``0.0`` when ``l > u``."""
        if l > u:
            return 0.0
        return self.cdf(u) - self.cdf(l)

    def plot(self, l=0.0, u=3.0):
        """Plot the density on ``[l, u]``, rescaled so the plotted values sum to 1.

        Uses the module-level ``plt``; it is ``None`` when matplotlib is missing.
        """
        x = np.linspace(u, l, 1001)
        y = stats.invgamma.pdf(x, a=self.alpha, scale=self.beta)
        y = y / y.sum()
        plt.plot(x, y)
        plt.xlim((l, u))

    def sample(self, n):
        """Draw ``n`` Inverse-Gamma values and use each as the location of a Normal draw.

        NOTE(review): the Inverse-Gamma draw is passed as the Normal's mean
        and ``shape`` as its standard deviation -- confirm against the
        intended model.
        """
        mean = stats.invgamma.rvs(a=self.alpha, scale=self.beta, size=n)
        return np.random.normal(mean, self.shape, n)

    def predict(self, x):
        """Inverse-Gamma CDF at ``x`` (same computation as ``cdf``)."""
        return stats.invgamma.cdf(x, a=self.alpha, scale=self.beta)

    def percentile(self, p):
        """Quantile of the Inverse-Gamma distribution at probability ``p``."""
        return stats.invgamma.ppf(p, a=self.alpha, scale=self.beta)
class NormalNormalKnownVar:
    """Normal prior over the mean of a Normal likelihood with known variance.

    Args:
        known_var: variance of the likelihood.
        prior_mean: mean of the distribution over the unknown mean.
        prior_var: variance of the distribution over the unknown mean.
    """

    __slots__ = ["mean", "var", "known_var"]

    def __init__(self, known_var, prior_mean=0, prior_var=1):
        self.mean = prior_mean
        self.var = prior_var
        self.known_var = known_var

    def update(self, data, var=None, n=None):
        """Return a new model of the same class after observing data.

        Two call forms:
            * ``update(samples)`` -- ``samples`` is a sequence of observations.
            * ``update(sample_mean, var, n)`` -- summary statistics of ``n``
              observations. ``var`` only selects this mode; the computation
              uses ``known_var``.

        The summary form previously called ``sum()`` on the scalar mean and
        could not work; it now uses ``sample_mean * n`` as the total.

        Raises:
            ValueError: if summary mode is requested without ``n``.
        """
        if var or n is not None:
            if n is None:
                raise ValueError("n is required when updating from summary statistics")
            total = data * n
        else:
            values = np.asarray(data, dtype=float)
            n = len(data)
            total = values.sum()
        # Precisions (inverse variances) add; the posterior mean is the
        # precision-weighted combination of the prior mean and the data.
        denom = (1.0 / self.var + n / self.known_var)
        posterior_mean = (self.mean / self.var + total / self.known_var) / denom
        return type(self)(self.known_var, posterior_mean, 1.0 / denom)

    def pdf(self, x):
        """Normal density at ``x`` (mean ``mean``, variance ``var``)."""
        return stats.norm.pdf(x, self.mean, np.sqrt(self.var))

    def cdf(self, x):
        """Normal cumulative distribution at ``x``."""
        return stats.norm.cdf(x, self.mean, np.sqrt(self.var))

    def posterior(self, l, u):
        """Probability mass between ``l`` and ``u``; ``0.0`` when ``l > u``."""
        if l > u:
            return 0.0
        return self.cdf(u) - self.cdf(l)

    def plot(self, l=0.0, u=1.0):
        """Plot the density on ``[l, u]``, rescaled so the plotted values sum to 1.

        Uses the module-level ``plt``; it is ``None`` when matplotlib is missing.
        """
        x = np.linspace(u, l, 1001)
        y = stats.norm.pdf(x, self.mean, np.sqrt(self.var))
        y = y / y.sum()
        plt.plot(x, y)
        plt.xlim((l, u))

    def predict(self, x):
        """Normal CDF at ``x`` with variance ``var + known_var``."""
        return stats.norm.cdf(x, self.mean, np.sqrt(self.var + self.known_var))

    def sample(self, n=1):
        """Draw ``n`` values from a Normal with variance ``var + known_var``."""
        return np.random.normal(self.mean, np.sqrt(self.var + self.known_var), size=n)

    def percentile(self, p):
        """Quantile of the distribution over the mean at probability ``p``."""
        return stats.norm.ppf(p, self.mean, np.sqrt(self.var))


class NormalLogNormalKnownVar(NormalNormalKnownVar):
    """Normal prior for a Log-Normal likelihood: the logs of the data are modelled as Normal."""

    def update(self, data):
        """Return a new ``NormalLogNormalKnownVar`` after observing positive ``data``.

        The observations are log-transformed and then handled by the Normal
        update; the base ``update`` builds ``type(self)``.
        """
        return super().update(np.log(data))

    def predict(self, x):
        """Not available for this model.

        Raises:
            NotImplementedError: always. The original raised
                ``NotImplemented(...)``, which is not callable and produced a
                ``TypeError`` instead.
        """
        raise NotImplementedError("No posterior predictive")

    def sample(self, n=1):
        """Draw ``n`` positive values by exponentiating Normal draws.

        The original passed ``size=n`` to ``np.log`` (a ``TypeError``) and
        took a log where mapping a Normal draw back to the data scale needs
        ``exp``.
        """
        return np.exp(np.random.normal(self.mean, np.sqrt(self.var + self.known_var), size=n))
Calculate AIC and BIC for each distribution 39 | aic_norm, bic_norm = aic_bic(X, 'norm', norm_params) 40 | aic_exp, bic_exp = aic_bic(X, 'expon', exp_params) 41 | aic_weibull, bic_weibull = aic_bic(X, 'weibull_min', weibull_params) 42 | 43 | # Collect results in a DataFrame for comparison 44 | results = pd.DataFrame({ 45 | 'Distribution': ['norm', 'expon', 'weibull_min'], 46 | 'AIC': [aic_norm, aic_exp, aic_weibull], 47 | 'BIC': [bic_norm, bic_exp, bic_weibull] 48 | }) 49 | 50 | if self.criterion.lower() == 'aic': 51 | self.best_fit = results.loc[results['AIC'].idxmin()]["Distribution"] 52 | elif self.criterion.lower() == 'bic': 53 | self.best_fit = results.loc[results['BIC'].idxmin()]["Distribution"] 54 | else: 55 | raise ValueError("Criterion must be either 'aic' or 'bic'") 56 | if self.best_fit == 'norm': 57 | self.params = norm_params 58 | elif self.best_fit == 'expon': 59 | self.params = exp_params 60 | elif self.best_fit == 'weibull_min': 61 | self.params = weibull_params 62 | 63 | def predict(self, X): 64 | if self.best_fit is None: 65 | raise ValueError("You must call the fit method first") 66 | dist = getattr(stats, self.best_fit) 67 | return dist.pdf(X, *self.params) 68 | 69 | def as_prior(self): 70 | # TODO: Check 71 | if self.best_fit == 'norm': 72 | return NormalNormalKnownVar(*self.params) 73 | elif self.best_fit == 'expon': 74 | return GammaExponential(*self.params) 75 | elif self.best_fit == 'weibull_min': 76 | return InvGammaWeibullKnownShape(*self.params) 77 | 78 | class BetaBinomialRanker: 79 | def __init__(self, n=0, prior=None,ucb_percentile = 0.95, discount_coefficient=1, names=None) -> None: 80 | self.cmpgns = [BetaBinomial(prior) for _ in range(n)] 81 | self.n = n 82 | self.prior = prior 83 | self.ucb_percentile = ucb_percentile 84 | self.discount_coefficient = discount_coefficient 85 | if names is None: 86 | self.names = [str(i) for i in range(n)] 87 | else: 88 | self.names = names 89 | def __getitem__(self, name): 90 | i = 
class GammaExponentialRanker:
    """Named collection of ``GammaExponential`` models that can be ranked.

    Args:
        n: number of campaigns created up front.
        prior: value passed as the single argument to every new
            ``GammaExponential``.
        ucb_percentile: percentile used by ``rank_by_ucb``.
        discount_coefficient: factor applied to every model's parameters by
            ``discount``.
        names: campaign names; defaults to ``"0" .. str(n - 1)``. The list is
            stored as given and mutated when campaigns are added or removed.
    """

    def __init__(self, n=0, prior=None, ucb_percentile=0.95, discount_coefficient=1, names=None) -> None:
        self.cmpgns = [GammaExponential(prior) for _ in range(n)]
        self.n = n
        self.prior = prior
        self.ucb_percentile = ucb_percentile
        self.discount_coefficient = discount_coefficient
        if names is None:
            self.names = [str(i) for i in range(n)]
        else:
            self.names = names

    def __getitem__(self, name):
        """Return the model registered under ``name`` (``ValueError`` if unknown)."""
        i = self.names.index(name)
        return self.cmpgns[i]

    def __setitem__(self, name, value):
        """Set ``(alpha, beta)`` for ``name``, registering the campaign first if needed."""
        try:
            i = self.names.index(name)
        except ValueError:
            self.names.append(name)
            self.cmpgns.append(GammaExponential(self.prior))
            self.n += 1
            i = self.n - 1
        a, b = value
        self.cmpgns[i].alpha = a
        self.cmpgns[i].beta = b

    def __delitem__(self, name):
        """Remove the campaign registered under ``name``."""
        i = self.names.index(name)
        del self.cmpgns[i]
        del self.names[i]
        self.n -= 1

    def __str__(self):
        return str(self.cmpgns)

    def reset(self):
        """Replace every model with a fresh one built from ``prior``."""
        self.cmpgns = [GammaExponential(self.prior) for _ in range(self.n)]

    def update(self, name, data):
        """Replace the model for ``name`` with its update on ``data``."""
        i = self.names.index(name)
        self.cmpgns[i] = self.cmpgns[i].update(data)

    def update_all(self, data: List[List[int]]):
        """Update every campaign; ``data[i]`` goes to the i-th campaign."""
        assert len(data) == self.n, "Data must have the same number of campaigns as the model"
        for i, d in enumerate(data):
            self.cmpgns[i] = self.cmpgns[i].update(d)

    def rank_by_mle(self):
        """Return campaign names sorted by ascending model mean."""
        lst = sorted([(c.mean(), i) for i, c in enumerate(self.cmpgns)])
        return [self.names[i] for _, i in lst]

    def rank_by_ucb(self):
        """Return campaign names sorted by ascending ``ucb_percentile`` quantile.

        The original built the sorted list but never returned it, so callers
        always received ``None``.
        """
        lst = sorted([(c.percentile(self.ucb_percentile), i) for i, c in enumerate(self.cmpgns)])
        return [self.names[i] for _, i in lst]

    def discount(self):
        """Scale every model's ``alpha`` and ``beta`` by ``discount_coefficient``.

        Mirrors ``BetaBinomialRanker.discount``; without it the
        ``discount_coefficient`` argument had no effect. Returns ``self``.
        """
        for c in self.cmpgns:
            c.alpha *= self.discount_coefficient
            c.beta *= self.discount_coefficient
        return self
10 | name="conjugate_prior", 11 | packages=["conjugate_prior"], 12 | install_requires=[ 13 | 'setuptools', 14 | 'scipy', 15 | 'numpy', 16 | 'matplotlib', 17 | ], 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | version=version, 21 | description='Bayesian Statistics conjugate prior distributions', 22 | author='Uri Goren', 23 | author_email='conjugate@argmaxml.com', 24 | url='https://github.com/argmaxml/conjugate_prior', 25 | keywords=['conjugate', 'bayesian', 'stats', 'statistics', 'bayes', 'distribution', 'probability', 'hypothesis', 26 | 'modelling', 'thompson sampling'], 27 | classifiers=[], 28 | ) 29 | --------------------------------------------------------------------------------