├── Images ├── alpha.png ├── beta.png ├── data.png ├── fitted_line.png ├── latex_8f584943c6692f6d022403eb38917573.png ├── latex_dist.png ├── lp.png ├── sigma.png └── stan_output.png ├── PyStan.html ├── PyStan.ipynb ├── PyStan_plotting.py ├── README.md └── regression_model.pkl /Images/alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mwestt/An-Introduction-to-Bayesian-Inference-in-PyStan/e60b40a2dde4374c595d3984dbaa466f1c691388/Images/alpha.png -------------------------------------------------------------------------------- /Images/beta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mwestt/An-Introduction-to-Bayesian-Inference-in-PyStan/e60b40a2dde4374c595d3984dbaa466f1c691388/Images/beta.png -------------------------------------------------------------------------------- /Images/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mwestt/An-Introduction-to-Bayesian-Inference-in-PyStan/e60b40a2dde4374c595d3984dbaa466f1c691388/Images/data.png -------------------------------------------------------------------------------- /Images/fitted_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mwestt/An-Introduction-to-Bayesian-Inference-in-PyStan/e60b40a2dde4374c595d3984dbaa466f1c691388/Images/fitted_line.png -------------------------------------------------------------------------------- /Images/latex_8f584943c6692f6d022403eb38917573.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mwestt/An-Introduction-to-Bayesian-Inference-in-PyStan/e60b40a2dde4374c595d3984dbaa466f1c691388/Images/latex_8f584943c6692f6d022403eb38917573.png -------------------------------------------------------------------------------- 
"""Fit a simple Bayesian linear regression with PyStan and plot the results.

The script simulates data from ``y = alpha + beta * x + noise``, samples the
posterior of ``alpha``, ``beta`` and ``sigma`` with Stan, and then shows the
data, the fitted regression lines and a trace/posterior plot per parameter.
"""
import pystan
import pickle
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

sns.set()  # Nice plot aesthetic
np.random.seed(101)

# Nice plot parameters
matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
## for Palatino and other serif fonts use:
# matplotlib.rc('font',**{'family':'serif','serif':['Palatino']})
matplotlib.rc('text', usetex=True)

# Workflow parameter: True compiles the Stan model and pickles it to disk,
# False loads the previously pickled model instead.
model_compile = True

## Stan Model ##################################################################

# sigma is the scale of the normal likelihood, so it must be declared with a
# lower bound of zero; without the bound the sampler can propose negative
# values, which normal() rejects.
model = """
data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
}
parameters {
    real alpha;
    real beta;
    real<lower=0> sigma;
}
model {
    y ~ normal(alpha + beta * x, sigma);
}
"""

## Data and Sampling ###########################################################

# Parameters to be inferred
alpha = 4.0
beta = 0.5
sigma = 1.0

# Generate and plot data
x = 10 * np.random.rand(100)
y = alpha + beta * x
y = np.random.normal(y, scale=sigma)
plt.scatter(x, y)

plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Scatter Plot of Data')

plt.show()

# Put our data in a dictionary
data = {'N': len(x), 'x': x, 'y': y}

if model_compile:
    # Compile the model
    sm = pystan.StanModel(model_code=model)
    # Save the model
    with open('regression_model.pkl', 'wb') as f:
        pickle.dump(sm, f)
else:
    # NOTE: unpickling can execute arbitrary code; only load a pickle that
    # this script produced itself.
    with open('regression_model.pkl', 'rb') as f:
        sm = pickle.load(f)

# Train the model and generate samples
fit = sm.sampling(data=data, iter=1000, chains=4, warmup=500, thin=1, seed=101,
                  verbose=True)
print(fit)

## Diagnostics #################################################################

summary_dict = fit.summary()
df = pd.DataFrame(summary_dict['summary'],
                  columns=summary_dict['summary_colnames'],
                  index=summary_dict['summary_rownames'])

alpha_mean, beta_mean = df['mean']['alpha'], df['mean']['beta']

# Extracting traces (from here on alpha/beta/sigma hold the posterior samples,
# no longer the true values used to simulate the data)
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']

# Plotting regression line
x_min, x_max = -0.5, 10.5
x_plot = np.linspace(x_min, x_max, 100)

# Plot a subset of sampled regression lines
for i in np.random.randint(0, len(alpha), 1000):
    plt.plot(x_plot, alpha[i] + beta[i] * x_plot, color='lightsteelblue',
             alpha=0.005)

# Plot mean regression line
plt.plot(x_plot, alpha_mean + beta_mean * x_plot)
plt.scatter(x, y)

plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Fitted Regression Line')
plt.xlim(x_min, x_max)
plt.show()


def plot_trace(param, param_name='parameter'):
    """Plot the trace and posterior of a parameter.

    Draws two stacked subplots on the current figure: the sample trace on top
    and a histogram with a kernel density estimate below. Both panels mark the
    mean, the median and the 2.5th / 97.5th percentiles of the samples.

    Args:
        param: Array of samples for a single parameter.
        param_name: Label used on the axes and in the plot title.
    """

    # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, [2.5, 97.5])

    # Plotting
    plt.subplot(2, 1, 1)
    plt.plot(param)
    plt.xlabel('samples')
    plt.ylabel(param_name)
    plt.axhline(mean, color='r', lw=2, linestyle='--')
    plt.axhline(median, color='c', lw=2, linestyle='--')
    plt.axhline(cred_min, linestyle=':', color='k', alpha=0.2)
    plt.axhline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.title('Trace and Posterior Distribution for {}'.format(param_name))

    plt.subplot(2, 1, 2)
    plt.hist(param, 30, density=True)
    # NOTE(review): newer seaborn releases deprecate `shade` in favour of
    # `fill` -- confirm against the seaborn version this script targets.
    sns.kdeplot(param, shade=True)
    plt.xlabel(param_name)
    plt.ylabel('density')
    plt.axvline(mean, color='r', lw=2, linestyle='--', label='mean')
    plt.axvline(median, color='c', lw=2, linestyle='--', label='median')
    # Only the lower bound carries a label so the legend lists the interval once.
    plt.axvline(cred_min, linestyle=':', color='k', alpha=0.2, label=r'95\% CI')
    plt.axvline(cred_max, linestyle=':', color='k', alpha=0.2)

    plt.gcf().tight_layout()
    plt.legend()


plot_trace(alpha, r'$\alpha$')
plt.show()
plot_trace(beta, r'$\beta$')
plt.show()
plot_trace(sigma, r'$\sigma$')
plt.show()
plot_trace(lp, r'lp\_\_')
plt.show()
An-Introduction-to-Bayesian-Inference-in-PyStan 2 | Code for blog post on Bayesian inference in PyStan published in [Towards Data Science](https://towardsdatascience.com/an-introduction-to-bayesian-inference-in-pystan-c27078e58d53). 3 | -------------------------------------------------------------------------------- /regression_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mwestt/An-Introduction-to-Bayesian-Inference-in-PyStan/e60b40a2dde4374c595d3984dbaa466f1c691388/regression_model.pkl --------------------------------------------------------------------------------