├── _config.yml ├── environment.yml ├── sat_scale.csv ├── dice.py ├── dice_soln.py ├── train.py ├── install_test.py ├── train_soln.py ├── euro.py ├── euro_soln.py ├── billiards.py ├── train2.py ├── sat_ranks.csv ├── euro2.py ├── euro2_soln.py ├── lincoln.py ├── volunteer.py ├── tutorial.md ├── sat.py ├── sat_soln.py ├── README.md ├── debug.ipynb ├── 02_dice.ipynb ├── 01_cookie.ipynb ├── 03_euro.ipynb ├── 05_world_cup.ipynb ├── 04_bandit.ipynb ├── distribution.py ├── thinkplot.py └── empyrical_dist.py /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: BayesMadeSimple 2 | 3 | dependencies: 4 | - python=3.7 5 | - jupyter 6 | - numpy 7 | - matplotlib 8 | - seaborn 9 | - pandas 10 | - scipy 11 | - pip 12 | - pip: 13 | - empiricaldist 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /sat_scale.csv: -------------------------------------------------------------------------------- 1 | "http://professionals.collegeboard.com/profdownload/sat-raw-score-to-scaled-score-ranges-2010.pdf",,,,, 2 | "Raw score","Critical Reading","Raw score","Mathematics","Raw score","Writing Skills" 3 | 67,800,,,, 4 | 65,"790-800",,,, 5 | 60,"710-730",,,, 6 | 55,"660-670",54,800,, 7 | 50,"620-630",50,"710-730",49,800 8 | 45,"580-590",45,"650-670",45,"690-730" 9 | 40,"550-560",40,"610-630",40,"620-660" 10 | 35,"520-530",35,"570-580",35,"570-660" 11 | 30,"490-510",30,"530-540",30,"520-550" 12 | 25,"460-480",25,"490-500",25,"480-510" 13 | 20,"420-450",20,"450-460",20,"440-460" 14 | 15,"390-420",15,"410-420",15,"400-420" 15 | 10,"350-380",10,"360-380",10,"350-370" 16 | 5,"300-330",5,"300-330",5,"300-330" 17 | 0,"220-270",0,"220-260",0,"210-260" 18 | -5,200,-5,200,-5,200 19 | 
-------------------------------------------------------------------------------- /dice.py: -------------------------------------------------------------------------------- 1 | """This file contains code for use with "Think Bayes", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | from thinkbayes import Suite 11 | 12 | 13 | class Dice(Suite): 14 | """Represents hypotheses about which die was rolled.""" 15 | 16 | def Likelihood(self, data, hypo): 17 | """Computes the likelihood of the data under the hypothesis. 18 | 19 | hypo: integer number of sides on the die 20 | data: integer die roll 21 | """ 22 | # write this method 23 | return 1 24 | 25 | 26 | def main(): 27 | suite = Dice([4, 6, 8, 12, 20]) 28 | 29 | suite.Update(6) 30 | print('After one 6') 31 | suite.Print() 32 | 33 | for roll in [8, 7, 7, 5, 4]: 34 | suite.Update(roll) 35 | 36 | print('After more rolls') 37 | suite.Print() 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /dice_soln.py: -------------------------------------------------------------------------------- 1 | """This file contains code for use with "Think Bayes", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | from thinkbayes import Suite 11 | 12 | 13 | class Dice(Suite): 14 | """Represents hypotheses about which die was rolled.""" 15 | 16 | def Likelihood(self, data, hypo): 17 | """Computes the likelihood of the data under the hypothesis. 
18 | 19 | hypo: integer number of sides on the die 20 | data: integer die roll 21 | """ 22 | if hypo < data: 23 | return 0 24 | else: 25 | return 1.0/hypo 26 | 27 | 28 | def main(): 29 | suite = Dice([4, 6, 8, 12, 20]) 30 | 31 | suite.Update(6) 32 | print('After one 6') 33 | suite.Print() 34 | 35 | for roll in [8, 7, 7, 5, 4]: 36 | suite.Update(roll) 37 | 38 | print('After more rolls') 39 | suite.Print() 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """This file contains code for use with "Think Bayes", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | 14 | class Train(thinkbayes.Suite): 15 | """Represents hypotheses about how many trains the company has. 16 | 17 | The likelihood function for the train problem is the same as 18 | for the Dice problem. 19 | """ 20 | def Likelihood(self, data, hypo): 21 | """Computes the likelihood of the data under the hypothesis. 22 | 23 | hypo: number of trains the carrier operates 24 | data: the number of the observed train 25 | """ 26 | # fill this in! 
27 | return 1 28 | 29 | 30 | 31 | def main(): 32 | hypos = range(100, 1001) 33 | suite = Train(hypos) 34 | 35 | suite.Update(321) 36 | 37 | thinkplot.PrePlot(1) 38 | thinkplot.Pmf(suite) 39 | thinkplot.Show(xlabel='Number of trains', 40 | ylabel='Probability', 41 | legend=False) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /install_test.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2013 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import math 11 | import numpy 12 | 13 | from matplotlib import pyplot 14 | 15 | import thinkplot 16 | import thinkbayes2 17 | 18 | 19 | def RenderPdf(mu, sigma, n=101): 20 | """Makes xs and ys for a normal PDF with (mu, sigma). 21 | 22 | n: number of places to evaluate the PDF 23 | """ 24 | xs = numpy.linspace(mu-4*sigma, mu+4*sigma, n) 25 | ys = [thinkbayes2.EvalNormalPdf(x, mu, sigma) for x in xs] 26 | return xs, ys 27 | 28 | 29 | def main(): 30 | xs, ys = RenderPdf(100, 15) 31 | 32 | n = 34 33 | pyplot.fill_between(xs[-n:], ys[-n:], y2=0.0001, color='blue', alpha=0.2) 34 | s = 'Congratulations!\nIf you got this far,\nyou must be here.' 35 | d = dict(shrink=0.05) 36 | pyplot.annotate(s, [127, 0.002], xytext=[80, 0.005], arrowprops=d) 37 | 38 | thinkplot.Plot(xs, ys) 39 | thinkplot.Show(title='Distribution of IQ', 40 | xlabel='IQ', 41 | ylabel='PDF', 42 | legend=False) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /train_soln.py: -------------------------------------------------------------------------------- 1 | """This file contains code for use with "Think Bayes", 2 | by Allen B. 
Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | 14 | class Train(thinkbayes.Suite): 15 | """Represents hypotheses about how many trains the company has. 16 | 17 | The likelihood function for the train problem is the same as 18 | for the Dice problem. 19 | """ 20 | def Likelihood(self, data, hypo): 21 | """Computes the likelihood of the data under the hypothesis. 22 | 23 | hypo: number of trains the carrier operates 24 | data: the number of the observed train 25 | """ 26 | if hypo < data: 27 | return 0 28 | else: 29 | return 1.0/hypo 30 | 31 | 32 | 33 | def main(): 34 | hypos = range(100, 1001) 35 | suite = Train(hypos) 36 | 37 | suite.Update(321) 38 | print('Posterior mean', suite.Mean()) 39 | print('Posterior MLE', suite.MaximumLikelihood()) 40 | print('Posterior CI 90', suite.CredibleInterval(90)) 41 | 42 | thinkplot.PrePlot(1) 43 | thinkplot.Pmf(suite) 44 | thinkplot.Show(xlabel='Number of trains', 45 | ylabel='Probability', 46 | legend=False) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /euro.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2013 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | 14 | """This file contains a partial solution to a problem from 15 | MacKay, "Information Theory, Inference, and Learning Algorithms." 
16 | 17 | Exercise 3.15 (page 50): A statistical statement appeared in 18 | "The Guardian" on Friday January 4, 2002: 19 | 20 | When spun on edge 250 times, a Belgian one-euro coin came 21 | up heads 140 times and tails 110. 'It looks very suspicious 22 | to me,' said Barry Blight, a statistics lecturer at the London 23 | School of Economics. 'If the coin weere unbiased, the chance of 24 | getting a result as extreme as that would be less than 7%.' 25 | 26 | MacKay asks, "But do these data give evidence that the coin is biased 27 | rather than fair?" 28 | 29 | """ 30 | 31 | class Euro(thinkbayes.Suite): 32 | 33 | def Likelihood(self, data, hypo): 34 | """Computes the likelihood of the data under the hypothesis. 35 | 36 | hypo: integer value of x, the probability of heads (0-100) 37 | data: string 'H' or 'T' 38 | """ 39 | # fill this in! 40 | return 1 41 | 42 | 43 | def main(): 44 | suite = Euro(range(0, 101)) 45 | 46 | suite.Update('H') 47 | 48 | thinkplot.Pdf(suite) 49 | thinkplot.Show(xlabel='x', 50 | ylabel='Probability', 51 | legend=False) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /euro_soln.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2013 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | 14 | """This file contains a partial solution to a problem from 15 | MacKay, "Information Theory, Inference, and Learning Algorithms." 16 | 17 | Exercise 3.15 (page 50): A statistical statement appeared in 18 | "The Guardian" on Friday January 4, 2002: 19 | 20 | When spun on edge 250 times, a Belgian one-euro coin came 21 | up heads 140 times and tails 110. 
'It looks very suspicious 22 | to me,' said Barry Blight, a statistics lecturer at the London 23 | School of Economics. 'If the coin weere unbiased, the chance of 24 | getting a result as extreme as that would be less than 7%.' 25 | 26 | MacKay asks, "But do these data give evidence that the coin is biased 27 | rather than fair?" 28 | 29 | """ 30 | 31 | class Euro(thinkbayes.Suite): 32 | 33 | def Likelihood(self, data, hypo): 34 | """Computes the likelihood of the data under the hypothesis. 35 | 36 | hypo: integer value of x, the probability of heads (0-100) 37 | data: string 'H' or 'T' 38 | """ 39 | x = hypo / 100.0 40 | if data == 'H': 41 | return x 42 | else: 43 | return 1-x 44 | 45 | 46 | def main(): 47 | suite = Euro(range(0, 101)) 48 | 49 | suite.Update('H') 50 | 51 | thinkplot.Pdf(suite) 52 | thinkplot.Show(xlabel='x', 53 | ylabel='Probability', 54 | legend=False) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /billiards.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2015 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import numpy 11 | import thinkbayes 12 | import thinkplot 13 | 14 | 15 | """ 16 | This problem presents a solution to the "Bayesian Billiards Problem", 17 | presented in this video: 18 | 19 | https://www.youtube.com/watch?v=KhAUfqhLakw 20 | 21 | Based on the formulation in this paper: 22 | 23 | http://www.nature.com/nbt/journal/v22/n9/full/nbt0904-1177.html 24 | 25 | Of a problem originally posed by Bayes himself. 26 | """ 27 | 28 | class Billiards(thinkbayes.Suite): 29 | 30 | def Likelihood(self, data, hypo): 31 | """Computes the likelihood of the data under the hypothesis. 
32 | 33 | data: tuple (#wins, #losses) 34 | hypo: float probability of win 35 | """ 36 | p = hypo 37 | win, lose = data 38 | like = p**win * (1-p)**lose 39 | return like 40 | 41 | 42 | def ProbWinMatch(pmf): 43 | total = 0 44 | for p, prob in pmf.Items(): 45 | total += prob * (1-p)**3 46 | return total 47 | 48 | 49 | def main(): 50 | ps = numpy.linspace(0, 1, 101) 51 | bill = Billiards(ps) 52 | bill.Update((5, 3)) 53 | thinkplot.Pdf(bill) 54 | thinkplot.Save(root='billiards1', 55 | xlabel='probability of win', 56 | ylabel='PDF', 57 | formats=['png']) 58 | 59 | bayes_result = ProbWinMatch(bill) 60 | print(thinkbayes.Odds(1-bayes_result)) 61 | 62 | mle = 5 / 8 63 | freq_result = (1-mle)**3 64 | print(thinkbayes.Odds(1-freq_result)) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /train2.py: -------------------------------------------------------------------------------- 1 | """This file contains code for use with "Think Bayes", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | 14 | class Train(thinkbayes.Suite): 15 | """Represents hypotheses about how many trains the company has. 16 | 17 | The likelihood function for the train problem is the same as 18 | for the Dice problem. 19 | """ 20 | def Likelihood(self, data, hypo): 21 | """Computes the likelihood of the data under the hypothesis. 
22 | 23 | hypo: number of trains the carrier operates 24 | data: the number of the observed train 25 | """ 26 | if hypo < data: 27 | return 0 28 | else: 29 | return 1.0/hypo 30 | 31 | 32 | 33 | def main(): 34 | hypos = range(1, 101) 35 | suite = Train(hypos) 36 | 37 | suite.Update(25) 38 | print('Posterior mean', suite.Mean()) 39 | print('Posterior MLE', suite.MaximumLikelihood()) 40 | print('Posterior CI 90', suite.CredibleInterval(90)) 41 | 42 | thinkplot.PrePlot(1) 43 | thinkplot.Pmf(suite, linewidth=5) 44 | thinkplot.Save(root='train2', 45 | xlabel='Number of trains', 46 | ylabel='Probability', 47 | formats=['png']) 48 | 49 | thinkplot.Pmf(suite, linewidth=5, color='0.8') 50 | suite.Update(42) 51 | print('Posterior mean', suite.Mean()) 52 | print('Posterior MLE', suite.MaximumLikelihood()) 53 | print('Posterior CI 90', suite.CredibleInterval(90)) 54 | 55 | thinkplot.PrePlot(1) 56 | thinkplot.Pmf(suite, linewidth=5) 57 | thinkplot.Save(root='train3', 58 | xlabel='Number of trains', 59 | ylabel='Probability', 60 | formats=['png']) 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /sat_ranks.csv: -------------------------------------------------------------------------------- 1 | "http://professionals.collegeboard.com/profdownload/sat-mathematics-percentile-ranks-2010.pdf",,,,,, 2 | ,"Total",,"Male",,"Female", 3 | "Score","Number","Percentile","Number","Percentile","Number","Percentile" 4 | 800,11959,99,8072,99,3887,"99+" 5 | 790,3588,99,2327,99,1261,99 6 | 780,3770,99,2377,98,1393,99 7 | 770,9663,98,6340,97,3323,99 8 | 760,7016,98,4413,97,2603,98 9 | 750,10313,97,6721,96,3592,98 10 | 740,7542,97,4612,95,2930,98 11 | 730,9721,96,6004,94,3717,97 12 | 720,7674,95,4664,94,3010,97 13 | 710,15694,94,9508,92,6186,96 14 | 700,17394,93,10568,91,6826,95 15 | 690,21305,92,12858,89,8447,94 16 | 680,22953,90,13552,87,9401,93 17 | 670,24377,89,14205,85,10172,92 18 | 
660,27356,87,15564,83,11792,91 19 | 650,26303,85,14685,81,11618,89 20 | 640,26844,84,14906,79,11938,88 21 | 630,28425,82,15675,77,12750,86 22 | 620,31515,80,17018,74,14497,84 23 | 610,44082,77,23042,71,21040,82 24 | 600,33358,75,17325,69,16033,80 25 | 590,38788,72,20170,66,18618,78 26 | 580,40879,70,20699,63,20180,75 27 | 570,51731,66,25790,60,25941,72 28 | 560,38232,64,18782,57,19450,70 29 | 550,38473,61,18659,54,19814,67 30 | 540,54552,58,26165,51,28387,64 31 | 530,55065,54,25373,47,29692,60 32 | 520,43599,51,20022,44,23577,57 33 | 510,45452,48,20673,42,24779,54 34 | 500,61084,45,26823,38,34261,50 35 | 490,57090,41,25113,34,31977,46 36 | 480,40871,38,17594,32,23277,44 37 | 470,47521,35,20181,29,27340,40 38 | 460,58137,31,23544,26,34593,36 39 | 450,50172,28,20777,23,29395,33 40 | 440,39180,26,15576,21,23604,30 41 | 430,46320,23,18412,18,27908,26 42 | 420,47752,20,18154,16,29598,23 43 | 410,41915,17,16052,14,25863,20 44 | 400,30052,15,11210,12,18842,17 45 | 390,32466,13,12140,10,20326,15 46 | 380,27135,11,9984,9,17151,13 47 | 370,25852,9,9484,8,16368,11 48 | 360,23847,8,8809,6,15038,9 49 | 350,20800,6,7755,5,13045,7 50 | 340,17639,5,6554,4,11085,6 51 | 330,14623,4,5371,4,9252,5 52 | 320,11400,4,4295,3,7105,4 53 | 310,11776,3,4411,2,7365,3 54 | 300,8430,2,3300,2,5130,3 55 | 290,7358,2,2898,2,4460,2 56 | 280,7358,1,2872,1,4486,2 57 | 270,3194,1,1301,1,1893,1 58 | 260,3495,1,1375,1,2120,1 59 | 250,2163,1,821,1,1342,1 60 | 240,3278,1,1303,1,1975,1 61 | 230,1454,1,555,"1-",899,1 62 | 220,2519,"1-",991,"1-",1528,"1-" 63 | 210,1221,"1-",501,"1-",720,"1-" 64 | 200,4265,"–",1868,"–",2397,"–" 65 | ,1547990,,720793,,827197, 66 | ,,,65606,,38728, 67 | ,,,0.0910191969,,0.0468183516, 68 | -------------------------------------------------------------------------------- /euro2.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. 
Downey, available from greenteapress.com 3 | 4 | Copyright 2013 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | 12 | 13 | """This file contains a partial solution to a problem from 14 | MacKay, "Information Theory, Inference, and Learning Algorithms." 15 | 16 | Exercise 3.15 (page 50): A statistical statement appeared in 17 | "The Guardian" on Friday January 4, 2002: 18 | 19 | When spun on edge 250 times, a Belgian one-euro coin came 20 | up heads 140 times and tails 110. 'It looks very suspicious 21 | to me,' said Barry Blight, a statistics lecturer at the London 22 | School of Economics. 'If the coin weere unbiased, the chance of 23 | getting a result as extreme as that would be less than 7%.' 24 | 25 | MacKay asks, "But do these data give evidence that the coin is biased 26 | rather than fair?" 27 | 28 | """ 29 | 30 | 31 | class Euro(thinkbayes.Suite): 32 | 33 | def Likelihood(self, data, hypo): 34 | """Computes the likelihood of the data under the hypothesis. 35 | 36 | data: tuple (#heads, #tails) 37 | hypo: integer value of x, the probability of heads (0-100) 38 | """ 39 | x = hypo / 100.0 40 | heads, tails = data 41 | like = x**heads * (1-x)**tails 42 | return like 43 | 44 | 45 | def AverageLikelihood(suite, data): 46 | """Computes the average likelihood over all hypothesis in suite. 
47 | 48 | Args: 49 | suite: Suite of hypotheses 50 | data: some representation of the observed data 51 | 52 | Returns: 53 | float 54 | """ 55 | total = 0 56 | 57 | for hypo, prob in suite.Items(): 58 | like = suite.Likelihood(data, hypo) 59 | total += prob * like 60 | 61 | return total 62 | 63 | 64 | def main(): 65 | fair = Euro() 66 | fair.Set(50, 1) 67 | 68 | bias = Euro() 69 | for x in range(0, 101): 70 | if x != 50: 71 | bias.Set(x, 1) 72 | bias.Normalize() 73 | 74 | # notice that we've changed the representation of the data 75 | data = 140, 110 76 | 77 | like_bias = AverageLikelihood(bias, data) 78 | print('like_bias', like_bias) 79 | 80 | like_fair = AverageLikelihood(fair, data) 81 | print('like_fair', like_fair) 82 | 83 | ratio = like_bias / like_fair 84 | print('Bayes factor', ratio) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /euro2_soln.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2013 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | 14 | """This file contains a partial solution to a problem from 15 | MacKay, "Information Theory, Inference, and Learning Algorithms." 16 | 17 | Exercise 3.15 (page 50): A statistical statement appeared in 18 | "The Guardian" on Friday January 4, 2002: 19 | 20 | When spun on edge 250 times, a Belgian one-euro coin came 21 | up heads 140 times and tails 110. 'It looks very suspicious 22 | to me,' said Barry Blight, a statistics lecturer at the London 23 | School of Economics. 'If the coin weere unbiased, the chance of 24 | getting a result as extreme as that would be less than 7%.' 
25 | 26 | MacKay asks, "But do these data give evidence that the coin is biased 27 | rather than fair?" 28 | 29 | """ 30 | 31 | class Euro(thinkbayes.Suite): 32 | 33 | def Likelihood(self, data, hypo): 34 | """Computes the likelihood of the data under the hypothesis. 35 | 36 | data: tuple (#heads, #tails) 37 | hypo: integer value of x, the probability of heads (0-100) 38 | """ 39 | x = hypo / 100.0 40 | heads, tails = data 41 | like = x**heads * (1-x)**tails 42 | return like 43 | 44 | 45 | def AverageLikelihood(suite, data): 46 | """Computes the average likelihood over all hypothesis in suite. 47 | 48 | Args: 49 | suite: Suite of hypotheses 50 | data: some representation of the observed data 51 | 52 | Returns: 53 | float 54 | """ 55 | total = 0 56 | 57 | for hypo, prob in suite.Items(): 58 | like = suite.Likelihood(data, hypo) 59 | total += prob * like 60 | 61 | return total 62 | 63 | 64 | def main(): 65 | fair = Euro() 66 | fair.Set(50, 1) 67 | 68 | bias = Euro() 69 | for x in range(0, 51): 70 | bias.Set(x, x) 71 | for x in range(51, 101): 72 | bias.Set(x, 100-x) 73 | bias.Normalize() 74 | 75 | thinkplot.Pdf(bias) 76 | thinkplot.Show() 77 | 78 | # notice that we've changed the representation of the data 79 | data = 140, 110 80 | 81 | like_bias = AverageLikelihood(bias, data) 82 | print('like_bias', like_bias) 83 | 84 | like_fair = AverageLikelihood(fair, data) 85 | print('like_fair', like_fair) 86 | 87 | ratio = like_bias / like_fair 88 | print('Bayes factor', ratio) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /lincoln.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2014 Allen B. 
Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | import numpy 14 | 15 | """ 16 | Bayesian solution to the Lincoln index, described in a blog 17 | article at Probably Overthinking It. 18 | 19 | Last year my occasional correspondent John D. Cook wrote an excellent 20 | blog post about the Lincoln index, which is a way to estimate the 21 | number of errors in a document (or program) by comparing results from 22 | two independent testers. Here's his presentation of the problem: 23 | 24 | "Suppose you have a tester who finds 20 bugs in your program. You 25 | want to estimate how many bugs are really in the program. You know 26 | there are at least 20 bugs, and if you have supreme confidence in your 27 | tester, you may suppose there are around 20 bugs. But maybe your 28 | tester isn't very good. Maybe there are hundreds of bugs. How can you 29 | have any idea how many bugs there are? There's no way to know with one 30 | tester. But if you have two testers, you can get a good idea, even if 31 | you don't know how skilled the testers are." 32 | 33 | Then he presents the Lincoln index, an estimator "described by 34 | Frederick Charles Lincoln in 1930," where Wikpedia's use of 35 | "described" is a hint that the index is another example of Stigler's 36 | law of eponymy. 37 | 38 | "Suppose two testers independently search for bugs. Let k1 be the 39 | number of errors the first tester finds and k2 the number of errors 40 | the second tester finds. Let c be the number of errors both testers 41 | find. The Lincoln Index estimates the total number of errors as k1 k2 42 | / c [I changed his notation to be consistent with mine]." 43 | 44 | So if the first tester finds 20 bugs, the second finds 15, and they 45 | find 3 in common, we estimate that there are about 100 bugs. 
46 | 47 | Of course, whenever I see something like this, the idea that pops into 48 | my head is that there must be a (better) Bayesian solution! And there 49 | is. 50 | 51 | """ 52 | 53 | def choose(n, k, d={}): 54 | """The binomial coefficient "n choose k". 55 | 56 | Args: 57 | n: number of trials 58 | k: number of successes 59 | d: map from (n,k) tuples to cached results 60 | 61 | Returns: 62 | int 63 | """ 64 | if k == 0: 65 | return 1 66 | if n == 0: 67 | return 0 68 | 69 | try: 70 | return d[n, k] 71 | except KeyError: 72 | res = choose(n-1, k) + choose(n-1, k-1) 73 | d[n, k] = res 74 | return res 75 | 76 | def binom(k, n, p): 77 | """Computes the rest of the binomial PMF. 78 | 79 | k: number of hits 80 | n: number of attempts 81 | p: probability of a hit 82 | """ 83 | return p**k * (1-p)**(n-k) 84 | 85 | 86 | class Lincoln(thinkbayes.Suite, thinkbayes.Joint): 87 | """Represents hypotheses about the number of errors.""" 88 | 89 | def Likelihood(self, data, hypo): 90 | """Computes the likelihood of the data under the hypothesis. 
91 | 92 | hypo: n, p1, p2 93 | data: k1, k2, c 94 | """ 95 | n, p1, p2 = hypo 96 | k1, k2, c = data 97 | 98 | part1 = choose(n, k1) * binom(k1, n, p1) 99 | part2 = choose(k1, c) * choose(n-k1, k2-c) * binom(k2, n, p2) 100 | return part1 * part2 101 | 102 | 103 | def main(): 104 | data = 20, 15, 3 105 | probs = numpy.linspace(0, 1, 101) 106 | hypos = [] 107 | for n in range(32, 350): 108 | for p1 in probs: 109 | for p2 in probs: 110 | hypos.append((n, p1, p2)) 111 | 112 | suite = Lincoln(hypos) 113 | suite.Update(data) 114 | 115 | n_marginal = suite.Marginal(0) 116 | 117 | thinkplot.Pmf(n_marginal, label='n') 118 | thinkplot.Save(root='lincoln1', 119 | xlabel='number of bugs', 120 | ylabel='PMF', 121 | formats=['pdf', 'png']) 122 | 123 | print(n_marginal.Mean()) 124 | print(n_marginal.MaximumLikelihood()) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /volunteer.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2013 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import thinkbayes 11 | import thinkplot 12 | 13 | import numpy 14 | 15 | """ 16 | Problem: students sign up to participate in a community service 17 | project. Some fraction, q, of the students who sign up actually 18 | participate, and of those some fraction, r, report back. 19 | 20 | Given a sample of students who sign up and the number who report 21 | back, we can estimate the product q*r, but don't learn much about 22 | q and r separately. 23 | 24 | If we can get a smaller sample of students where we know who 25 | participated and who reported, we can use that to improve the 26 | estimates of q and r. 
27 | 28 | And we can use that to compute the posterior distribution of the 29 | number of students who participated. 30 | 31 | """ 32 | 33 | class Volunteer(thinkbayes.Suite): 34 | 35 | def Likelihood(self, data, hypo): 36 | """Computes the likelihood of the data under the hypothesis. 37 | 38 | hypo: pair of (q, r) 39 | data: one of two possible formats 40 | """ 41 | if len(data) == 2: 42 | return self.Likelihood1(data, hypo) 43 | elif len(data) == 3: 44 | return self.Likelihood2(data, hypo) 45 | else: 46 | raise ValueError() 47 | 48 | def Likelihood1(self, data, hypo): 49 | """Computes the likelihood of the data under the hypothesis. 50 | 51 | hypo: pair of (q, r) 52 | data: tuple (signed up, reported) 53 | """ 54 | q, r = hypo 55 | p = q * r 56 | signed_up, reported = data 57 | yes = reported 58 | no = signed_up - reported 59 | 60 | like = p**yes * (1-p)**no 61 | return like 62 | 63 | def Likelihood2(self, data, hypo): 64 | """Computes the likelihood of the data under the hypothesis. 65 | 66 | hypo: pair of (q, r) 67 | data: tuple (signed up, participated, reported) 68 | """ 69 | q, r = hypo 70 | 71 | signed_up, participated, reported = data 72 | 73 | yes = participated 74 | no = signed_up - participated 75 | like1 = q**yes * (1-q)**no 76 | 77 | yes = reported 78 | no = participated - reported 79 | like2 = r**yes * (1-r)**no 80 | 81 | return like1 * like2 82 | 83 | 84 | def MarginalDistribution(suite, index): 85 | """Extracts the marginal distribution of one parameter. 86 | 87 | suite: Suite 88 | index: which parameter 89 | 90 | returns: Pmf 91 | """ 92 | pmf = thinkbayes.Pmf() 93 | for t, prob in suite.Items(): 94 | pmf.Incr(t[index], prob) 95 | return pmf 96 | 97 | 98 | def MarginalProduct(suite): 99 | """Extracts the distribution of the product of the parameters. 
100 | 101 | suite: Suite 102 | 103 | returns: Pmf 104 | """ 105 | pmf = thinkbayes.Pmf() 106 | for (q, r), prob in suite.Items(): 107 | pmf.Incr(q*r, prob) 108 | return pmf 109 | 110 | 111 | def main(): 112 | probs = numpy.linspace(0, 1, 101) 113 | 114 | hypos = [] 115 | for q in probs: 116 | for r in probs: 117 | hypos.append((q, r)) 118 | 119 | suite = Volunteer(hypos) 120 | 121 | # update the Suite with the larger sample of students who 122 | # signed up and reported 123 | data = 140, 50 124 | suite.Update(data) 125 | 126 | # update again with the smaller sample of students who signed 127 | # up, participated, and reported 128 | data = 5, 3, 1 129 | suite.Update(data) 130 | 131 | #p_marginal = MarginalProduct(suite) 132 | q_marginal = MarginalDistribution(suite, 0) 133 | r_marginal = MarginalDistribution(suite, 1) 134 | 135 | thinkplot.Pmf(q_marginal, label='q') 136 | thinkplot.Pmf(r_marginal, label='r') 137 | #thinkplot.Pmf(p_marginal) 138 | 139 | thinkplot.Save(root='volunteer1', 140 | xlabel='fraction participating/reporting', 141 | ylabel='PMF', 142 | formats=['png'] 143 | ) 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /tutorial.md: -------------------------------------------------------------------------------- 1 | ## Tutorial: Bayes Made Simple 2 | 3 | Allen Downey 4 | 5 | The tutorial material is based on my book, [*Think Bayes*](http://greenteapress.com/wp/think-bayes/), 6 | a class I teach at Olin College, and my blog, [“Probably Overthinking It.”](http://allendowney.com/blog) 7 | 8 | 9 | ### Installation instructions 10 | 11 | Note: Please try to install everything you need for this tutorial before you leave home! 12 | 13 | To prepare for this tutorial, you have two options: 14 | 15 | 1. Install Jupyter on your laptop and download my code from GitHub. 16 | 17 | 2. Run the Jupyter notebooks on a virtual machine on Binder. 
18 | 19 | I'll provide instructions for both, but here's the catch: if everyone chooses Option 2, 20 | the wireless network will fail and no one will be able to do the hands-on part of the workshop. 21 | 22 | So, I strongly encourage you to try Option 1 and only resort to Option 2 if you can't get Option 1 working. 23 | 24 | 25 | 26 | #### Option 1A: If you already have Jupyter installed. 27 | 28 | To do the exercises, you need Python 2 or 3 with NumPy, SciPy, and matplotlib. 29 | If you are not sure whether you have those modules already, the easiest way to check is to run my code and see if it works. 30 | 31 | Code for this workshop is in a Git repository on Github. 32 | If you have a Git client installed, you should be able to download it by running: 33 | 34 | git clone https://github.com/AllenDowney/BayesMadeSimple 35 | 36 | It should create a directory named `BayesMadeSimple`. 37 | Otherwise you can download the repository in [this zip file](https://github.com/AllenDowney/BayesMadeSimple/archive/master.zip). 38 | 39 | To start Jupyter, run: 40 | 41 | cd BayesMadeSimple 42 | jupyter notebook 43 | 44 | Jupyter should launch your default browser or open a tab in an existing browser window. 45 | If not, the Jupyter server should print a URL you can use. For example, when I launch Jupyter, I get 46 | 47 | ``` 48 | ~/ThinkComplexity2$ jupyter notebook 49 | [I 10:03:20.115 NotebookApp] Serving notebooks from local directory: /home/downey/BayesMadeSimple 50 | [I 10:03:20.115 NotebookApp] 0 active kernels 51 | [I 10:03:20.115 NotebookApp] The Jupyter Notebook is running at: http://localhost:8888/ 52 | [I 10:03:20.115 NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation). 53 | ``` 54 | 55 | In this case, the URL is [http://localhost:8888](http://localhost:8888). 56 | When you start your server, you might get a different URL. 
57 | Whatever it is, if you paste it into a browser, you should should see a home page with a list of the 58 | notebooks in the repository. 59 | 60 | Click on `workshop01.ipynb`. It should open the first notebook for the tutorial. 61 | 62 | Select the cell with the import statements and press "Shift-Enter" to run the code in the cell. 63 | If it works and you get no error messages, **you are all set**. 64 | 65 | If you get error messages about missing packages, you can install the packages you need using your package manager, 66 | or try Option 1B and install Anaconda. 67 | 68 | 69 | #### Option 1B: If you don't already have Jupyter. 70 | 71 | I highly recommend installing Anaconda, which is a Python distribution that contains everything 72 | you need for this tutorial. It is easy to install on Windows, Mac, and Linux, and because it does a 73 | user-level install, it will not interfere with other Python installations. 74 | 75 | [Information about installing Anaconda is here](http://docs.continuum.io/anaconda/install.html). 76 | 77 | When you install Anaconda, you should get Jupyter by default, but if not, run 78 | 79 | conda install jupyter 80 | 81 | Then go to Option 1A to make sure you can run my code. 82 | 83 | If you don't want to install Anaconda, 84 | [you can see some other options here](http://jupyter.readthedocs.io/en/latest/install.html). 85 | 86 | 87 | #### Option 2: only if Option 1 failed. 88 | 89 | You can run my notebook in a virtual machine on Binder. To launch the VM, press this button: 90 | 91 | [![Binder](http://mybinder.org/badge.svg)](http://mybinder.org:/repo/allendowney/BayesMadeSimple) 92 | 93 | You should see a home page with a list of the files in the repository. 94 | 95 | If you want to try the exercises, open `workshop01.ipynb`. 96 | You should be able to run the notebooks in your browser and try out the examples. 97 | 98 | However, be aware that the virtual machine you are running is temporary. 
99 | If you leave it idle for more than an hour or so, it will disappear along with any work you have done. 100 | 101 | Special thanks to the generous people who run Binder, which makes it easy to share and reproduce computation. 102 | -------------------------------------------------------------------------------- /sat.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Bayes", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import csv 11 | 12 | import thinkbayes 13 | import thinkplot 14 | 15 | 16 | def ReadScale(filename='sat_scale.csv', col=2): 17 | """Reads a CSV file of SAT scales (maps from raw score to standard score). 18 | 19 | Args: 20 | filename: string filename 21 | col: which column to start with (0=Reading, 2=Math, 4=Writing) 22 | 23 | Returns: thinkbayes.Interpolator object 24 | """ 25 | def ParseRange(s): 26 | t = [int(x) for x in s.split('-')] 27 | return 1.0 * sum(t) / len(t) 28 | 29 | fp = open(filename) 30 | reader = csv.reader(fp) 31 | raws = [] 32 | scores = [] 33 | 34 | for t in reader: 35 | try: 36 | raw = int(t[col]) 37 | raws.append(raw) 38 | score = ParseRange(t[col+1]) 39 | scores.append(score) 40 | except: 41 | pass 42 | 43 | raws.sort() 44 | scores.sort() 45 | return thinkbayes.Interpolator(raws, scores) 46 | 47 | 48 | def ReadRanks(filename='sat_ranks.csv'): 49 | """Reads a CSV file of SAT scores. 
50 | 51 | Args: 52 | filename: string filename 53 | 54 | Returns: 55 | list of (score, freq) pairs 56 | """ 57 | fp = open(filename) 58 | reader = csv.reader(fp) 59 | res = [] 60 | 61 | for t in reader: 62 | try: 63 | score = int(t[0]) 64 | freq = int(t[1]) 65 | res.append((score, freq)) 66 | except ValueError: 67 | pass 68 | 69 | return res 70 | 71 | 72 | def DivideValues(pmf, denom): 73 | """Divides the values in a Pmf by denom. 74 | 75 | Returns a new Pmf. 76 | """ 77 | new = thinkbayes.Pmf() 78 | denom = float(denom) 79 | for val, prob in pmf.Items(): 80 | x = val / denom 81 | new.Set(x, prob) 82 | return new 83 | 84 | 85 | class Exam(object): 86 | """Encapsulates information about an exam. 87 | 88 | Contains the distribution of scaled scores and an 89 | Interpolator that maps between scaled and raw scores. 90 | """ 91 | def __init__(self): 92 | self.scale = ReadScale() 93 | 94 | scores = ReadRanks() 95 | score_pmf = thinkbayes.MakePmfFromDict(dict(scores)) 96 | 97 | self.raw = self.ReverseScale(score_pmf) 98 | self.max_score = max(self.raw.Values()) 99 | self.prior = DivideValues(self.raw, denom=self.max_score) 100 | 101 | def Lookup(self, raw): 102 | """Looks up a raw score and returns a scaled score.""" 103 | return self.scale.Lookup(raw) 104 | 105 | def Reverse(self, score): 106 | """Looks up a scaled score and returns a raw score. 107 | 108 | Since we ignore the penalty, negative scores round up to zero. 109 | """ 110 | raw = self.scale.Reverse(score) 111 | return raw if raw > 0 else 0 112 | 113 | def ReverseScale(self, pmf): 114 | """Applies the reverse scale to the values of a PMF. 
115 | 116 | Args: 117 | pmf: Pmf object 118 | scale: Interpolator object 119 | 120 | Returns: 121 | new Pmf 122 | """ 123 | new = thinkbayes.Pmf() 124 | for val, prob in pmf.Items(): 125 | raw = self.Reverse(val) 126 | new.Incr(raw, prob) 127 | return new 128 | 129 | 130 | class Sat(thinkbayes.Suite): 131 | """Represents the distribution of efficacy for a test-taker.""" 132 | 133 | def __init__(self, exam): 134 | thinkbayes.Suite.__init__(self) 135 | 136 | self.exam = exam 137 | 138 | # start with the prior distribution 139 | for x, prob in exam.prior.Items(): 140 | self.Set(x, prob) 141 | 142 | def Likelihood(self, data, hypo): 143 | """Computes the likelihood of a test score, given x.""" 144 | x = hypo 145 | score = data 146 | raw = self.exam.Reverse(score) 147 | 148 | yes, no = raw, self.exam.max_score - raw 149 | like = x**yes * (1-x)**no 150 | return like 151 | 152 | 153 | def PmfProbGreater(pmf1, pmf2): 154 | """Probability that a value from pmf1 is less than a value from pmf2. 155 | 156 | Args: 157 | pmf1: Pmf object 158 | pmf2: Pmf object 159 | 160 | Returns: 161 | float probability 162 | """ 163 | total = 0.0 164 | for x1, p1 in pmf1.Items(): 165 | for x2, p2 in pmf2.Items(): 166 | # Fill this in! 
167 | pass 168 | 169 | return total 170 | 171 | 172 | def main(): 173 | exam = Exam() 174 | 175 | alice = Sat(exam) 176 | alice.label = 'alice' 177 | alice.Update(780) 178 | 179 | bob = Sat(exam) 180 | bob.label = 'bob' 181 | bob.Update(760) 182 | 183 | print('Prob Alice is "smarter":', PmfProbGreater(alice, bob)) 184 | print('Prob Bob is "smarter":', PmfProbGreater(bob, alice)) 185 | 186 | thinkplot.PrePlot(2) 187 | thinkplot.Pdfs([alice, bob]) 188 | thinkplot.Show(xlabel='x', 189 | ylabel='Probability', 190 | loc='upper left', 191 | xlim=[0.7, 1.02]) 192 | 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /sat_soln.py: -------------------------------------------------------------------------------- 1 | """This file contains code used in "Think Bayes", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2012 Allen B. Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function, division 9 | 10 | import csv 11 | 12 | import thinkbayes 13 | import thinkplot 14 | 15 | 16 | def ReadScale(filename='sat_scale.csv', col=2): 17 | """Reads a CSV file of SAT scales (maps from raw score to standard score). 
18 | 19 | Args: 20 | filename: string filename 21 | col: which column to start with (0=Reading, 2=Math, 4=Writing) 22 | 23 | Returns: thinkbayes.Interpolator object 24 | """ 25 | def ParseRange(s): 26 | t = [int(x) for x in s.split('-')] 27 | return 1.0 * sum(t) / len(t) 28 | 29 | fp = open(filename) 30 | reader = csv.reader(fp) 31 | raws = [] 32 | scores = [] 33 | 34 | for t in reader: 35 | try: 36 | raw = int(t[col]) 37 | raws.append(raw) 38 | score = ParseRange(t[col+1]) 39 | scores.append(score) 40 | except: 41 | pass 42 | 43 | raws.sort() 44 | scores.sort() 45 | return thinkbayes.Interpolator(raws, scores) 46 | 47 | 48 | def ReadRanks(filename='sat_ranks.csv'): 49 | """Reads a CSV file of SAT scores. 50 | 51 | Args: 52 | filename: string filename 53 | 54 | Returns: 55 | list of (score, freq) pairs 56 | """ 57 | fp = open(filename) 58 | reader = csv.reader(fp) 59 | res = [] 60 | 61 | for t in reader: 62 | try: 63 | score = int(t[0]) 64 | freq = int(t[1]) 65 | res.append((score, freq)) 66 | except ValueError: 67 | pass 68 | 69 | return res 70 | 71 | 72 | def DivideValues(pmf, denom): 73 | """Divides the values in a Pmf by denom. 74 | 75 | Returns a new Pmf. 76 | """ 77 | new = thinkbayes.Pmf() 78 | denom = float(denom) 79 | for val, prob in pmf.Items(): 80 | x = val / denom 81 | new.Set(x, prob) 82 | return new 83 | 84 | 85 | class Exam(object): 86 | """Encapsulates information about an exam. 87 | 88 | Contains the distribution of scaled scores and an 89 | Interpolator that maps between scaled and raw scores. 
90 | """ 91 | def __init__(self): 92 | self.scale = ReadScale() 93 | 94 | scores = ReadRanks() 95 | score_pmf = thinkbayes.MakePmfFromDict(dict(scores)) 96 | 97 | self.raw = self.ReverseScale(score_pmf) 98 | self.max_score = max(self.raw.Values()) 99 | self.prior = DivideValues(self.raw, denom=self.max_score) 100 | 101 | def Lookup(self, raw): 102 | """Looks up a raw score and returns a scaled score.""" 103 | return self.scale.Lookup(raw) 104 | 105 | def Reverse(self, score): 106 | """Looks up a scaled score and returns a raw score. 107 | 108 | Since we ignore the penalty, negative scores round up to zero. 109 | """ 110 | raw = self.scale.Reverse(score) 111 | return raw if raw > 0 else 0 112 | 113 | def ReverseScale(self, pmf): 114 | """Applies the reverse scale to the values of a PMF. 115 | 116 | Args: 117 | pmf: Pmf object 118 | scale: Interpolator object 119 | 120 | Returns: 121 | new Pmf 122 | """ 123 | new = thinkbayes.Pmf() 124 | for val, prob in pmf.Items(): 125 | raw = self.Reverse(val) 126 | new.Incr(raw, prob) 127 | return new 128 | 129 | 130 | class Sat(thinkbayes.Suite): 131 | """Represents the distribution of efficacy for a test-taker.""" 132 | 133 | def __init__(self, exam): 134 | thinkbayes.Suite.__init__(self) 135 | 136 | self.exam = exam 137 | 138 | # start with the prior distribution 139 | for x, prob in exam.prior.Items(): 140 | self.Set(x, prob) 141 | 142 | def Likelihood(self, data, hypo): 143 | """Computes the likelihood of a test score, given x.""" 144 | x = hypo 145 | score = data 146 | raw = self.exam.Reverse(score) 147 | 148 | yes, no = raw, self.exam.max_score - raw 149 | like = x**yes * (1-x)**no 150 | return like 151 | 152 | 153 | def PmfProbGreater(pmf1, pmf2): 154 | """Probability that a value from pmf1 is less than a value from pmf2. 
155 | 156 | Args: 157 | pmf1: Pmf object 158 | pmf2: Pmf object 159 | 160 | Returns: 161 | float probability 162 | """ 163 | total = 0.0 164 | for x1, p1 in pmf1.Items(): 165 | for x2, p2 in pmf2.Items(): 166 | if x1 > x2: 167 | total += p1 * p2 168 | 169 | return total 170 | 171 | 172 | def main(): 173 | exam = Exam() 174 | 175 | alice = Sat(exam) 176 | alice.label = 'alice' 177 | alice.Update(780) 178 | 179 | bob = Sat(exam) 180 | bob.label = 'bob' 181 | bob.Update(760) 182 | 183 | print('Prob Alice is "smarter":', PmfProbGreater(alice, bob)) 184 | print('Prob Bob is "smarter":', PmfProbGreater(bob, alice)) 185 | 186 | thinkplot.PrePlot(2) 187 | thinkplot.Pdfs([alice, bob]) 188 | thinkplot.Show(xlabel='x', 189 | ylabel='Probability', 190 | loc='upper left', 191 | xlim=[0.7, 1.02]) 192 | 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Bayesian Statistics Made Simple 2 | 3 | [Allen Downey](https:allendowney.com) 4 | 5 | Bayesian statistical methods are becoming more common, but there are not many resources to help beginners get started. People who know Python can use their programming skills to get a head start. 6 | 7 | In this tutorial, I introduce Bayesian methods using grid algorithms, which help develop understanding and prepare for MCMC, which is a powerful algorithm for real-world problems. 8 | 9 | It is based on my book, [*Think Bayes*](http://greenteapress.com/wp/think-bayes/), 10 | a [class I teach at Olin College](https://sites.google.com/site/compbayes18/), and my blog, [“Probably Overthinking It.”](http://allendowney.com/blog) 11 | 12 | [Slides for this tutorial are here](https://docs.google.com/presentation/d/e/2PACX-1vTUIf7LJJpUd4NzInBGRyHnHqoZ4E736sqd6Iwq_ne3_aDXdJlNgO8O57_USzQzFfDx0gA44fniKe5R/pub). 
13 | 14 | 15 | ### Installation instructions 16 | 17 | Note: Please try to install everything you need for this tutorial before you leave home! 18 | 19 | To prepare for this tutorial, you have two options: 20 | 21 | 1. Install Jupyter on your laptop and download my code from GitHub. 22 | 23 | 2. Run the Jupyter notebooks on a virtual machine on Binder. 24 | 25 | I'll provide instructions for both, but here's the catch: if everyone chooses Option 2, the wireless network might not be able to handle the load. So, I strongly encourage you to try Option 1 and only resort to Option 2 if you can't get Option 1 working. 26 | 27 | 28 | 29 | #### Option 1A: If you already have Jupyter installed. 30 | 31 | Code for this workshop is in a Git repository on Github. 32 | You can download it in [this zip file](https://github.com/AllenDowney/BayesMadeSimple/archive/master.zip). When you unzip it, you should get a directory named `BayesMadeSimple`. 33 | 34 | Or, if you have a Git client installed, you can clone the repo by running: 35 | 36 | ``` 37 | git clone https://github.com/AllenDowney/BayesMadeSimple 38 | ``` 39 | 40 | It should create a directory named `BayesMadeSimple`. 41 | 42 | To run the notebooks, you need Python 3 with Jupyter, NumPy, SciPy, matplotlib and Seaborn. 43 | If you are not sure whether you have those modules already, the easiest way to check is to run my code and see if it works. 44 | 45 | You will also need a small library I wrote, called `empyrical-dist`. You can [see it on PyPI](https://pypi.org/project/empyrical-dist/) and you can install it using pip: 46 | 47 | 48 | ``` 49 | pip install empyrical-dist 50 | ``` 51 | 52 | To start Jupyter, run: 53 | 54 | ``` 55 | cd BayesMadeSimple 56 | jupyter notebook 57 | ``` 58 | 59 | Jupyter should launch your default browser or open a tab in an existing browser window. 60 | If not, the Jupyter server should print a URL you can use. 
For example, when I launch Jupyter, I get 61 | 62 | ``` 63 | ~/BayesMadeSimple$ jupyter notebook 64 | [I 10:03:20.115 NotebookApp] Serving notebooks from local directory: /home/downey/BayesMadeSimple 65 | [I 10:03:20.115 NotebookApp] 0 active kernels 66 | [I 10:03:20.115 NotebookApp] The Jupyter Notebook is running at: http://localhost:8888/ 67 | [I 10:03:20.115 NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation). 68 | ``` 69 | 70 | In this case, the URL is [http://localhost:8888](http://localhost:8888). 71 | When you start your server, you might get a different URL. 72 | Whatever it is, if you paste it into a browser, you should should see a home page with a list of the 73 | notebooks in the repository. 74 | 75 | Click on `01_cookie.ipynb`. It should open the first notebook for the tutorial. 76 | 77 | Select the cell with the import statements and press "Shift-Enter" to run the code in the cell. 78 | If it works and you get no error messages, **you are all set**. 79 | 80 | If you get error messages about missing packages, you can install the packages you need using your package manager, 81 | or try Option 1B and install Anaconda. 82 | 83 | 84 | #### Option 1B: If you don't already have Jupyter. 85 | 86 | I highly recommend installing Anaconda, which is a Python distribution that contains everything 87 | you need for this tutorial. It is easy to install on Windows, Mac, and Linux, and because it does a 88 | user-level install, it will not interfere with other Python installations. 89 | 90 | [Information about installing Anaconda is here](https://www.anaconda.com/distribution/#download-section). 91 | 92 | Choose the Python 3.7 distribution. 
93 | 94 | After you install Anaconda, you can install the packages you need like this: 95 | 96 | ``` 97 | conda install jupyter numpy scipy matplotlib seaborn 98 | pip install empyrical-dist 99 | ``` 100 | 101 | Or you can create a Conda environment just for the workshop, like this: 102 | 103 | ``` 104 | cd BayesMadeSimple 105 | conda env create -f environment.yml 106 | conda activate BayesMadeSimple 107 | ``` 108 | 109 | Then go to Option 1A to make sure you can run my code. 110 | 111 | 112 | #### Option 2: if Option 1 failed. 113 | 114 | You can run my notebook in a virtual machine on Binder. To launch the VM, press this button: 115 | 116 | [![Binder](http://mybinder.org/badge.svg)](https://mybinder.org/v2/gh/AllenDowney/BayesMadeSimple/master) 117 | 118 | You should see a home page with a list of the files in the repository. 119 | 120 | If you want to try the exercises, open `01_cookie.ipynb`. 121 | You should be able to run the notebooks in your browser and try out the examples. 122 | 123 | However, be aware that the virtual machine you are running is temporary. 124 | If you leave it idle for more than an hour or so, it will disappear along with any work you have done. 125 | 126 | Special thanks to the people who run Binder, which makes it easy to share and reproduce computation. 127 | 128 | -------------------------------------------------------------------------------- /debug.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/downey/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. 
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 13 | " from ._conv import register_converters as _register_converters\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "from __future__ import print_function, division\n", 19 | "\n", 20 | "%matplotlib inline\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import pymc3 as pm\n", 24 | "import scipy\n", 25 | "import seaborn as sns\n", 26 | "\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "\n", 29 | "import thinkbayes2\n", 30 | "import thinkplot" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "I want to predict the number of goals scored in the next game, where\n", 38 | "\n", 39 | "`goals ~ Poisson(mu)`\n", 40 | "\n", 41 | "`mu ~ Gamma(alpha, beta)`\n", 42 | "\n", 43 | "Suppose my posterior distribution for `mu` has `alpha=10`, `beta=5`." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "alpha = 10\n", 53 | "beta = 5" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "I can draw a sample from the posterior, and it has the mean I expect, `alpha/beta`" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "2.0014370180768606" 72 | ] 73 | }, 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "iters = 100000\n", 81 | "sample_mu = np.random.gamma(shape=alpha, scale=1/beta, size=iters)\n", 82 | "np.mean(sample_mu)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "2.0" 94 | ] 95 | }, 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "mu = alpha / beta\n", 103 
| "mu" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "I can sample from the predictive distribution by drawing one Poisson sample for each sampled value of `mu`, and it has the mean I expect." 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "2.00996" 122 | ] 123 | }, 124 | "execution_count": 5, 125 | "metadata": {}, 126 | "output_type": "execute_result" 127 | } 128 | ], 129 | "source": [ 130 | "sample_pred = np.random.poisson(sample_mu)\n", 131 | "np.mean(sample_pred)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "Now I'll try to do the same thing with pymc3.\n", 139 | "\n", 140 | "Pretending that `mu` is a known constant, I can sample from `Poisson(mu)` and I get the mean I expect." 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 6, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "2.00449" 152 | ] 153 | }, 154 | "execution_count": 6, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "model = pm.Model()\n", 161 | "\n", 162 | "with model:\n", 163 | " goals = pm.Poisson('goals', mu)\n", 164 | " sample_pred_wrong_pm = goals.random(size=iters)\n", 165 | "\n", 166 | "np.mean(sample_pred_wrong_pm)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "And sampling from the posterior disrtribution of `mu`, I get the mean I expect." 
174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 7, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/plain": [ 184 | "1.9981993583520818" 185 | ] 186 | }, 187 | "execution_count": 7, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "model = pm.Model()\n", 194 | "\n", 195 | "with model:\n", 196 | " mu = pm.Gamma('mu', alpha, beta)\n", 197 | " sample_post_pm = mu.random(size=iters)\n", 198 | "\n", 199 | "np.mean(sample_post_pm)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "But if I try to sample from the posterior predictive distribution (at least in the way I expected it to work), I don't get the mean I expect." 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 8, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "1.37646" 218 | ] 219 | }, 220 | "execution_count": 8, 221 | "metadata": {}, 222 | "output_type": "execute_result" 223 | } 224 | ], 225 | "source": [ 226 | "model = pm.Model()\n", 227 | "\n", 228 | "with model:\n", 229 | " mu = pm.Gamma('mu', alpha, beta)\n", 230 | " goals = pm.Poisson('goals', mu)\n", 231 | " sample_pred_pm = goals.random(size=iters)\n", 232 | "\n", 233 | "np.mean(sample_pred_pm)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "It looks like it might be taking one sample from the Gamma distribution and using it to generate the entire sample of goals.\n", 241 | "\n", 242 | "I suspect something is wrong with my mental model of how to specify the model in pymc3." 
243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [] 251 | } 252 | ], 253 | "metadata": { 254 | "kernelspec": { 255 | "display_name": "Python 3", 256 | "language": "python", 257 | "name": "python3" 258 | }, 259 | "language_info": { 260 | "codemirror_mode": { 261 | "name": "ipython", 262 | "version": 3 263 | }, 264 | "file_extension": ".py", 265 | "mimetype": "text/x-python", 266 | "name": "python", 267 | "nbconvert_exporter": "python", 268 | "pygments_lexer": "ipython3", 269 | "version": "3.6.4" 270 | } 271 | }, 272 | "nbformat": 4, 273 | "nbformat_minor": 1 274 | } 275 | -------------------------------------------------------------------------------- /02_dice.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Bayesian Statistics Made Simple\n", 8 | "===\n", 9 | "\n", 10 | "Code and exercises from my workshop on Bayesian statistics in Python.\n", 11 | "\n", 12 | "Copyright 2016 Allen Downey\n", 13 | "\n", 14 | "MIT License: https://opensource.org/licenses/MIT" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "\n", 28 | "import seaborn as sns\n", 29 | "sns.set_style('white')\n", 30 | "sns.set_context('talk')\n", 31 | "\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "from empiricaldist import Pmf" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### The dice problem\n", 42 | "\n", 43 | "Create a Suite to represent dice with different numbers of sides." 
44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "dice = Pmf.from_seq([4, 6, 8, 12])\n", 53 | "dice" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "**Exercise 1:** We'll solve this problem two ways. First we'll do it \"by hand\", as we did with the cookie problem; that is, we'll multiply each hypothesis by the likelihood of the data, and then renormalize.\n", 61 | "\n", 62 | "In the space below, update `dice` based on the likelihood of the data (rolling a 6), then normalize and display the results." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# Solution goes here" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "**Exercise 2:** Now let's do the same calculation using `Pmf.update`, which encodes the structure of a Bayesian update." 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "Define a function called `likelihood_dice` that takes `data` and `hypo` and returns the probability of the data (the outcome of rolling the die) for a given hypothesis (number of sides on the die).\n", 86 | "\n", 87 | "Hint: What should you do if the outcome exceeds the hypothetical number of sides on the die?\n", 88 | "\n", 89 | "Here's an outline to get you started." 
90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def likelihood_dice(data, hypo):\n", 99 | " \"\"\"Likelihood function for the dice problem.\n", 100 | " \n", 101 | " data: outcome of the die roll\n", 102 | " hypo: number of sides\n", 103 | " \n", 104 | " returns: float probability\n", 105 | " \"\"\"\n", 106 | " # TODO: fill this in!\n", 107 | " return 1" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Solution goes here" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "Now we can create a `Pmf` object and update it." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "dice = Pmf.from_seq([4, 6, 8, 12])\n", 133 | "dice.update(likelihood_dice, 6)\n", 134 | "dice" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "If we get more data, we can perform more updates." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 7, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "for roll in [8, 7, 7, 5, 4]:\n", 151 | " dice.update(likelihood_dice, roll)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "Here are the results." 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "dice" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "### The German tank problem\n", 175 | "\n", 176 | "The German tank problem is actually identical to the dice problem." 
177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 9, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "def likelihood_tank(data, hypo):\n", 186 | " \"\"\"Likelihood function for the tank problem.\n", 187 | " \n", 188 | " data: observed serial number\n", 189 | " hypo: number of tanks\n", 190 | " \n", 191 | " returns: float probability\n", 192 | " \"\"\"\n", 193 | " if data > hypo:\n", 194 | " return 0\n", 195 | " else:\n", 196 | " return 1 / hypo" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "Here is the update after seeing Tank #42." 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 10, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "tank = Pmf.from_seq(range(100))\n", 213 | "tank.update(likelihood_tank, 42)\n", 214 | "tank.mean()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "And here's what the posterior distribution looks like." 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 11, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "def decorate_tank(title):\n", 231 | " \"\"\"Labels the axes.\n", 232 | " \n", 233 | " title: string\n", 234 | " \"\"\"\n", 235 | " plt.xlabel('Number of tanks')\n", 236 | " plt.ylabel('PMF')\n", 237 | " plt.title(title)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 12, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "tank.plot()\n", 247 | "decorate_tank('Distribution after one tank')" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "**Exercise 3:** Suppose we see another tank with serial number 17. What effect does this have on the posterior probabilities?\n", 255 | "\n", 256 | "Update the `Pmf` with the new data and plot the results." 
257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 13, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "# Solution goes here" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "Python 3", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.7.3" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 1 297 | } 298 | -------------------------------------------------------------------------------- /01_cookie.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Bayesian Statistics Made Simple\n", 8 | "===\n", 9 | "\n", 10 | "Code and exercises from my workshop on Bayesian statistics in Python.\n", 11 | "\n", 12 | "Copyright 2016 Allen Downey\n", 13 | "\n", 14 | "MIT License: https://opensource.org/licenses/MIT" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "\n", 28 | "import seaborn as sns\n", 29 | "sns.set_style('white')\n", 30 | "sns.set_context('talk')\n", 31 | "\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "from empiricaldist import Pmf" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### Working with Pmfs\n", 42 | "\n", 43 | "Create a Pmf object to represent a 
six-sided die." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "d6 = Pmf()" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "A Pmf is a map from possible outcomes to their probabilities." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "for x in [1,2,3,4,5,6]:\n", 69 | "    d6[x] = 1" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "Initially the probabilities don't add up to 1." 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "d6" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "`normalize` adds up the probabilities and divides through. The return value is the total probability before normalizing." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "d6.normalize()" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "Now the Pmf is normalized." 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 6, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "d6" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "And we can compute its mean (which only works if it's normalized)." 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 7, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "d6.mean()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "`choice` chooses random values from the Pmf."
141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "d6.choice(size=10)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "`bar` plots the Pmf as a bar chart" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 9, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def decorate_dice(title):\n", 166 | " \"\"\"Labels the axes.\n", 167 | " \n", 168 | " title: string\n", 169 | " \"\"\"\n", 170 | " plt.xlabel('Outcome')\n", 171 | " plt.ylabel('PMF')\n", 172 | " plt.title(title)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 10, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "d6.bar()\n", 182 | "decorate_dice('One die')" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "`d6.add_dist(d6)` creates a new `Pmf` that represents the sum of two six-sided dice." 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 11, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "twice = d6.add_dist(d6)\n", 199 | "twice" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "**Exercise 1:** Plot `twice` and compute its mean." 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 12, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "# Solution goes here" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "**Exercise 2:** Suppose I roll two dice and tell you the result is greater than 3.\n", 223 | "\n", 224 | "Plot the `Pmf` of the remaining possible outcomes and compute its mean." 
225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 13, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "# Solution goes here" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "**Bonus exercise:** In Dungeons and Dragons, the amount of damage a [goblin](https://www.dndbeyond.com/monsters/goblin) can withstand is the sum of two six-sided dice. The amount of damage you inflict with a [short sword](https://www.dndbeyond.com/equipment/shortsword) is determined by rolling one six-sided die.\n", 241 | "\n", 242 | "Suppose you are fighting a goblin and you have already inflicted 3 points of damage. What is your probability of defeating the goblin with your next successful attack?\n", 243 | "\n", 244 | "Hint: `Pmf` provides comparator functions like `gt_dist` and `le_dist`, which compare two distributions and return a probability." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 14, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "# Solution goes here" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "### The cookie problem\n", 261 | "\n", 262 | "`Pmf.from_seq` makes a `Pmf` object from a sequence of values.\n", 263 | "\n", 264 | "Here's how we can use it to create a `Pmf` with two equally likely hypotheses." 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 15, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "cookie = Pmf.from_seq(['Bowl 1', 'Bowl 2'])\n", 274 | "cookie" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "Now we can update each hypothesis with the likelihood of the data (a vanilla cookie)." 
282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 16, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "cookie['Bowl 1'] *= 0.75\n", 291 | "cookie['Bowl 2'] *= 0.5\n", 292 | "cookie.normalize()" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "And display the posterior probabilities." 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 17, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "cookie" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "**Exercise 3:** Suppose we put the first cookie back, stir, choose again from the same bowl, and get a chocolate cookie. \n", 316 | "\n", 317 | "What are the posterior probabilities after the second cookie?\n", 318 | "\n", 319 | "Hint: The posterior (after the first cookie) becomes the prior (before the second cookie)." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 18, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "# Solution goes here" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "**Exercise 4:** Instead of doing two updates, what if we collapse the two pieces of data into one update?\n", 336 | "\n", 337 | "Re-initialize `Pmf` with two equally likely hypotheses and perform one update based on two pieces of data, a vanilla cookie and a chocolate cookie.\n", 338 | "\n", 339 | "The result should be the same regardless of how many updates you do (or the order of updates)." 
340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 19, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "# Solution goes here" 349 | ] 350 | } 351 | ], 352 | "metadata": { 353 | "kernelspec": { 354 | "display_name": "Python 3", 355 | "language": "python", 356 | "name": "python3" 357 | }, 358 | "language_info": { 359 | "codemirror_mode": { 360 | "name": "ipython", 361 | "version": 3 362 | }, 363 | "file_extension": ".py", 364 | "mimetype": "text/x-python", 365 | "name": "python", 366 | "nbconvert_exporter": "python", 367 | "pygments_lexer": "ipython3", 368 | "version": "3.7.3" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 1 373 | } 374 | -------------------------------------------------------------------------------- /03_euro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Bayesian Statistics Made Simple\n", 8 | "===\n", 9 | "\n", 10 | "Code and exercises from my workshop on Bayesian statistics in Python.\n", 11 | "\n", 12 | "Copyright 2016 Allen Downey\n", 13 | "\n", 14 | "MIT License: https://opensource.org/licenses/MIT" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "\n", 28 | "import seaborn as sns\n", 29 | "sns.set_style('white')\n", 30 | "sns.set_context('talk')\n", 31 | "\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "from empiricaldist import Pmf" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### The Euro problem\n", 42 | "\n", 43 | "*\"When spun on edge 250 times, a Belgian one-euro coin came up heads 140 times and tails 110. 
'It looks very suspicious to me,' said Barry Blight, a statistics lecturer at the London School of Economics. 'If the coin were unbiased, the chance of getting a result as extreme as that would be less than 7%.' \"*\n", 44 | "\n", 45 | "From “The Guardian” quoted by MacKay, *Information Theory, Inference, and Learning Algorithms*.\n", 46 | "\n", 47 | "\n", 48 | "**Exercise 1:** Write a function called `likelihood_euro` that defines the likelihood function for the Euro problem. Note that `hypo` is in the range 0 to 100.\n", 49 | "\n", 50 | "Here's an outline to get you started." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "def likelihood_euro(data, hypo):\n", 60 | " \"\"\" Likelihood function for the Euro problem.\n", 61 | " \n", 62 | " data: string, either 'H' or 'T'\n", 63 | " hypo: prob of heads (0-100)\n", 64 | " \n", 65 | " returns: float probability\n", 66 | " \"\"\"\n", 67 | " # TODO: fill this in!\n", 68 | " return 1" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# Solution goes here" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "For the prior, we'll start with a uniform distribution from 0 to 100." 
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def decorate_euro(title):\n", 94 | " \"\"\"Labels the axes.\n", 95 | " \n", 96 | " title: string\n", 97 | " \"\"\"\n", 98 | " plt.xlabel('Probability of heads')\n", 99 | " plt.ylabel('PMF')\n", 100 | " plt.title(title)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "euro = Pmf.from_seq(range(101))\n", 110 | "euro.plot()\n", 111 | "decorate_euro('Prior distribution')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "Now we can update with a single heads:" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "euro.update(likelihood_euro, 'H')\n", 128 | "euro.plot()\n", 129 | "decorate_euro('Posterior distribution, one heads')" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "Another heads:" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "euro.update(likelihood_euro, 'H')\n", 146 | "euro.plot()\n", 147 | "decorate_euro('Posterior distribution, two heads')" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "And a tails:" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 8, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "euro.update(likelihood_euro, 'T')\n", 164 | "euro.plot()\n", 165 | "decorate_euro('Posterior distribution, HHT')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "Starting over, here's what it looks like after 7 heads and 3 tails." 
173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 9, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "euro = Pmf.from_seq(range(101))\n", 182 | "\n", 183 | "for outcome in 'HHHHHHHTTT':\n", 184 | "    euro.update(likelihood_euro, outcome)\n", 185 | "\n", 186 | "euro.plot()\n", 187 | "decorate_euro('Posterior distribution, 7 heads, 3 tails')" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "The maximum a posteriori probability (MAP) is 70%, which is the observed proportion." 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 10, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "euro.max_prob()" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "Here are the posterior probabilities after 140 heads and 110 tails." 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 11, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "euro = Pmf.from_seq(range(101))\n", 220 | "\n", 221 | "evidence = 'H' * 140 + 'T' * 110\n", 222 | "for outcome in evidence:\n", 223 | "    euro.update(likelihood_euro, outcome)\n", 224 | "    \n", 225 | "euro.plot()\n", 226 | "\n", 227 | "decorate_euro('Posterior distribution, 140 heads, 110 tails')" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "The posterior mean is about 56%" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 12, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "euro.mean()" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "So is the MAP."
251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 13, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "euro.max_prob()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "And the median (50th percentile)." 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 14, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "euro.quantile(0.5)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "The posterior credible interval has a 90% chance of containing the true value (provided that the prior distribution truly represents our background knowledge)." 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 15, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "euro.credible_interval(0.9)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "### Swamping the prior\n", 299 | "\n", 300 | "The following function makes a Euro object with a triangle prior." 
301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 16, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "def TrianglePrior():\n", 310 | " \"\"\"Makes a Suite with a triangular prior.\n", 311 | " \"\"\"\n", 312 | " suite = Pmf(name='triangle')\n", 313 | " for x in range(0, 51):\n", 314 | " suite[x] = x\n", 315 | " for x in range(51, 101):\n", 316 | " suite[x] = 100-x \n", 317 | " suite.normalize()\n", 318 | " return suite" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "And here's what it looks like:" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 17, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "euro1 = Pmf.from_seq(range(101), name='uniform')\n", 335 | "euro1.plot()\n", 336 | "\n", 337 | "euro2 = TrianglePrior()\n", 338 | "euro2.plot()\n", 339 | "\n", 340 | "plt.legend()\n", 341 | "decorate_euro('Prior distributions')" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": {}, 347 | "source": [ 348 | "**Exercise 9:** Update `euro1` and `euro2` with the same data we used before (140 heads and 110 tails) and plot the posteriors. How big is the difference in the means?" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 18, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "# Solution goes here" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "The posterior distributions are not identical, but with this data, they converge to the point where there is no practical difference, for most purposes." 
365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.7.3" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 1 396 | } 397 | -------------------------------------------------------------------------------- /05_world_cup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Bayesian Statistics Made Simple\n", 8 | "\n", 9 | "Code and exercises from my workshop on Bayesian statistics in Python.\n", 10 | "\n", 11 | "Copyright 2019 Allen Downey\n", 12 | "\n", 13 | "MIT License: https://opensource.org/licenses/MIT" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%matplotlib inline\n", 23 | "\n", 24 | "import numpy as np\n", 25 | "import pandas as pd\n", 26 | "\n", 27 | "import seaborn as sns\n", 28 | "sns.set_style('white')\n", 29 | "sns.set_context('talk')\n", 30 | "\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "\n", 33 | "from empiricaldist import Pmf" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### The World Cup problem\n", 41 | "\n", 42 | "In the 2018 FIFA World Cup final, France defeated Croatia 4 goals to 2. Based on this outcome, we can answer the following questions:\n", 43 | "\n", 44 | "1. 
How confident should we be that France is the better team?\n", 45 | "\n", 46 | "2. If the same teams played again, what is the chance Croatia would win?\n", 47 | "\n", 48 | "To answer these questions, we have to make some modeling assumptions:\n", 49 | "\n", 50 | "1. Goal scoring can be well modeled by a Poisson process, so the distribution of goals scored by each team against the other is Poisson($\lambda$), where $\lambda$ is a goal-scoring rate, measured in goals per game.\n", 51 | "\n", 52 | "2. For two random World Cup teams, the distribution of goal scoring rates is Gamma($\alpha$), where $\alpha$ is a parameter we can choose based on past results.\n", 53 | "\n", 54 | "To determine $\alpha$, I used [data from previous World Cups](https://www.statista.com/statistics/269031/goals-scored-per-game-at-the-fifa-world-cup-since-1930/) to estimate that the average goal scoring rate is about 1.4 goals per game.\n", 55 | "\n", 56 | "We can use `scipy.stats.gamma` to compute the PDF of the Gamma distribution." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from scipy.stats import gamma\n", 66 | "\n", 67 | "α = 1.4\n", 68 | "qs = np.linspace(0, 6)\n", 69 | "ps = gamma(α).pdf(qs)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "Now we can use `qs` and `ps` to make a `Pmf` that represents the prior distribution" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "prior = Pmf(ps, index=qs)\n", 86 | "prior.normalize()\n", 87 | "prior.mean()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "And plot it."
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def decorate_rate(title):\n", 104 | " \"\"\"Labels the axes.\n", 105 | " \n", 106 | " title: string\n", 107 | " \"\"\"\n", 108 | " plt.xlabel('Goal scoring rate')\n", 109 | " plt.ylabel('PMF')\n", 110 | " plt.title(title)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "prior.plot()\n", 120 | "decorate_rate('Prior distribution')" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "This prior implies:\n", 128 | "\n", 129 | "1. The most common goal-scoring rates are near 1.\n", 130 | "\n", 131 | "2. The goal-scoring rate is never 0; eventually, any team will score against any other.\n", 132 | "\n", 133 | "3. The goal-scoring rate is unlikely to be greater than 4, and never greater than 6." 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "### The likelihood function\n", 141 | "\n", 142 | "Suppose you are given the goal-scoring rate, $\\lambda$, and asked to compute the probability of scoring a number of goals, $k$. The answer is given by the Poisson PMF:\n", 143 | "\n", 144 | "$ \\mathrm{PMF}(k; \\lambda) = \\frac{\\lambda^k \\exp(-\\lambda)}{k!} $\n", 145 | "\n", 146 | "**Exercise 1:** Write a likelihood function that takes $k$ and $\\lambda$ as parameters `data` and `hypo`, and computes $\\mathrm{PMF}(k; \\lambda)$.\n", 147 | "\n", 148 | "You can use NumPy/SciPy functions or `scipy.stats.poisson`. 
" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "def likelihood(data, hypo):\n", 158 | " \"\"\"Likelihood function for World Cup\n", 159 | " \n", 160 | " data: integer number of goals in a game\n", 161 | " hypo: goal scoring rate in goals per game\n", 162 | " \n", 163 | " returns: float probability\n", 164 | " \"\"\"\n", 165 | " # TODO: fill this in!\n", 166 | " return 1" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 7, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# Solution goes here" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 8, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "# Solution goes here" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "### Update\n", 192 | "\n", 193 | "First we'll compute the posterior distribution for France, having seen them score 4 goals." 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 9, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "france = Pmf(prior, copy=True)\n", 203 | "\n", 204 | "france.update(likelihood, 4)\n", 205 | "france.plot(label='France')\n", 206 | "decorate_rate('Posterior distribution, 4 goals')\n", 207 | "\n", 208 | "france.mean()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "**Exercise 2:** Do the same for Croatia." 
216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 10, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "## Solution\n", 225 | "\n", 226 | "croatia = Pmf(prior, copy=True)\n", 227 | "\n", 228 | "croatia.update(likelihood, 2)\n", 229 | "croatia.plot(label='Croatia', color='C3')\n", 230 | "decorate_rate('Posterior distribution, 2 goals')\n", 231 | "\n", 232 | "croatia.mean()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "### Probability of superiority\n", 240 | "\n", 241 | "Now that we have a posterior distribution for each team, we can answer the first question: How confident should we be that France is the better team?\n", 242 | "\n", 243 | "In the model, \"better\" means having a higher goal-scoring rate against the opponent. We can use the posterior distributions to compute the \"probability of superiority\", which is the probability that a random value drawn from France's distribution exceeds a value drawn from Croatia's.\n", 244 | "\n", 245 | "Remember that `Pmf` provides `choice`, which returns a random sample as a NumPy array:" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 11, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "sample_france = france.choice(size=1000)\n", 255 | "sample_france.mean()" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "**Exercise 3:** Generate a similar sample for Croatia; then compute the fraction of samples where the goal-scoring rate is higher for Croatia. \n", 263 | "\n", 264 | "Hint: use `np.mean`."
265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 12, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# Solution goes here" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 13, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "# Solution goes here" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "On the basis of one game, we have only moderate confidence that France is actually the better team." 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "### Predicting the rematch\n", 297 | "\n", 298 | "Now we can take on the second question: If the same teams played again, what is the chance Croatia would win?\n", 299 | "\n", 300 | "To answer this question, we'll generate a sample from the \"posterior predictive distribution\", which is the number of goals we expect a team to score.\n", 301 | "\n", 302 | "If we knew the goal scoring rate, $\\lambda$, the distribution of goals would be $Poisson(\\lambda)$.\n", 303 | "\n", 304 | "Since we don't know $\\lambda$, we can use the sample we generated in the previous section to generate a sample of goals, like this:" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 14, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "goals_france = np.random.poisson(sample_france)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "Now we can plot the results:" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 15, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "def decorate_goals(title):\n", 330 | " \"\"\"Labels the axes.\n", 331 | " \n", 332 | " title: string\n", 333 | " \"\"\"\n", 334 | " plt.xlabel('Goals scored')\n", 335 | " plt.ylabel('PMF')\n", 336 | " plt.ylim([0, 0.32])\n", 337 | " plt.title(title)" 
338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 16, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "pmf_france = Pmf.from_seq(goals_france)\n", 347 | "pmf_france.bar(label='France')\n", 348 | "decorate_goals('Predictive distribution')\n", 349 | "plt.legend()\n", 350 | "\n", 351 | "goals_france.mean()" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "This distribution represents two sources of uncertainty: we don't know the actual value of $\\lambda$, and even if we did, we would not know the number of goals in the next game." 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "**Exercise 4:** Generate and plot the predictive distribution for Croatia." 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 17, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "# Solution goes here" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": {}, 380 | "source": [ 381 | "In a sense, these distributions represent the outcomes of 1000 simulated games." 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "**Exercise 5:** Compute the fraction of simulated rematches Croatia would win, how many France would win, and how many would end in a tie." 
389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 18, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "# Solution goes here" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 19, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "# Solution goes here" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 20, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "# Solution goes here" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "Assuming that Croatia wins half of the ties, their chance of winning the rematch is about 33%." 423 | ] 424 | } 425 | ], 426 | "metadata": { 427 | "kernelspec": { 428 | "display_name": "Python 3", 429 | "language": "python", 430 | "name": "python3" 431 | }, 432 | "language_info": { 433 | "codemirror_mode": { 434 | "name": "ipython", 435 | "version": 3 436 | }, 437 | "file_extension": ".py", 438 | "mimetype": "text/x-python", 439 | "name": "python", 440 | "nbconvert_exporter": "python", 441 | "pygments_lexer": "ipython3", 442 | "version": "3.7.3" 443 | } 444 | }, 445 | "nbformat": 4, 446 | "nbformat_minor": 1 447 | } 448 | -------------------------------------------------------------------------------- /04_bandit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Bayesian Statistics Made Simple\n", 8 | "===\n", 9 | "\n", 10 | "Code and exercises from my workshop on Bayesian statistics in Python.\n", 11 | "\n", 12 | "Copyright 2018 Allen Downey\n", 13 | "\n", 14 | "MIT License: https://opensource.org/licenses/MIT" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "%matplotlib inline\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import pandas 
as pd\n", 27 | "\n", 28 | "import seaborn as sns\n", 29 | "sns.set_style('white')\n", 30 | "sns.set_context('talk')\n", 31 | "\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "from empiricaldist import Pmf" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### The Bayesian bandit problem\n", 42 | "\n", 43 | "Suppose you have several \"one-armed bandit\" slot machines, and reason to think that they have different probabilities of paying off.\n", 44 | "\n", 45 | "Each time you play a machine, you either win or lose, and you can use the outcome to update your belief about the probability of winning.\n", 46 | "\n", 47 | "Then, to decide which machine to play next, you can use the \"Bayesian bandit\" strategy, explained below.\n", 48 | "\n", 49 | "First, let's see how to do the update." 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "### The prior\n", 57 | "\n", 58 | "If we know nothing about the probability of wining, we can start with a uniform prior." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def decorate_bandit(title):\n", 68 | " \"\"\"Labels the axes.\n", 69 | " \n", 70 | " title: string\n", 71 | " \"\"\"\n", 72 | " plt.xlabel('Probability of winning')\n", 73 | " plt.ylabel('PMF')\n", 74 | " plt.title(title)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "bandit = Pmf.from_seq(range(101))\n", 84 | "bandit.plot()\n", 85 | "decorate_bandit('Prior distribution')" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### The likelihood function\n", 93 | "\n", 94 | "The likelihood function that computes the probability of an outcome (W or L) for a hypothetical value of x, the probability of winning (from 0 to 1)." 
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def likelihood_bandit(data, hypo):\n", 104 | " \"\"\"Likelihood function for Bayesian bandit\n", 105 | " \n", 106 | " data: string, either 'W' or 'L'\n", 107 | " hypo: probability of winning (0-100)\n", 108 | " \n", 109 | " returns: float probability\n", 110 | " \"\"\"\n", 111 | " x = hypo / 100\n", 112 | " if data == 'W':\n", 113 | " return x\n", 114 | " else:\n", 115 | " return 1-x" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "**Exercise 1:** Suppose you play a machine 10 times and win once. What is the posterior distribution of $x$?" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 5, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "# Solution goes here" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "## Multiple bandits" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "Now suppose we have several bandits and we want to decide which one to play." 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "For this example, we have 4 machines with these probabilities:" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "actual_probs = [0.10, 0.20, 0.30, 0.40]" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "The function `play` simulates playing one machine once and returns `W` or `L`." 
169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 7, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "from random import random\n", 178 | "from collections import Counter\n", 179 | "\n", 180 | "# count how many times we've played each machine\n", 181 | "counter = Counter()\n", 182 | "\n", 183 | "def flip(p):\n", 184 | " \"\"\"Return True with probability p.\"\"\"\n", 185 | " return random() < p\n", 186 | "\n", 187 | "def play(i):\n", 188 | " \"\"\"Play machine i.\n", 189 | " \n", 190 | " returns: string 'W' or 'L'\n", 191 | " \"\"\"\n", 192 | " counter[i] += 1\n", 193 | " p = actual_probs[i]\n", 194 | " if flip(p):\n", 195 | " return 'W'\n", 196 | " else:\n", 197 | " return 'L'" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "Here's a test, playing machine 3 twenty times:" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 8, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "for i in range(20):\n", 214 | " result = play(3)\n", 215 | " print(result, end=' ')" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "Now I'll make 4 `Pmf` objects to represent our beliefs about the 4 machines." 
223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 9, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "prior = range(101)\n", 232 | "beliefs = [Pmf.from_seq(prior) for i in range(4)]" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "This function displays the four posterior distributions" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 10, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "options = dict(xticklabels='invisible', yticklabels='invisible')\n", 249 | "\n", 250 | "def plot(beliefs, **options):\n", 251 | " sns.set_context('paper')\n", 252 | " for i, b in enumerate(beliefs):\n", 253 | " plt.subplot(2, 2, i+1)\n", 254 | " b.plot(label='Machine %s' % i)\n", 255 | " plt.gca().set_yticklabels([])\n", 256 | " plt.legend()\n", 257 | " \n", 258 | " plt.tight_layout()\n", 259 | " sns.set_context('talk')" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 11, 265 | "metadata": { 266 | "scrolled": true 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "plot(beliefs)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "The following function updates our beliefs about one of the machines based on one outcome." 
278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 12, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "def update(beliefs, i, outcome):\n", 287 | " \"\"\"Update beliefs about machine i, given outcome.\n", 288 | " \n", 289 | " beliefs: list of Pmf\n", 290 | " i: index into beliefs\n", 291 | " outcome: string 'W' or 'L'\n", 292 | " \"\"\"\n", 293 | " beliefs[i].update(likelihood_bandit, outcome)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "**Exercise 2:** Write a nested loop that plays each machine 10 times; then plot the posterior distributions. \n", 301 | "\n", 302 | "Hint: call `play` and then `update`." 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 13, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "# Solution goes here" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 14, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "# Solution goes here" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "After playing each machine 10 times, we have some information about their probabilies:" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 15, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "[belief.mean() for belief in beliefs]" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": {}, 342 | "source": [ 343 | "## Bayesian Bandits\n", 344 | "\n", 345 | "To get more information, we could play each machine 100 times, but while we are gathering data, we are not making good use of it. The kernel of the Bayesian Bandits algorithm is that it collects and uses data at the same time. 
In other words, it balances exploration and exploitation.\n", 346 | "\n", 347 | "The following function chooses among the machines so that the probability of choosing each machine is proportional to its \"probability of superiority\".\n", 348 | "\n", 349 | "`choice` chooses a value from the posterior distribution.\n", 350 | "\n", 351 | "`argmax` returns the index of the machine that chose the highest value." 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 16, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "def choose(beliefs):\n", 361 | " \"\"\"Use the Bayesian bandit strategy to choose a machine.\n", 362 | " \n", 363 | " Draws a sample from each distributions.\n", 364 | " \n", 365 | " returns: index of the machine that yielded the highest value\n", 366 | " \"\"\"\n", 367 | " ps = [b.choice() for b in beliefs]\n", 368 | " return np.argmax(ps)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "Here's an example." 
376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 17, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "choose(beliefs)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "**Exercise 3:** Putting it all together, fill in the following function to choose a machine, play once, and update `beliefs`:" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 18, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "def choose_play_update(beliefs, verbose=False):\n", 401 | " \"\"\"Chose a machine, play it, and update beliefs.\n", 402 | " \n", 403 | " beliefs: list of Pmf objects\n", 404 | " verbose: Boolean, whether to print results\n", 405 | " \"\"\"\n", 406 | " # choose a machine\n", 407 | " machine = ____\n", 408 | " \n", 409 | " # play it\n", 410 | " outcome = ____\n", 411 | " \n", 412 | " # update beliefs\n", 413 | " update(____)\n", 414 | " \n", 415 | " if verbose:\n", 416 | " print(i, outcome, beliefs[machine].mean())" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 19, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "# Solution goes here" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "Here's an example" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 20, 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "choose_play_update(beliefs, verbose=True)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "## Trying it out" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "Let's start again with a fresh set of machines (and an empty `Counter`)." 
456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 21, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "beliefs = [Pmf.from_seq(prior) for i in range(4)]\n", 465 | "counter = Counter()" 466 | ] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": {}, 471 | "source": [ 472 | "If we run the bandit algorithm 100 times, we can see how `beliefs` gets updated:" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": 22, 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "num_plays = 100\n", 482 | "\n", 483 | "for i in range(num_plays):\n", 484 | " choose_play_update(beliefs)\n", 485 | " \n", 486 | "plot(beliefs)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "We can summarize `beliefs` by printing the posterior mean and credible interval:" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 23, 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "for i, b in enumerate(beliefs):\n", 503 | " print(b.mean(), b.credible_interval(0.9))" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": {}, 509 | "source": [ 510 | "The credible intervals usually contain the true values (10, 20, 30, and 40).\n", 511 | "\n", 512 | "The estimates are still rough, especially for the lower-probability machines. But that's a feature, not a bug: the goal is to play the high-probability machines most often. Making the estimates more precise is a means to that end, but not an end itself.\n", 513 | "\n", 514 | "Let's see how many times each machine got played. If things go according to plan, the machines with higher probabilities should get played more often." 
def underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    d: dictionary
    options: keyword args to add to d

    returns: modified d
    """
    # Only fill in keys the caller has not already set; existing
    # entries always win over the provided defaults.
    for key, default in options.items():
        if key not in d:
            d[key] = default

    return d
106 | 107 | returns: float 108 | """ 109 | return self.quantile(0.5) 110 | 111 | def quantile(self, ps): 112 | """Quantiles. 113 | 114 | Computes the inverse CDF of ps, that is, 115 | the values that correspond to the given probabilities. 116 | 117 | returns: float 118 | """ 119 | return self.make_cdf().quantile(ps) 120 | 121 | def var(self): 122 | """Variance of a PMF. 123 | 124 | returns: float 125 | """ 126 | m = self.mean() 127 | d = self.qs - m 128 | return np.sum(d**2 * self.ps) 129 | 130 | def std(self): 131 | """Standard deviation of a PMF. 132 | 133 | returns: float 134 | """ 135 | return np.sqrt(self.var()) 136 | 137 | def sample(self, *args, **kwargs): 138 | """Makes a random sample. 139 | 140 | args: same as ps.Series.sample 141 | options: same as ps.Series.sample 142 | 143 | returns: Series 144 | """ 145 | # TODO: finish this 146 | underride(kwargs, weights=self.ps) 147 | return self.index.sample(*args, **kwargs) 148 | 149 | def choice(self, *args, **kwargs): 150 | """Makes a random sample. 151 | 152 | Uses the probabilities as weights unless `p` is provided. 153 | 154 | args: same as np.random.choice 155 | options: same as np.random.choice 156 | 157 | returns: NumPy array 158 | """ 159 | underride(kwargs, p=self.ps) 160 | return np.random.choice(self.qs, *args, **kwargs) 161 | 162 | def bar(self, **options): 163 | """Makes a bar plot. 164 | 165 | options: same as plt.bar 166 | """ 167 | underride(options, label=self.name) 168 | plt.bar(self.qs, self.ps, **options) 169 | 170 | def __add__(self, x): 171 | """Computes the Pmf of the sum of values drawn from self and x. 172 | 173 | x: another Pmf or a scalar 174 | 175 | returns: new Pmf 176 | """ 177 | if isinstance(x, Pmf): 178 | return pmf_add(self, x) 179 | else: 180 | return Pmf(self.ps, index=self.qs + x) 181 | 182 | __radd__ = __add__ 183 | 184 | def __sub__(self, x): 185 | """Computes the Pmf of the diff of values drawn from self and other. 
186 | 187 | x: another Pmf 188 | 189 | returns: new Pmf 190 | """ 191 | if isinstance(x, Pmf): 192 | return pmf_sub(self, x) 193 | else: 194 | return Pmf(self.ps, index=self.qs - x) 195 | 196 | # TODO: implement rsub 197 | # __rsub__ = __sub__ 198 | 199 | # TODO: mul, div, truediv, divmod? 200 | 201 | def make_joint(self, other, **options): 202 | """Make joint distribution 203 | 204 | :param self: 205 | :param other: 206 | :param options: passed to Pmf constructor 207 | 208 | :return: new Pmf 209 | """ 210 | qs = pd.MultiIndex.from_product([self.qs, other.qs]) 211 | ps = np.multiply.outer(self.ps, other.ps).flatten() 212 | return Pmf(ps, index=qs, **options) 213 | 214 | def marginal(self, i, name=None): 215 | """Gets the marginal distribution of the indicated variable. 216 | 217 | i: index of the variable we want 218 | name: string 219 | 220 | Returns: Pmf 221 | """ 222 | # TODO: rewrite this using multiindex operations 223 | pmf = Pmf(name=name) 224 | for vs, p in self.items(): 225 | pmf[vs[i]] += p 226 | return pmf 227 | 228 | def conditional(self, i, j, val, name=None): 229 | """Gets the conditional distribution of the indicated variable. 230 | 231 | Distribution of vs[i], conditioned on vs[j] = val. 232 | 233 | i: index of the variable we want 234 | j: which variable is conditioned on 235 | val: the value the jth variable has to have 236 | name: string 237 | 238 | Returns: Pmf 239 | """ 240 | # TODO: rewrite this using multiindex operations 241 | pmf = Pmf(name=name) 242 | for vs, p in self.items(): 243 | if vs[j] == val: 244 | pmf[vs[i]] += p 245 | 246 | pmf.normalize() 247 | return pmf 248 | 249 | def update(self, likelihood, data): 250 | """Bayesian update. 
251 | 252 | likelihood: function that takes (data, hypo) and returns 253 | likelihood of data under hypo 254 | data: whatever format like_func understands 255 | 256 | returns: normalizing constant 257 | """ 258 | for hypo in self.qs: 259 | self[hypo] *= likelihood(data, hypo) 260 | 261 | return self.normalize() 262 | 263 | def max_prob(self): 264 | """Value with the highest probability. 265 | 266 | returns: the value with the highest probability 267 | """ 268 | return self.idxmax() 269 | 270 | def make_cdf(self, normalize=True): 271 | """Make a Cdf from the Pmf. 272 | 273 | It can be good to normalize the cdf even if the Pmf was normalized, 274 | to guarantee that the last element of `ps` is 1. 275 | 276 | returns: Cdf 277 | """ 278 | cdf = Cdf(self.cumsum()) 279 | if normalize: 280 | cdf.normalize() 281 | return cdf 282 | 283 | def quantile(self, ps): 284 | """Quantities corresponding to given probabilities. 285 | 286 | ps: sequence of probabilities 287 | 288 | return: sequence of quantities 289 | """ 290 | cdf = self.sort_index().cumsum() 291 | interp = interp1d(cdf.values, cdf.index, 292 | kind='next', 293 | copy=False, 294 | assume_sorted=True, 295 | bounds_error=False, 296 | fill_value=(self.qs[0], np.nan)) 297 | return interp(ps) 298 | 299 | def credible_interval(self, p): 300 | """Credible interval containing the given probability. 301 | 302 | p: float 0-1 303 | 304 | returns: array of two quantities 305 | """ 306 | tail = (1-p) / 2 307 | ps = [tail, 1-tail] 308 | return self.quantile(ps) 309 | 310 | @staticmethod 311 | def from_seq(seq, normalize=True, sort=True, **options): 312 | """Make a PMF from a sequence of values. 
313 | 314 | seq: any kind of sequence 315 | normalize: whether to normalize the Pmf, default True 316 | sort: whether to sort the Pmf by values, default True 317 | options: passed to the pd.Series constructor 318 | 319 | returns: Pmf object 320 | """ 321 | series = pd.Series(seq).value_counts(sort=False) 322 | 323 | options['copy'] = False 324 | pmf = Pmf(series, **options) 325 | 326 | if sort: 327 | pmf.sort_index(inplace=True) 328 | 329 | if normalize: 330 | pmf.normalize() 331 | 332 | return pmf 333 | 334 | # Comparison operators 335 | 336 | def gt(self, x): 337 | """Probability that a sample from this Pmf > x. 338 | 339 | x: number 340 | 341 | returns: float probability 342 | """ 343 | if isinstance(x, Pmf): 344 | return pmf_gt(self, x) 345 | else: 346 | return self[self.qs > x].sum() 347 | 348 | __gt__ = gt 349 | 350 | def lt(self, x): 351 | """Probability that a sample from this Pmf < x. 352 | 353 | x: number 354 | 355 | returns: float probability 356 | """ 357 | if isinstance(x, Pmf): 358 | return pmf_lt(self, x) 359 | else: 360 | return self[self.qs < x].sum() 361 | 362 | __lt__ = lt 363 | 364 | def ge(self, x): 365 | """Probability that a sample from this Pmf >= x. 366 | 367 | x: number 368 | 369 | returns: float probability 370 | """ 371 | if isinstance(x, Pmf): 372 | return pmf_ge(self, x) 373 | else: 374 | return self[self.qs >= x].sum() 375 | 376 | __ge__ = ge 377 | 378 | def le(self, x): 379 | """Probability that a sample from this Pmf <= x. 380 | 381 | x: number 382 | 383 | returns: float probability 384 | """ 385 | if isinstance(x, Pmf): 386 | return pmf_le(self, x) 387 | else: 388 | return self[self.qs <= x].sum() 389 | 390 | __le__ = le 391 | 392 | def eq(self, x): 393 | """Probability that a sample from this Pmf == x. 
394 | 395 | x: number 396 | 397 | returns: float probability 398 | """ 399 | if isinstance(x, Pmf): 400 | return pmf_eq(self, x) 401 | else: 402 | return self[self.qs == x].sum() 403 | 404 | __eq__ = eq 405 | 406 | def ne(self, x): 407 | """Probability that a sample from this Pmf != x. 408 | 409 | x: number 410 | 411 | returns: float probability 412 | """ 413 | if isinstance(x, Pmf): 414 | return pmf_ne(self, x) 415 | else: 416 | return self[self.qs != x].sum() 417 | 418 | __ne__ = ne 419 | 420 | 421 | def pmf_conv(pmf1, pmf2, ufunc): 422 | """Convolve two PMFs. 423 | 424 | pmf1: 425 | pmf2: 426 | ufunc: elementwise function for arrays 427 | 428 | returns: new Pmf 429 | """ 430 | qs = ufunc(pmf1.qs, pmf2.qs).flatten() 431 | ps = np.multiply.outer(pmf1.ps, pmf2.ps).flatten() 432 | series = pd.Series(ps).groupby(qs).sum() 433 | return Pmf(series) 434 | 435 | 436 | def pmf_add(pmf1, pmf2): 437 | """Distribution of the sum. 438 | 439 | pmf1: 440 | pmf2: 441 | 442 | returns: new Pmf 443 | """ 444 | return pmf_conv(pmf1, pmf2, np.add.outer) 445 | 446 | 447 | def pmf_sub(pmf1, pmf2): 448 | """Distribution of the difference. 449 | 450 | pmf1: 451 | pmf2: 452 | 453 | returns: new Pmf 454 | """ 455 | return pmf_conv(pmf1, pmf2, np.subtract.outer) 456 | 457 | 458 | def pmf_outer(pmf1, pmf2, ufunc): 459 | """Computes the outer product of two PMFs. 460 | 461 | pmf1: 462 | pmf2: 463 | ufunc: function to apply to the qs 464 | 465 | returns: NumPy array 466 | """ 467 | qs = ufunc.outer(pmf1.qs, pmf2.qs) 468 | ps = np.multiply.outer(pmf1.ps, pmf2.ps) 469 | return qs * ps 470 | 471 | 472 | def pmf_gt(pmf1, pmf2): 473 | """Probability that a value from pmf1 is greater than a value from pmf2. 474 | 475 | pmf1: Pmf object 476 | pmf2: Pmf object 477 | 478 | returns: float probability 479 | """ 480 | outer = pmf_outer(pmf1, pmf2, np.greater) 481 | return outer.sum() 482 | 483 | 484 | def pmf_lt(pmf1, pmf2): 485 | """Probability that a value from pmf1 is less than a value from pmf2. 
def pmf_ne(pmf1, pmf2):
    """Probability that a value from pmf1 is not equal to a value from pmf2.

    pmf1: Pmf object
    pmf2: Pmf object

    returns: float probability
    """
    # Docstring fixed: it was a copy-paste of pmf_le's "<=" wording,
    # but the code computes np.not_equal.
    outer = pmf_outer(pmf1, pmf2, np.not_equal)
    return outer.sum()
562 | 563 | returns: new Cdf 564 | """ 565 | return Cdf(self, **kwargs) 566 | 567 | @property 568 | def forward(self): 569 | interp = interp1d(self.qs, self.ps, 570 | kind='previous', 571 | copy=False, 572 | assume_sorted=True, 573 | bounds_error=False, 574 | fill_value=(0,1)) 575 | return interp 576 | 577 | @property 578 | def inverse(self): 579 | interp = interp1d(self.ps, self.qs, 580 | kind='next', 581 | copy=False, 582 | assume_sorted=True, 583 | bounds_error=False, 584 | fill_value=(self.qs[0], np.nan)) 585 | return interp 586 | 587 | # calling a Cdf like a function does forward lookup 588 | __call__ = forward 589 | 590 | # quantile is the same as an inverse lookup 591 | quantile = inverse 592 | 593 | @staticmethod 594 | def from_seq(seq, normalize=True, sort=True, **options): 595 | """Make a CDF from a sequence of values. 596 | 597 | seq: any kind of sequence 598 | normalize: whether to normalize the Cdf, default True 599 | sort: whether to sort the Cdf by values, default True 600 | options: passed to the pd.Series constructor 601 | 602 | returns: CDF object 603 | """ 604 | pmf = Pmf.from_seq(seq, normalize=False, sort=sort, **options) 605 | return pmf.make_cdf(normalize=normalize) 606 | 607 | @property 608 | def qs(self): 609 | """Get the quantities. 610 | 611 | returns: NumPy array 612 | """ 613 | return self.index.values 614 | 615 | @property 616 | def ps(self): 617 | """Get the probabilities. 618 | 619 | returns: NumPy array 620 | """ 621 | return self.values 622 | 623 | def _repr_html_(self): 624 | """Returns an HTML representation of the series. 625 | 626 | Mostly used for Jupyter notebooks. 627 | """ 628 | df = pd.DataFrame(dict(probs=self)) 629 | return df._repr_html_() 630 | 631 | def normalize(self): 632 | """Make the probabilities add up to 1 (modifies self). 
633 | 634 | returns: normalizing constant 635 | """ 636 | total = self.ps[-1] 637 | self /= total 638 | return total 639 | 640 | def make_pmf(self, normalize=False): 641 | """Make a Pmf from the Cdf. 642 | 643 | returns: Cdf 644 | """ 645 | ps = self.ps 646 | diff = np.ediff1d(ps, to_begin=ps[0]) 647 | pmf = Pmf(pd.Series(diff, index=self.index.copy())) 648 | if normalize: 649 | pmf.normalize() 650 | return pmf 651 | 652 | def choice(self, *args, **kwargs): 653 | """Makes a random sample. 654 | 655 | Uses the probabilities as weights unless `p` is provided. 656 | 657 | args: same as np.random.choice 658 | options: same as np.random.choice 659 | 660 | returns: NumPy array 661 | """ 662 | # TODO: Make this more efficient by implementing the inverse CDF method. 663 | pmf = self.make_pmf() 664 | return pmf.choice(*args, *kwargs) 665 | 666 | def mean(self): 667 | """Expected value. 668 | 669 | returns: float 670 | """ 671 | return self.make_pmf().mean() 672 | 673 | def var(self): 674 | """Variance. 675 | 676 | returns: float 677 | """ 678 | return self.make_pmf().var() 679 | 680 | def std(self): 681 | """Standard deviation. 682 | 683 | returns: float 684 | """ 685 | return self.make_pmf().std() 686 | 687 | def median(self): 688 | """Median (50th percentile). 689 | 690 | returns: float 691 | """ 692 | return self.quantile(0.5) 693 | 694 | 695 | -------------------------------------------------------------------------------- /thinkplot.py: -------------------------------------------------------------------------------- 1 | """This file contains code for use with "Think Stats", 2 | by Allen B. Downey, available from greenteapress.com 3 | 4 | Copyright 2014 Allen B. 
Downey 5 | License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html 6 | """ 7 | 8 | from __future__ import print_function 9 | 10 | import math 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import pandas 15 | 16 | import warnings 17 | 18 | # customize some matplotlib attributes 19 | #matplotlib.rc('figure', figsize=(4, 3)) 20 | 21 | #matplotlib.rc('font', size=14.0) 22 | #matplotlib.rc('axes', labelsize=22.0, titlesize=22.0) 23 | #matplotlib.rc('legend', fontsize=20.0) 24 | 25 | #matplotlib.rc('xtick.major', size=6.0) 26 | #matplotlib.rc('xtick.minor', size=3.0) 27 | 28 | #matplotlib.rc('ytick.major', size=6.0) 29 | #matplotlib.rc('ytick.minor', size=3.0) 30 | 31 | 32 | class _Brewer(object): 33 | """Encapsulates a nice sequence of colors. 34 | 35 | Shades of blue that look good in color and can be distinguished 36 | in grayscale (up to a point). 37 | 38 | Borrowed from http://colorbrewer2.org/ 39 | """ 40 | color_iter = None 41 | 42 | colors = ['#f7fbff', '#deebf7', '#c6dbef', 43 | '#9ecae1', '#6baed6', '#4292c6', 44 | '#2171b5','#08519c','#08306b'][::-1] 45 | 46 | # lists that indicate which colors to use depending on how many are used 47 | which_colors = [[], 48 | [1], 49 | [1, 3], 50 | [0, 2, 4], 51 | [0, 2, 4, 6], 52 | [0, 2, 3, 5, 6], 53 | [0, 2, 3, 4, 5, 6], 54 | [0, 1, 2, 3, 4, 5, 6], 55 | [0, 1, 2, 3, 4, 5, 6, 7], 56 | [0, 1, 2, 3, 4, 5, 6, 7, 8], 57 | ] 58 | 59 | current_figure = None 60 | 61 | @classmethod 62 | def Colors(cls): 63 | """Returns the list of colors. 64 | """ 65 | return cls.colors 66 | 67 | @classmethod 68 | def ColorGenerator(cls, num): 69 | """Returns an iterator of color strings. 
70 | 71 | n: how many colors will be used 72 | """ 73 | for i in cls.which_colors[num]: 74 | yield cls.colors[i] 75 | raise StopIteration('Ran out of colors in _Brewer.') 76 | 77 | @classmethod 78 | def InitIter(cls, num): 79 | """Initializes the color iterator with the given number of colors.""" 80 | cls.color_iter = cls.ColorGenerator(num) 81 | fig = plt.gcf() 82 | cls.current_figure = fig 83 | 84 | @classmethod 85 | def ClearIter(cls): 86 | """Sets the color iterator to None.""" 87 | cls.color_iter = None 88 | cls.current_figure = None 89 | 90 | @classmethod 91 | def GetIter(cls, num): 92 | """Gets the color iterator.""" 93 | fig = plt.gcf() 94 | if fig != cls.current_figure: 95 | cls.InitIter(num) 96 | cls.current_figure = fig 97 | 98 | if cls.color_iter is None: 99 | cls.InitIter(num) 100 | 101 | return cls.color_iter 102 | 103 | 104 | def _UnderrideColor(options): 105 | """If color is not in the options, chooses a color. 106 | """ 107 | if 'color' in options: 108 | return options 109 | 110 | # get the current color iterator; if there is none, init one 111 | color_iter = _Brewer.GetIter(5) 112 | 113 | try: 114 | options['color'] = next(color_iter) 115 | except StopIteration: 116 | # if you run out of colors, initialize the color iterator 117 | # and try again 118 | warnings.warn('Ran out of colors. Starting over.') 119 | _Brewer.ClearIter() 120 | _UnderrideColor(options) 121 | 122 | return options 123 | 124 | 125 | def PrePlot(num=None, rows=None, cols=None): 126 | """Takes hints about what's coming. 
127 | 128 | num: number of lines that will be plotted 129 | rows: number of rows of subplots 130 | cols: number of columns of subplots 131 | """ 132 | if num: 133 | _Brewer.InitIter(num) 134 | 135 | if rows is None and cols is None: 136 | return 137 | 138 | if rows is not None and cols is None: 139 | cols = 1 140 | 141 | if cols is not None and rows is None: 142 | rows = 1 143 | 144 | # resize the image, depending on the number of rows and cols 145 | size_map = {(1, 1): (8, 6), 146 | (1, 2): (12, 6), 147 | (1, 3): (12, 6), 148 | (1, 4): (12, 5), 149 | (1, 5): (12, 4), 150 | (2, 2): (10, 10), 151 | (2, 3): (16, 10), 152 | (3, 1): (8, 10), 153 | (4, 1): (8, 12), 154 | } 155 | 156 | if (rows, cols) in size_map: 157 | fig = plt.gcf() 158 | fig.set_size_inches(*size_map[rows, cols]) 159 | 160 | # create the first subplot 161 | if rows > 1 or cols > 1: 162 | ax = plt.subplot(rows, cols, 1) 163 | global SUBPLOT_ROWS, SUBPLOT_COLS 164 | SUBPLOT_ROWS = rows 165 | SUBPLOT_COLS = cols 166 | else: 167 | ax = plt.gca() 168 | 169 | return ax 170 | 171 | def SubPlot(plot_number, rows=None, cols=None, **options): 172 | """Configures the number of subplots and changes the current plot. 173 | 174 | rows: int 175 | cols: int 176 | plot_number: int 177 | options: passed to subplot 178 | """ 179 | rows = rows or SUBPLOT_ROWS 180 | cols = cols or SUBPLOT_COLS 181 | return plt.subplot(rows, cols, plot_number, **options) 182 | 183 | 184 | def _Underride(d, **options): 185 | """Add key-value pairs to d only if key is not in d. 186 | 187 | If d is None, create a new dictionary. 
188 | 189 | d: dictionary 190 | options: keyword args to add to d 191 | """ 192 | if d is None: 193 | d = {} 194 | 195 | for key, val in options.items(): 196 | d.setdefault(key, val) 197 | 198 | return d 199 | 200 | 201 | def Clf(): 202 | """Clears the figure and any hints that have been set.""" 203 | global LOC 204 | LOC = None 205 | _Brewer.ClearIter() 206 | plt.clf() 207 | fig = plt.gcf() 208 | fig.set_size_inches(8, 6) 209 | 210 | 211 | def Figure(**options): 212 | """Sets options for the current figure.""" 213 | _Underride(options, figsize=(6, 8)) 214 | plt.figure(**options) 215 | 216 | 217 | def Plot(obj, ys=None, style='', **options): 218 | """Plots a line. 219 | 220 | Args: 221 | obj: sequence of x values, or Series, or anything with Render() 222 | ys: sequence of y values 223 | style: style string passed along to plt.plot 224 | options: keyword args passed to plt.plot 225 | """ 226 | options = _UnderrideColor(options) 227 | label = getattr(obj, 'label', '_nolegend_') 228 | options = _Underride(options, linewidth=3, alpha=0.7, label=label) 229 | 230 | xs = obj 231 | if ys is None: 232 | if hasattr(obj, 'Render'): 233 | xs, ys = obj.Render() 234 | if isinstance(obj, pandas.Series): 235 | ys = obj.values 236 | xs = obj.index 237 | 238 | if ys is None: 239 | plt.plot(xs, style, **options) 240 | else: 241 | plt.plot(xs, ys, style, **options) 242 | 243 | 244 | def Vlines(xs, y1, y2, **options): 245 | """Plots a set of vertical lines. 246 | 247 | Args: 248 | xs: sequence of x values 249 | y1: sequence of y values 250 | y2: sequence of y values 251 | options: keyword args passed to plt.vlines 252 | """ 253 | options = _UnderrideColor(options) 254 | options = _Underride(options, linewidth=1, alpha=0.5) 255 | plt.vlines(xs, y1, y2, **options) 256 | 257 | 258 | def Hlines(ys, x1, x2, **options): 259 | """Plots a set of horizontal lines. 
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to plt.plot
    """
    # pull out the rendering options before forwarding the rest to Plot
    low = options.pop('low', None)
    high = options.pop('high', None)
    n = options.pop('n', 101)

    xs, ps = pdf.Render(low=low, high=high, n=n)
    Plot(xs, ps, **_Underride(options, label=pdf.label))
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist.  If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to plt.bar
    """
    # find the minimum distance between adjacent values
    xs, ys = hist.Render()

    # see if the values support arithmetic
    try:
        xs[0] - xs[0]
    except TypeError:
        # if not, replace values with numbers
        labels = [str(x) for x in xs]
        xs = np.arange(len(xs))
        plt.xticks(xs+0.5, labels)

    if 'width' not in options:
        try:
            options['width'] = 0.9 * np.diff(xs).min()
        except TypeError:
            # BUG FIX (message): the original concatenated these sentences
            # with no separating spaces.
            warnings.warn("Hist: Can't compute bar width automatically. "
                          "Check for non-numeric types in Hist. "
                          "Or try providing width option.")

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        options['align'] = 'edge'
        # BUG FIX: guard the negation -- if the width could not be computed
        # above (non-numeric values, no explicit width), there is no 'width'
        # key and the original `options['width'] *= -1` raised KeyError
        # instead of falling back to matplotlib's default width.
        if 'width' in options:
            options['width'] *= -1

    Bar(xs, ys, **options)
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to plt.plot

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    # let explicit scale options override the transform's defaults
    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0-p for p in ps]

    if transform == 'weibull':
        # drop the last point, where log(1-p) is undefined
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0-p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # BUG FIX: was `xp.delete(xs, 0)` -- a typo for `np.delete` that
        # raised NameError whenever transform='gumbel' was used.
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: sequence of x values
    ys: sequence of y values
    zs: array of z values (as expected by plt.pcolormesh)
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to plt.pcolor and/or plt.contour
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # keep the x tick labels from switching to offset notation
    formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    plt.gca().xaxis.set_major_formatter(formatter)

    if pcolor:
        plt.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contours = plt.contour(grid_x, grid_y, zs, **options)
        plt.clabel(contours, inline=1, fontsize=10)
620 | 621 | x: number 622 | y: number 623 | s: string 624 | options: keyword args passed to plt.text 625 | """ 626 | options = _Underride(options, 627 | fontsize=16, 628 | verticalalignment='top', 629 | horizontalalignment='left') 630 | plt.text(x, y, s, **options) 631 | 632 | 633 | LEGEND = True 634 | LOC = None 635 | 636 | def Config(**options): 637 | """Configures the plot. 638 | 639 | Pulls options out of the option dictionary and passes them to 640 | the corresponding plt functions. 641 | """ 642 | names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale', 643 | 'xticks', 'yticks', 'axis', 'xlim', 'ylim'] 644 | 645 | for name in names: 646 | if name in options: 647 | getattr(plt, name)(options[name]) 648 | 649 | global LEGEND 650 | LEGEND = options.get('legend', LEGEND) 651 | 652 | if LEGEND: 653 | global LOC 654 | LOC = options.get('loc', LOC) 655 | frameon = options.get('frameon', True) 656 | 657 | warnings.filterwarnings('error', category=UserWarning) 658 | try: 659 | plt.legend(loc=LOC, frameon=frameon) 660 | except UserWarning: 661 | pass 662 | warnings.filterwarnings('default', category=UserWarning) 663 | 664 | # x and y ticklabels can be made invisible 665 | val = options.get('xticklabels', None) 666 | if val is not None: 667 | if val == 'invisible': 668 | ax = plt.gca() 669 | labels = ax.get_xticklabels() 670 | plt.setp(labels, visible=False) 671 | 672 | val = options.get('yticklabels', None) 673 | if val is not None: 674 | if val == 'invisible': 675 | ax = plt.gca() 676 | labels = ax.get_yticklabels() 677 | plt.setp(labels, visible=False) 678 | 679 | 680 | def Show(**options): 681 | """Shows the plot. 682 | 683 | For options, see Config. 684 | 685 | options: keyword args used to invoke various plt functions 686 | """ 687 | clf = options.pop('clf', True) 688 | Config(**options) 689 | plt.show() 690 | if clf: 691 | Clf() 692 | 693 | 694 | def Plotly(**options): 695 | """Shows the plot. 696 | 697 | For options, see Config. 
698 | 699 | options: keyword args used to invoke various plt functions 700 | """ 701 | clf = options.pop('clf', True) 702 | Config(**options) 703 | import plotly.plotly as plotly 704 | url = plotly.plot_mpl(plt.gcf()) 705 | if clf: 706 | Clf() 707 | return url 708 | 709 | 710 | def Save(root=None, formats=None, **options): 711 | """Saves the plot in the given formats and clears the figure. 712 | 713 | For options, see Config. 714 | 715 | Args: 716 | root: string filename root 717 | formats: list of string formats 718 | options: keyword args used to invoke various plt functions 719 | """ 720 | clf = options.pop('clf', True) 721 | 722 | save_options = {} 723 | for option in ['bbox_inches', 'pad_inches']: 724 | if option in options: 725 | save_options[option] = options.pop(option) 726 | 727 | Config(**options) 728 | 729 | if formats is None: 730 | formats = ['pdf', 'eps'] 731 | 732 | try: 733 | formats.remove('plotly') 734 | Plotly(clf=False) 735 | except ValueError: 736 | pass 737 | 738 | if root: 739 | for fmt in formats: 740 | SaveFormat(root, fmt, **save_options) 741 | if clf: 742 | Clf() 743 | 744 | 745 | def SaveFormat(root, fmt='eps', **options): 746 | """Writes the current figure to a file in the given format. 
def underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    d: dictionary
    options: keyword args to add to d

    :return: modified d
    """
    for key, default in options.items():
        if key not in d:
            d[key] = default

    return d
107 | 108 | :return: float 109 | """ 110 | return self.quantile(0.5) 111 | 112 | def quantile(self, ps, **kwargs): 113 | """Quantiles. 114 | 115 | Computes the inverse CDF of ps, that is, 116 | the values that correspond to the given probabilities. 117 | 118 | :return: float 119 | """ 120 | return self.make_cdf().quantile(ps, **kwargs) 121 | 122 | def var(self): 123 | """Variance of a PMF. 124 | 125 | :return: float 126 | """ 127 | m = self.mean() 128 | d = self.qs - m 129 | return np.sum(d**2 * self.ps) 130 | 131 | def std(self): 132 | """Standard deviation of a PMF. 133 | 134 | :return: float 135 | """ 136 | return np.sqrt(self.var()) 137 | 138 | def choice(self, *args, **kwargs): 139 | """Makes a random sample. 140 | 141 | Uses the probabilities as weights unless `p` is provided. 142 | 143 | args: same as np.random.choice 144 | kwargs: same as np.random.choice 145 | 146 | :return: NumPy array 147 | """ 148 | underride(kwargs, p=self.ps) 149 | return np.random.choice(self.qs, *args, **kwargs) 150 | 151 | def sample(self, *args, **kwargs): 152 | """Makes a random sample. 153 | 154 | Uses the probabilities as weights unless `weights` is provided. 155 | 156 | This function returns an array containing a sample of the quantities in this Pmf, 157 | which is different from Series.sample, which returns a Series with a sample of 158 | the rows in the original Series. 159 | 160 | args: same as Series.sample 161 | options: same as Series.sample 162 | 163 | :return: NumPy array 164 | """ 165 | series = pd.Series(self.qs) 166 | underride(kwargs, weights=self.ps) 167 | sample = series.sample(*args, **kwargs) 168 | return sample.values 169 | 170 | def plot(self, **options): 171 | """Plot the Pmf as a line. 172 | 173 | :param options: passed to plt.plot 174 | :return: 175 | """ 176 | underride(options, label=self.name) 177 | plt.plot(self.qs, self.ps, **options) 178 | 179 | def bar(self, **options): 180 | """Makes a bar plot. 
181 | 182 | options: passed to plt.bar 183 | """ 184 | underride(options, label=self.name) 185 | plt.bar(self.qs, self.ps, **options) 186 | 187 | def add(self, x): 188 | """Computes the Pmf of the sum of values drawn from self and x. 189 | 190 | x: another Pmf or a scalar or a sequence 191 | 192 | :return: new Pmf 193 | """ 194 | if isinstance(x, Pmf): 195 | return pmf_conv(self, x, np.add.outer) 196 | else: 197 | return Pmf(self.ps, index=self.qs + x) 198 | 199 | __add__ = add 200 | __radd__ = add 201 | 202 | def sub(self, x): 203 | """Computes the Pmf of the diff of values drawn from self and other. 204 | 205 | x: another Pmf or a scalar or a sequence 206 | 207 | :return: new Pmf 208 | """ 209 | if isinstance(x, Pmf): 210 | return pmf_conv(self, x, np.subtract.outer) 211 | else: 212 | return Pmf(self.ps, index=self.qs - x) 213 | 214 | subtract = sub 215 | __sub__ = sub 216 | 217 | def rsub(self, x): 218 | """Computes the Pmf of the diff of values drawn from self and other. 219 | 220 | x: another Pmf or a scalar or a sequence 221 | 222 | :return: new Pmf 223 | """ 224 | if isinstance(x, Pmf): 225 | return pmf_conv(x, self, np.subtract.outer) 226 | else: 227 | return Pmf(self.ps, index=x - self.qs) 228 | 229 | __rsub__ = rsub 230 | 231 | def mul(self, x): 232 | """Computes the Pmf of the product of values drawn from self and x. 233 | 234 | x: another Pmf or a scalar or a sequence 235 | 236 | :return: new Pmf 237 | """ 238 | if isinstance(x, Pmf): 239 | return pmf_conv(self, x, np.multiply.outer) 240 | else: 241 | return Pmf(self.ps, index=self.qs * x) 242 | 243 | multiply = mul 244 | __mul__ = mul 245 | __rmul__ = mul 246 | 247 | def div(self, x): 248 | """Computes the Pmf of the ratio of values drawn from self and x. 
249 | 250 | x: another Pmf or a scalar or a sequence 251 | 252 | :return: new Pmf 253 | """ 254 | if isinstance(x, Pmf): 255 | return pmf_conv(self, x, np.divide.outer) 256 | else: 257 | return Pmf(self.ps, index=self.qs / x) 258 | 259 | divide = div 260 | __div = div 261 | __truediv__ = div 262 | 263 | def rdiv(self, x): 264 | """Computes the Pmf of the ratio of values drawn from self and x. 265 | 266 | x: another Pmf or a scalar or a sequence 267 | 268 | :return: new Pmf 269 | """ 270 | if isinstance(x, Pmf): 271 | return pmf_conv(x, self, np.divide.outer) 272 | else: 273 | return Pmf(self.ps, index=x / self.qs) 274 | 275 | __rdiv__ = rdiv 276 | __rtruediv__ = rdiv 277 | 278 | def make_joint(self, other, **options): 279 | """Make joint distribution (assuming independence). 280 | 281 | :param self: 282 | :param other: 283 | :param options: passed to Pmf constructor 284 | 285 | :return: new Pmf 286 | """ 287 | qs = pd.MultiIndex.from_product([self.qs, other.qs]) 288 | ps = np.multiply.outer(self.ps, other.ps).flatten() 289 | return Pmf(ps, index=qs, **options) 290 | 291 | def marginal(self, i, name=None): 292 | """Gets the marginal distribution of the indicated variable. 293 | 294 | i: index of the variable we want 295 | name: string 296 | 297 | :return: Pmf 298 | """ 299 | # TODO: rewrite this using MultiIndex operations 300 | pmf = Pmf(name=name) 301 | for vs, p in self.items(): 302 | pmf[vs[i]] += p 303 | return pmf 304 | 305 | def conditional(self, i, j, val, name=None): 306 | """Gets the conditional distribution of the indicated variable. 307 | 308 | Distribution of vs[i], conditioned on vs[j] = val. 
309 | 310 | i: index of the variable we want 311 | j: which variable is conditioned on 312 | val: the value the jth variable has to have 313 | name: string 314 | 315 | :return: Pmf 316 | """ 317 | # TODO: rewrite this using MultiIndex operations 318 | pmf = Pmf(name=name) 319 | for vs, p in self.items(): 320 | if vs[j] == val: 321 | pmf[vs[i]] += p 322 | 323 | pmf.normalize() 324 | return pmf 325 | 326 | def update(self, likelihood, data): 327 | """Bayesian update. 328 | 329 | likelihood: function that takes (data, hypo) and returns 330 | likelihood of data under hypo 331 | data: whatever format like_func understands 332 | 333 | :return: normalizing constant 334 | """ 335 | for hypo in self.qs: 336 | self[hypo] *= likelihood(data, hypo) 337 | 338 | return self.normalize() 339 | 340 | def max_prob(self): 341 | """Value with the highest probability. 342 | 343 | :return: the value with the highest probability 344 | """ 345 | return self.idxmax() 346 | 347 | def make_cdf(self, normalize=True): 348 | """Make a Cdf from the Pmf. 349 | 350 | It can be good to normalize the cdf even if the Pmf was normalized, 351 | to guarantee that the last element of `ps` is 1. 352 | 353 | :return: Cdf 354 | """ 355 | cdf = Cdf(self.cumsum()) 356 | if normalize: 357 | cdf.normalize() 358 | return cdf 359 | 360 | def quantile(self, ps): 361 | """Quantities corresponding to given probabilities. 362 | 363 | ps: sequence of probabilities 364 | 365 | return: sequence of quantities 366 | """ 367 | cdf = self.make_cdf() 368 | return cdf.quantile(ps) 369 | 370 | def credible_interval(self, p): 371 | """Credible interval containing the given probability. 372 | 373 | p: float 0-1 374 | 375 | :return: array of two quantities 376 | """ 377 | tail = (1-p) / 2 378 | ps = [tail, 1-tail] 379 | return self.quantile(ps) 380 | 381 | @staticmethod 382 | def from_seq(seq, normalize=True, sort=True, **options): 383 | """Make a PMF from a sequence of values. 
384 | 385 | seq: any kind of sequence 386 | normalize: whether to normalize the Pmf, default True 387 | sort: whether to sort the Pmf by values, default True 388 | options: passed to the pd.Series constructor 389 | 390 | :return: Pmf object 391 | """ 392 | series = pd.Series(seq).value_counts(sort=False) 393 | 394 | options['copy'] = False 395 | pmf = Pmf(series, **options) 396 | 397 | if sort: 398 | pmf.sort_index(inplace=True) 399 | 400 | if normalize: 401 | pmf.normalize() 402 | 403 | return pmf 404 | 405 | # Comparison operators 406 | 407 | def gt(self, x): 408 | """Probability that a sample from this Pmf > x. 409 | 410 | x: number 411 | 412 | :return: float probability 413 | """ 414 | if isinstance(x, Pmf): 415 | return pmf_gt(self, x) 416 | else: 417 | return self[self.qs > x].sum() 418 | 419 | __gt__ = gt 420 | 421 | def lt(self, x): 422 | """Probability that a sample from this Pmf < x. 423 | 424 | x: number 425 | 426 | :return: float probability 427 | """ 428 | if isinstance(x, Pmf): 429 | return pmf_lt(self, x) 430 | else: 431 | return self[self.qs < x].sum() 432 | 433 | __lt__ = lt 434 | 435 | def ge(self, x): 436 | """Probability that a sample from this Pmf >= x. 437 | 438 | x: number 439 | 440 | :return: float probability 441 | """ 442 | if isinstance(x, Pmf): 443 | return pmf_ge(self, x) 444 | else: 445 | return self[self.qs >= x].sum() 446 | 447 | __ge__ = ge 448 | 449 | def le(self, x): 450 | """Probability that a sample from this Pmf <= x. 451 | 452 | x: number 453 | 454 | :return: float probability 455 | """ 456 | if isinstance(x, Pmf): 457 | return pmf_le(self, x) 458 | else: 459 | return self[self.qs <= x].sum() 460 | 461 | __le__ = le 462 | 463 | def eq(self, x): 464 | """Probability that a sample from this Pmf == x. 
465 | 466 | x: number 467 | 468 | :return: float probability 469 | """ 470 | if isinstance(x, Pmf): 471 | return pmf_eq(self, x) 472 | else: 473 | return self[self.qs == x].sum() 474 | 475 | __eq__ = eq 476 | 477 | def ne(self, x): 478 | """Probability that a sample from this Pmf != x. 479 | 480 | x: number 481 | 482 | :return: float probability 483 | """ 484 | if isinstance(x, Pmf): 485 | return pmf_ne(self, x) 486 | else: 487 | return self[self.qs != x].sum() 488 | 489 | __ne__ = ne 490 | 491 | 492 | def pmf_conv(pmf1, pmf2, ufunc): 493 | """Convolve two PMFs. 494 | 495 | pmf1: 496 | pmf2: 497 | ufunc: elementwise function for arrays 498 | 499 | :return: new Pmf 500 | """ 501 | qs = ufunc(pmf1.qs, pmf2.qs).flatten() 502 | ps = np.multiply.outer(pmf1.ps, pmf2.ps).flatten() 503 | series = pd.Series(ps).groupby(qs).sum() 504 | return Pmf(series) 505 | 506 | 507 | def pmf_add(pmf1, pmf2): 508 | """Distribution of the sum. 509 | 510 | pmf1: 511 | pmf2: 512 | 513 | :return: new Pmf 514 | """ 515 | return pmf_conv(pmf1, pmf2, np.add.outer) 516 | 517 | 518 | def pmf_sub(pmf1, pmf2): 519 | """Distribution of the difference. 520 | 521 | pmf1: 522 | pmf2: 523 | 524 | :return: new Pmf 525 | """ 526 | return pmf_conv(pmf1, pmf2, np.subtract.outer) 527 | 528 | 529 | def pmf_mul(pmf1, pmf2): 530 | """Distribution of the product. 531 | 532 | pmf1: 533 | pmf2: 534 | 535 | :return: new Pmf 536 | """ 537 | return pmf_conv(pmf1, pmf2, np.multiply.outer) 538 | 539 | def pmf_div(pmf1, pmf2): 540 | """Distribution of the ratio. 541 | 542 | pmf1: 543 | pmf2: 544 | 545 | :return: new Pmf 546 | """ 547 | return pmf_conv(pmf1, pmf2, np.divide.outer) 548 | 549 | def pmf_outer(pmf1, pmf2, ufunc): 550 | """Computes the outer product of two PMFs. 
551 | 552 | pmf1: 553 | pmf2: 554 | ufunc: function to apply to the qs 555 | 556 | :return: NumPy array 557 | """ 558 | qs = ufunc.outer(pmf1.qs, pmf2.qs) 559 | ps = np.multiply.outer(pmf1.ps, pmf2.ps) 560 | return qs * ps 561 | 562 | 563 | def pmf_gt(pmf1, pmf2): 564 | """Probability that a value from pmf1 is greater than a value from pmf2. 565 | 566 | pmf1: Pmf object 567 | pmf2: Pmf object 568 | 569 | :return: float probability 570 | """ 571 | outer = pmf_outer(pmf1, pmf2, np.greater) 572 | return outer.sum() 573 | 574 | 575 | def pmf_lt(pmf1, pmf2): 576 | """Probability that a value from pmf1 is less than a value from pmf2. 577 | 578 | pmf1: Pmf object 579 | pmf2: Pmf object 580 | 581 | :return: float probability 582 | """ 583 | outer = pmf_outer(pmf1, pmf2, np.less) 584 | return outer.sum() 585 | 586 | 587 | def pmf_ge(pmf1, pmf2): 588 | """Probability that a value from pmf1 is >= than a value from pmf2. 589 | 590 | pmf1: Pmf object 591 | pmf2: Pmf object 592 | 593 | :return: float probability 594 | """ 595 | outer = pmf_outer(pmf1, pmf2, np.greater_equal) 596 | return outer.sum() 597 | 598 | 599 | def pmf_le(pmf1, pmf2): 600 | """Probability that a value from pmf1 is <= than a value from pmf2. 601 | 602 | pmf1: Pmf object 603 | pmf2: Pmf object 604 | 605 | :return: float probability 606 | """ 607 | outer = pmf_outer(pmf1, pmf2, np.less_equal) 608 | return outer.sum() 609 | 610 | 611 | def pmf_eq(pmf1, pmf2): 612 | """Probability that a value from pmf1 equals a value from pmf2. 613 | 614 | pmf1: Pmf object 615 | pmf2: Pmf object 616 | 617 | :return: float probability 618 | """ 619 | outer = pmf_outer(pmf1, pmf2, np.equal) 620 | return outer.sum() 621 | 622 | 623 | def pmf_ne(pmf1, pmf2): 624 | """Probability that a value from pmf1 is <= than a value from pmf2. 
625 | 626 | pmf1: Pmf object 627 | pmf2: Pmf object 628 | 629 | :return: float probability 630 | """ 631 | outer = pmf_outer(pmf1, pmf2, np.not_equal) 632 | return outer.sum() 633 | 634 | 635 | class Cdf(pd.Series): 636 | """Represents a Cumulative Distribution Function (CDF).""" 637 | 638 | def __init__(self, *args, **kwargs): 639 | """Initialize a Cdf. 640 | 641 | Note: this cleans up a weird Series behavior, which is 642 | that Series() and Series([]) yield different results. 643 | See: https://github.com/pandas-dev/pandas/issues/16737 644 | """ 645 | if args: 646 | super().__init__(*args, **kwargs) 647 | else: 648 | underride(kwargs, dtype=np.float64) 649 | super().__init__([], **kwargs) 650 | 651 | def copy(self, deep=True): 652 | """Make a copy. 653 | 654 | :return: new Pmf 655 | """ 656 | return Cdf(self, copy=deep) 657 | 658 | @staticmethod 659 | def from_seq(seq, normalize=True, sort=True, **options): 660 | """Make a CDF from a sequence of values. 661 | 662 | seq: any kind of sequence 663 | normalize: whether to normalize the Cdf, default True 664 | sort: whether to sort the Cdf by values, default True 665 | options: passed to the pd.Series constructor 666 | 667 | :return: CDF object 668 | """ 669 | pmf = Pmf.from_seq(seq, normalize=False, sort=sort, **options) 670 | return pmf.make_cdf(normalize=normalize) 671 | 672 | @property 673 | def qs(self): 674 | """Get the quantities. 675 | 676 | :return: NumPy array 677 | """ 678 | return self.index.values 679 | 680 | @property 681 | def ps(self): 682 | """Get the probabilities. 683 | 684 | :return: NumPy array 685 | """ 686 | return self.values 687 | 688 | def _repr_html_(self): 689 | """Returns an HTML representation of the series. 690 | 691 | Mostly used for Jupyter notebooks. 692 | """ 693 | df = pd.DataFrame(dict(probs=self)) 694 | return df._repr_html_() 695 | 696 | def plot(self, **options): 697 | """Plot the Cdf as a line. 
698 | 699 | :param options: passed to plt.plot 700 | :return: 701 | """ 702 | underride(options, label=self.name) 703 | plt.plot(self.qs, self.ps, **options) 704 | 705 | def step(self, **options): 706 | """Plot the Cdf as a step function. 707 | 708 | :param options: passed to plt.step 709 | :return: 710 | """ 711 | underride(options, label=self.name, where='post') 712 | plt.step(self.qs, self.ps, **options) 713 | 714 | def normalize(self): 715 | """Make the probabilities add up to 1 (modifies self). 716 | 717 | :return: normalizing constant 718 | """ 719 | total = self.ps[-1] 720 | self /= total 721 | return total 722 | 723 | @property 724 | def forward(self, **kwargs): 725 | """Compute the forward Cdf 726 | 727 | :param kwargs: keyword arguments passed to interp1d 728 | 729 | :return array of probabilities 730 | """ 731 | 732 | underride(kwargs, kind='previous', 733 | copy=False, 734 | assume_sorted=True, 735 | bounds_error=False, 736 | fill_value=(0, 1)) 737 | 738 | interp = interp1d(self.qs, self.ps, **kwargs) 739 | return interp 740 | 741 | @property 742 | def inverse(self, **kwargs): 743 | """Compute the inverse Cdf 744 | 745 | :param kwargs: keyword arguments passed to interp1d 746 | 747 | :return array of quantities 748 | """ 749 | underride(kwargs, kind='next', 750 | copy=False, 751 | assume_sorted=True, 752 | bounds_error=False, 753 | fill_value=(self.qs[0], np.nan)) 754 | 755 | interp = interp1d(self.ps, self.qs, **kwargs) 756 | return interp 757 | 758 | # calling a Cdf like a function does forward lookup 759 | __call__ = forward 760 | 761 | # quantile is the same as an inverse lookup 762 | quantile = inverse 763 | 764 | def make_pmf(self, normalize=False): 765 | """Make a Pmf from the Cdf. 
766 | 767 | :return: Cdf 768 | """ 769 | ps = self.ps 770 | diff = np.ediff1d(ps, to_begin=ps[0]) 771 | pmf = Pmf(pd.Series(diff, index=self.index.copy())) 772 | if normalize: 773 | pmf.normalize() 774 | return pmf 775 | 776 | def choice(self, *args, **kwargs): 777 | """Makes a random sample. 778 | 779 | Uses the probabilities as weights unless `p` is provided. 780 | 781 | args: same as np.random.choice 782 | options: same as np.random.choice 783 | 784 | :return: NumPy array 785 | """ 786 | # TODO: Make this more efficient by implementing the inverse CDF method. 787 | pmf = self.make_pmf() 788 | return pmf.choice(*args, **kwargs) 789 | 790 | def sample(self, *args, **kwargs): 791 | """Makes a random sample. 792 | 793 | Uses the probabilities as weights unless `weights` is provided. 794 | 795 | This function returns an array containing a sample of the quantities in this Pmf, 796 | which is different from Series.sample, which returns a Series with a sample of 797 | the rows in the original Series. 798 | 799 | args: same as Series.sample 800 | options: same as Series.sample 801 | 802 | :return: NumPy array 803 | """ 804 | # TODO: Make this more efficient by implementing the inverse CDF method. 805 | pmf = self.make_pmf() 806 | return pmf.sample(*args, **kwargs) 807 | 808 | def mean(self): 809 | """Expected value. 810 | 811 | :return: float 812 | """ 813 | return self.make_pmf().mean() 814 | 815 | def var(self): 816 | """Variance. 817 | 818 | :return: float 819 | """ 820 | return self.make_pmf().var() 821 | 822 | def std(self): 823 | """Standard deviation. 824 | 825 | :return: float 826 | """ 827 | return self.make_pmf().std() 828 | 829 | def median(self): 830 | """Median (50th percentile). 831 | 832 | :return: float 833 | """ 834 | return self.quantile(0.5) 835 | 836 | 837 | --------------------------------------------------------------------------------