├── .gitignore ├── Probabilistic ├── ExpectationMaximisationCoinToss.ipynb ├── Importance_sampling.ipynb ├── NormalTransformations.ipynb ├── lagavin_monte_carlo.ipynb ├── lagavin_regression.ipynb ├── lagavin_requirements.txt └── test_jax_cuda.ipynb ├── README.md ├── computer_vision ├── WebcamDetector.ipynb └── WebcamInNotebook.ipynb ├── control ├── CartPole_LQR_Control.ipynb └── PID.ipynb ├── cs_algorithms ├── binary_search.ipynb ├── binary_search_tree.ipynb ├── binary_search_tree.py ├── mergesort.ipynb ├── priority_queue.ipynb ├── priority_queue.py ├── programming_challanges │ ├── construct_tree_from_edge_list.ipynb │ └── find_longest_substring.ipynb ├── quicksort.ipynb ├── red_black_tree.py └── tree_traversal.ipynb ├── dataset_specific ├── FAA │ ├── FAA_Exploration.ipynb │ ├── FAA_regression.ipynb │ └── README.md ├── README.md └── housing │ ├── XGBoost_regression.ipynb │ ├── box_cox_tests.py │ ├── extended_linear_regression.ipynb │ ├── naive_linear_regression.ipynb │ └── preprocess.py ├── ml_algorithms ├── Autoencoder.ipynb ├── BayesianOptimization.ipynb ├── CollaborativeFiltering.ipynb ├── GAN_basic.ipynb ├── HMM_Python.ipynb ├── KMeans.ipynb ├── LDA.ipynb ├── LearningToRank.ipynb ├── Mean.ipynb ├── NeuralNetInformationTheory.ipynb ├── PCA_intuition.ipynb ├── ToRead.ipynb ├── VAE.ipynb ├── VI.ipynb ├── WGAN.ipynb └── diffusion │ ├── fm.ipynb │ ├── fm_bimodal_guass_1d.ipynb │ ├── fm_halfmoon.ipynb │ └── fm_requirements.txt ├── nlp_text ├── README.md ├── SimpleWikipedia_Gensim.ipynb ├── Spacy.ipynb └── WineExploration.ipynb ├── probabilistic_programming ├── Changepoint detection.ipynb ├── Clustering_3clusters_mixture.ipynb ├── Clustering_simple_2clusters_categorical.ipynb ├── CoinFlip.ipynb ├── MultivariateGaussian.ipynb ├── README.md └── numpyro │ ├── GMM-2D.ipynb │ ├── GMM-discrete.ipynb │ ├── GMM.ipynb │ ├── LKJ_Cholesky_Covariance.ipynb │ ├── LKJ_Prior.ipynb │ ├── MixtureSameFamily.ipynb │ ├── Poisson_unnormalized.ipynb │ ├── SVI │ ├── SVI_Part_01.ipynb │ └── linear_regression_SVI.ipynb │ ├── env │ ├── README.md │ └── conda_env.yml │ ├── linear_regression.ipynb │ ├── multivariate_gaussian_mixture_model.ipynb │ ├── numpyro_getting_started.ipynb │ └── test-batch-shape.ipynb ├── survival └── Lifelines_quickstart.ipynb └── visualizations ├── Derivative anim.ipynb └── Gradient Descent.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints/ 2 | *.idea/ 3 | *.pyc 4 | *.bz2 5 | *.mm 6 | *.csv 7 | *.index 8 | *.pb 9 | *.gz 10 | *.tar 11 | *.tar.gz 12 | *.jpeg 13 | *.jpg 14 | *.png 15 | -------------------------------------------------------------------------------- /Probabilistic/ExpectationMaximisationCoinToss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Expectation Maximisation coin-toss. - WIP\n", 8 | "\n", 9 | "Illustration of the expectation-maximisation (EM) algorithm on a coin-toss problem. Based on Nature's [\"What is the expectation maximization algorithm?\"](http://ai.stanford.edu/~chuongdo/papers/em_tutorial.pdf).\n", 10 | "\n", 11 | "consider a coin-flipping experiment in which we are given a pair\n", 12 | "of coins $A$ and $B$ of unknown biases. 
On any given flip, coin $A$ will land on heads with probability $\\theta_A$ and tails with probability $1 - \\theta_A$, and similarly for coin $B$, which has bias $\\theta_B$.\n", 13 | "\n", 14 | "Let's say that we run an experiment with the following procedure:\n", 15 | "\n", 16 | "* Randomly choose one of the two coins out of a bag (we don't know which coin we take; each coin can be taken with equal probability).\n", 17 | "* Perform $m$ independent coin tosses with the selected coin and count how many times heads is observed.\n", 18 | "* Put the selected coin back in the bag.\n", 19 | "* Repeat this $n$ times.\n", 20 | "\n", 21 | "Our goal now is to estimate $\\theta_A$ and $\\theta_B$ from this experiment. The problem is that we have incomplete data: we don't know which coin we picked, so the selection of coin $A$ or $B$ is a latent variable.\n", 22 | "\n", 23 | "This notebook will illustrate how we can estimate $\\theta_A$ and $\\theta_B$ with the help of the EM algorithm. We will refer to both of them together as the model parameters $\\theta = \\{\\theta_A, \\theta_B\\}$.\n", 24 | "\n", 25 | "* https://am207.github.io/2017/wiki/EM.html\n", 26 | "* https://stats.stackexchange.com/questions/72774/numerical-example-to-understand-expectation-maximization\n", 27 | "* https://math.stackexchange.com/questions/25111/how-does-expectation-maximization-work\n", 28 | "* https://math.stackexchange.com/questions/81004/how-does-expectation-maximization-work-in-coin-flipping-problem\n", 29 | "* https://people.duke.edu/~ccc14/sta-663/EMAlgorithm.html\n", 30 | "* https://am207.github.io/2017/wiki/EM.html\n", 31 | "* https://media.nature.com/original/nature-assets/nbt/journal/v26/n8/extref/nbt1406-S1.pdf\n", 32 | "* https://mk-minchul.github.io/EM/\n", 33 | "* log-likelihood:\n", 34 | " * https://blog.metaflow.fr/ml-notes-why-the-log-likelihood-24f7b6c40f83\n", 35 | " * https://stats.stackexchange.com/questions/174481/why-to-optimize-max-log-probability-instead-of-probability\n", 36 | "* https://media.nature.com/original/nature-assets/nbt/journal/v26/n8/extref/nbt1406-S1.pdf\n", 37 | "* http://www.rmki.kfki.hu/~banmi/elte/bishop_em.pdf" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "import warnings\n", 47 | "import numpy as np\n", 48 | "import matplotlib.pyplot as plt\n", 49 | "import seaborn as sns\n", 50 | "\n", 51 | "import scipy.stats as stats\n", 52 | "\n", 53 | "sns.set_style('darkgrid')\n", 54 | "%matplotlib inline\n", 55 | "warnings.filterwarnings('ignore')\n", 56 | "np.random.seed(42)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "The observations of each experiment are the number of times heads was seen for each coin picked. Let's represent this as $X = (x_1, ... , x_n)$, where $n$ is the number of times a coin was chosen from the bag. $x_i$ is the number of times heads was observed for the $i$-th picked coin.\n", 64 | "\n", 65 | "Note that for each coin $i$ drawn we don't actually know its true identity ($A$ or $B$). We will represent this unknown as the latent variable $Z = (z_1, ... 
, z_n)$ with $z_i \\in \\{A, B\\}$.\n", 66 | "\n", 67 | "Note that this means that $x_i$ is dependent on $z_i$.\n", 68 | "\n", 69 | "![Latent variable graph](https://i.imgur.com/jphWJbX.png)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stdout", 79 | "output_type": "stream", 80 | "text": [ 81 | "zs: [ 0.8 0.3 0.3 0.3 0.8 0.8 0.8 0.3 0.3 0.3 0.8 0.3 0.3 0.8 0.8\n", 82 | " 0.8 0.8 0.3 0.8 0.8 0.3 0.8 0.8 0.8 0.8 0.3 0.8 0.3 0.3 0.8\n", 83 | " 0.3 0.8 0.8 0.3 0.3 0.3 0.8 0.8 0.3 0.8 0.8 0.8 0.8 0.3 0.8\n", 84 | " 0.3 0.8 0.3 0.3 0.8]\n", 85 | "xs: [ 5 4 5 5 8 6 10 2 1 2 8 2 4 9 9 8 9 4 10 5 4 9 10 7 7\n", 86 | " 4 7 1 2 9 5 8 9 1 2 2 7 8 5 8 9 7 7 3 7 3 8 3 0 9]\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "m = 10 # Number of flips per coin drawn\n", 92 | "n = 50 # Number of times coin is drawn from bag\n", 93 | "theta_A = 0.8 # Bias of coin A towards landing on heads\n", 94 | "theta_B = 0.3 # Bias of coin B towards landing on heads\n", 95 | "theta = (theta_A, theta_B)\n", 96 | "p_A = 0.5 # Probability of picking A\n", 97 | "p_B = 0.5 # Probability of picking B\n", 98 | "p = (p_A, p_B)\n", 99 | "\n", 100 | "# Choose between A or B (note that this is latent and not known when we run EM)\n", 101 | "# Note that z should be either A or B but we'll represent those by their bias here\n", 102 | "zs = np.random.choice(theta, size=n, p=p, replace=True)\n", 103 | "print('zs: ', zs)\n", 104 | "\n", 105 | "# Observed number of heads for each coin drawn\n", 106 | "xs = np.random.binomial(n=m, p=zs)\n", 107 | "print('xs: ', xs)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "The [expectation-maximization (EM) algorithm](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) will help us estimate the parameters $\\theta_A$ and $\\theta_B$ in this latent-variable model. It will do this by iteratively trying to find the maximum-likelihood estimate of the parameters $\\theta$. The algorithm starts with an initial guess of the parameters $\\theta_A$ and $\\theta_B$ and then follows an alternating 2-step procedure until convergence:\n", 115 | "\n", 116 | "1. **E-step**: Given the current guess of $\\theta_A$ and $\\theta_B$ and the observations $X$, determine for each $x_i$ how likely each coin (with bias $\\theta_A$ or $\\theta_B$) was to have generated the observed sequence of heads. We can then assign weights to each sample according to how likely it was to have been generated by coin $A$ or $B$.\n", 117 | "\n", 118 | "2. **M-step**: Given the weighted data from the previous step, update the guesses for $\\theta_A$ and $\\theta_B$ by maximizing their likelihoods.\n", 119 | "\n", 120 | "\n", 121 | "![EM algorithm](https://i.imgur.com/Taa1gcv.png)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "\n", 129 | "The [expectation-maximization (EM) algorithm](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) will try to optimize the likelihood of the dataset $X$ given the parameters $\\theta$. We'll write this as:\n", 130 | "\n", 131 | "$$ \\max_{\\theta} p(X \\mid \\theta) = \\max_{\\theta} \\prod_{i=1}^n p(x_i \\mid \\theta) = \\max_{\\theta} \\prod_{i=1}^n \\sum_{c \\in \\{A,B\\}} p(x_i, z_i=c \\mid \\theta)$$\n", 132 | "\n", 133 | "\n", 134 | "\n", 135 | "The expectation maximization algorithm is a refinement on this basic idea. 
Rather than picking the single most likely completion of the missing coin assignments on each iteration, the expectation maximization algorithm computes probabilities for each possible completion of the missing data, using the current parameters $\\hat{\\theta}^{(t)}$. These probabilities are used to create a weighted training set consisting of all possible completions of the data. Finally, a modified version of maximum likelihood estimation that deals with weighted training examples provides new parameter estimates, $\\hat{\\theta}^{(t+1)}$. By using weighted training examples rather than choosing the single best completion, the expectation maximization algorithm accounts for the confidence of the model in each completion of the data.\n", 136 | "\n", 137 | "\n", 138 | "In summary, the expectation maximization algorithm alternates between the steps of guessing a probability distribution over completions of missing data given the current model (known as the E-step) and then reestimating the model parameters using these completions (known as the M-step).\n", 139 | "\n", 140 | "\n", 141 | "The name ‘E-step’ comes from the fact that one does not usually need to form the probability distribution over completions explicitly, but rather need only compute ‘expected’ sufficient statistics over these completions.\n", 142 | "\n", 143 | "The name ‘M-step’ comes from the fact that model reestimation can be thought of as ‘maximization’ of the expected log-likelihood of the data.\n", 144 | "\n", 145 | "\n", 146 | "\n", 147 | "\n", 148 | "More pragmatically speaking, the EM algorithm is an iterative method that alternates between computing a conditional expectation and solving a maximization problem, hence the name expectation maximization.\n", 149 | "\n", 150 | "\n", 151 | "We do not know $z$, so instead of using the true values of $z$ we're going to use their expectations. In particular, we will compute $\\theta$ by maximizing its likelihood under the expected values of $z$." 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "## E-step (Expectation)\n", 166 | "\n", 167 | "\n", 168 | "\n", 169 | "In the E-step, a [likelihood function](https://en.wikipedia.org/wiki/Likelihood_function) of the latent variables $z_i$ given the data $x_i$ is computed using the current estimates of $\\theta$. This likelihood function describes the plausibility of the latent variables $z_i$ being assigned to $A$ or $B$. For each $z_i$ we can write this as:\n", 170 | "\n", 171 | "$$ p(z_i=c \\mid x_i, \\theta) \\quad c \\in \\{A, B\\}$$\n", 172 | "\n", 173 | "\n", 174 | "\n", 175 | "Since $\\theta$ is fixed, we simplify the notation and write this as $p(c \\mid x_i)$. We can then use [Bayes' rule](https://en.wikipedia.org/wiki/Bayes%27_theorem) to write:\n", 176 | "\n", 177 | "$$ p(c \\mid x_i) = \\frac{p(x_i \\mid c) p(c)}{p(x_i)} $$\n", 178 | "\n", 179 | "Now $p(c)=0.5$ since we pick one of the 2 coins out of the bag at random. \n", 180 | "\n", 181 | "$p(x_i \\mid c)$ is the probability of observing $x_i$ heads out of $m$ flips when coin $c$ was drawn. 
It depends only on $\\theta_c$ and by definition follows the [binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) ($m$ coin flips with $x_i$ heads).\n", 182 | "\n", 183 | "$$p(x_i \\mid c) = {m \\choose x_i}\\,\\theta_c^{x_i}(1-\\theta_c)^{m-x_i}$$\n", 184 | "\n", 185 | "\n", 186 | "$p(x_i)$ is the probability of observing $x_i$ heads in the sequence without knowing which coin was selected, and is equal to $p(x_i \\mid A)p(A) + p(x_i \\mid B)p(B)$.\n", 187 | "\n", 188 | "Now, since the [binomial coefficient](https://en.wikipedia.org/wiki/Binomial_coefficient) appears in all terms of our fraction, we can remove it and write the likelihood function as:\n", 189 | "\n", 190 | "$$ p(c \\mid x_i) = \\frac{ \\theta_c^{x_i}(1-\\theta_c)^{m-x_i} p(c)}{\\theta_A^{x_i}(1-\\theta_A)^{m-x_i} p(A) + \\theta_B^{x_i}(1-\\theta_B)^{m-x_i} p(B)} $$\n", 191 | "\n", 192 | "\n", 193 | "### Example\n", 194 | "\n", 195 | "Let's say we start out with initial parameter guesses $\\theta_A = 0.6$ and $\\theta_B = 0.5$; we pick one coin at random ($p=0.5$) and observe 8 heads and 2 tails. If we fill in the equations we find that:\n", 196 | "\n", 197 | "$$\n", 198 | "\\theta_A^{x_i}(1-\\theta_A)^{m - x_i} p(A) = 0.6^8*(1-0.6)^2*0.5 ≈ 0.0013 \\\\\n", 199 | "\\theta_B^{x_i}(1-\\theta_B)^{m - x_i} p(B) = 0.5^8*(1-0.5)^2*0.5 ≈ 0.0004 \\\\\n", 200 | "\\Downarrow \\\\\n", 201 | "p(A \\mid x_i=8) = \\frac{0.0013}{0.0013 + 0.0004} ≈ 0.73\\\\\n", 202 | "p(B \\mid x_i=8) = \\frac{0.0004}{0.0013 + 0.0004} ≈ 0.27\n", 203 | "$$" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 3, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "# E-step: calculate probability distributions over possible completions\n", 213 | "\n", 214 | "def e_step(tA, tB):\n", 215 | " # Weighted data over n chosen coins\n", 216 | " w_A_H = 0\n", 217 | " w_A_T = 0\n", 218 | " w_B_H = 0\n", 219 | " w_B_T = 0\n", 220 | "\n", 221 | " # For each chosen coin\n", 222 | " for x in xs:\n", 223 | " coef_A = ((tA**x) * ((1-tA)**(m-x))) * p_A\n", 224 | " coef_B = ((tB**x) * ((1-tB)**(m-x))) * p_B\n", 225 | "\n", 226 | " coef_AB = coef_A + coef_B\n", 227 | "\n", 228 | " p_A_x = coef_A / coef_AB\n", 229 | " p_B_x = coef_B / coef_AB\n", 230 | "\n", 231 | " w_A_H += p_A_x * x\n", 232 | " w_A_T += p_A_x * (m - x)\n", 233 | " w_B_H += p_B_x * x\n", 234 | " w_B_T += p_B_x * (m - x)\n", 235 | " return w_A_H, w_A_T, w_B_H, w_B_T" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 4, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# M-step\n", 245 | "def m_step(w_A_H, w_A_T, w_B_H, w_B_T):\n", 246 | " tA = w_A_H / (w_A_H + w_A_T)\n", 247 | " tB = w_B_H / (w_B_H + w_B_T)\n", 248 | " return tA, tB" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 5, 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "Step 0: A=0.51, B=0.5\n", 261 | "Step 1: A=0.5902116444208838, B=0.5573355860971709\n", 262 | "Step 2: A=0.6281270259605728, B=0.5197970793429936\n", 263 | "Step 3: A=0.7224739053517539, B=0.421905516120847\n", 264 | "Step 4: A=0.7973640474115157, B=0.33034231342770287\n", 265 | "Step 5: A=0.8097749133059613, B=0.3138660059289479\n", 266 | "Step 6: A=0.8108856254648229, B=0.3126081748766472\n", 267 | "Step 7: A=0.811010957899099, B=0.3125615147577201\n", 268 | "Step 8: A=0.8110385124408093, B=0.3125781746882331\n", 269 | "Step 9: A=0.811048367275363, 
B=0.31258799858636593\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "tA, tB = 0.51, 0.5\n", 275 | "\n", 276 | "for i in range(10):\n", 277 | " print('Step {}: A={}, B={}'.format(i, tA, tB))\n", 278 | " w_A_H, w_A_T, w_B_H, w_B_T = e_step(tA, tB)\n", 279 | " tA, tB = m_step(w_A_H, w_A_T, w_B_H, w_B_T)" 280 | ] 281 | } 282 | ], 283 | "metadata": { 284 | "kernelspec": { 285 | "display_name": "Python 3", 286 | "language": "python", 287 | "name": "python3" 288 | }, 289 | "language_info": { 290 | "codemirror_mode": { 291 | "name": "ipython", 292 | "version": 3 293 | }, 294 | "file_extension": ".py", 295 | "mimetype": "text/x-python", 296 | "name": "python", 297 | "nbconvert_exporter": "python", 298 | "pygments_lexer": "ipython3", 299 | "version": "3.6.6" 300 | } 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 2 304 | } 305 | -------------------------------------------------------------------------------- /Probabilistic/NormalTransformations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Transformations of normal distribution\n", 8 | "\n", 9 | "* \"Any continuous distribution can be turned into a normal distribution through a process called Gaussianization\"\n", 10 | " - https://stats.stackexchange.com/a/137083/26888\n", 11 | " - http://papers.nips.cc/paper/1856-gaussianization.pdf\n", 12 | "* \"a transformation of variables is actually a change of measure\"\n", 13 | " - https://stats.stackexchange.com/a/294871/26888\n", 14 | " - http://www.statsathome.com/2017/06/26/measure-theory-made-rediculously-simple/\n", 15 | " - https://vannevar.ece.uw.edu/techsite/papers/documents/UWEETR-2006-0008.pdf\n", 16 | "* Affine transform of normal distribution is again a normal distribution\n", 17 | "* [Related distributions](https://en.wikipedia.org/wiki/Log-normal_distribution)\n", 18 | " - [Log-normal](https://en.wikipedia.org/wiki/Log-normal_distribution)\n", 19 | " - [Folded normal](https://en.wikipedia.org/wiki/Folded_normal_distribution) and [half-normal](https://en.wikipedia.org/wiki/Half-normal_distribution)\n", 20 | " - [Chi-distribution](https://en.wikipedia.org/wiki/Chi_distribution)\n", 21 | " - [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution)\n", 22 | " - The chi-squared distribution can lead to the gamma distribution and the Beta distribution.\n", 23 | " - [Levy distribution](https://en.wikipedia.org/wiki/L%C3%A9vy_distribution)\n", 24 | " - [Cauchy distribution](https://en.wikipedia.org/wiki/Cauchy_distribution)\n", 25 | " - [Rayleigh distribution](https://en.wikipedia.org/wiki/Rayleigh_distribution)\n", 26 | " - [Student's t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution)\n", 27 | " - [F-distribution](https://en.wikipedia.org/wiki/F-distribution)\n", 28 | " - F-distribution could lead to Beta distribution\n", 29 | "* Approximations\n", 30 | " - [Logistic distribution](https://en.wikipedia.org/wiki/Logistic_distribution) resembles the normal distribution in shape but has heavier tails. \n", 31 | " \n", 32 | " \n", 33 | "Gamma distribution !!! 
https://en.wikipedia.org/wiki/Gamma_distribution\n", 34 | "\n", 35 | "https://www.johndcook.com/blog/conjugate_prior_diagram/" 36 | ] 37 | } 38 | ], 39 | "metadata": { 40 | "kernelspec": { 41 | "display_name": "Python 3", 42 | "language": "python", 43 | "name": "python3" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": "ipython3", 55 | "version": "3.7.0" 56 | } 57 | }, 58 | "nbformat": 4, 59 | "nbformat_minor": 2 60 | } 61 | -------------------------------------------------------------------------------- /Probabilistic/lagavin_requirements.txt: -------------------------------------------------------------------------------- 1 | ipykernel 2 | numpy<2 3 | scipy 4 | # JAX 5 | jax[cuda12] 6 | # CUDA Dependencies 7 | # nvidia-cuda-runtime-cu12==12.1.* 8 | # nvidia-cudnn-cu12 9 | # Plotting 10 | matplotlib 11 | seaborn 12 | tqdm 13 | ipywidgets 14 | -------------------------------------------------------------------------------- /Probabilistic/test_jax_cuda.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!nvidia-smi" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!nvcc --version" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Avoid JAX from using up all the GPU memory\n", 28 | "import os\n", 29 | "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 4, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import jax\n", 39 | "import jax.numpy as jnp\n", 40 | "from jax import random\n", 41 | "from jax.lib import xla_bridge" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "print(f\"{jax.__version__=}\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "backend = str(xla_bridge.get_backend().platform)\n", 60 | "print(f\"JAX backend: {backend}\")\n", 61 | "assert backend == \"gpu\"" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "key = random.PRNGKey(0)\n", 71 | "x = random.normal(key, (10,))\n", 72 | "print(f\"{x=}\")\n", 73 | "print(f\"{x.device=}\")" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "tags": [] 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "y = jnp.dot(x, x.T).block_until_ready() # Should run on the GPU\n", 85 | "print(f\"{y.device=}\")\n", 86 | "\n", 87 | "\n", 88 | "%timeit y = jnp.dot(x, x.T).block_until_ready() # Should run on the GPUa" 89 | ] 90 | } 91 | ], 92 | "metadata": { 93 | "kernelspec": { 94 | "display_name": "Python 3 (ipykernel)", 95 | "language": "python", 96 | "name": "python3" 97 | }, 98 | "language_info": { 99 | "codemirror_mode": { 100 | "name": "ipython", 101 | "version": 3 102 | }, 103 | "file_extension": ".py", 104 | "mimetype": "text/x-python", 105 | "name": "python", 106 | 
"nbconvert_exporter": "python", 107 | "pygments_lexer": "ipython3", 108 | "version": "3.11.10" 109 | }, 110 | "vscode": { 111 | "interpreter": { 112 | "hash": "219a60bd7c2cef5a225b6a627b5206a9c3c0d26466c6e9be19e95afd9fc964d1" 113 | } 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 4 118 | } 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Collection of Jupyter notebooks 2 | 3 | Topics: 4 | 5 | * Probabilistic programming 6 | * NLP & text analytics 7 | * Dataset explorations 8 | * Algorithms 9 | * Visualizations 10 | -------------------------------------------------------------------------------- /computer_vision/WebcamDetector.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Webcam bounding box detection in notebook\n", 8 | "\n", 9 | "Use OpenCV to read webcam feed, use tensorflow do perform bounding box detection and use ipywidgets to display videostream.\n", 10 | "\n", 11 | "The TensorFlow detection model is currently the [`ssdlite_mobilenet_v2_coco`](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) pretrained model from the Tensorflow [object detection API](https://github.com/tensorflow/models/tree/master/research/object_detection).\n", 12 | "\n", 13 | "This notebook is inspired by the [object detection demo notebook](https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb)." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "import sys\n", 24 | "import time\n", 25 | "import tarfile\n", 26 | "import zipfile\n", 27 | "import six.moves.urllib as urllib\n", 28 | "from collections import defaultdict\n", 29 | "from io import StringIO\n", 30 | "import threading\n", 31 | "\n", 32 | "import numpy as np\n", 33 | "import tensorflow as tf\n", 34 | "import cv2\n", 35 | "\n", 36 | "from ipywidgets import widgets\n", 37 | "from IPython.display import display" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Download the model from the tensorflow repo" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "# Threhold used for prediction\n", 54 | "PREDICT_THRESHOLD = 0.3\n", 55 | "\n", 56 | "# What model to download.\n", 57 | "MODEL_NAME = 'ssdlite_mobilenet_v2_coco_2018_05_09'\n", 58 | "MODEL_FILE = MODEL_NAME + '.tar.gz'\n", 59 | "DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'\n", 60 | "PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def download_model():\n", 70 | " \"\"\"\n", 71 | " Download pretrained model from the tensorflow repo\n", 72 | " \"\"\"\n", 73 | " opener = urllib.request.URLopener()\n", 74 | " opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)\n", 75 | " tar_file = tarfile.open(MODEL_FILE)\n", 76 | " for file in tar_file.getmembers():\n", 77 | " file_name = os.path.basename(file.name)\n", 78 | " if 'frozen_inference_graph.pb' in file_name:\n", 79 | " tar_file.extract(file, 
os.getcwd())\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "print('Downloading model')\n", 89 | "download_model()\n", 90 | "print('Model downloaded')" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "## Model loading functionality" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "def get_model_graph(model_path):\n", 107 | " \"\"\"\n", 108 | " Load the downloaded Tensorflow model into memory.\n", 109 | " \"\"\"\n", 110 | " detection_graph = tf.Graph()\n", 111 | " with detection_graph.as_default():\n", 112 | " od_graph_def = tf.GraphDef()\n", 113 | " with tf.gfile.GFile(model_path, 'rb') as fid:\n", 114 | " serialized_graph = fid.read()\n", 115 | " od_graph_def.ParseFromString(serialized_graph)\n", 116 | " tf.import_graph_def(od_graph_def, name='')\n", 117 | " return detection_graph\n", 118 | "\n", 119 | "\n", 120 | "def get_tf_tensors(graph):\n", 121 | " \"\"\"\n", 122 | " Get handles to input and output tensors.\n", 123 | " \"\"\"\n", 124 | " ops = graph.get_operations()\n", 125 | " all_tensor_names = {output.name for op in ops for output in op.outputs}\n", 126 | " tensor_dict = {}\n", 127 | " for key in ['detection_boxes', 'detection_scores', 'detection_classes']:\n", 128 | " tensor_name = key + ':0'\n", 129 | " if tensor_name in all_tensor_names:\n", 130 | " tensor_dict[key] = graph.get_tensor_by_name(\n", 131 | " tensor_name)\n", 132 | " input_image_tensor = graph.get_tensor_by_name('image_tensor:0')\n", 133 | " return tensor_dict, input_image_tensor\n", 134 | "\n", 135 | "\n", 136 | "def get_graph_tensors(model_path):\n", 137 | " \"\"\"\n", 138 | " Load model into memory and get the inputs.\n", 139 | " \"\"\"\n", 140 | " graph = get_model_graph(model_path)\n", 141 | " tensor_dict, input_image_tensor = get_tf_tensors(graph)\n", 142 | " return graph, tensor_dict, input_image_tensor" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "class WebcamBoundingBoxThread(threading.Thread):\n", 152 | " \"\"\"\n", 153 | " Background thread to read image from webcam, detect objects, and update\n", 154 | " interactive display.\n", 155 | " \"\"\"\n", 156 | " def __init__(self, interactive_img, interactive_framerate_text, model_path):\n", 157 | " super(WebcamBoundingBoxThread, self).__init__()\n", 158 | " self._stop_event = threading.Event()\n", 159 | " self.model_path = model_path\n", 160 | " self.interactive_img = interactive_img\n", 161 | " self.interactive_framerate_text = interactive_framerate_text\n", 162 | " \n", 163 | " def setup(self):\n", 164 | " \"\"\"\n", 165 | " Setup tensorflow graph and webcam connection.\n", 166 | " \"\"\"\n", 167 | " # Setup Tensorflow detector\n", 168 | " (self.detection_graph, self.tensor_dict, \n", 169 | " self.input_image_tensor) = get_graph_tensors(self.model_path)\n", 170 | " # Setup camera capture\n", 171 | " self.camera = cv2.VideoCapture(0)\n", 172 | "\n", 173 | " def stop(self):\n", 174 | " \"\"\"\n", 175 | " Stop thread.\n", 176 | " \"\"\"\n", 177 | " self._stop_event.set()\n", 178 | "\n", 179 | " def stopped(self):\n", 180 | " \"\"\"\n", 181 | " True iff tread is stopped.\n", 182 | " \"\"\"\n", 183 | " return self._stop_event.is_set()\n", 184 | " \n", 185 | " def run_tf_inference(self, image):\n", 186 | " \"\"\"\n", 
187 | " Run tf inference to detect bounding boxes\n", 188 | " \"\"\"\n", 189 | " output_dict = self.session.run(\n", 190 | " self.tensor_dict,\n", 191 | " feed_dict={self.input_image_tensor: np.expand_dims(image, 0)})\n", 192 | " # all outputs are float32 numpy arrays, so convert types as appropriate\n", 193 | " output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)\n", 194 | " output_dict['detection_boxes'] = output_dict['detection_boxes'][0]\n", 195 | " output_dict['detection_scores'] = output_dict['detection_scores'][0]\n", 196 | " return output_dict\n", 197 | " \n", 198 | " def process_frame(self, frame):\n", 199 | " \"\"\"\n", 200 | " Process a single frame for bounding box detection.\n", 201 | " \"\"\"\n", 202 | " image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", 203 | " output_dict = self.run_tf_inference(image)\n", 204 | " for i in range(100):\n", 205 | " # Assume predictions are ordered by probability\n", 206 | " if output_dict['detection_scores'][i] < PREDICT_THRESHOLD:\n", 207 | " break\n", 208 | " ymin, xmin, ymax, xmax = output_dict['detection_boxes'][i,:]\n", 209 | " ymin_pix = int(ymin*image.shape[0])\n", 210 | " xmin_pix = int(xmin*image.shape[1])\n", 211 | " ymax_pix = int(ymax*image.shape[0])\n", 212 | " xmax_pix = int(xmax*image.shape[1])\n", 213 | " cv2.rectangle(image, (xmin_pix,ymin_pix), (xmax_pix,ymax_pix), (0,255,0), 3)\n", 214 | " return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)\n", 215 | " \n", 216 | " def run(self):\n", 217 | " \"\"\"\n", 218 | " Start thread.\n", 219 | " Creates loop that reads frame from webcam, detects bounding boxes,\n", 220 | " and updates interactive image until thread is stopped.\n", 221 | " \"\"\"\n", 222 | " self.setup()\n", 223 | " is_capturing = self.camera.isOpened()\n", 224 | " start_time = time.time()\n", 225 | " with self.detection_graph.as_default():\n", 226 | " self.session = tf.Session()\n", 227 | " with self.session as sess:\n", 228 | " while is_capturing and not self.stopped():\n", 229 | " start_time = time.time()\n", 230 | " is_capturing, frame = self.camera.read()\n", 231 | " processed_frame = self.process_frame(frame)\n", 232 | " self.interactive_img.value = cv2.imencode('.png', processed_frame)[1].tostring()\n", 233 | " self.interactive_framerate_text.value = '{:.2f}'.format(1/(time.time() - start_time))\n", 234 | " self.camera.release()" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "print('Start')\n", 244 | "# Create interactive image\n", 245 | "interactive_img = widgets.Image(\n", 246 | " value=b'',\n", 247 | " format='png',\n", 248 | " width=800,\n", 249 | " height=600,\n", 250 | ")\n", 251 | "\n", 252 | "# Create interactive text to display framerate\n", 253 | "interactive_framerate_text = widgets.Text(\n", 254 | " value='0',\n", 255 | " placeholder='0',\n", 256 | " description='Fps:',\n", 257 | " disabled=False\n", 258 | ")\n", 259 | "\n", 260 | "\n", 261 | "# Create thread to update interactive image with webcam\n", 262 | "print('Setup thread')\n", 263 | "thread = WebcamBoundingBoxThread(\n", 264 | " interactive_img, interactive_framerate_text, PATH_TO_CKPT)\n", 265 | "thread.daemon = True\n", 266 | "print('Display image')\n", 267 | "display(interactive_img)\n", 268 | "display(interactive_framerate_text)\n", 269 | "print('Starting thread')\n", 270 | "thread.start()\n", 271 | "\n", 272 | "\n", 273 | "# Stop thread upon exit\n", 274 | "print('Running loop')\n", 275 | "while True:\n", 
276 | " try:\n", 277 | " time.sleep(1)\n", 278 | " except:\n", 279 | " thread.stop()\n", 280 | " thread.join()\n", 281 | " break\n", 282 | "\n", 283 | " \n", 284 | "print('Finish')" 285 | ] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python 3", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.6.5" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 2 309 | } 310 | -------------------------------------------------------------------------------- /computer_vision/WebcamInNotebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Show webcam feed in notebook\n", 8 | "\n", 9 | "Use OpenCV to read webcam feed, use ipywidgets to display videostream." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import time\n", 19 | "import cv2\n", 20 | "import threading\n", 21 | "from ipywidgets import widgets\n", 22 | "from IPython.display import display" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "class WebcamThread(threading.Thread):\n", 32 | " \"\"\"\n", 33 | " Thread to read frame by frame from webcam stream and update the\n", 34 | " interactive image widget.\n", 35 | " \"\"\"\n", 36 | " def __init__(self, interactive_img):\n", 37 | " \"\"\"\n", 38 | " Create thread.\n", 39 | " \n", 40 | " Args:\n", 41 | " interactive_img (Widget): Interactive widget to display image\n", 42 | " \"\"\"\n", 43 | " super(WebcamThread, self).__init__()\n", 44 | " self._stop_event = threading.Event()\n", 45 | " self.interactive_img = interactive_img\n", 46 | " self.camera = cv2.VideoCapture(0)\n", 47 | "\n", 48 | " def stop(self):\n", 49 | " \"\"\"\n", 50 | " Stop thread.\n", 51 | " \"\"\"\n", 52 | " self._stop_event.set()\n", 53 | "\n", 54 | " def stopped(self):\n", 55 | " \"\"\"\n", 56 | " True iff tread is stopped.\n", 57 | " \"\"\"\n", 58 | " return self._stop_event.is_set()\n", 59 | " \n", 60 | " def run(self):\n", 61 | " \"\"\"\n", 62 | " Start thread.\n", 63 | " Creates loop that reads frame from webcam and updates interactive image\n", 64 | " until thread is stopped.\n", 65 | " \"\"\"\n", 66 | " is_capturing = self.camera.isOpened()\n", 67 | " while is_capturing and not self.stopped():\n", 68 | " is_capturing, frame = self.camera.read()\n", 69 | " interactive_img.value = cv2.imencode(\n", 70 | " '.png', frame)[1].tostring()\n", 71 | " time.sleep(0.01)\n", 72 | " self.camera.release()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "scrolled": false 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# Create interactive image\n", 84 | "interactive_img = widgets.Image(\n", 85 | " value=b'',\n", 86 | " format='png',\n", 87 | " width=800,\n", 88 | " height=600,\n", 89 | ")\n", 90 | "\n", 91 | "# Create thread to update interactive image with webcam\n", 92 | "thread = WebcamThread(interactive_img)\n", 93 | "display(interactive_img)\n", 94 | "thread.start()\n", 95 | "\n", 96 | "# Stop thread upon exit\n", 
97 | "while True:\n", 98 | " try:\n", 99 | " time.sleep(1)\n", 100 | " except:\n", 101 | " thread.stop()\n", 102 | " thread.join()\n", 103 | " break\n", 104 | "\n", 105 | "print('Finish')" 106 | ] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Python 3", 112 | "language": "python", 113 | "name": "python3" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.6.5" 126 | } 127 | }, 128 | "nbformat": 4, 129 | "nbformat_minor": 2 130 | } 131 | -------------------------------------------------------------------------------- /cs_algorithms/binary_search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Binary search" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "-1\n", 20 | "[0, 1, 2, 3, 4, 5, 6, 7]\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "# Data to test\n", 26 | "# 0 1 2 3 4 5 6 7\n", 27 | "ls = [1, 2, 3, 6, 7, 9, 13, 16]\n", 28 | "\n", 29 | "def binary_search(elem: int, ls: list) -> int:\n", 30 | " \"\"\"\n", 31 | " Search for the given elem in the given list and return the index.\n", 32 | " Return -1 is the elem was not found.\n", 33 | " \"\"\"\n", 34 | " low = 0\n", 35 | " high = len(ls)-1\n", 36 | " while True:\n", 37 | " middle = low + ((high - low) // 2)\n", 38 | " if ls[middle] == elem:\n", 39 | " return middle\n", 40 | " if low == high:\n", 41 | " return -1\n", 42 | " elif ls[middle] > elem:\n", 43 | " high = middle - 1\n", 44 | " elif ls[middle] < elem:\n", 45 | " low = middle + 1\n", 46 | "\n", 47 | "\n", 48 | "print(binary_search(4, ls))\n", 49 | "print([binary_search(i, ls) for i in ls])" 50 | ] 51 | } 52 | ], 53 | "metadata": { 54 | "kernelspec": { 55 | "display_name": "Python 3", 56 | "language": "python", 57 | "name": "python3" 58 | }, 59 | "language_info": { 60 | "codemirror_mode": { 61 | "name": "ipython", 62 | "version": 3 63 | }, 64 | "file_extension": ".py", 65 | "mimetype": "text/x-python", 66 | "name": "python", 67 | "nbconvert_exporter": "python", 68 | "pygments_lexer": "ipython3", 69 | "version": "3.7.6" 70 | } 71 | }, 72 | "nbformat": 4, 73 | "nbformat_minor": 4 74 | } 75 | -------------------------------------------------------------------------------- /cs_algorithms/binary_search_tree.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Illustrate implementations of binary search tree and red-black tree\n", 8 | "\n", 9 | "Implementation are in same folder as this notebook" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "# Import tree implementations from this folder\n", 22 | "from binary_search_tree import BST\n", 23 | "from red_black_tree import RedBlackTree\n", 24 | "import random" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": 
"stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "None\n", 37 | "{11}\n", 38 | "{{9}11}\n", 39 | "{{{1}9}11}\n", 40 | "{{{1}9}11{17}}\n", 41 | "{{{1}9}11{{16}17}}\n", 42 | "{{{1}9}11{{16}17{19}}}\n", 43 | "{{{1}9}11{{16}17{{18}19}}}\n", 44 | "{{{1}9}11{{{14}16}17{{18}19}}}\n", 45 | "min: 1\n", 46 | "max: 19\n", 47 | "del: 11 result: {{{1}9}14{{16}17{{18}19}}}\n", 48 | "del: 9 result: {{1}14{{16}17{{18}19}}}\n", 49 | "del: 1 result: {14{{16}17{{18}19}}}\n", 50 | "del: 17 result: {14{{16}18{19}}}\n", 51 | "del: 16 result: {14{18{19}}}\n", 52 | "del: 19 result: {14{18}}\n", 53 | "del: 18 result: {14}\n", 54 | "del: 14 result: None\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "tree = BST()\n", 60 | "print(tree)\n", 61 | "\n", 62 | "sample = random.sample(range(1, 20), 8)\n", 63 | "for x in sample:\n", 64 | " tree.put(x)\n", 65 | " print(tree)\n", 66 | " \n", 67 | "print('min: ', tree.min())\n", 68 | "print('max: ', tree.max())\n", 69 | "\n", 70 | "# while not tree.root is None:\n", 71 | "# tree.delete_min()\n", 72 | "# print tree\n", 73 | "\n", 74 | "for i in sample:\n", 75 | " tree.delete(i)\n", 76 | " print('del: ', i , ' result: ', tree)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "None\n", 89 | "x: 15 tree: {15}\n", 90 | "x: 18 tree: {{15}18}\n", 91 | "x: 17 tree: {{15}17{18}}\n", 92 | "x: 10 tree: {{10}15{17{18}}}\n", 93 | "x: 5 tree: {{5}10{15{17{18}}}}\n", 94 | "x: 4 tree: {{4}5{10{15{17{18}}}}}\n", 95 | "x: 7 tree: {{4}5{{7}10{15{17{18}}}}}\n", 96 | "x: 2 tree: {{2}4{5{{7}10{15{17{18}}}}}}\n", 97 | "min: 2\n", 98 | "max: 18\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "tree = RedBlackTree()\n", 104 | "print(tree)\n", 105 | "\n", 106 | "sample = random.sample(range(1, 20), 8)\n", 107 | "for x in sample:\n", 108 | " tree.put(x)\n", 109 | " print('x: ', x, 'tree: ', tree)\n", 110 | " \n", 111 | "print('min: ', tree.min())\n", 112 | "print('max: ', tree.max())" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.6.4" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 1 137 | } 138 | -------------------------------------------------------------------------------- /cs_algorithms/binary_search_tree.py: -------------------------------------------------------------------------------- 1 | class Node(object): 2 | """ 3 | Class to represent a node in a binary search tree 4 | """ 5 | def __init__(self, key, left=None, right=None): 6 | """ 7 | Initialise the node 8 | """ 9 | self.left = left 10 | self.right = right 11 | self.key = key 12 | 13 | def get(self, get_key): 14 | """ 15 | Get the node with the given key. Return None if not found. 16 | """ 17 | if self.key == get_key: 18 | return self.key 19 | elif self.key < get_key and not self.right is None: 20 | return self.right.get(get_key) 21 | elif self.key > get_key and not self.left is None: 22 | return self.left.get(get_key) 23 | else: 24 | return None 25 | 26 | def put(self, new_key): 27 | """ 28 | Put the given key in the BST starting at this node. 
29 | """ 30 | if self.key < new_key: 31 | if self.right is None: 32 | self.right = Node(new_key) 33 | else: 34 | self.right.put(new_key) 35 | elif self.key > new_key: 36 | if self.left is None: 37 | self.left = Node(new_key) 38 | else: 39 | self.left.put(new_key) 40 | 41 | def min(self): 42 | """ 43 | Return the minimum in the BST under this node 44 | """ 45 | if self.left is None: 46 | return self.key 47 | else: 48 | return self.left.min() 49 | 50 | def max(self): 51 | """ 52 | Return the maximum in the BST under this node 53 | """ 54 | if self.right is None: 55 | return self.key 56 | else: 57 | return self.right.max() 58 | 59 | def __str__(self): 60 | """ 61 | Represent the node as a string. 62 | """ 63 | s ='{' 64 | if not self.left is None: 65 | s += str(self.left) 66 | s += str(self.key) 67 | if not self.right is None: 68 | s += str(self.right) 69 | s += '}' 70 | return s 71 | 72 | 73 | class BST(object): 74 | """ 75 | Class to represent a binary search tree 76 | """ 77 | def __init__(self): 78 | self.root = None 79 | 80 | def get(self, key): 81 | return self.root.get(key) 82 | 83 | def put(self, key): 84 | if self.root is None: 85 | self.root = Node(key) 86 | else: 87 | self.root.put(key) 88 | 89 | def min(self): 90 | if self.root is None: 91 | return None 92 | return self.root.min() 93 | 94 | def max(self): 95 | if self.root is None: 96 | return None 97 | return self.root.max() 98 | 99 | def delete_min(self): 100 | # replace the root with the root with the minimum removed 101 | self.root = BST.delete_min_rec(self.root) 102 | 103 | @staticmethod 104 | def delete_min_rec(node): 105 | # If the current left node is empty, the current node is the minimum 106 | # return the right of the current node to replace the current node in the previous call. 107 | if node.left is None: 108 | return node.right 109 | # The current left node should be the left node with the minimum deleted 110 | node.left = BST.delete_min_rec(node.left) 111 | # Return the current node to replace the current node with the minimum removed 112 | return node 113 | 114 | def delete(self, key): 115 | # The tree with the key element removed is the same tree with the element removed 116 | self.root = BST.delete_rec(self.root, key) 117 | 118 | @staticmethod 119 | def delete_rec(n, key): 120 | if n is None: 121 | return None 122 | if key < n.key: 123 | # If the deleted key is on the left, replace the left part of the tree 124 | # with the left part were key is deleted. 125 | n.left = BST.delete_rec(n.left, key) 126 | elif key > n.key: 127 | # If the deleted key is on the right, replace the right part of the tree 128 | # with the right part were key is deleted. 129 | n.right = BST.delete_rec(n.right, key) 130 | else: # key == n.key 131 | # Remove the key at this position 132 | if n.right is None: 133 | # If the right is empty, just replace the current node with its left part 134 | return n.left 135 | if n.left is None: 136 | # If the left is empty, just replace the current node with its right part 137 | return n.right 138 | # Left or right is not empty 139 | # all nodes from n.left < key, all nodes from n.right > key 140 | # so new key needs to be the minimum key from n.right to keep this property true. 
141 | # Create new node as minimum of t.right (minimum has never any left nodes) 142 | new_n = Node(n.right.min()) 143 | # Remove this minimum from n.right and set the result as the right of the new node 144 | new_n.right = BST.delete_min_rec(n.right) 145 | # set the left node of the current node (minimum of right, minimum has no left nodes) 146 | new_n.left = n.left 147 | n = new_n 148 | return n 149 | 150 | def __str__(self): 151 | """ 152 | Represent the tree as a string. 153 | """ 154 | return str(self.root) -------------------------------------------------------------------------------- /cs_algorithms/mergesort.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Merge sort" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "original = [1]\n", 20 | "sorted = [1]\n", 21 | "\n", 22 | "original = [1, 2, 3]\n", 23 | "sorted = [1, 2, 3]\n", 24 | "\n", 25 | "original = [3, 2, 1]\n", 26 | "sorted = [1, 2, 3]\n", 27 | "\n", 28 | "original = [2, 1, 4, 3]\n", 29 | "sorted = [1, 2, 3, 4]\n", 30 | "\n", 31 | "original = [2, 7, 6, 8, 3, 5, 4, 1]\n", 32 | "sorted = [1, 2, 3, 4, 5, 6, 7, 8]\n", 33 | "\n", 34 | "original = [1, 8, 8, 8, 5]\n", 35 | "sorted = [1, 5, 8, 8, 8]\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "def merge(left: list, right: list, output: list):\n", 41 | " l = r = o = 0\n", 42 | " while o < len(output):\n", 43 | " if l >= len(left):\n", 44 | " output[o] = right[r]\n", 45 | " o += 1\n", 46 | " r += 1\n", 47 | " continue\n", 48 | " if r >= len(right):\n", 49 | " output[o] = left[l]\n", 50 | " o += 1\n", 51 | " l += 1\n", 52 | " continue\n", 53 | " if left[l] <= right[r]:\n", 54 | " output[o] = left[l]\n", 55 | " o += 1\n", 56 | " l += 1\n", 57 | " continue\n", 58 | " if left[l] > right[r]:\n", 59 | " output[o] = right[r]\n", 60 | " o += 1\n", 61 | " r += 1\n", 62 | " continue\n", 63 | " \n", 64 | " \n", 65 | "def mergesort(ls: list):\n", 66 | " if len(ls) == 1:\n", 67 | " return ls # list of length 1 is sorted by default\n", 68 | " mid_idx = len(ls) // 2\n", 69 | " left = mergesort(ls[:mid_idx])\n", 70 | " right = mergesort(ls[mid_idx:])\n", 71 | " merge(left, right, ls)\n", 72 | " return ls\n", 73 | "\n", 74 | "\n", 75 | "# Test\n", 76 | "ls = [1]\n", 77 | "print(f'original = {ls}')\n", 78 | "mergesort(ls)\n", 79 | "print(f'sorted = {ls}')\n", 80 | "print('')\n", 81 | "ls = [1, 2, 3]\n", 82 | "print(f'original = {ls}')\n", 83 | "mergesort(ls)\n", 84 | "print(f'sorted = {ls}')\n", 85 | "print('')\n", 86 | "ls = [3, 2, 1]\n", 87 | "print(f'original = {ls}')\n", 88 | "mergesort(ls)\n", 89 | "print(f'sorted = {ls}')\n", 90 | "print('')\n", 91 | "ls = [2, 1, 4, 3]\n", 92 | "print(f'original = {ls}')\n", 93 | "mergesort(ls)\n", 94 | "print(f'sorted = {ls}')\n", 95 | "print('')\n", 96 | "ls = [2, 7, 6, 8, 3, 5, 4, 1]\n", 97 | "print(f'original = {ls}')\n", 98 | "mergesort(ls)\n", 99 | "print(f'sorted = {ls}')\n", 100 | "print('')\n", 101 | "ls = [1, 8, 8, 8, 5]\n", 102 | "print(f'original = {ls}')\n", 103 | "mergesort(ls)\n", 104 | "print(f'sorted = {ls}')" 105 | ] 106 | } 107 | ], 108 | "metadata": { 109 | "kernelspec": { 110 | "display_name": "Python 3", 111 | "language": "python", 112 | "name": "python3" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 
| }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.7.6" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 4 129 | } 130 | -------------------------------------------------------------------------------- /cs_algorithms/priority_queue.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Illustrate implementation of priority queue\n", 8 | "\n", 9 | "Implementation is in python file in the same folder as this notebook." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "from priority_queue import BinaryHeap\n", 22 | "import random" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "13 [None, 13]\n", 35 | "13 [None, 13, 8]\n", 36 | "13 [None, 13, 8, 2]\n", 37 | "19 [None, 19, 13, 2, 8]\n", 38 | "19 [None, 19, 13, 2, 8, 12]\n", 39 | "19 [None, 19, 13, 11, 8, 12, 2]\n", 40 | "19 [None, 19, 13, 11, 8, 12, 2, 7]\n", 41 | "19 [None, 19, 15, 11, 13, 12, 2, 7, 8]\n", 42 | "\n", 43 | "start emptying the heap\n", 44 | "19 [None, 15, 13, 11, 8, 12, 2, 7]\n", 45 | "15 [None, 13, 12, 11, 8, 7, 2]\n", 46 | "13 [None, 12, 8, 11, 2, 7]\n", 47 | "12 [None, 11, 8, 7, 2]\n", 48 | "11 [None, 8, 2, 7]\n", 49 | "8 [None, 7, 2]\n", 50 | "7 [None, 2]\n", 51 | "2 [None]\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "pq = BinaryHeap()\n", 57 | "\n", 58 | "sample = random.sample(range(1, 20), 8)\n", 59 | "\n", 60 | "for s in sample:\n", 61 | " pq.insert(s)\n", 62 | " print(pq.max(), pq.heap)\n", 63 | "\n", 64 | "print('\\nstart emptying the heap')\n", 65 | "while not pq.is_empty():\n", 66 | " m = pq.delete_max()\n", 67 | " print(m, pq.heap)" 68 | ] 69 | } 70 | ], 71 | "metadata": { 72 | "kernelspec": { 73 | "display_name": "Python 3", 74 | "language": "python", 75 | "name": "python3" 76 | }, 77 | "language_info": { 78 | "codemirror_mode": { 79 | "name": "ipython", 80 | "version": 3 81 | }, 82 | "file_extension": ".py", 83 | "mimetype": "text/x-python", 84 | "name": "python", 85 | "nbconvert_exporter": "python", 86 | "pygments_lexer": "ipython3", 87 | "version": "3.6.4" 88 | } 89 | }, 90 | "nbformat": 4, 91 | "nbformat_minor": 1 92 | } 93 | -------------------------------------------------------------------------------- /cs_algorithms/priority_queue.py: -------------------------------------------------------------------------------- 1 | class BinaryHeap(): 2 | """ 3 | Binary Heap datastructure. 4 | Each node is larger than or equal to the keys in that node's two children (if any). 5 | The largest key is found at the root. 6 | 7 | The parent of the node at index i can be found at index floor(i/2). 8 | The 2 children of a node at index i can be found at index 2i and 2i+1 9 | 10 | The heap is indexed by the algorithm starting from index 1. 11 | """ 12 | def __init__(self): 13 | # Set the heap 14 | self.heap = [None] 15 | 16 | def size(self): 17 | return len(self.heap)-1 18 | 19 | def swim(self, i): 20 | """ 21 | Bottom-up reheapify (swim). 22 | Fix heap violation when a node's key becomes LARGER than that node's PARENT key. 
23 | """ 24 | # Keep exchanging the parent at floor(i/2) if this parent is smaller than the current child at i, 25 | # and while i is still in the range of the heap list 26 | while i > 1 and self.heap[i//2] < self.heap[i]: 27 | # Exchange parent with child if parent < child 28 | self.heap[i//2], self.heap[i] = self.heap[i], self.heap[i//2] 29 | # Update position in heap 30 | i //= 2 31 | 32 | def sink(self, i): 33 | """ 34 | Top-down reheapify (sink). 35 | Fix heap violation when a node's key becomes SMALLER than than one or both of that 36 | node's CHILDREN's keys. 37 | """ 38 | # Keep exchanging the larger child at 2i or 2i+1 (j) with the current parent at i, 39 | # while the child index is still withing the heap, and the parent is still smaller than the child. 40 | while 2*i <= self.size(): 41 | # Get the child at position j=2i 42 | j = 2*i 43 | if j < self.size() and self.heap[j] < self.heap[j+1]: 44 | # Set j to j+1 if the child at j+1 is larger than the child at j. 45 | # The child that needs to be exchanged with the parent needs to be 46 | # larger than the other child. 47 | j += 1 48 | if not self.heap[i] < self.heap[j]: 49 | # Stop if the parent at i is larger or equal to the child at j 50 | break 51 | # Exchange parent with larger child 52 | self.heap[i], self.heap[j] = self.heap[j], self.heap[i] 53 | # Update i to the child because we need to check the new child at j 54 | # in the next iteration. 55 | i = j 56 | 57 | def insert(self, key): 58 | """ 59 | Insert the given key into the heap. 60 | """ 61 | # Insert the key at the end of the list 62 | self.heap.append(key) 63 | # Fix the possible heap violation 64 | self.swim(self.size()) 65 | 66 | def max(self): 67 | """ 68 | Return the maximum element in the heap 69 | """ 70 | # The maximum element is at index 1. (index 0 is not used) 71 | return self.heap[1] 72 | 73 | def delete_max(self): 74 | """ 75 | Delete and return the maximal element in the heap 76 | """ 77 | max = self.max() 78 | n = self.size() 79 | # Exchange the maximum with the last elemnt in the heap 80 | self.heap[1], self.heap[n] = self.heap[n], self.heap[1] 81 | # Remove the last (maximum now) element 82 | self.heap.pop() 83 | # Fix the heap violation 84 | self.sink(1) 85 | return max 86 | 87 | def is_empty(self): 88 | """ 89 | Return true if and only if this heap is empty. 90 | """ 91 | return self.size() == 0 92 | -------------------------------------------------------------------------------- /cs_algorithms/programming_challanges/construct_tree_from_edge_list.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Construct a Tree Given Its Edges\n", 8 | "\n", 9 | "Example from: http://xahlee.info/python/python_construct_tree_from_edge.html\n", 10 | "\n", 11 | "Problem: given a list of edges of a tree: [child, parent], construct the tree. Here's a sample input in Python nested list syntax: [[0, 2], [3, 0], [1, 4], [2, 4]]." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from collections import defaultdict\n", 21 | "import json" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Edge is [child, parent] list\n", 31 | "edges = [[0, 6], [17, 5], [2, 7], [4, 14], [12, 9], [15, 5], \n", 32 | " [11, 1], [14, 8], [16, 6], [5, 1], [10, 7], [6, 10], [8, 2], [13, 1], \n", 33 | " [1, 12], [7, 1], [3, 2], [19, 12], [18, 19]]" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def construct_tree(ls_of_edges):\n", 43 | " \"\"\"\n", 44 | " Construct a tree given a list of edges\n", 45 | " \"\"\"\n", 46 | " # Represent the tree as a dictionary.\n", 47 | " # Each node is represented as a dictionary that holds references to its children.\n", 48 | " # Use a defaultdict, so that upon the first access of each element\n", 49 | " # that element refers to an empty dictonary\n", 50 | " # https://docs.python.org/2/library/collections.html#collections.defaultdict\n", 51 | " tree = defaultdict(dict)\n", 52 | " # To find the root hold a set of parents and a set of children\n", 53 | " child_set = set()\n", 54 | " parent_set = set()\n", 55 | " # fill the dictionary\n", 56 | " for child, parent in ls_of_edges:\n", 57 | " # The parent holds the child nodes.\n", 58 | " tree[parent][child] = tree[child]\n", 59 | " # Get all the children and parents from the list of edges\n", 60 | " child_set.add(child)\n", 61 | " parent_set.add(parent)\n", 62 | " # Get and return the root\n", 63 | " root = parent_set.difference(child_set).pop()\n", 64 | " # Get the tree under the root and append the root as the root node\n", 65 | " return {9: tree[root]}" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "{\n", 78 | " \"9\": {\n", 79 | " \"12\": {\n", 80 | " \"1\": {\n", 81 | " \"11\": {},\n", 82 | " \"5\": {\n", 83 | " \"17\": {},\n", 84 | " \"15\": {}\n", 85 | " },\n", 86 | " \"13\": {},\n", 87 | " \"7\": {\n", 88 | " \"2\": {\n", 89 | " \"8\": {\n", 90 | " \"14\": {\n", 91 | " \"4\": {}\n", 92 | " }\n", 93 | " },\n", 94 | " \"3\": {}\n", 95 | " },\n", 96 | " \"10\": {\n", 97 | " \"6\": {\n", 98 | " \"0\": {},\n", 99 | " \"16\": {}\n", 100 | " }\n", 101 | " }\n", 102 | " }\n", 103 | " },\n", 104 | " \"19\": {\n", 105 | " \"18\": {}\n", 106 | " }\n", 107 | " }\n", 108 | " }\n", 109 | "}\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "print(json.dumps(construct_tree(edges), indent=1))" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.6.4" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 1 139 | } 140 | -------------------------------------------------------------------------------- /cs_algorithms/programming_challanges/find_longest_substring.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Find the longest substring\n", 8 | "\n", 9 | "Given string like this: `ccaababbaccbabbcc`. \n", 10 | "Find the longest contiguous substring made up of only two distinct chars." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "given_string = 'ccaababbaccbabbcc'" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "def find_longest_substring_distinct(s):\n", 29 | " \"\"\"\n", 30 | " Find the longest substring made up of only 2 distint chars in s.\n", 31 | " \"\"\"\n", 32 | " # max_length holds a tuple (max_length, start_idx, stop_idx)\n", 33 | " max_length = (0,0,0)\n", 34 | " left_idx = 0\n", 35 | " right_idx = 0\n", 36 | " # dictionary to hold the character count between the left and the right index\n", 37 | " char_dict = {s[0]:1}\n", 38 | " # iterate over string and update the right index\n", 39 | " while right_idx < len(s)-1:\n", 40 | " if len(char_dict) <= 2:\n", 41 | " # update max length if the current substring is longer than the\n", 42 | " # current maxlength substring\n", 43 | " new_length = right_idx - left_idx + 1\n", 44 | " if new_length > max_length[0]:\n", 45 | " max_length = (new_length, left_idx, right_idx)\n", 46 | " # update the right index and the char_dict\n", 47 | " right_idx += 1\n", 48 | " right_char = s[right_idx]\n", 49 | " if right_char in char_dict:\n", 50 | " char_dict[right_char] += 1\n", 51 | " else:\n", 52 | " char_dict[right_char] = 1\n", 53 | " else: # len(char_dict) > 2\n", 54 | " # start updating the left index\n", 55 | " left_char = s[left_idx]\n", 56 | " char_dict[left_char] -= 1\n", 57 | " if char_dict[left_char] == 0:\n", 58 | " del char_dict[left_char]\n", 59 | " left_idx += 1\n", 60 | "# print(left_idx, right_idx, s[left_idx:right_idx+1], len(char_dict), char_dict)\n", 61 | " return s[max_length[1]:max_length[2]+1], max_length[0], max_length[1], max_length[2]" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "('aababba', 7, 2, 8)\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "print(find_longest_substring_distinct(given_string))" 79 | ] 80 | } 81 | ], 82 | "metadata": { 83 | "kernelspec": { 84 | "display_name": "Python 3", 85 | "language": "python", 86 | "name": "python3" 87 | }, 88 | "language_info": { 89 | "codemirror_mode": { 90 | "name": "ipython", 91 | "version": 3 92 | }, 93 | "file_extension": ".py", 94 | "mimetype": "text/x-python", 95 | "name": "python", 96 | "nbconvert_exporter": "python", 97 | "pygments_lexer": "ipython3", 98 | "version": "3.6.4" 99 | } 100 | }, 101 | "nbformat": 4, 102 | "nbformat_minor": 1 103 | } 104 | -------------------------------------------------------------------------------- /cs_algorithms/quicksort.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Quicksort" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "scrolled": false 15 | }, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "original = [1]\n", 22 | "sorted = [1]\n", 23 | "\n", 24 | "original = [1, 2, 3]\n", 25 | "sorted = 
[1, 2, 3]\n", 26 | "\n", 27 | "original = [3, 2, 1]\n", 28 | "sorted = [1, 2, 3]\n", 29 | "\n", 30 | "original = [2, 1, 4, 3]\n", 31 | "sorted = [1, 2, 3, 4]\n", 32 | "\n", 33 | "original = [2, 7, 6, 8, 3, 5, 4, 1]\n", 34 | "sorted = [1, 2, 3, 4, 5, 6, 7, 8]\n", 35 | "\n", 36 | "original = [1, 8, 8, 8, 5]\n", 37 | "sorted = [1, 5, 8, 8, 8]\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "def quicksort(ls: list) -> list:\n", 43 | " \"\"\"\n", 44 | " Sort the given list inplace.\n", 45 | " \"\"\"\n", 46 | " quicksort_rec(ls, 0, len(ls)-1)\n", 47 | "\n", 48 | " \n", 49 | "def quicksort_rec(ls: list, start: int, stop: int) -> list:\n", 50 | " \"\"\"\n", 51 | " Sort the given list between start and stop inplace\n", 52 | " \"\"\"\n", 53 | " if start >= stop:\n", 54 | " return\n", 55 | " else:\n", 56 | " # partition\n", 57 | " part_idx = partition(ls, start, stop)\n", 58 | " # sort left and right\n", 59 | " quicksort_rec(ls, start, part_idx)\n", 60 | " quicksort_rec(ls, part_idx+1, stop)\n", 61 | "\n", 62 | " \n", 63 | "def partition(ls: list, start: int, stop: int) -> int:\n", 64 | " \"\"\"\n", 65 | " Partition the given list between start and stop. left <= partition element <= right\n", 66 | " \n", 67 | " The partition point `part_elem` is chosen to be the first element ls[start]\n", 68 | " (note this could also be randomly chosen).\n", 69 | " The partition is created so that all elements <= `part_elem` are to the left of `part_elem`\n", 70 | " and all > `part_elem` are to the right of `part_elem`.\n", 71 | " \n", 72 | " Return the partition index that divides the partitions.\n", 73 | " \"\"\"\n", 74 | " if start == stop:\n", 75 | " return start # If only one element start is the partion index (no real partioning is happening)\n", 76 | " idx_part = start\n", 77 | " part_elem = ls[start] # Partition element is chosen as first element\n", 78 | " idx_right = stop\n", 79 | " while idx_part < idx_right:\n", 80 | " if ls[idx_part+1] < ls[idx_part]:\n", 81 | " # Element next to partition element is smaller so swap\n", 82 | " ls[idx_part], ls[idx_part+1] = ls[idx_part+1], ls[idx_part]\n", 83 | " idx_part += 1\n", 84 | " else:\n", 85 | " # ls[idx_part+1] >= ls[idx_part]\n", 86 | " # element next to partition element is larger so keep and try to swap with a right element\n", 87 | " if ls[idx_right] >= part_elem:\n", 88 | " # Element to the right is larger than partion element so keep it\n", 89 | " idx_right -= 1\n", 90 | " else:\n", 91 | " # ls[idx_right] < part_elem:\n", 92 | " # Element to to the right should be swapped\n", 93 | " # Swap with element next to partition\n", 94 | " ls[idx_part+1], ls[idx_right] = ls[idx_right], ls[idx_part+1]\n", 95 | " idx_right -= 1\n", 96 | " return idx_part\n", 97 | "\n", 98 | "\n", 99 | "# Test\n", 100 | "ls = [1]\n", 101 | "print(f'original = {ls}')\n", 102 | "quicksort(ls)\n", 103 | "print(f'sorted = {ls}')\n", 104 | "print('')\n", 105 | "ls = [1, 2, 3]\n", 106 | "print(f'original = {ls}')\n", 107 | "quicksort(ls)\n", 108 | "print(f'sorted = {ls}')\n", 109 | "print('')\n", 110 | "ls = [3, 2, 1]\n", 111 | "print(f'original = {ls}')\n", 112 | "quicksort(ls)\n", 113 | "print(f'sorted = {ls}')\n", 114 | "print('')\n", 115 | "ls = [2, 1, 4, 3]\n", 116 | "print(f'original = {ls}')\n", 117 | "quicksort(ls)\n", 118 | "print(f'sorted = {ls}')\n", 119 | "print('')\n", 120 | "ls = [2, 7, 6, 8, 3, 5, 4, 1]\n", 121 | "print(f'original = {ls}')\n", 122 | "quicksort(ls)\n", 123 | "print(f'sorted = {ls}')\n", 124 | "print('')\n", 125 | "ls = [1, 8, 8, 8, 5]\n", 126 | 
"print(f'original = {ls}')\n", 127 | "quicksort(ls)\n", 128 | "print(f'sorted = {ls}')" 129 | ] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.7.6" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 1 153 | } 154 | -------------------------------------------------------------------------------- /cs_algorithms/red_black_tree.py: -------------------------------------------------------------------------------- 1 | class Node(object): 2 | """ 3 | Class to represent a node in a binary search tree 4 | """ 5 | def __init__(self, key, left=None, right=None, is_red=False): 6 | """ 7 | Initialise the node 8 | """ 9 | self.left = left 10 | self.right = right 11 | self.key = key 12 | self.is_red = is_red 13 | 14 | def get(self, get_key): 15 | """ 16 | Get the node with the given key. Return None if not found. 17 | """ 18 | if self.key == get_key: 19 | return self.key 20 | elif self.key < get_key and not self.right is None: 21 | return self.right.get(get_key) 22 | elif self.key > get_key and not self.left is None: 23 | return self.left.get(get_key) 24 | else: 25 | return None 26 | 27 | def min(self): 28 | """ 29 | Return the minimum in the BST under this node 30 | """ 31 | if self.left is None: 32 | return self.key 33 | else: 34 | return self.left.min() 35 | 36 | def max(self): 37 | """ 38 | Return the maximum in the BST under this node 39 | """ 40 | if self.right is None: 41 | return self.key 42 | else: 43 | return self.right.max() 44 | 45 | def __str__(self): 46 | """ 47 | Represent the node as a string. 48 | """ 49 | s ='{' 50 | if not self.left is None: 51 | s += str(self.left) 52 | s += str(self.key) 53 | if not self.right is None: 54 | s += str(self.right) 55 | s += '}' 56 | return s 57 | 58 | 59 | class RedBlackTree(object): 60 | """ 61 | Class to represent a binary search tree 62 | """ 63 | def __init__(self): 64 | self.root = None 65 | 66 | def get(self, key): 67 | return self.root.get(key) 68 | 69 | def put(self, key): 70 | self.root =RedBlackTree.put_rec(self.root, key) 71 | self.root.is_red = False 72 | 73 | @staticmethod 74 | def put_rec(n, key): 75 | if n is None: 76 | # If the current node to insert the key is None return a new red Node 77 | # to be inserted on the parent 78 | return Node(key, is_red=True) 79 | if key < n.key: 80 | # Key is smaller and needs to be inserted on the left 81 | # The new left subtree is the current left side with the new key added 82 | n.left = RedBlackTree.put_rec(n.left, key) 83 | elif key > n.key: 84 | # Key is larger and needs to be inserted on the right 85 | # The new right subtree is the current right side with the new key added 86 | n.right = RedBlackTree.put_rec(n.right, key) 87 | # Perform rotations 88 | if RedBlackTree.is_red(n.right) and not RedBlackTree.is_red(n.left): 89 | # If the right link is red, and the left link is black: 90 | # rotate the red link to the left and reset the current node as the rotated link. 91 | n = RedBlackTree.rotate_left(n) 92 | if RedBlackTree.is_red(n.left) and RedBlackTree.is_red(n.left.left): 93 | # If the left tree contains to following red links: 94 | # rotate right so that the middle node become the top node. 
95 | n = RedBlackTree.rotate_right(n) 96 | if RedBlackTree.is_red(n.left) and RedBlackTree.is_red(n.right): 97 | # Flip colors if both left and right links are red. 98 | RedBlackTree.flip_colors(n) 99 | # return the updated node 100 | return n 101 | 102 | @staticmethod 103 | def is_red(n): 104 | if n is None: 105 | return False 106 | return n.is_red 107 | 108 | @staticmethod 109 | def rotate_left(n): 110 | """ 111 | Right red link of node n needs to be rotated to the left. 112 | """ 113 | # get the node to the right that has the red color 114 | r = n.right 115 | if not r.is_red: 116 | raise ValueError('Red link expected on the right during rotate_left') 117 | # Put the left side of r (right side of n) on the right side of n 118 | # Move the left side of r through the red link to n 119 | n.right = r.left 120 | # The left side of r becomes n, r is tilted above n 121 | r.left = n 122 | # r will take the color of n 123 | r.is_red = n.is_red 124 | # n become red, the color of r 125 | n.is_red = True 126 | # return r as the new node that should be linked to the parent of n 127 | return r 128 | 129 | @staticmethod 130 | def rotate_right(n): 131 | """ 132 | Left red link of node n needs to be rotated to the right. 133 | """ 134 | # get the node to the lef that has the red color 135 | l = n.left 136 | if not l.is_red: 137 | raise ValueError('Red link expected on the left during rotate_right') 138 | # Put the right side of l (left side of n) on the left side of n 139 | # Move the right side of l through the red link to n 140 | n.left = l.right 141 | # The right side of l becomes n, l is tilted above n 142 | l.right = n 143 | # l will take the color of n 144 | l.is_red = n.is_red 145 | # n become red, the color of l 146 | l.is_red = True 147 | # return l as the new node that should be linked to the parent of n 148 | return l 149 | 150 | @staticmethod 151 | def flip_colors(n): 152 | """ 153 | Flip colors of children to black, and color if this node to red 154 | """ 155 | n.is_red = True 156 | n.left.is_red = False 157 | n.right.is_red = False 158 | 159 | def min(self): 160 | if self.root is None: 161 | return None 162 | return self.root.min() 163 | 164 | def max(self): 165 | if self.root is None: 166 | return None 167 | return self.root.max() 168 | 169 | def __str__(self): 170 | """ 171 | Represent the tree as a string. 
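        The rendering nests curly braces around each subtree; for example, a
        root D with children B (holding leaves A and C) and E comes out as
        '{{{A}B{C}}D{E}}'.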
172 | """ 173 | return str(self.root) -------------------------------------------------------------------------------- /cs_algorithms/tree_traversal.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tree traversal algorithms implementations\n", 8 | "\n", 9 | "Follow examples from wikipedia: http://en.wikipedia.org/wiki/Tree_traversal" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "from __future__ import print_function\n", 22 | "from collections import deque" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "tree = ('F',\n", 32 | " ('B',\n", 33 | " ('A', None, None),\n", 34 | " ('D',\n", 35 | " ('C', None, None),\n", 36 | " ('E', None, None))),\n", 37 | " ('G', \n", 38 | " None, \n", 39 | " ('I', \n", 40 | " ('H', None, None),\n", 41 | " None)))" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "F, B, A, D, C, E, G, I, H, " 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "# Print Pre-order: F, B, A, D, C, E, G, I, H\n", 59 | "\n", 60 | "def print_pre_order(tree):\n", 61 | " if tree: # if not None\n", 62 | " print(tree[0], end=\", \") # Print current node\n", 63 | " print_pre_order(tree[1]) # Print left\n", 64 | " print_pre_order(tree[2]) # Print right\n", 65 | " \n", 66 | "print_pre_order(tree)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "A, B, C, D, E, F, G, H, I, " 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "# Print In-order: A, B, C, D, E, F, G, H, I\n", 84 | "\n", 85 | "def print_in_order(tree):\n", 86 | " if tree: # if not None\n", 87 | " print_in_order(tree[1]) # Print left\n", 88 | " print(tree[0], end=\", \") # Print current node\n", 89 | " print_in_order(tree[2]) # Print right\n", 90 | " \n", 91 | "print_in_order(tree)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "A, C, E, D, B, H, I, G, F, " 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "# Print Post-order: A, C, E, D, B, H, I, G, F\n", 109 | "\n", 110 | "def print_post_order(tree):\n", 111 | " if tree: # if not None\n", 112 | " print_post_order(tree[1]) # Print left\n", 113 | " print_post_order(tree[2]) # Print right\n", 114 | " print(tree[0], end=\", \") # Print current node\n", 115 | " \n", 116 | "print_post_order(tree)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "F, B, G, A, D, I, C, E, H, " 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "# Print Level-order: F, B, G, A, D, I, C, E, H\n", 134 | "\n", 135 | "def print_queue(queue):\n", 136 | " node = queue.popleft()\n", 137 | " if node: # if not None\n", 138 | " print(node[0], end=\", \") # print first node in queue\n", 139 | " # add rest to queue\n", 140 | " queue.append(node[1])\n", 141 | " 
queue.append(node[2])\n", 142 | "\n", 143 | "def print_level_order(tree):\n", 144 | " queue = deque() # Queue to hold the bread first search\n", 145 | " queue.append(tree)\n", 146 | " while queue:\n", 147 | " print_queue(queue)\n", 148 | " \n", 149 | "print_level_order(tree)" 150 | ] 151 | } 152 | ], 153 | "metadata": { 154 | "kernelspec": { 155 | "display_name": "Python 3", 156 | "language": "python", 157 | "name": "python3" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 3 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython3", 169 | "version": "3.6.4" 170 | } 171 | }, 172 | "nbformat": 4, 173 | "nbformat_minor": 1 174 | } 175 | -------------------------------------------------------------------------------- /dataset_specific/FAA/README.md: -------------------------------------------------------------------------------- 1 | # FAA data: Predict delays 2 | 3 | Predicting flight delays with regression methods on FAA data. 4 | 5 | Linear regression + XGBoost regression run on 4M-5M datapoints using sparse variables. 6 | -------------------------------------------------------------------------------- /dataset_specific/README.md: -------------------------------------------------------------------------------- 1 | # Dataset experiments 2 | 3 | * Kaggle House prices 4 | * FAA Flight delays 5 | * 6 | -------------------------------------------------------------------------------- /dataset_specific/housing/box_cox_tests.py: -------------------------------------------------------------------------------- 1 | # Box-cox transform for all the independent variables 2 | # Not used atm 3 | 4 | def find_transform(df): 5 | """Find box-cox shift and lambda transform parameters.""" 6 | features_transform_dict = {} 7 | for feature_name in features: 8 | vals = df[feature_name].values 9 | # Ensure positive before box cox 10 | m_val = min(vals) 11 | if m_val > 0: 12 | m_val = 0 13 | shift = abs(m_val) + 1 14 | vals = vals + shift 15 | # Find box cox 16 | _, lmbda = stats.boxcox(vals) 17 | features_transform_dict[feature_name] = (shift, lmbda) 18 | return features_transform_dict 19 | 20 | def transform(df, features_transform_dict): 21 | """Tranforms box-cos according to the parameter found before.""" 22 | df = df.copy() 23 | for feature_name in features: 24 | shift, lmbda = features_transform_dict[feature_name] 25 | df[feature_name] = df[feature_name] + shift 26 | df[feature_name] = stats.boxcox(df[feature_name], lmbda=lmbda) 27 | return df 28 | 29 | 30 | # Merge all data and find transformation of features. 
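# (Usage sketch only, since this module is noted above as not used yet: once
#  the parameters are fitted below, `transform` would be applied with the same
#  dict to both splits, e.g.
#      df_train_bc = transform(df_train, features_transform_dict)
#      df_test_bc = transform(df_test, features_transform_dict)
#  where df_train / df_test are the housing train and test frames.)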
31 | df_all = pd.concat([df_train[features], df_test[features]], axis=0) 32 | print('df_all: ', len(df_all)) 33 | features_transform_dict = find_transform(df_all[features]) 34 | -------------------------------------------------------------------------------- /dataset_specific/housing/preprocess.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | 4 | 5 | def combine_categories(df, col1, col2, name): 6 | """Combine categories if a row can have multiple categories of a certain type.""" 7 | # Find unique categories 8 | uniques = set(pd.unique(df[col1])) | set(pd.unique(df[col2])) 9 | # Merge different columns 10 | all_dummies = pd.get_dummies(df[[col1, col2]], dummy_na=True) 11 | for condition in uniques: 12 | if type(condition) == float and math.isnan(condition): 13 | continue 14 | c1 = col1 + '_' + condition 15 | c2 = col2 + '_' + condition 16 | c_combined = name + condition 17 | if c1 in all_dummies and c2 in all_dummies: 18 | df[c_combined] = all_dummies[c1] | all_dummies[c2] 19 | elif c1 in all_dummies: 20 | df[c_combined] = all_dummies[c1] 21 | elif c2 in all_dummies: 22 | df[c_combined] = all_dummies[c2] 23 | del df[col1] 24 | del df[col2] 25 | 26 | 27 | def preprocess(df, columns_needed=None): 28 | if columns_needed is None: 29 | columns_needed = [] 30 | 31 | ### MSSubClass is integer but should be categorical (integer values don't have meaning) 32 | df['MSSubClass'] = df['MSSubClass'].astype('int').astype('category') 33 | 34 | # Alley has NaN variable that actually have meaning 35 | df['Alley'].fillna('NoAlley', inplace=True) 36 | assert df['Alley'].notnull().all() 37 | 38 | # LotShape is an ordinal variable 39 | assert df['LotShape'].notnull().all() 40 | df['LotShape'].replace({'Reg': 0, 'IR1': 1, 'IR2': 2, 'IR3': 3}, inplace=True) 41 | 42 | # Utilities is complex categorical 43 | df['Utilities_Electricity'] = df['Utilities'].apply( 44 | lambda x: 1 if x in ['ELO', 'NoSeWa', 'NoSewr', 'AllPub'] else 0) 45 | df['Utilities_Gas'] = df['Utilities'].apply( 46 | lambda x: 1 if x in ['NoSeWa', 'NoSewr', 'AllPub'] else 0) 47 | df['Utilities_Water'] = df['Utilities'].apply( 48 | lambda x: 1 if x in ['NoSewr', 'AllPub'] else 0) 49 | df['Utilities_SepticTank'] = df['Utilities'].apply( 50 | lambda x: 1 if x in ['AllPub'] else 0) 51 | del df['Utilities'] 52 | 53 | # LandSlope is ordinal 54 | assert df['LandSlope'].notnull().all() 55 | df['LandSlope'].replace({'Gtl': 0, 'Mod': 1, 'Sev': 2}, inplace=True) 56 | 57 | # Neighborhood is a categorical 58 | assert df['Neighborhood'].notnull().all() 59 | df['Neighborhood'] = df['Neighborhood'].astype('category') 60 | 61 | # Condition1 and Condition2 are similar categoricals 62 | combine_categories(df, 'Condition1', 'Condition2', 'Condition') 63 | 64 | # Exterior1st and Exterior2nd are similar categoricals 65 | combine_categories(df, 'Exterior1st', 'Exterior2nd', 'Exterior') 66 | 67 | # ExterQual is an ordinal variable 68 | df['ExterQual'].fillna(-1, inplace=True) 69 | assert df['ExterQual'].notnull().all() 70 | df['ExterQual'].replace({'Po': 0, 'Fa': 1, 'TA': 2, 'Gd': 3, 'Ex': 4}, inplace=True) 71 | 72 | # ExterCond is an ordinal variable 73 | df['ExterCond'].fillna(-1, inplace=True) 74 | assert df['ExterCond'].notnull().all() 75 | df['ExterCond'].replace({'Po': 0, 'Fa': 1, 'TA': 2, 'Gd': 3, 'Ex': 4}, inplace=True) 76 | 77 | # BsmtQual is an ordinal variable 78 | df['BsmtQual'].fillna('NA', inplace=True) 79 | assert df['BsmtQual'].notnull().all() 80 | 
df['BsmtQual'].replace({'NA':0 , 'Po': 1, 'Fa': 2, 'TA': 3, 'Gd': 4, 'Ex': 5}, inplace=True) 81 | 82 | # BsmtCond is an ordinal variable 83 | df['BsmtCond'].fillna('NA', inplace=True) 84 | assert df['BsmtCond'].notnull().all() 85 | df['BsmtCond'].replace({'NA':0 , 'Po': 1, 'Fa': 2, 'TA': 3, 'Gd': 4, 'Ex': 5}, inplace=True) 86 | 87 | # BsmtExposure is an ordinal variable 88 | df['BsmtExposure'].fillna('NA', inplace=True) 89 | assert df['BsmtExposure'].notnull().all() 90 | df['BsmtExposure'].replace({'NA':0 , 'No': 1, 'Mn': 2, 'Av': 3, 'Gd': 4}, inplace=True) 91 | 92 | # BsmtFinType1 is an ordinal variable 93 | df['BsmtFinType1'].fillna('NA', inplace=True) 94 | assert df['BsmtFinType1'].notnull().all() 95 | df['BsmtFinType1'].replace({'NA':0 , 'Unf': 1, 'LwQ': 2, 'Rec': 3, 'BLQ': 4, 'ALQ': 5, 'GLQ':6}, inplace=True) 96 | 97 | # BsmtFinType2 is an ordinal variable 98 | df['BsmtFinType2'].fillna('NA', inplace=True) 99 | assert df['BsmtFinType2'].notnull().all() 100 | df['BsmtFinType2'].replace({'NA':0 , 'Unf': 1, 'LwQ': 2, 'Rec': 3, 'BLQ': 4, 'ALQ': 5, 'GLQ':6}, inplace=True) 101 | 102 | # HeatingQC is an ordinal variable 103 | df['HeatingQC'].fillna(-1, inplace=True) 104 | assert df['HeatingQC'].notnull().all() 105 | df['HeatingQC'].replace({'Po': 0, 'Fa': 1, 'TA': 2, 'Gd': 3, 'Ex': 4}, inplace=True) 106 | 107 | # CentralAir is a binary variable 108 | df['CentralAir'].fillna(-1, inplace=True) 109 | assert df['CentralAir'].notnull().all() 110 | df['CentralAir'].replace({'N': 0, 'Y': 1}, inplace=True) 111 | 112 | # KitchenQual is an ordinal variable 113 | df['KitchenQual'].fillna(-1, inplace=True) 114 | assert df['KitchenQual'].notnull().all() 115 | df['KitchenQual'].replace({'Po': 0, 'Fa': 1, 'TA': 2, 'Gd': 3, 'Ex': 4}, inplace=True) 116 | 117 | # Functional is an ordinal variable 118 | df['Functional'].fillna(-1, inplace=True) 119 | assert df['Functional'].notnull().all() 120 | df['Functional'].replace( 121 | {'Sal': 0, 'Sev': 1, 'Maj2': 2, 'Maj1': 3, 'Mod': 4, 'Min2': 5, 'Min1': 6, 'Typ': 7}, inplace=True) 122 | 123 | # FireplaceQu is an ordinal variable 124 | df['FireplaceQu'].fillna('NA', inplace=True) 125 | assert df['FireplaceQu'].notnull().all() 126 | df['FireplaceQu'].replace({'NA':0 , 'Po': 1, 'Fa': 2, 'TA': 3, 'Gd': 4, 'Ex': 5}, inplace=True) 127 | 128 | # GarageFinish is an ordinal variable 129 | df['GarageFinish'].fillna('NA', inplace=True) 130 | assert df['GarageFinish'].notnull().all() 131 | df['GarageFinish'].replace({'NA': 0, 'Unf': 1, 'RFn': 2, 'Fin': 3}, inplace=True) 132 | 133 | # FireplaceQu is an ordinal variable 134 | df['GarageQual'].fillna('NA', inplace=True) 135 | assert df['GarageQual'].notnull().all() 136 | df['GarageQual'].replace({'NA':0 , 'Po': 1, 'Fa': 2, 'TA': 3, 'Gd': 4, 'Ex': 5}, inplace=True) 137 | 138 | # GarageCond is an ordinal variable 139 | df['GarageCond'].fillna('NA', inplace=True) 140 | assert df['GarageCond'].notnull().all() 141 | df['GarageCond'].replace({'NA':0 , 'Po': 1, 'Fa': 2, 'TA': 3, 'Gd': 4, 'Ex': 5}, inplace=True) 142 | 143 | # CentralAir is an ordinal variable 144 | df['PavedDrive'].fillna(-1, inplace=True) 145 | assert df['PavedDrive'].notnull().all() 146 | df['PavedDrive'].replace({'N': 0, 'P': 1, 'Y': 2}, inplace=True) 147 | 148 | # PoolQC is an ordinal variable 149 | df['PoolQC'].fillna('NA', inplace=True) 150 | assert df['PoolQC'].notnull().all() 151 | df['PoolQC'].replace({'NA':0 , 'Fa': 1, 'TA': 2, 'Gd': 3, 'Ex': 4}, inplace=True) 152 | 153 | # Fence is an ordinal variable 154 | df['Fence'].fillna('NA', inplace=True) 155 | assert 
df['Fence'].notnull().all() 156 | df['Fence'].replace({'NA': 0, 'MnWw': 1, 'GdWo': 2, 'MnPrv': 3, 'GdPrv': 4}, inplace=True) 157 | 158 | # Combine YrSold and MoSold into more or less continous variable 159 | df['YrSold'] = df['YrSold'] + (df['MoSold'] - 1) / 12. 160 | # Still convert MoSold to categorical to keep seasonality effect 161 | # (note: to fully do this it should be encoded into a circular 2D variable) 162 | df['MoSold'] = df['MoSold'].astype('int').astype('category') 163 | 164 | # Assume that LotFrontage==NaN means there is no street connected directly 165 | df['LotFrontage'].fillna(0, inplace=True) 166 | 167 | # No veneer area when veneer is not present 168 | df['MasVnrArea'].fillna(0, inplace=True) 169 | 170 | # No Garage means no year build 171 | df['GarageYrBlt'].fillna(0, inplace=True) 172 | 173 | # No Basement 174 | df['BsmtFinSF1'].fillna(0, inplace=True) 175 | df['BsmtFinSF2'].fillna(0, inplace=True) 176 | df['BsmtUnfSF'].fillna(0, inplace=True) 177 | df['TotalBsmtSF'].fillna(0, inplace=True) 178 | df['BsmtFullBath'].fillna(0, inplace=True) 179 | df['BsmtHalfBath'].fillna(0, inplace=True) 180 | # No Garage 181 | df['GarageCars'].fillna(0, inplace=True) 182 | df['GarageArea'].fillna(0, inplace=True) 183 | 184 | df = pd.get_dummies(df, dummy_na=True) 185 | 186 | missing_columns = set(columns_needed) - set(df.columns) 187 | if missing_columns: 188 | print('Columns {} are missing, adding them.'.format(missing_columns)) 189 | for col in missing_columns: 190 | df[col] = 0 191 | 192 | assert df.notnull().all().all(), 'Nan s in {}'.format( 193 | df.columns[df.isnull().any()].tolist()) 194 | return df 195 | -------------------------------------------------------------------------------- /ml_algorithms/BayesianOptimization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Bayesian Optimization\n", 8 | "\n", 9 | "- https://www.cs.ubc.ca/~nando/540-2013/lectures/l7.pdf\n", 10 | "- https://www.youtube.com/watch?v=vz3D36VXefI&index=10&list=PLE6Wd9FR--EdyJ5lbFl8UuGjecvVw66F6\n", 11 | "- https://www.iro.umontreal.ca/~bengioy/cifar/NCAP2014-summerschool/slides/Ryan_adams_140814_bayesopt_ncap.pdf\n", 12 | "- http://gpss.cc/gpmc17/slides/LancasterMasterclass_1.pdf\n", 13 | "- https://www.cs.ox.ac.uk/people/nando.defreitas/publications/BayesOptLoop.pdf\n", 14 | "- https://papers.nips.cc/paper/4522-practical-bayesian-optimization-of-machine-learning-algorithms.pdf\n", 15 | "- https://www.coursera.org/lecture/bayesian-methods-in-machine-learning/bayesian-optimization-iRLaF\n", 16 | "- https://towardsdatascience.com/shallow-understanding-on-bayesian-optimization-324b6c1f7083\n", 17 | "- https://github.com/fmfn/BayesianOptimization\n" 18 | ] 19 | } 20 | ], 21 | "metadata": { 22 | "kernelspec": { 23 | "display_name": "Python 3", 24 | "language": "python", 25 | "name": "python3" 26 | }, 27 | "language_info": { 28 | "codemirror_mode": { 29 | "name": "ipython", 30 | "version": 3 31 | }, 32 | "file_extension": ".py", 33 | "mimetype": "text/x-python", 34 | "name": "python", 35 | "nbconvert_exporter": "python", 36 | "pygments_lexer": "ipython3", 37 | "version": "3.7.3" 38 | } 39 | }, 40 | "nbformat": 4, 41 | "nbformat_minor": 2 42 | } 43 | -------------------------------------------------------------------------------- /ml_algorithms/NeuralNetInformationTheory.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 
3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Occams Razor\n", 8 | "\n", 9 | "## Neural networks, Occam's razor, and heuristic search [Blogpost]\n", 10 | "\n", 11 | "https://dselsam.github.io/neural-networks-occams-razor/\n", 12 | "\n", 13 | "#### Quotes\n", 14 | "\n", 15 | "\"Neural networks sacrifice Occam’s razor in exchange for the ability to efficiently find parameters that fit the training data.\"\n", 16 | "\n", 17 | "\"In general, the best way to generalize from a finite set of training data involves finding the shortest computer program that (efficiently) explains the training data.\"\n", 18 | "\n", 19 | "\"the effects of such regularization are subtle and do not correspond to what we really want, which is to prefer parameters that behave like short computer programs. Indeed, l2 regularization does not help learn the identity function in the example above.\"" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## Occam's Razor [Paper]\n", 27 | "\n", 28 | "https://papers.nips.cc/paper/1925-occams-razor.pdf\n", 29 | "\n", 30 | "Paper focusses on Occam's Razor for Bayesian models. It illustrated that in a Bayesian approach, Occam's razor is always at work.\n", 31 | "\n", 32 | "#### Quotes\n", 33 | "\n", 34 | "\"One might think that one has to build a prior over models which explicitly favours simpler models. But as we will see, Occam's Razor is in fact embodied in the application of Bayesian theory.\"\n", 35 | "\n", 36 | "\"From a non-Bayesian perspective, arguments are put forward for adjusting model complexity in the light of limited training data, to avoid over-fitting.\"\n", 37 | "\n", 38 | "\"If the model complexity is either too low or too high performance on an independent test set will suffer, giving rise to a characteristic Occam's Hill. Typically an estimator of the generalization error or an independent validation set is used to control the model complexity.\"\n", 39 | "\n", 40 | "\"Typically an estimator of the generalization error or an independent validation set is used to control the model complexity.\"\n", 41 | "\n", 42 | "\"One of the central quantities in Bayesian learning is the evidence, the probability of the data given the model P(Y | Mi) computed as the integral over the parameters W of the likelihood times the prior.\"\n", 43 | "\n", 44 | "\"In non-parametric Bayesian models there is no statistical reason to constrain models, as\n", 45 | "long as our prior reflects our beliefs.\"\n", 46 | "\n", 47 | "\"The ratio of prior to posterior volumes is the Occam Factor, which may be interpreted as a penalty to pay for fitting parameters.\"" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## Sensitivity and Generalization in Neural Networks: an Empirical Study [Paper]\n", 55 | "\n", 56 | "https://arxiv.org/abs/1802.08760\n", 57 | "\n", 58 | "* Literatature study on complexity metrics in neural networks and measuring generalizatoin\n", 59 | "* Measure two aspects of sensitivity\n", 60 | " 1. How does the output of the network change as the input is perturbed within the linear region? -> Done with the help of \"Jacobian norm\" that measures the sensitivity of the class probabilities with respect to the samples of interest. This is done by calculating the average gradient of the predicted class for a set of test samples.\n", 61 | " 2. How likely is the linear region to change in response to change in the input? 
-> Done by using piecewise non-linearities like relu and assigning a code to each linear piece. A concatenation of all codes in the neural network describes the linear region of an input. Then different trajectories are defined. The transitions of linear regions in these trajectories are then measured.\n", 62 | "\n", 63 | "They show that the best achieved generalization is not necessarily the simplest model.\n", 64 | "\n", 65 | "Intersting paper!\n", 66 | "\n", 67 | "\n", 68 | "#### Quotes\n", 69 | "\n", 70 | "\"we investigate this tension between complexity and generalization through an extensive empirical exploration of two natural metrics of complexity related to sensitivity to input perturbations\"\n", 71 | "\n", 72 | "\"We find that trained neural networks are more robust to input perturbations in the vicinity of the training data manifold, as measured by the norm of the input-output Jacobian of the network, and that it correlates well with generalization.\"\n", 73 | "\n", 74 | "\"not only do large networks demonstrate good test performance, but larger networks often generalize better, counter to what would be expected from classical measures, such as VC dimension\"\n", 75 | "\n", 76 | "\"the norm of the input-output Jacobian, correlates with generalization in a very wide variety of scenarios\"\n", 77 | "\n", 78 | "\"Then we can measure two aspects of sensitivity by answering\n", 79 | "1. How does the output of the network change as the input is perturbed within the linear region?\n", 80 | "2. How likely is the linear region to change in response to change in the input?\"\n", 81 | "\n", 82 | "\"We find that, according to both the Jacobian norm and transitions metrics, functions exhibit much more robust behavior around the training data\"\n", 83 | "\n", 84 | "\"We conjecture large networks to have access to a larger space of robust solutions due to solving a highly-underdetermined system when fitting a dataset, while small models converge to more extreme weight values due to being overconstrained by the data.\"" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 3", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.7.3" 105 | } 106 | }, 107 | "nbformat": 4, 108 | "nbformat_minor": 2 109 | } 110 | -------------------------------------------------------------------------------- /ml_algorithms/ToRead.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# To Read" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "\n", 15 | "## GANS\n", 16 | "\n", 17 | "\"The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency.\"\n", 18 | "\"In the space of arbitraryfunctions G and D, a unique solution exists, with G recovering the training data distribution and D equal to 1/2 everywhere.\"\n", 19 | "\"The generator G implicitly defines a probability distribution pg as the distribution of the samples G(z) obtained when z ∼ 
pz.\" - Generative Adversarial Nets paper: https://arxiv.org/pdf/1406.2661.pdf\n", 20 | "\n", 21 | "\n", 22 | "- \n", 23 | "[Generalization and Equilibrium in Generative Adversarial Nets (GANs)](https://arxiv.org/abs/1703.00573)\n", 24 | "- [Do GANs actually learn the distribution? An empirical study](https://arxiv.org/abs/1706.08224)\n", 25 | "- Presentation: [Understanding Generative Adversarial Networks](http://www.gatsby.ucl.ac.uk/~balaji/Understanding-GANs.pdf)\n", 26 | "- [Adversarially Learned Inference](https://ishmaelbelghazi.github.io/ALI/)\n", 27 | "- [GAN Applications](https://github.com/nashory/gans-awesome-applications)\n", 28 | "- [From GAN to WGAN](https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html)\n", 29 | "- [Generative Adversarial Nets in TensorFlow](https://blog.evjang.com/2016/06/generative-adversarial-nets-in.html)\n", 30 | "- [gan_simple_nb.ipynb](https://github.com/AYLIEN/jupyter-notebooks/blob/master/research/gan/gan_simple_nb.ipynb)\n", 31 | "- [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks)\n", 32 | "- [Building a simple Generative Adversarial Network (GAN) using TensorFlow](https://blog.paperspace.com/implementing-gans-in-tensorflow/)\n", 33 | "- [GAN-1D-distribution-fitting](https://github.com/jankrepl/GAN-1D-distribution-fitting)\n", 34 | "- An Intuitive Guide to Optimal Transport [part 1](https://www.mindcodec.com/an-intuitive-guide-to-optimal-transport-for-machine-learning/) [part 2](https://www.mindcodec.com/an-intuitive-guide-to-optimal-transport-part-ii-the-wasserstein-gan-made-easy/)\n", 35 | "- [GANs are Broken in More than One Way: The Numerics of GANs](https://www.inference.vc/my-notes-on-the-numerics-of-gans/)\n", 36 | "- [GAN — Why it is so hard to train Generative Adversarial Networks!](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b)\n", 37 | "- [Mode collapse in GANs](http://aiden.nibali.org/blog/2017-01-18-mode-collapse-gans/)\n", 38 | "- [Generative Adversarial Networks (GANs) in 50 lines of code (PyTorch)](https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f)\n", 39 | "- [A Primer on Optimal Transport](https://nips.cc/Conferences/2017/Schedule?showEvent=8736)\n", 40 | "- https://www.youtube.com/watch?v=31mqB4yGgQY&t=1s&index=73&list=WL" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Batch norm\n", 48 | "- [An Intuitive Explanation of Why Batch Normalization Really Works (Normalization in Deep Learning Part 1)](http://mlexplained.com/2018/01/10/an-intuitive-explanation-of-why-batch-normalization-really-works-normalization-in-deep-learning-part-1/)\n", 49 | "- [Intuit and Implement: Batch Normalization](https://towardsdatascience.com/intuit-and-implement-batch-normalization-c05480333c5b)\n", 50 | "- [\n", 51 | "How Does Batch Normalization Help Optimization? 
(No, It Is Not About Internal Covariate Shift)](https://arxiv.org/abs/1805.11604)\n", 52 | "- [Batch normalization in Neural Networks](https://towardsdatascience.com/batch-normalization-in-neural-networks-1ac91516821c)\n", 53 | "- [Batch Normalization — What the hey?\n", 54 | "](https://gab41.lab41.org/batch-normalization-what-the-hey-d480039a9e3b)\n", 55 | "- [Understanding the backward pass through Batch Normalization Layer](http://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html)\n", 56 | "- [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)\n", 57 | "- [Intro to optimization in deep learning: Busting the myth about batch normalization](https://blog.paperspace.com/busting-the-myths-about-batch-normalization/)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## Variational inference\n", 65 | "\n", 66 | "- https://www.inference.vc/variational-inference-with-implicit-probabilistic-models-part-1-2/\n", 67 | "- http://pyro.ai/examples/svi_part_i.html\n", 68 | "- https://lingpipe-blog.com/2013/03/25/mean-field-variational-inference-made-easy/\n", 69 | "- https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf\n", 70 | "- https://www.cs.jhu.edu/~jason/tutorials/variational.html\n", 71 | "- http://www.robots.ox.ac.uk/~sjrob/Pubs/vbTutorialFinal.pdf\n", 72 | "- https://rpubs.com/cakapourani/variational-bayes-gmm\n", 73 | "- https://zhiyzuo.github.io/VI/\n", 74 | "- https://xyang35.github.io/2017/04/14/variational-lower-bound/\n", 75 | "- http://bjlkeng.github.io/posts/variational-bayes-and-the-mean-field-approximation/\n", 76 | "- http://kronosapiens.github.io/blog/2015/11/22/understanding-variational-inference.html\n", 77 | "- https://blog.evjang.com/2016/08/variational-bayes.html\n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Misc\n", 85 | "\n", 86 | "- [Neural Processes as distributions over functions](https://kasparmartens.rbind.io/post/np/)\n", 87 | "- [Hamiltonian Descent Methods](https://arxiv.org/abs/1809.05042)" 88 | ] 89 | } 90 | ], 91 | "metadata": { 92 | "kernelspec": { 93 | "display_name": "Python 3", 94 | "language": "python", 95 | "name": "python3" 96 | }, 97 | "language_info": { 98 | "codemirror_mode": { 99 | "name": "ipython", 100 | "version": 3 101 | }, 102 | "file_extension": ".py", 103 | "mimetype": "text/x-python", 104 | "name": "python", 105 | "nbconvert_exporter": "python", 106 | "pygments_lexer": "ipython3", 107 | "version": "3.7.3" 108 | } 109 | }, 110 | "nbformat": 4, 111 | "nbformat_minor": 2 112 | } 113 | -------------------------------------------------------------------------------- /ml_algorithms/VI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Variational inference\n", 8 | "\n", 9 | "- https://en.wikipedia.org/wiki/Variational_Bayesian_methods\n", 10 | "- http://www.cs.jhu.edu/~jason/tutorials/variational.html\n", 11 | "- http://www.robots.ox.ac.uk/~sjrob/Pubs/fox_vbtut.pdf\n", 12 | "- https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf\n", 13 | "- https://ermongroup.github.io/cs228-notes/inference/variational/\n", 14 | "- https://arxiv.org/pdf/1601.00670.pdf\n", 15 | "- 
http://edwardlib.org/tutorials/variational-inference\n", 16 | "- http://pyro.ai/examples/svi_part_i.html\n", 17 | "- http://www.stat.columbia.edu/~gelman/research/unpublished/bbvb.pdf" 18 | ] 19 | } 20 | ], 21 | "metadata": { 22 | "kernelspec": { 23 | "display_name": "Python 3", 24 | "language": "python", 25 | "name": "python3" 26 | }, 27 | "language_info": { 28 | "codemirror_mode": { 29 | "name": "ipython", 30 | "version": 3 31 | }, 32 | "file_extension": ".py", 33 | "mimetype": "text/x-python", 34 | "name": "python", 35 | "nbconvert_exporter": "python", 36 | "pygments_lexer": "ipython3", 37 | "version": "3.7.0" 38 | } 39 | }, 40 | "nbformat": 4, 41 | "nbformat_minor": 2 42 | } 43 | -------------------------------------------------------------------------------- /ml_algorithms/WGAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Wasserstein GAN" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "* https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html\n", 15 | "* https://vincentherrmann.github.io/blog/wasserstein/\n", 16 | "* https://www.alexirpan.com/2017/02/22/wasserstein-gan.html\n", 17 | "* https://medium.com/@jonathan_hui/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490\n", 18 | "* https://mindcodec.ai/2018/09/19/an-intuitive-guide-to-optimal-transport-part-i-formulating-the-problem/\n", 19 | "* https://mindcodec.ai/2018/09/23/an-intuitive-guide-to-optimal-transport-part-ii-the-wasserstein-gan-made-easy/\n", 20 | "* https://arxiv.org/abs/1701.07875\n", 21 | "* https://arxiv.org/abs/1704.00028\n", 22 | "* https://github.com/martinarjovsky/WassersteinGAN\n", 23 | "* https://www.reddit.com/r/MachineLearning/comments/5qxoaz/r_170107875_wasserstein_gan/\n", 24 | "* https://paper.dropbox.com/doc/Wasserstein-GAN--AcUxXxcIZ3w7oqWmSn69D98vAg-GvU0p2V9ThzdwY3BbhoP7\n", 25 | "* https://github.com/Yangyangii/GAN-Tutorial/tree/master/Notebooks\n", 26 | "* https://nips.cc/Conferences/2017/Schedule?showEvent=8736" 27 | ] 28 | } 29 | ], 30 | "metadata": { 31 | "kernelspec": { 32 | "display_name": "Python 3", 33 | "language": "python", 34 | "name": "python3" 35 | }, 36 | "language_info": { 37 | "codemirror_mode": { 38 | "name": "ipython", 39 | "version": 3 40 | }, 41 | "file_extension": ".py", 42 | "mimetype": "text/x-python", 43 | "name": "python", 44 | "nbconvert_exporter": "python", 45 | "pygments_lexer": "ipython3", 46 | "version": "3.7.3" 47 | } 48 | }, 49 | "nbformat": 4, 50 | "nbformat_minor": 2 51 | } 52 | -------------------------------------------------------------------------------- /ml_algorithms/diffusion/fm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\"\"\"\n", 10 | "Minimal PyTorch implementation of Flow Matching with Optimal Transport (OT) for generative modeling, based on \"Flow Matching for Generative Modeling\" by Lipman et al. 
(2023).\n", 11 | "\n", 12 | "This code provides a simplified example using the half-moons dataset and Euler method for sampling.\n", 13 | "\"\"\"\n", 14 | "\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import torch\n", 18 | "import torch.nn as nn\n", 19 | "import torch.optim as optim\n", 20 | "from sklearn.datasets import make_moons\n", 21 | "from tqdm import tqdm\n", 22 | "from typing import Tuple\n", 23 | "\n", 24 | "# Device configuration\n", 25 | "device: torch.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 26 | "\n", 27 | "# Hyperparameters\n", 28 | "data_size: int = 10_000\n", 29 | "batch_size: int = 256\n", 30 | "input_dim: int = 2\n", 31 | "hidden_dim: int = 64\n", 32 | "n_iterations: int = 10_000 # Total number of iterations.\n", 33 | "sigma_min: float = (\n", 34 | " 1e-5 # For OT path (Eq. 22). Controls variance at t=1 (should be small).\n", 35 | ")\n", 36 | "lr: float = 1e-3\n", 37 | "\n", 38 | "# Data #######################################################################\n", 39 | "# Generate half moon data as toy dataset\n", 40 | "data, _ = make_moons(data_size, noise=0.05)\n", 41 | "data = torch.tensor(data, dtype=torch.float32).to(device)\n", 42 | "data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n", 43 | "\n", 44 | "\n", 45 | "# Model ######################################################################\n", 46 | "# Define the vector field network v_t\n", 47 | "class VectorFieldNetwork(nn.Module):\n", 48 | " \"\"\"\n", 49 | " Parametric model for the time-dependent vector field v_t(x; θ) (Section 3 in Lipman et al.).\n", 50 | "\n", 51 | " This network represents the function v_t(x; θ), which is the Continuous Normalizing Flow (CNF) vector field.\n", 52 | " It is used in the Flow Matching objective (Eq. 5 in Lipman et al.).\n", 53 | " The network outputs the *value* of the vector field at a given time t and point x.\n", 54 | " \"\"\"\n", 55 | "\n", 56 | " def __init__(self, input_dim: int, hidden_dim: int):\n", 57 | " super().__init__()\n", 58 | " # MLP architecture (no specific architecture defined in the paper)\n", 59 | " self.net: nn.Sequential = nn.Sequential(\n", 60 | " nn.Linear(input_dim + 1, hidden_dim), # +1 for time embedding\n", 61 | " nn.ReLU(),\n", 62 | " nn.Linear(hidden_dim, hidden_dim),\n", 63 | " nn.ReLU(),\n", 64 | " nn.Linear(hidden_dim, input_dim),\n", 65 | " )\n", 66 | "\n", 67 | " def time_embedding(self, t: torch.Tensor) -> torch.Tensor:\n", 68 | " \"\"\"\n", 69 | " Embeds time into a higher dimensional space.\n", 70 | " \"\"\"\n", 71 | " return t.unsqueeze(-1) # Simple embedding, no frequencies used\n", 72 | "\n", 73 | " def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:\n", 74 | " \"\"\"\n", 75 | " Computes the vector field v_t(x) at time t and point x.\n", 76 | " \"\"\"\n", 77 | " t_emb: torch.Tensor = self.time_embedding(t)\n", 78 | " tx: torch.Tensor = torch.cat([t_emb, x], dim=-1)\n", 79 | " return self.net(tx) # Output is v_t(x; θ) (Eq. 5 in Lipman et al.)\n", 80 | "\n", 81 | "\n", 82 | "def ot_path(\n", 83 | " x_0: torch.Tensor,\n", 84 | " x_1: torch.Tensor,\n", 85 | " t: torch.Tensor,\n", 86 | " sigma_min: float,\n", 87 | ") -> Tuple[torch.Tensor, torch.Tensor]:\n", 88 | " \"\"\"\n", 89 | " Computes the Optimal Transport (OT) path (ψ_t) and target vector field (u_t).\n", 90 | "\n", 91 | " This function implements the conditional flow ψ_t (Eq. 22 in Lipman et al.) and the corresponding target vector field u_t (Eq. 23 in Lipman et al.) 
for the Optimal Transport (OT) formulation in Flow Matching. ψ_t defines a straight-line interpolation path between a sample x_0 from the prior distribution and a target data point x_1. The target vector field u_t represents the ideal vector field that the learned vector field v_t should approximate.\n", 92 | "\n", 93 | " Args:\n", 94 | " x_0: Samples from the prior distribution (typically a standard Gaussian).\n", 95 | " x_1: Target data points.\n", 96 | " t: Time points along the OT path (between 0 and 1).\n", 97 | " sigma_min: Small constant controlling the variance at t=1 (see Eq. 22).\n", 98 | "\n", 99 | " Returns:\n", 100 | " A tuple containing the OT path (ψ_t) and the target vector field (u_t).\n", 101 | " \"\"\"\n", 102 | " t = t.unsqueeze(-1)\n", 103 | " # ψ_t (Eq. 22 in Lipman et al.) Linear interpolation between prior and target.\n", 104 | " psi_t: torch.Tensor = (1 - (1 - sigma_min) * t) * x_0 + (t * x_1)\n", 105 | " # Target vector field for OT (u_t) (Eq. 23 in Lipman et al.)\n", 106 | " target_v: torch.Tensor = x_1 - (1 - sigma_min) * x_0\n", 107 | " return psi_t, target_v\n", 108 | "\n", 109 | "\n", 110 | "def compute_cfm_loss(\n", 111 | " v_net: VectorFieldNetwork,\n", 112 | " x_0: torch.Tensor,\n", 113 | " x_1: torch.Tensor,\n", 114 | " t: torch.Tensor,\n", 115 | " sigma_min: float,\n", 116 | ") -> torch.Tensor:\n", 117 | " \"\"\"\n", 118 | " Computes the Conditional Flow Matching (CFM) loss for Optimal Transport.\n", 119 | "\n", 120 | " This function computes the CFM loss (Eq. 23 in Lipman et al.) using the Optimal Transport (OT) formulation. It compares the predicted vector field v_t(ψ_t(x_0, x_1, t)) (output of the v_net) with the target vector field u_t derived from the OT path. The loss is calculated as the mean squared error (MSE) between the predicted and target vector fields. Minimizing this loss encourages the learned vector field to accurately represent the OT flow.\n", 121 | "\n", 122 | " Args:\n", 123 | " v_net: The vector field network v_t(x; θ).\n", 124 | " x_0: Samples from the prior distribution.\n", 125 | " x_1: Target data points.\n", 126 | " t: Time points along the OT path.\n", 127 | " sigma_min: Small constant controlling the variance at t=1.\n", 128 | "\n", 129 | " Returns:\n", 130 | " The CFM loss as a scalar tensor.\n", 131 | " \"\"\"\n", 132 | " psi_t, target_v = ot_path(x_0, x_1, t, sigma_min)\n", 133 | " v: torch.Tensor = v_net(t, psi_t) # v_t(ψ_t(x_0, x_1, t))\n", 134 | " # CFM loss (Eq. 
23 in Lipman et al.), simplified as MSE loss\n", 135 | " loss: torch.Tensor = ((v - target_v) ** 2).mean()\n", 136 | " return loss\n", 137 | "\n", 138 | "\n", 139 | "# Training ###################################################################\n", 140 | "# Initialize the vector field network v_t and optimizer\n", 141 | "v_net = VectorFieldNetwork(input_dim, hidden_dim).to(device)\n", 142 | "optimizer = optim.Adam(v_net.parameters(), lr=lr)\n", 143 | "\n", 144 | "# Training loop\n", 145 | "progress_bar = tqdm(range(n_iterations), desc=\"Training\", unit=\"iteration\")\n", 146 | "losses = [] # List to store loss values\n", 147 | "for iteration in progress_bar:\n", 148 | " x_1 = next(iter(data_loader)) # Sample data x_1 ~ q(x_1)\n", 149 | " t = torch.rand(x_1.shape[0], device=device) # Sample time t ~ U[0, 1]\n", 150 | " x_0 = torch.randn_like(x_1) # Sample prior noise x_0 ~ N(0, I)\n", 151 | "\n", 152 | " loss = compute_cfm_loss(v_net, x_0, x_1, t, sigma_min) # CFM loss\n", 153 | "\n", 154 | " # Optimization step\n", 155 | " optimizer.zero_grad()\n", 156 | " loss.backward()\n", 157 | " optimizer.step()\n", 158 | "\n", 159 | " losses.append(loss.item()) # Store the loss value\n", 160 | " progress_bar.set_postfix({\"Loss\": f\"{loss.item():.4f}\"})\n", 161 | "\n", 162 | "# Plot loss curve after training\n", 163 | "fig, ax = plt.subplots(figsize=(10, 4))\n", 164 | "ax.plot(losses)\n", 165 | "ax.set_xlabel(\"Iteration\")\n", 166 | "ax.set_ylabel(\"Loss\")\n", 167 | "ax.set_title(\"Training Loss Curve\")\n", 168 | "ax.grid(True) # Add grid\n", 169 | "plt.show()\n", 170 | "\n", 171 | "\n", 172 | "# Sampling ###################################################################\n", 173 | "def sample(\n", 174 | " n_samples: int,\n", 175 | " vector_field: VectorFieldNetwork,\n", 176 | " dt: float = 0.01, # Step size for Euler integration\n", 177 | ") -> torch.Tensor:\n", 178 | " \"\"\"\n", 179 | " Generates samples by integrating the learned vector field.\n", 180 | "\n", 181 | " This function generates samples by numerically integrating the learned vector field `v_t(x)` from t=0 to t=1. It uses the Euler method (a simple first-order numerical integration method) to approximate the solution to the ODE defined by the vector field (Eq. 1 in Lipman et al. with a simplified discretization). The integration starts from random samples drawn from the prior distribution (typically a standard Gaussian).\n", 182 | "\n", 183 | " Args:\n", 184 | " n_samples: The number of samples to generate.\n", 185 | " vector_field: The learned vector field network v_t(x; θ).\n", 186 | " dt: The integration step size.\n", 187 | "\n", 188 | " Returns:\n", 189 | " A tensor of generated samples.\n", 190 | " \"\"\"\n", 191 | " x = torch.randn(n_samples, input_dim).to(device) # Sample x_0 ~ N(0, I)\n", 192 | " for t in torch.arange(0, 1, dt): # Euler integration from t=0 to t=1\n", 193 | " t_batch = t.expand(n_samples).to(device)\n", 194 | " x = (\n", 195 | " x + vector_field(t_batch, x) * dt\n", 196 | " ) # Euler update (Eq. 28 in Lipman et al. 
with simplified discretization)\n", 197 | " return x\n", 198 | "\n", 199 | "\n", 200 | "# Sample and plot generated data\n", 201 | "generated_samples = sample(data_size, v_net)\n", 202 | "generated_samples = generated_samples.cpu().detach().numpy()\n", 203 | "\n", 204 | "\n", 205 | "# Visualization ###############################################################\n", 206 | "# Plot the generated samples\n", 207 | "fig, ax = plt.subplots(figsize=(10, 6))\n", 208 | "ax.scatter(\n", 209 | " data[:, 0].cpu().numpy(), data[:, 1].cpu().numpy(), label=\"Real Data\", alpha=0.1\n", 210 | ")\n", 211 | "ax.scatter(\n", 212 | " generated_samples[:, 0], generated_samples[:, 1], label=\"Generated Data\", alpha=0.1\n", 213 | ")\n", 214 | "ax.set_title(\"Flow Matching with Optimal Transport - Generated Samples\")\n", 215 | "ax.legend()\n", 216 | "plt.show()" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "pytorch", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.11.10" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | -------------------------------------------------------------------------------- /ml_algorithms/diffusion/fm_requirements.txt: -------------------------------------------------------------------------------- 1 | ipython 2 | ipykernel 3 | # Plotting 4 | matplotlib 5 | seaborn 6 | tqdm 7 | ipywidgets 8 | # Numpy 9 | numpy 10 | scipy 11 | scikit-learn 12 | # PyTorch 13 | torch>=2.4 14 | torchaudio 15 | torchvision 16 | # CUDA Dependencies 17 | nvidia-cuda-runtime-cu12 18 | nvidia-cudnn-cu12 19 | -------------------------------------------------------------------------------- /nlp_text/README.md: -------------------------------------------------------------------------------- 1 | # Text analytics experiments 2 | 3 | Notebooks containing experiments with topic modelling, NLP, etc. 4 | -------------------------------------------------------------------------------- /probabilistic_programming/README.md: -------------------------------------------------------------------------------- 1 | # Probabilistic Programming Examples 2 | 3 | Examples of probabilistic programming using tools such as [PyMC](http://docs.pymc.io/index.html). 
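A minimal sketch of the kind of model these notebooks contain (a hypothetical example; it assumes a recent PyMC release, whose API differs slightly from the PyMC3 links listed below):

```python
import numpy as np
import pymc as pm

# Ten coin flips, seven heads: infer the probability of heads
flips = np.array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0])

with pm.Model():
    p = pm.Beta("p", alpha=1.0, beta=1.0)      # Prior on the coin bias
    pm.Bernoulli("obs", p=p, observed=flips)   # Likelihood of the observed flips
    idata = pm.sample(1000, tune=1000)         # Draw posterior samples with NUTS
```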
4 | 5 | Useful sources: 6 | * [PyMC docs](http://docs.pymc.io/) 7 | * [PyMC3 Github](https://github.com/pymc-devs/pymc3) 8 | * [PyMC Paper](https://arxiv.org/abs/1507.08050) 9 | * [PyMC3 tutorial from Duke University](http://people.duke.edu/~ccc14/sta-663-2016/16C_PyMC3.html) 10 | * [Bayesian Methods for Hackers](https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers) 11 | 12 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/GMM-discrete.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "25b95485-4c74-4702-99e1-033029306fe6", 6 | "metadata": {}, 7 | "source": [ 8 | "# NumPyro Gaussian Mixture Model with discrete sampling\n", 9 | "\n", 10 | "This notebook illustrates how to build a GMM in NumPyro using discrete sampling from the categorical distribution.\n", 11 | "\n", 12 | "Note that in practice you probably want to use a marginalized mixture model, as is [illustrated in PyMC3 here](https://docs.pymc.io/notebooks/marginalized_gaussian_mixture_model.html)." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "0208b2af-1e72-4cfa-a196-4df2593b16df", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import sys\n", 23 | "import warnings\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "\n", 27 | "import jax\n", 28 | "import jax.numpy as jnp\n", 29 | "\n", 30 | "import numpyro\n", 31 | "from numpyro.infer import MCMC, NUTS, Predictive\n", 32 | "import numpyro.distributions as dist\n", 33 | "\n", 34 | "import matplotlib\n", 35 | "import matplotlib.pyplot as plt\n", 36 | "from matplotlib import cm # Colormaps\n", 37 | "import seaborn as sns\n", 38 | "import arviz as az\n", 39 | "\n", 40 | "from tqdm import tqdm_notebook as tqdm" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "dfff55a6-cbf3-46e1-a102-1937e4a96865", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "%load_ext watermark\n", 51 | "%watermark --iversions" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "b0c9a67e-9890-485d-b4ea-8e78272eb715", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "np.random.seed(42)\n", 62 | "rng_key = jax.random.PRNGKey(42)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "27374bda-5172-4b00-ab57-aa2b9f031f84", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "np.random.seed(42)\n", 73 | "\n", 74 | "n = 2500 # Total number of samples\n", 75 | "k = 3 # Number of clusters\n", 76 | "p_real = np.array([0.2, 0.3, 0.5]) # Probability of choosing each cluster\n", 77 | "mus_real = np.array([-1., 1., 4.]) # Mu of clusters\n", 78 | "sigmas_real = np.array([0.2, 0.9, 0.5]) # Sigma of clusters\n", 79 | "clusters = np.random.choice(k, size=n, p=p_real)\n", 80 | "data = np.random.normal(mus_real[clusters], sigmas_real[clusters], size=n)\n", 81 | "\n", 82 | "print(f'{n} samples in total from {k} clusters. 
data: {data.shape}')\n", 83 | "sns.histplot(data, kde=True)\n", 84 | "plt.show()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "d4d92ca3-e036-4f0e-9f1e-92fff30cbab1", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def gmm_model(data, k):\n", 95 | "    # Prior for cluster probabilities\n", 96 | "    # Dirichlet([1, 1, 1]) is like a uniform distribution over all clusters\n", 97 | "    selection_prob = numpyro.sample('selection_prob', dist.Dirichlet(concentration=jnp.ones(k)))\n", 98 | "    # Prior on cluster means\n", 99 | "    with numpyro.plate('k_plate', k):\n", 100 | "        mu = numpyro.sample('mu', dist.Normal(loc=0., scale=10.))\n", 101 | "        sigma = numpyro.sample('scale', dist.HalfCauchy(scale=10))\n", 102 | "    # The data needs its own plate due to the categorical selection\n", 103 | "    with numpyro.plate('data', len(data)):\n", 104 | "        cluster_idx = numpyro.sample('cluster_idx', dist.Categorical(selection_prob))\n", 105 | "        numpyro.sample('x', dist.Normal(loc=mu[cluster_idx], scale=sigma[cluster_idx]), obs=data)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "fe4047b6-c536-4d83-a14d-3219ee3fa2e0", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "rng_key = jax.random.PRNGKey(0)\n", 116 | "\n", 117 | "num_warmup, num_samples = 1000, 2000\n", 118 | "\n", 119 | "# Run NUTS.\n", 120 | "kernel = NUTS(gmm_model)\n", 121 | "mcmc = MCMC(\n", 122 | "    kernel,\n", 123 | "    num_warmup=num_warmup,\n", 124 | "    num_samples=num_samples,\n", 125 | ")\n", 126 | "mcmc.run(rng_key, data=data, k=k)\n", 127 | "mcmc.print_summary()\n", 128 | "posterior_samples = mcmc.get_samples()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "9c2b6819-d0a7-4e2f-9b9b-4a1edb8ff57b", 134 | "metadata": {}, 135 | "source": [ 136 | "Make sure to pass `infer_discrete=True` to `Predictive`, as [mentioned here](https://github.com/pyro-ppl/numpyro/issues/1121#issuecomment-897363003)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "5f04ce76-01fc-4f52-9512-e5888fc673d0", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "posterior_predictive = Predictive(gmm_model, posterior_samples, infer_discrete=True)\n", 147 | "posterior_predictions = posterior_predictive(rng_key, k=k, data=data)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "b834be75-b7f8-452d-9850-4b8ce75b6507", 153 | "metadata": {}, 154 | "source": [ 155 | "Make sure to add the selected cluster indices (discrete samples) to the MCMC samples, as [mentioned here](https://github.com/pyro-ppl/numpyro/issues/1121#issuecomment-897363003)." 
156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "8948599a-20a5-420c-8aa2-af013cae85cd", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# Add \"cluster_idx\" values to mcmc samples\n", 166 | "posterior_samples[\"cluster_idx\"] = posterior_predictions[\"cluster_idx\"]" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "6fb68b55-5a27-4241-beab-bf10170ba501", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "\n", 177 | "\n", 178 | "inference_data = az.from_numpyro(\n", 179 | " posterior=mcmc,\n", 180 | " posterior_predictive=posterior_predictions,\n", 181 | ")\n", 182 | "display(inference_data)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "c179084b-c131-423e-a2d1-2218447dd1d3", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "data = az.from_numpyro(mcmc)\n", 193 | "az.plot_trace(data, compact=True)\n", 194 | "plt.show()" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3 (ipykernel)", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.9.6" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 5 219 | } 220 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/GMM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "72089cf0-9420-4b89-9a88-4a3994029bbd", 6 | "metadata": {}, 7 | "source": [ 8 | "# Gaussian Mixture Model (GMM)\n", 9 | "\n", 10 | "Based on:\n", 11 | "- https://docs.pymc.io/notebooks/marginalized_gaussian_mixture_model.html\n", 12 | "\n", 13 | "NumPyro Marginalized Mixture model:\n", 14 | "- https://forum.pyro.ai/t/sample-from-the-mixture-same-family-distribution/3178/\n", 15 | "\n", 16 | "More info on marginalized mixture models:\n", 17 | "- https://www.youtube.com/watch?v=KOIudAB6vJ0\n", 18 | "- https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "ba6d44a9-1389-4e76-a155-2c76f375f2c2", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# Imports\n", 29 | "%matplotlib inline\n", 30 | "%config InlineBackend.figure_format = 'svg'\n", 31 | "\n", 32 | "import sys\n", 33 | "import warnings\n", 34 | "\n", 35 | "import numpy as np\n", 36 | "\n", 37 | "import jax\n", 38 | "import jax.numpy as jnp\n", 39 | "\n", 40 | "import numpyro\n", 41 | "from numpyro.infer import MCMC, NUTS, Predictive\n", 42 | "import numpyro.distributions as dist\n", 43 | "\n", 44 | "import matplotlib\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "from matplotlib import cm # Colormaps\n", 47 | "import seaborn as sns\n", 48 | "import arviz as az\n", 49 | "\n", 50 | "from tqdm import tqdm_notebook as tqdm" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "c3072eea-aedb-4bff-a56a-950c35dbc26c", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "sns.set_style('darkgrid')\n", 61 | "az.rcParams['stats.hdi_prob'] = 
0.90\n", 62 | "az.style.use(\"arviz-darkgrid\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "1366a6f2-8eb0-4810-b796-33b5dbd04407", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "numpyro.set_platform('cpu')\n", 73 | "numpyro.set_host_device_count(8)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "b0c9a67e-9890-485d-b4ea-8e78272eb715", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "np.random.seed(42)\n", 84 | "rng_key = jax.random.PRNGKey(42)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "27374bda-5172-4b00-ab57-aa2b9f031f84", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "np.random.seed(42)\n", 95 | "\n", 96 | "n = 2500 # Total number of samples\n", 97 | "k = 3 # Number of clusters\n", 98 | "p_real = np.array([0.2, 0.3, 0.5]) # Probability of choosing each cluster\n", 99 | "mus_real = np.array([-1., 1., 4.]) # Mu of clusters\n", 100 | "sigmas_real = np.array([0.2, 0.9, 0.5]) # Sigma of clusters\n", 101 | "clusters = np.random.choice(k, size=n, p=p_real)\n", 102 | "x_data = np.random.normal(mus_real[clusters], sigmas_real[clusters], size=n)\n", 103 | "\n", 104 | "print(f'{n} samples in total from {k} clusters. x_data: {x_data.shape}')\n", 105 | "fig, ax = plt.subplots(1, 1, figsize=(6, 3))\n", 106 | "sns.histplot(x_data, kde=True, ax=ax)\n", 107 | "ax.set_xlabel('x')\n", 108 | "plt.show()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "4130d6cb-7284-4342-929e-febb974bd50b", 114 | "metadata": {}, 115 | "source": [ 116 | "A natural parameterization of the Gaussian mixture model is as the latent variable model\n", 117 | "\n", 118 | "$$\n", 119 | "\\begin{split}\\begin{align*}\n", 120 | "\\mu_1, \\ldots, \\mu_k & \\sim \\mathcal{N}(0, \\sigma^2) \\\\\n", 121 | "\\sigma_1, \\ldots, \\sigma_k & \\sim \\text{HalfCauchy}(b) \\\\\n", 122 | "w & \\sim \\text{Dirichlet}(\\alpha_1, \\ldots, \\alpha_k) \\\\\n", 123 | "z \\mid w & \\sim \\text{Categorical}(w) \\\\\n", 124 | "x \\mid z & \\sim \\mathcal{N}(\\mu_z, \\sigma_z).\n", 125 | "\\end{align*}\\end{split}\n", 126 | "$$\n", 127 | "\n", 128 | "The disadvantage of this parameterization is that sampling the posterior relies on sampling the discrete categorical variable $z$, and the sampler can easily get stuck while exploring these discrete cluster assignments.\n", 129 | "\n", 130 | "An alternative is to marginalise out the categorical $z$ and instead sample from a single [mixture distribution](https://en.wikipedia.org/wiki/Mixture_distribution):\n", 131 | "\n", 132 | "$$\n", 133 | "\\begin{split}\\begin{align*}\n", 134 | "\\mu_1, \\ldots, \\mu_k & \\sim \\mathcal{N}(0, \\sigma^2) \\\\\n", 135 | "\\sigma_1, \\ldots, \\sigma_k & \\sim \\text{HalfCauchy}(b) \\\\\n", 136 | "w & \\sim \\text{Dirichlet}(\\alpha_1, \\ldots, 
\alpha_k) \\\n", 137 | "f(x \mid w) & = \sum_{i = 1}^k w_i \mathcal{N}(x \mid \mu_i, \sigma_i)\n", 138 | "\end{align*}\end{split}\n", 139 | "$$\n", 140 | "\n", 141 | "with\n", 142 | "\n", 143 | "$$\n", 144 | "\mathcal{N}(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi} \sigma} \exp\left(-\frac{1}{2 \sigma^2} (x - \mu)^2\right)\n", 145 | "$$" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "id": "d729f68e-cd1a-472a-bbb2-544b2c52340c", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "class MixtureGaussian(dist.Distribution):\n", 156 | "    def __init__(self, loc, scale, mixing_probs, validate_args=None):\n", 157 | "        expand_shape = jax.lax.broadcast_shapes(\n", 158 | "            jnp.shape(loc), jnp.shape(scale), jnp.shape(mixing_probs)\n", 159 | "        )\n", 160 | "        self._gaussian = dist.Normal(loc=loc, scale=scale).expand(expand_shape)\n", 161 | "        self._categorical = dist.Categorical(jnp.broadcast_to(mixing_probs, expand_shape))\n", 162 | "        super(MixtureGaussian, self).__init__(batch_shape=expand_shape[:-1], validate_args=validate_args)\n", 163 | "\n", 164 | "    def sample(self, key, sample_shape=()):\n", 165 | "        key, key_idx = jax.random.split(key)\n", 166 | "        samples = self._gaussian.sample(key, sample_shape)\n", 167 | "        ind = self._categorical.sample(key_idx, sample_shape)\n", 168 | "        return jnp.take_along_axis(samples, ind[..., None], -1)[..., 0]\n", 169 | "\n", 170 | "    def log_prob(self, value):\n", 171 | "        print(f\"\\nlog_prob(value={value.shape})\")\n", 172 | "        value_reshaped = value[..., None]\n", 173 | "        print(\"value_reshaped: \", value_reshaped.shape)\n", 174 | "        probs_mixture = self._gaussian.log_prob(value[..., None])\n", 175 | "        print(\"probs_mixture: \", probs_mixture.shape)\n", 176 | "        sum_probs = self._categorical.logits + probs_mixture\n", 177 | "        print(\"sum_probs: \", sum_probs.shape)\n", 178 | "        lse = jax.nn.logsumexp(sum_probs, axis=-1)\n", 179 | "        print(\"lse: \", lse.shape)\n", 180 | "        return lse" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "6368c908-e0a2-49cb-bae5-c3f027b23b8b", 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "def gmm_model(k, x=None):\n", 191 | "    # Prior for cluster probabilities\n", 192 | "    prob_cluster = numpyro.sample('prob_cluster', dist.Dirichlet(concentration=jnp.ones(k)))\n", 193 | "    # Prior on cluster means\n", 194 | "    with numpyro.plate('k_plate', k):\n", 195 | "        loc = numpyro.sample('loc', dist.Normal(loc=0., scale=10.))\n", 196 | "        sigma = numpyro.sample('scale', dist.HalfCauchy(scale=10))\n", 197 | "    print(\"loc: \", loc.shape)\n", 198 | "    print(\"sigma: \", sigma.shape)\n", 199 | "    numpyro.sample('x', MixtureGaussian(loc=loc, scale=sigma, mixing_probs=prob_cluster), obs=x)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "fe4047b6-c536-4d83-a14d-3219ee3fa2e0", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "rng_key = jax.random.PRNGKey(42)\n", 210 | "\n", 211 | "num_warmup, num_samples = 1000, 2000\n", 212 | "\n", 213 | "# Run NUTS.\n", 214 | "kernel = NUTS(gmm_model)\n", 215 | "mcmc = MCMC(\n", 216 | "    kernel,\n", 217 | "    num_warmup=num_warmup,\n", 218 | "    num_samples=num_samples,\n", 219 | ")\n", 220 | "mcmc.run(rng_key, x=x_data, k=k)\n", 221 | "mcmc.print_summary()\n", 222 | "posterior_samples = mcmc.get_samples()" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "c6870be3-61ba-47ac-b35d-e72648d567e4", 
229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "rng_key = jax.random.PRNGKey(42)\n", 233 | "\n", 234 | "x_posterior = np.linspace(min(x_data)-1, max(x_data)+1, 100)\n", 235 | "\n", 236 | "posterior_predictive = Predictive(gmm_model, posterior_samples=posterior_samples)\n", 237 | "posterior_predictions = posterior_predictive(rng_key, k=k)\n", 238 | "print('Posterior predictions: ', posterior_predictions['x'].shape)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "id": "5bdd8b6b-0ec5-4e02-9f71-5ce6e8a28eb9", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "fig, axes = plt.subplots(3, 2, figsize=(10, 6))\n", 249 | "inference_data = az.from_numpyro(\n", 250 | " posterior=mcmc,\n", 251 | " posterior_predictive=posterior_predictions,\n", 252 | " coords={\"cluster\": np.arange(k)},\n", 253 | " dims={\"loc\": [\"cluster\"], \"scale\": [\"cluster\"], \"prob_cluster\": [\"cluster\"]}\n", 254 | ")\n", 255 | "display(inference_data)\n", 256 | "\n", 257 | "az.plot_trace(inference_data, compact=True, axes=axes)\n", 258 | "plt.suptitle('Trace plots', fontsize=18)\n", 259 | "plt.show()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "id": "0acedc27-7cff-4833-acb5-49b23b0cb8ba", 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "fig, axes = plt.subplots(3, 3, figsize=(12, 8))\n", 270 | "az.plot_posterior(inference_data, var_names=['loc', 'scale', 'prob_cluster'], kind='hist', ax=axes)\n", 271 | "plt.suptitle('Posterior plots', fontsize=18)\n", 272 | "plt.show()" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "f50a4b0a-dace-4b23-b20f-4a810f7669cd", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "\n", 283 | "fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n", 284 | "sns.histplot(x_data, kde=True, label='original', ax=ax, color=sns.color_palette(\"tab10\")[0])\n", 285 | "sns.histplot(posterior_predictions['x'], kde=True, label='posterior', ax=ax, color=sns.color_palette(\"tab10\")[1])\n", 286 | "ax.set_title(\"Posterior predictive vs Original\")\n", 287 | "ax.set_xlabel(\"x\")\n", 288 | "ax.legend()\n", 289 | "plt.show()" 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3 (ipykernel)", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.9.6" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 5 314 | } 315 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/LKJ_Cholesky_Covariance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f6ee01ec-da34-433c-88e9-8172209196ae", 6 | "metadata": {}, 7 | "source": [ 8 | "# LKJ Cholesky Covariance Priors for Multivariate Normal Models\n", 9 | "\n", 10 | "Replication of the PyMC3 notebook on LKJ Cholesky Covariance priors: https://docs.pymc.io/notebooks/LKJ.html\n", 11 | "\n", 12 | "More info:\n", 13 | "- [PyMC3 LKJCholeskyCov](https://docs.pymc.io/api/distributions/multivariate.html#pymc3.distributions.multivariate.LKJCholeskyCov)\n", 14 | "- [NumPyro 
LKJCholesky](http://num.pyro.ai/en/stable/distributions.html#lkjcholesky)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "3b084404-e85b-4059-a74b-b37ddf4b332d", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "%matplotlib inline\n", 25 | "%config InlineBackend.figure_format='svg'" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "6352ad1e-2898-47a5-839b-229ed56a3dfc", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "import sys\n", 36 | "import warnings\n", 37 | "\n", 38 | "import numpy as np\n", 39 | "import scipy\n", 40 | "import scipy.stats\n", 41 | "\n", 42 | "import jax\n", 43 | "import jax.numpy as jnp\n", 44 | "\n", 45 | "import numpyro\n", 46 | "from numpyro.infer import MCMC, NUTS, Predictive\n", 47 | "import numpyro.distributions as dist\n", 48 | "\n", 49 | "import matplotlib\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "from matplotlib import cm # Colormaps\n", 52 | "from matplotlib.patches import Ellipse\n", 53 | "import seaborn as sns\n", 54 | "import arviz as az\n", 55 | "\n", 56 | "from tqdm import tqdm_notebook as tqdm" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "7905a261-3d2d-414d-aabe-aecaa2be1a3e", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "sns.set_style('darkgrid')\n", 67 | "az.rcParams['stats.hdi_prob'] = 0.90\n", 68 | "az.style.use(\"arviz-darkgrid\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "c240ae1d-f9e1-4ed7-bba1-6ae81925cb0f", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "numpyro.set_platform('cpu')\n", 79 | "numpyro.set_host_device_count(8)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "c8e57c76-fd2e-49c7-82ea-4f5c9b4d7741", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "np.random.seed(42)\n", 90 | "rng_key = jax.random.PRNGKey(42)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "85f10991-2874-40ed-bbab-d7f3fae02693", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "blue, orange, *_ = sns.color_palette(\"tab10\")" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "135ff97d-2570-4374-8cd7-80ecc8d9b76d", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "RANDOM_SEED = 8924\n", 111 | "np.random.seed(3264602) # from random.org\n", 112 | "\n", 113 | "N = 10000\n", 114 | "\n", 115 | "μ_actual = np.array([1.0, -2.0])\n", 116 | "sigmas_actual = np.array([0.7, 1.5])\n", 117 | "Rho_actual = np.matrix([[1.0, -0.4], [-0.4, 1.0]])\n", 118 | "\n", 119 | "Σ_actual = np.diag(sigmas_actual) * Rho_actual * np.diag(sigmas_actual)\n", 120 | "\n", 121 | "x = np.random.multivariate_normal(μ_actual, Σ_actual, size=N)\n", 122 | "Σ_actual" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "21969fe2-0aa0-40da-97cc-a536c790ee81", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "var, U = np.linalg.eig(Σ_actual)\n", 133 | "angle = 180.0 / np.pi * np.arccos(np.abs(U[0, 0]))\n", 134 | "\n", 135 | "fig, ax = plt.subplots(figsize=(8, 6))\n", 136 | "\n", 137 | "blue, _, red, *_ = sns.color_palette()\n", 138 | "\n", 139 | "e = Ellipse(μ_actual, 2 * np.sqrt(5.991 * var[0]), 2 * np.sqrt(5.991 * var[1]), angle=angle)\n", 140 | "e.set_alpha(0.5)\n", 141 | "e.set_facecolor(blue)\n", 142 | "e.set_zorder(10)\n", 143 
| "ax.add_artist(e)\n", 144 | "\n", 145 | "ax.scatter(x[:, 0], x[:, 1], c=\"k\", alpha=0.05, zorder=11)\n", 146 | "\n", 147 | "rect = plt.Rectangle((0, 0), 1, 1, fc=blue, alpha=0.5)\n", 148 | "ax.legend([rect], [\"95% density region\"], loc=2);" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "19f4313b-b7f3-4705-9c0a-2d7654f8da21", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "def model(obs):\n", 159 | " chol_stds = numpyro.sample(\"chol_stds\", dist.Exponential(rate=jnp.ones(2)))\n", 160 | " lkj_chol = numpyro.sample(\"lkj_chol\", dist.LKJCholesky(dimension=2, concentration=2.0))\n", 161 | " chol_corr = numpyro.deterministic(\"chol_corr\", lkj_chol@lkj_chol.T)\n", 162 | " # Create cholesky matrix by scaling lkj_chol matrix with standard deviations\n", 163 | "# chol = numpyro.deterministic(\"chol\", jnp.matmul(jnp.diag(chol_stds), lkj_chol))\n", 164 | " chol = numpyro.deterministic(\"chol\", chol_stds[..., None] * lkj_chol)\n", 165 | "\n", 166 | " μ = numpyro.sample(\"μ\", dist.Normal(loc=jnp.zeros(2), scale=jnp.ones(2)*1.5))\n", 167 | " mvn = dist.MultivariateNormal(loc=μ, scale_tril=chol)\n", 168 | " cov = numpyro.deterministic(\"cov\", mvn.covariance_matrix)\n", 169 | " obs = numpyro.sample(\"obs\", mvn, obs=obs)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "26da476f-54e8-43b4-9167-a5d26cb65d22", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "rng_key = jax.random.PRNGKey(42)\n", 180 | "\n", 181 | "num_warmup, num_samples = 1000, 1000\n", 182 | "\n", 183 | "# Run NUTS.\n", 184 | "kernel = NUTS(model)\n", 185 | "mcmc = MCMC(\n", 186 | " kernel,\n", 187 | " num_warmup=num_warmup,\n", 188 | " num_samples=num_samples,\n", 189 | " num_chains=4,\n", 190 | " chain_method='parallel',\n", 191 | " \n", 192 | ")\n", 193 | "mcmc.run(rng_key, obs=x)\n", 194 | "posterior_samples = mcmc.get_samples()" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "4aceaa53-3454-4e94-8035-31022ebd8b14", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "import warnings\n", 205 | "warnings.filterwarnings('ignore')" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "f1b344e2-57ba-444b-9e31-36c172bd16e3", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "az.summary(mcmc, var_names=[\"~lkj_chol\", \"~chol\"], round_to=2)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "9ef69ee6-f7e4-4c57-a7b0-db850049cd88", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "rng_key = jax.random.PRNGKey(42)\n", 226 | "\n", 227 | "posterior_predictive = Predictive(model, posterior_samples=posterior_samples)\n", 228 | "posterior_predictions = posterior_predictive(rng_key, obs=x)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "c452e10d-e5fb-4bcc-a9bf-a01857f339b7", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "inference_data = az.from_numpyro(\n", 239 | " posterior=mcmc,\n", 240 | " posterior_predictive=posterior_predictions,\n", 241 | " coords={\"xy\": jnp.arange(2)},\n", 242 | " dims={\"μ\": [\"xy\"], \"chol_stds\": [\"xy\"], \"chol\": [\"xy\", \"xy\"], \"lkj_chol\": [\"xy\", \"xy\"], \"cov\": [\"xy\", \"xy\"], \"chol_corr\": [\"xy\", \"xy\"]}\n", 243 | ")\n", 244 | "display(inference_data)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 
249 | "execution_count": null, 250 | "id": "0633abbb-724f-4cb0-a84a-f299a20cba00", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "az.plot_trace(\n", 255 | " inference_data,\n", 256 | " compact=True,\n", 257 | " var_names=[\"~lkj_chol\", \"~chol\"],\n", 258 | " lines=[\n", 259 | " (\"μ\", {}, μ_actual),\n", 260 | " (\"chol_stds\", {}, sigmas_actual),\n", 261 | " (\"chol_corr\", {}, Rho_actual),\n", 262 | " (\"cov\", {}, Σ_actual),\n", 263 | " ],\n", 264 | ")\n", 265 | "plt.suptitle('Trace plots', fontsize=18)\n", 266 | "plt.show()" 267 | ] 268 | } 269 | ], 270 | "metadata": { 271 | "kernelspec": { 272 | "display_name": "Python 3 (ipykernel)", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 3 280 | }, 281 | "file_extension": ".py", 282 | "mimetype": "text/x-python", 283 | "name": "python", 284 | "nbconvert_exporter": "python", 285 | "pygments_lexer": "ipython3", 286 | "version": "3.9.6" 287 | } 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 5 291 | } 292 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/SVI/SVI_Part_01.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ee48511e-03db-419c-b609-65d53de6054e", 6 | "metadata": {}, 7 | "source": [ 8 | "# NumPyro SVI (Part 1 & 2)\n", 9 | "\n", 10 | "Based on Pyro tutorial on SVI: \n", 11 | "- Part 1: http://pyro.ai/examples/svi_part_i.html\n", 12 | " - About SVI\n", 13 | "- Part 2: http://pyro.ai/examples/svi_part_ii.html\n", 14 | " - About conditional independence via plates\n", 15 | "\n", 16 | "NumPyro SVI documentation:\n", 17 | "- http://num.pyro.ai/en/stable/svi.html" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "7e611633-7719-49d9-a3ec-36fcb0888911", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Imports\n", 28 | "%matplotlib inline\n", 29 | "%config InlineBackend.figure_format = 'svg'" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "0c162427-b1b2-4839-83e2-98a50dba304d", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import sys\n", 40 | "import warnings\n", 41 | "\n", 42 | "import numpy as np\n", 43 | "\n", 44 | "import jax\n", 45 | "import jax.numpy as jnp\n", 46 | "from jax.experimental import optimizers\n", 47 | "\n", 48 | "import numpyro\n", 49 | "from numpyro.infer import SVI, Trace_ELBO, Predictive\n", 50 | "import numpyro.distributions as dist\n", 51 | "\n", 52 | "import matplotlib\n", 53 | "import matplotlib.pyplot as plt\n", 54 | "from matplotlib import cm\n", 55 | "import seaborn as sns\n", 56 | "import arviz as az\n", 57 | "from tqdm import tqdm_notebook as tqdm" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "d77dd4f7-03a0-4a2b-9168-3aeb117b87cf", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "sns.set_style('darkgrid')\n", 68 | "az.rcParams['stats.hdi_prob'] = 0.90\n", 69 | "az.style.use(\"arviz-darkgrid\")" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "db4cbea1-2fea-4c41-96d9-a4a02e22f8fb", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "numpyro.set_platform('cpu')\n", 80 | "numpyro.set_host_device_count(8)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | 
"id": "32386488-d01e-4a5b-b0f5-ec14d1d631bf", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "rng_key = jax.random.PRNGKey(42)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "763e61e7-c703-430a-9363-90b78c7eafc7", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# create some data with 7 observed heads and 3 observed tails\n", 101 | "data = jnp.concatenate([jnp.ones(7), jnp.zeros(3)])\n", 102 | "data" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "c9e819d0-07d0-4e2d-83ef-497b213bc51b", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "def model(data):\n", 113 | " # define the hyperparameters that control the beta prior\n", 114 | " alpha0 = 10.\n", 115 | " beta0 = 10.\n", 116 | " # sample f from the beta prior\n", 117 | " f = numpyro.sample(\"latent_fairness\", dist.Beta(alpha0, beta0))\n", 118 | " # loop over the observed data\n", 119 | " with numpyro.plate(\"N\", data.shape[0]):\n", 120 | " # observe datapoint i using the bernoulli likelihood\n", 121 | " numpyro.sample(\"obs\", dist.Bernoulli(f), obs=data)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "6a381b4d-1ec8-4039-85bf-fc655573b390", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "def guide(data):\n", 132 | " # register the two variational parameters with NumPyro\n", 133 | " # - both parameters will have initial value set in `numpyro.param`.\n", 134 | " # - because we invoke constraints.positive, the optimizer\n", 135 | " # will take gradients on the unconstrained parameters\n", 136 | " # (which are related to the constrained parameters by a log)\n", 137 | " alpha_q = numpyro.param(\n", 138 | " \"alpha_q\", 15.,\n", 139 | " constraint=dist.constraints.positive\n", 140 | " )\n", 141 | " beta_q = numpyro.param(\n", 142 | " \"beta_q\", 15.,\n", 143 | " constraint=dist.constraints.positive\n", 144 | " )\n", 145 | " # sample latent_fairness from the distribution Beta(alpha_q, beta_q)\n", 146 | " numpyro.sample(\"latent_fairness\", dist.Beta(alpha_q, beta_q))" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "a0e2a121-92b6-47bb-8e6f-6187b528d6d8", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "%%time\n", 157 | "\n", 158 | "# setup the optimizer\n", 159 | "optimizer = numpyro.optim.Adam(step_size=2e-4)\n", 160 | "\n", 161 | "# setup the inference algorithm\n", 162 | "svi = SVI(\n", 163 | " model=model,\n", 164 | " guide=guide,\n", 165 | " optim=optimizer,\n", 166 | " loss=Trace_ELBO()\n", 167 | ")\n", 168 | "\n", 169 | "# Run\n", 170 | "svi_result = svi.run(\n", 171 | " jax.random.PRNGKey(0),\n", 172 | " num_steps=5000,\n", 173 | " data=data\n", 174 | ")" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "9c8c0c93-e682-4649-a056-b7a776c9886b", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "fig, ax = plt.subplots(1, 1, figsize=(7, 3))\n", 185 | "ax.plot(svi_result.losses)\n", 186 | "ax.set_title(\"losses\")\n", 187 | "plt.show()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "1625acb3-9cfc-4514-ba4c-39d62672f518", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# grab the learned variational parameters\n", 198 | "svi_result.params\n", 199 | "alpha_q = svi_result.params[\"alpha_q\"]\n", 200 | "print(\"alpha_q: \", float(alpha_q))\n", 
201 | "beta_q = svi_result.params[\"beta_q\"]\n", 202 | "print(\"beta_q: \", float(beta_q))\n", 203 | " \n", 204 | "# here we use some facts about the beta distribution\n", 205 | "# compute the inferred mean of the coin's fairness\n", 206 | "inferred_mean = alpha_q / (alpha_q + beta_q)\n", 207 | "# compute inferred standard deviation\n", 208 | "factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))\n", 209 | "inferred_std = inferred_mean * jnp.sqrt(factor)\n", 210 | "print(\"based on the data and our prior belief, the fairness \"\n", 211 | " f\"of the coin is {inferred_mean:.3f} +- {inferred_std:.3f}\"\n", 212 | ")" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "770c99e0-945e-4cb1-b103-c7449a84b7db", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "predictive = Predictive(\n", 223 | " guide,\n", 224 | " params=svi_result.params,\n", 225 | " num_samples=2500\n", 226 | ")\n", 227 | "samples = predictive(jax.random.PRNGKey(0), data)\n", 228 | "sns.histplot(samples['latent_fairness'])\n", 229 | "plt.show()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "id": "bd67bbbc-74a3-44e0-8e05-8c3555b4bbb0", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "634ba479-8b99-4ff6-9a50-5842ef712ae2", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "3b16d261-e2fe-4454-82f2-f190c0341514", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "5566094a-dc57-4e09-8347-246d090918dd", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "47df365a-e830-4e9a-a68b-1da06c3c43a8", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "id": "a8176cd0-0b49-49ae-8a64-eb4f3522775f", 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "44304439-4a89-4e1b-a6f4-539260f0f223", 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "642fbd6a-2a2a-432c-8bf6-8aa6714e155b", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "2f1d4716-3d2f-46a3-aa64-b7433b210777", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "id": "7b9c2128-4426-4279-b6d2-3249d12a041a", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "id": "63556154-7ffd-4816-8c3c-89873d3b3ce0", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "id": "45154dbb-6fc3-49b0-b704-aec7ce408d91", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [] 327 | } 328 | ], 329 | "metadata": { 330 | "kernelspec": { 331 | "display_name": "Python 3 (ipykernel)", 332 | "language": "python", 333 | "name": "python3" 334 | }, 335 | "language_info": { 
336 | "codemirror_mode": { 337 | "name": "ipython", 338 | "version": 3 339 | }, 340 | "file_extension": ".py", 341 | "mimetype": "text/x-python", 342 | "name": "python", 343 | "nbconvert_exporter": "python", 344 | "pygments_lexer": "ipython3", 345 | "version": "3.9.7" 346 | } 347 | }, 348 | "nbformat": 4, 349 | "nbformat_minor": 5 350 | } 351 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/SVI/linear_regression_SVI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8c6e0fd9-e61e-4c03-a09b-6aa83ebe1a80", 6 | "metadata": {}, 7 | "source": [ 8 | "# Linear regression with SVI\n", 9 | "\n", 10 | "Resources:\n", 11 | "- http://pyro.ai/examples/svi_part_i.html\n", 12 | "- http://pyro.ai/examples/svi_part_ii.html\n", 13 | "- http://pyro.ai/examples/svi_part_iii.html\n", 14 | "- http://pyro.ai/examples/svi_part_iv.html\n", 15 | "- http://num.pyro.ai/en/stable/svi.html\n", 16 | "\n", 17 | "Discourse thread:\n", 18 | "- https://forum.pyro.ai/t/large-variance-in-svi-losses-to-be-expected/3435" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "b6c94525-1d50-4229-a7c1-42ff4a44f1d8", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# Imports\n", 29 | "%matplotlib inline\n", 30 | "%config InlineBackend.figure_format = 'svg'" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "liable-infection", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import sys\n", 41 | "import warnings\n", 42 | "\n", 43 | "import numpy as np\n", 44 | "\n", 45 | "import jax\n", 46 | "import jax.numpy as jnp\n", 47 | "from jax.experimental import optimizers\n", 48 | "\n", 49 | "import numpyro\n", 50 | "from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive\n", 51 | "import numpyro.distributions as dist\n", 52 | "from numpyro import handlers\n", 53 | "\n", 54 | "import matplotlib\n", 55 | "import matplotlib.pyplot as plt\n", 56 | "import seaborn as sns" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "be26a56c-189e-46ba-a79b-35993966248d", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "sns.set_style('darkgrid')" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "rapid-masters", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "numpyro.set_platform('cpu')\n", 77 | "numpyro.set_host_device_count(8)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "054cc38d-1c44-49ec-b9ec-00497e49fb1f", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "ground_truth_params = {\n", 88 | " \"slope\" : 2.32,\n", 89 | " \"intercept\": 4.11,\n", 90 | " \"noise_std\": 0.5\n", 91 | "}" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "9fd0a863-3b4f-47b1-98e8-d3070fb97954", 97 | "metadata": {}, 98 | "source": [ 99 | "# Create Dataset" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "experienced-patch", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# Define the data\n", 110 | "np.random.seed(42)\n", 111 | "# Generate random data\n", 112 | "n = 51 # Number of samples\n", 113 | "# Linear relation\n", 114 | "slope_true = ground_truth_params[\"slope\"]\n", 115 | "intercept_true = ground_truth_params[\"intercept\"]\n", 116 | 
"fn = lambda x_: x_ * slope_true + intercept_true\n", 117 | "# Noise\n", 118 | "err = ground_truth_params[\"noise_std\"] * np.random.randn(n) # Noise\n", 119 | "# Features and output\n", 120 | "x_data = np.random.uniform(-1., 1., n) # Independent variable x\n", 121 | "y_data = fn(x_data) + err # Dependent variable\n", 122 | "\n", 123 | "# Show data\n", 124 | "plt.figure(figsize=(7, 4), dpi=100)\n", 125 | "plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')\n", 126 | "x_bound = (float(x_data.min()), float(x_data.max()))\n", 127 | "plt.plot(\n", 128 | " x_bound, [fn(x_bound[0]), fn(x_bound[1])], color='black', linestyle='-',\n", 129 | " label=f'$y = {intercept_true:.2f} + {slope_true:.2f} x$')\n", 130 | "plt.xlim(x_bound)\n", 131 | "plt.xlabel('$x$')\n", 132 | "plt.ylabel('$y$')\n", 133 | "plt.title('Noisy data samples from linear line')\n", 134 | "plt.legend()\n", 135 | "plt.show()\n", 136 | "#" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "efe91237-cd00-4b9e-8867-9a7c569bc94e", 142 | "metadata": {}, 143 | "source": [ 144 | "# Define model and variational distribution" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "chief-craft", 150 | "metadata": {}, 151 | "source": [ 152 | "$$\n", 153 | "\\mu_i = \\text{intercept} + \\text{slope} * x_i \\\\\n", 154 | "y_i \\sim \\mathcal{N}(\\mu_i, \\sigma) \\quad (i = 1, \\ldots, n)\n", 155 | "$$" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "45eb9ce3-1472-403c-897f-89f92e107a6e", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def model(x, y):\n", 166 | " slope = numpyro.sample('slope', dist.Normal(0., 10.))\n", 167 | " intercept = numpyro.sample('intercept', dist.Normal(0., 10.))\n", 168 | " noise_std = numpyro.sample('noise_std', dist.Exponential(1.))\n", 169 | " with numpyro.plate('obs', x.shape[0]):\n", 170 | " y_loc = numpyro.deterministic('y_loc', intercept + slope * x)\n", 171 | " numpyro.sample('y', dist.Normal(y_loc, noise_std), obs=y)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "2ee034fc-cbd2-4fee-89b1-b7878dc66b25", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "def guide(x, y):\n", 182 | " slope_loc = numpyro.param(\"slope_loc\", 0.)\n", 183 | " slope_scale = numpyro.param(\"slope_scale\", 0.01, constraint=dist.constraints.positive)\n", 184 | " slope = numpyro.sample('slope', dist.Normal(slope_loc, slope_scale))\n", 185 | " intercept_loc = numpyro.param(\"intercept_loc\", 0.)\n", 186 | " intercept_scale = numpyro.param(\"intercept_scale\", 0.01, constraint=dist.constraints.positive)\n", 187 | " intercept = numpyro.sample('intercept', dist.Normal(intercept_loc, intercept_scale))\n", 188 | " noise_std_log_loc = numpyro.param(\"noise_std_log_loc\", 0.1)\n", 189 | " noise_std_scale = numpyro.param(\"noise_std_scale\", 0.01, constraint=dist.constraints.positive)\n", 190 | " noise_std = numpyro.sample('noise_std', dist.LogNormal(noise_std_log_loc, noise_std_scale))" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "id": "ba2433f7-4e35-4bb5-b67a-a5a95dd68a13", 196 | "metadata": {}, 197 | "source": [ 198 | "Use `LogNormal` for guide distribution, since using `Exponential` leads to high variance.\n", 199 | "Exponential guid that leads to high variance:\n", 200 | "```\n", 201 | "noise_std_rate = numpyro.param(\"noise_std_rate\", 1., constraint=dist.constraints.positive)\n", 202 | "noise_std = numpyro.sample('noise_std', 
dist.Exponential(noise_std_rate))\n", 203 | "```" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "a8264ece-4a89-42a3-860a-c93e22e7b597", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "_s = dist.LogNormal(loc=-0.72457, scale=0.13578562).sample(key=jax.random.PRNGKey(42), sample_shape=(1000,))\n", 214 | "sns.histplot(_s)\n", 215 | "plt.show()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "d4f45418-4dda-467d-8b0f-9f2336707efb", 221 | "metadata": {}, 222 | "source": [ 223 | "## Fit SVI" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "95870cb8-2363-4e31-b451-ff5105fcf947", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "# Learning rate schedule\n", 234 | "def cosine_annealing(lr_min, lr_max, num_steps, i):\n", 235 | " return lr_min + 0.5 * (lr_max - lr_min) * (1 + jnp.cos(jnp.pi * i / num_steps))\n", 236 | "\n", 237 | "\n", 238 | "num_steps = 5000\n", 239 | "lr_max = 2e-3\n", 240 | "lr_min = 1e-4\n", 241 | "\n", 242 | "iterations = jnp.arange(num_steps)\n", 243 | "lr_steps = cosine_annealing(lr_min, lr_max, num_steps, iterations)\n", 244 | "\n", 245 | "\n", 246 | "def lr_schedule(idx):\n", 247 | " return lr_steps[idx]" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "id": "6e881e2e-d75b-405e-84ae-db803d8c3878", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "# Use clipped Optimizer to deal with unstable gradients\n", 258 | "# http://num.pyro.ai/en/stable/optimizers.html#clippedadam\n", 259 | "optimizer = numpyro.optim.ClippedAdam(step_size=lr_schedule, clip_norm=1.0)\n", 260 | "\n", 261 | "# setup the inference algorithm\n", 262 | "svi = SVI(\n", 263 | " model=model,\n", 264 | " guide=guide,\n", 265 | " optim=optimizer,\n", 266 | " loss=TraceMeanField_ELBO(num_particles=1)\n", 267 | ")\n", 268 | "\n", 269 | "# Run\n", 270 | "svi_result = svi.run(\n", 271 | " jax.random.PRNGKey(0),\n", 272 | " num_steps=5000,\n", 273 | " x=x_data,\n", 274 | " y=y_data\n", 275 | ")" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "id": "267f7d0a-90c5-426a-bacd-23235f04f850", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "fig, ax = plt.subplots(1, 1, figsize=(7, 3))\n", 286 | "ax.plot(svi_result.losses)\n", 287 | "ax.set_title(\"losses\")\n", 288 | "ax.set_yscale(\"symlog\")\n", 289 | "plt.show()" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "84ba71f9-83c2-4d71-b90b-b8f7f0632543", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "svi_predictive = Predictive(\n", 300 | " guide,\n", 301 | " params=svi_result.params,\n", 302 | " num_samples=2000\n", 303 | ")\n", 304 | "posterior_samples = svi_predictive(\n", 305 | " jax.random.PRNGKey(0),\n", 306 | " x=x_data,\n", 307 | " y=y_data\n", 308 | ")" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "id": "84370b92-46f7-4202-a3ec-b883f669d2ad", 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "fig, axs = plt.subplots(1, len(posterior_samples), figsize=(12, 4))\n", 319 | "\n", 320 | "for ax, (param_name, param_samples) in zip(axs, posterior_samples.items()):\n", 321 | " d = sns.histplot(param_samples, kde=True, stat='probability', ax=ax)\n", 322 | " ax.set_xlabel(param_name)\n", 323 | " ax.set_title(f\"Samples from {param_name!r}\")\n", 324 | " 
ax.axvline(np.mean(param_samples), color=\"black\", label=\"mean\")\n", 325 | " ax.axvline( ground_truth_params[param_name], color=\"red\", label=\"true\")\n", 326 | "fig.legend(*ax.get_legend_handles_labels(), bbox_to_anchor=(0., 0.7, 1.0, -.0))\n", 327 | "plt.show()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "id": "999a8a48-c613-4446-bf2b-3be0bd7c67af", 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "for param_name, param_samples in posterior_samples.items():\n", 338 | " param_gt = ground_truth_params[param_name]\n", 339 | " param_mean = np.mean(param_samples)\n", 340 | " param_std = np.std(param_samples)\n", 341 | " param_median = np.median(param_samples)\n", 342 | " param_quantile_low, param_quantile_high = np.quantile(param_samples, (.025, .975))\n", 343 | " print(f\"{param_name:>13}: true={param_gt:.2f}\\t median={param_median:.2f}\\t 95%-interval: {param_quantile_low:+.2f} - {param_quantile_high:+.2f}\\t \"\n", 344 | " f\"mean:{param_mean:.2f}±{param_std:.2f}\")" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "id": "b0404fd1-c84e-40ab-b14c-29364da05fcb", 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "mean_slope = np.mean(posterior_samples[\"slope\"])\n", 355 | "mean_intercept = np.mean(posterior_samples[\"intercept\"])\n", 356 | "y_mean_pred = jnp.array([-1., 1]) * mean_slope + mean_intercept\n", 357 | "\n", 358 | "\n", 359 | "# Show mean fit vs data\n", 360 | "plt.figure(figsize=(7, 4), dpi=100)\n", 361 | "plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')\n", 362 | "x_bound = (float(x_data.min()), float(x_data.max()))\n", 363 | "plt.plot(\n", 364 | " x_bound, [fn(x_bound[0]), fn(x_bound[1])], color='black', linestyle='-',\n", 365 | " label='true'\n", 366 | ")\n", 367 | "plt.plot(\n", 368 | " x_bound, y_mean_pred, color='red', linestyle='-',\n", 369 | " label=f'$pred = {mean_intercept:.2f} + {mean_slope:.2f} x$'\n", 370 | ")\n", 371 | "plt.xlim(x_bound)\n", 372 | "plt.xlabel('$x$')\n", 373 | "plt.ylabel('$y$')\n", 374 | "plt.title('Mean fit vs ground-truth data')\n", 375 | "plt.legend()\n", 376 | "plt.show()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "id": "b7958a67-2c3a-4ea0-8e2d-98e2fb171931", 382 | "metadata": {}, 383 | "source": [ 384 | "## Posterior predictions" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "id": "0f6dfb3d-d3cc-4c62-a57e-342ab035ee9e", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "def plot_predictions(x_samples, predictions, name):\n", 395 | " x_bound = (float(x_samples.min()), float(x_samples.max()))\n", 396 | " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 8))\n", 397 | " # Plot prior parameters\n", 398 | " y_mu_mean = jnp.mean(predictions['y_loc'], 0)\n", 399 | " y_mu_pct = jnp.percentile(predictions['y_loc'], q=np.array([5., 95., 0.5, 99.5]), axis=0)\n", 400 | " for i in range(min(10, predictions['y_loc'].shape[0])):\n", 401 | " yi = predictions['y_loc'][i]\n", 402 | " label=None\n", 403 | " if i == 0:\n", 404 | " label = 'samples'\n", 405 | " ax1.plot(x_samples, yi, color='tab:gray', linestyle='-', alpha=0.5, label=label)\n", 406 | " ax1.plot(x_samples, y_mu_mean, color='tab:blue', linestyle='-', label='mean($\\mu_y$)', linewidth=2)\n", 407 | " ax1.fill_between(x_samples, y_mu_pct[0], y_mu_pct[1], color='tab:blue', alpha=0.2, label='$\\mu_y \\; 90\\%$')\n", 408 | " ax1.fill_between(x_samples, y_mu_pct[2], y_mu_pct[3], 
color='tab:blue', alpha=0.1, label='$\mu_y \; 99\%$')\n", 409 | "    ax1.set_xlim(x_bound)\n", 410 | "    ax1.set_xlabel('$x$')\n", 411 | "    ax1.set_ylabel('$y$')\n", 412 | "    ax1.set_title(f'{name} parameter distribution')\n", 413 | "    ax1.legend(loc='lower right')\n", 414 | "\n", 415 | "    # Plot prior predictions\n", 416 | "    y_mean = jnp.mean(predictions['y'], 0)\n", 417 | "    y_pct = jnp.percentile(predictions['y'], q=np.array([5., 95., 0.5, 99.5]), axis=0)\n", 418 | "    # Plot samples\n", 419 | "    for i in range(min(100, predictions['y'].shape[0])):\n", 420 | "        yi = predictions['y'][i]\n", 421 | "        label=None\n", 422 | "        if i == 0:\n", 423 | "            label = 'samples'\n", 424 | "        ax2.plot(x_samples, yi, color='tab:blue', marker='o', alpha=0.03, label=label)\n", 425 | "    ax2.plot(x_samples, y_mean, 'k-', label='mean($y$)')\n", 426 | "    ax2.fill_between(x_samples, y_pct[0], y_pct[1], color='k', alpha=0.2, label='$y \; 90\%$')\n", 427 | "    ax2.fill_between(x_samples, y_pct[2], y_pct[3], color='k', alpha=0.1, label='$y \; 99\%$')\n", 428 | "    ax2.set_xlim(x_bound)\n", 429 | "    ax2.set_xlabel('$x$')\n", 430 | "    ax2.set_ylabel('$y$')\n", 431 | "    ax2.set_title(f'{name} predictive distribution')\n", 432 | "    ax2.legend(loc='lower right')\n", 433 | "    plt.tight_layout()\n", 434 | "    plt.show()" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "2ac4f24a-cd65-4c7b-85a9-b2f79c885bcb", 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "# Get posterior predictive samples\n", 445 | "# https://forum.pyro.ai/t/svi-version-of-mcmc-get-samples/3069/4\n", 446 | "posterior_predictive = Predictive(\n", 447 | "    model=model,\n", 448 | "    guide=guide,\n", 449 | "    params=svi_result.params,\n", 450 | "    num_samples=1000\n", 451 | ")\n", 452 | "\n", 453 | "x_samples = np.linspace(-1.5, 1.5, 100)\n", 454 | "posterior_predictions = posterior_predictive(jax.random.PRNGKey(42), x=x_samples, y=None)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "id": "8c6cc3fe-923b-423c-ac7b-d8ed074d5991", 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "plot_predictions(x_samples, posterior_predictions, 'Posterior')" 465 | ] 466 | } 467 | ], 468 | "metadata": { 469 | "kernelspec": { 470 | "display_name": "Python 3 (ipykernel)", 471 | "language": "python", 472 | "name": "python3" 473 | }, 474 | "language_info": { 475 | "codemirror_mode": { 476 | "name": "ipython", 477 | "version": 3 478 | }, 479 | "file_extension": ".py", 480 | "mimetype": "text/x-python", 481 | "name": "python", 482 | "nbconvert_exporter": "python", 483 | "pygments_lexer": "ipython3", 484 | "version": "3.9.7" 485 | } 486 | }, 487 | "nbformat": 4, 488 | "nbformat_minor": 5 489 | } 490 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/env/README.md: -------------------------------------------------------------------------------- 1 | # Python Environment 2 | 3 | ## Local conda environment 4 | 5 | The Miniconda environment can be built and activated locally if you have conda installed by running: 6 | ``` 7 | conda env create --file ./env/conda_env.yml 8 | conda activate numpyro 9 | ``` 10 | 11 | And can be cleaned up afterwards with: 12 | ``` 13 | conda deactivate && conda remove --name numpyro --all 14 | ``` 15 | 16 | 17 | ### Updating 18 | 19 | To list outdated packages you can use `pip list --outdated` and `conda update --all --dry-run --channel conda-forge` in the activated environment. 
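If `conda_env.yml` itself changes, one way to bring an existing environment back in sync (a sketch, assuming the environment keeps the `numpyro` name used above) is:
```
conda env update --name numpyro --file ./env/conda_env.yml --prune
```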
20 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/env/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: numpyro 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - bokeh 6 | - datashader 7 | - holoviews 8 | - jupyter_bokeh 9 | - jupyterlab 10 | - matplotlib 11 | - nodejs 12 | - numpy 13 | - pip 14 | - python>=3.9 15 | - pyviz_comms 16 | - seaborn 17 | - scipy 18 | - tqdm 19 | - xarray 20 | - pip: 21 | - arviz 22 | - numpyro 23 | - watermark 24 | - jax 25 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/linear_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "partial-beijing", 6 | "metadata": {}, 7 | "source": [ 8 | "# Linear regression\n", 9 | "\n", 10 | "based on https://peterroelants.github.io/posts/linear-regression-four-ways/\n", 11 | "\n", 12 | "\n", 13 | "- https://towardsdatascience.com/introduction-to-bayesian-linear-regression-e66e60791ea7\n", 14 | "- http://num.pyro.ai/en/latest/tutorials/bayesian_hierarchical_linear_regression.html\n", 15 | "- https://www.hellocybernetics.tech/entry/2020/02/23/034551\n", 16 | "- https://arviz-devs.github.io/arviz/user_guide/numpyro_refitting.html\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "b6c94525-1d50-4229-a7c1-42ff4a44f1d8", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Imports\n", 27 | "%matplotlib inline\n", 28 | "%config InlineBackend.figure_format = 'svg'" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "liable-infection", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import sys\n", 39 | "import warnings\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "\n", 43 | "import jax\n", 44 | "import jax.numpy as jnp\n", 45 | "\n", 46 | "import numpyro\n", 47 | "from numpyro.infer import MCMC, NUTS, Predictive\n", 48 | "import numpyro.distributions as dist\n", 49 | "\n", 50 | "import matplotlib\n", 51 | "import matplotlib.pyplot as plt\n", 52 | "from matplotlib import cm # Colormaps\n", 53 | "import seaborn as sns\n", 54 | "import arviz as az\n", 55 | "from tqdm import tqdm_notebook as tqdm\n", 56 | "from IPython.display import display\n", 57 | "\n", 58 | "sns.set_style('darkgrid')\n", 59 | "az.rcParams['stats.hdi_prob'] = 0.90\n", 60 | "az.style.use(\"arviz-darkgrid\")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "rapid-masters", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "numpyro.set_platform('cpu')\n", 71 | "numpyro.set_host_device_count(8)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "final-examination", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "rng_key = jax.random.PRNGKey(42)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "experienced-patch", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# Define the data\n", 92 | "rng_key, rng_key_ = jax.random.split(rng_key)\n", 93 | "np.random.seed(rng_key_)\n", 94 | "# Generate random data\n", 95 | "n = 50 # Number of samples\n", 96 | "# Underlying linear relation\n", 97 | "m = 2.32 # slope\n", 98 | "b = 4.11 # bias\n", 99 | "fn = lambda x_: x_ * m + b\n", 100 | "# Noise\n", 101 
| "e_std = 0.5 # Standard deviation of the noise\n", 102 | "err = e_std * np.random.randn(n) # Noise\n", 103 | "# Features and output\n", 104 | "x_data = np.random.uniform(-1, 1, n) # Independent variable x\n", 105 | "y_data = fn(x_data) + err # Dependent variable\n", 106 | "\n", 107 | "# Show data\n", 108 | "plt.figure(figsize=(7, 4), dpi=100)\n", 109 | "plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')\n", 110 | "plt.plot(\n", 111 | " [-1, 1], [fn(-1), fn(1)], color='black', linestyle='-',\n", 112 | " label=f'$y = {b:.2f} + {m:.2f} x$')\n", 113 | "plt.xlim((-1, 1))\n", 114 | "plt.xlabel('$x$')\n", 115 | "plt.ylabel('$y$')\n", 116 | "plt.title('Noisy data samples from linear line')\n", 117 | "plt.legend()\n", 118 | "plt.show()\n", 119 | "#" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "recent-waterproof", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "def plot_predictions(x_samples, predictions, name):\n", 130 | " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 8), dpi=100)\n", 131 | " # Plot prior parameters\n", 132 | " y_mu_mean = jnp.mean(predictions['y_mu'], 0)\n", 133 | " y_mu_pct = jnp.percentile(predictions['y_mu'], q=np.array([10., 90., 1., 99.]), axis=0)\n", 134 | " for i in range(min(10, predictions['y_mu'].shape[0])):\n", 135 | " yi = predictions['y_mu'][i]\n", 136 | " label=None\n", 137 | " if i == 0:\n", 138 | " label = 'samples'\n", 139 | " ax1.plot(x_samples, yi, color='tab:gray', linestyle='-', alpha=0.5, label=label)\n", 140 | " ax1.plot(x_samples, y_mu_mean, color='tab:blue', linestyle='-', label='mean($\\mu_y$)', linewidth=2)\n", 141 | " ax1.fill_between(x_samples, y_mu_pct[0], y_mu_pct[1], color='tab:blue', alpha=0.2, label='$\\mu_y \\; 90\\%$')\n", 142 | " ax1.fill_between(x_samples, y_mu_pct[2], y_mu_pct[3], color='tab:blue', alpha=0.1, label='$\\mu_y \\; 99\\%$')\n", 143 | " ax1.set_xlim((-1, 1))\n", 144 | " ax1.set_xlabel('$x$')\n", 145 | " ax1.set_ylabel('$y$')\n", 146 | " ax1.set_title(f'{name} parameter distribution')\n", 147 | " ax1.legend(loc='lower right')\n", 148 | "\n", 149 | " # Plot prior predictions\n", 150 | " y_mean = jnp.mean(predictions['y'], 0)\n", 151 | " y_pct = jnp.percentile(predictions['y'], q=np.array([10., 90., 1., 99.]), axis=0)\n", 152 | " # Plot samples\n", 153 | " for i in range(min(200, predictions['y'].shape[0])):\n", 154 | " yi = predictions['y'][i]\n", 155 | " label=None\n", 156 | " if i == 0:\n", 157 | " label = 'samples'\n", 158 | " ax2.plot(x_samples, yi, color='tab:blue', marker='o', alpha=0.02, label=label)\n", 159 | " ax2.plot(x_samples, y_mean, 'k-', label='mean($y$)')\n", 160 | " ax2.fill_between(x_samples, y_pct[0], y_pct[1], color='k', alpha=0.2, label='$y \\; 90\\%$')\n", 161 | " ax2.fill_between(x_samples, y_pct[2], y_pct[3], color='k', alpha=0.1, label='$y \\; 99\\%$')\n", 162 | " ax2.set_xlim((-1, 1))\n", 163 | " ax2.set_xlabel('$x$')\n", 164 | " ax2.set_ylabel('$y$')\n", 165 | " ax2.set_title(f'{name} predictive distribution')\n", 166 | " ax2.legend(loc='lower right')\n", 167 | " plt.show()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "chief-craft", 173 | "metadata": {}, 174 | "source": [ 175 | "$$\n", 176 | "y_i \\sim \\mathcal{N}(\\theta_0 + \\theta_1 x_i, \\sigma^2) \\quad (i = 1, \\ldots, n)\n", 177 | "$$" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "structural-settle", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "def model(x, 
y):\n", 188 | " theta_0 = numpyro.sample('theta_0', dist.Normal(0., 10.))\n", 189 | " theta_1 = numpyro.sample('theta_1', dist.Normal(0., 10.))\n", 190 | " y_sigma = numpyro.sample('y_sigma', dist.Exponential(1.))\n", 191 | " y_mu = numpyro.deterministic('y_mu', theta_0 + theta_1 * x)\n", 192 | " numpyro.sample('y', dist.Normal(y_mu, y_sigma), obs=y)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "id": "unnecessary-frost", 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "rng_key, rng_key_ = jax.random.split(rng_key)\n", 203 | "\n", 204 | "x_samples = np.linspace(-1, 1, 100)\n", 205 | "num_prior_predictive_samples = 1000\n", 206 | "prior_predictive = Predictive(model, num_samples=num_prior_predictive_samples)\n", 207 | "prior_predictions = prior_predictive(rng_key_, x=x_samples, y=None)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "republican-country", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "plot_predictions(x_samples, prior_predictions, 'Prior')" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "scheduled-stations", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "rng_key, rng_key_ = jax.random.split(rng_key)\n", 228 | "\n", 229 | "num_warmup, num_samples = 1000, 2000\n", 230 | "\n", 231 | "# Run NUTS.\n", 232 | "kernel = NUTS(model)\n", 233 | "mcmc = MCMC(\n", 234 | " kernel,\n", 235 | " num_warmup=num_warmup,\n", 236 | " num_samples=num_samples,\n", 237 | " num_chains=4,\n", 238 | " chain_method='parallel'\n", 239 | ")\n", 240 | "mcmc.run(rng_key_, x=x_data, y=y_data)\n", 241 | "mcmc.print_summary()\n", 242 | "mcmc_samples = mcmc.get_samples()" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "innocent-stupid", 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "az_posterior = az.from_numpyro(posterior=mcmc)\n", 253 | "display(az_posterior)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "legendary-narrative", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", 264 | "az.plot_posterior(az_posterior, var_names=['theta_0', 'theta_1', 'y_sigma'], ax=ax)\n", 265 | "plt.suptitle('Posterior plots', fontsize=18)\n", 266 | "plt.show()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "married-rebate", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "fig, axes = plt.subplots(4, 2, figsize=(12, 8))\n", 277 | "az.plot_trace(az_posterior, compact=True, axes=axes)\n", 278 | "plt.suptitle('Trace plots', fontsize=18)\n", 279 | "plt.show()" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "id": "clean-graphic", 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "# fig, ax = plt.subplots(1, 3, figsize=(14, 4), dpi=70)\n", 290 | "# az.plot_posterior(az_posterior, var_names=['theta_0', 'theta_1', 'y_sigma'], ax=ax)\n", 291 | "# plt.suptitle('Posterior plots', fontsize=18)\n", 292 | "# plt.tight_layout()\n", 293 | "# plt.show()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "answering-compact", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "rng_key, rng_key_ = jax.random.split(rng_key)\n", 304 | "\n", 305 | "x_samples = np.linspace(-1, 1, 
100)\n", 306 | "posterior_predictive = Predictive(model, posterior_samples=mcmc_samples)\n", 307 | "posterior_predictions = posterior_predictive(rng_key_, x=x_samples, y=None)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "id": "fatal-buffer", 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "plot_predictions(x_samples, posterior_predictions, 'Posterior')" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "id": "776ba40a-55e7-493c-b964-4dd7c7ee3fcc", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [] 327 | } 328 | ], 329 | "metadata": { 330 | "kernelspec": { 331 | "display_name": "Python 3 (ipykernel)", 332 | "language": "python", 333 | "name": "python3" 334 | }, 335 | "language_info": { 336 | "codemirror_mode": { 337 | "name": "ipython", 338 | "version": 3 339 | }, 340 | "file_extension": ".py", 341 | "mimetype": "text/x-python", 342 | "name": "python", 343 | "nbconvert_exporter": "python", 344 | "pygments_lexer": "ipython3", 345 | "version": "3.9.7" 346 | } 347 | }, 348 | "nbformat": 4, 349 | "nbformat_minor": 5 350 | } 351 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/multivariate_gaussian_mixture_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e34a6de4", 6 | "metadata": {}, 7 | "source": [ 8 | "# 2D GMM" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "7e64d3c8", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%matplotlib inline\n", 19 | "%config InlineBackend.figure_format = 'svg'" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "763aaffe", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import numpy as np\n", 30 | "\n", 31 | "import jax\n", 32 | "import jax.numpy as jnp\n", 33 | "\n", 34 | "import numpyro\n", 35 | "from numpyro.infer import MCMC, NUTS, Predictive\n", 36 | "import numpyro.distributions as dist\n", 37 | "\n", 38 | "import matplotlib\n", 39 | "import matplotlib.pyplot as plt\n", 40 | "from matplotlib import cm\n", 41 | "import seaborn as sns\n", 42 | "import arviz as az\n", 43 | "\n", 44 | "from tqdm import tqdm_notebook as tqdm" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "7e48bf9e", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "sns.set_style('darkgrid')\n", 55 | "az.style.use(\"arviz-darkgrid\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "d1e6e09b", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "numpyro.set_platform('cpu')\n", 66 | "numpyro.set_host_device_count(4)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "eb594649", 72 | "metadata": {}, 73 | "source": [ 74 | "## Create data" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "d23f51ff", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "np.random.seed(42)\n", 85 | "\n", 86 | "n = 2500 # Total number of samples\n", 87 | "k = 3 # Number of clusters\n", 88 | "\n", 89 | "# Probability of choosing each cluster\n", 90 | "true_mixture_probs = np.array([0.2, 0.5, 0.3]) \n", 91 | "assert np.isclose(true_mixture_probs.sum(), 1.)\n", 92 | "\n", 93 | "# Mean of clusters\n", 94 | "true_locs = np.array([ \n", 95 | " [-1.2, 
1.5],\n", 96 | " [2.0, 2.],\n", 97 | " [-1, 4.]\n", 98 | "])\n", 99 | "\n", 100 | "# Correlation between x and y in clusters\n", 101 | "true_corrs = np.array([-0.85, 0.0, 0.85])\n", 102 | "\n", 103 | "# Correlation matrix\n", 104 | "true_corr_mats = np.stack([np.array([[1., true_corrs[i]], [true_corrs[i], 1.]]) for i in range(k)])\n", 105 | "print(\"true_corr_mats: \", true_corr_mats.shape)\n", 106 | "# Scales, or standard deviation in x&y directions of clusters\n", 107 | "true_scales = np.array([\n", 108 | " [0.9, 1.6],\n", 109 | " [1.0, 1.0],\n", 110 | " [1.4, 0.8],\n", 111 | "])\n", 112 | "print(\"true_scales: \", true_scales.shape)\n", 113 | "# Covariance matrix\n", 114 | "true_cov = np.einsum('ki,kj,kij->kij', true_scales, true_scales, true_corr_mats)\n", 115 | "\n", 116 | "# Sample mixture component indices\n", 117 | "true_mixture_idxs = np.random.choice(np.arange(k), p=true_mixture_probs, size=n)\n", 118 | "\n", 119 | "# Sample observations\n", 120 | "obs_data = np.vstack([\n", 121 | " np.random.multivariate_normal(true_locs[idx], true_cov[idx])\n", 122 | " for idx in true_mixture_idxs\n", 123 | "])\n", 124 | "assert obs_data.shape == (n, 2)\n", 125 | "\n", 126 | "cmap = {\n", 127 | " i: sns.color_palette(\"tab10\")[i]\n", 128 | " for i in range(k)\n", 129 | "}\n", 130 | "\n", 131 | "# Show observations\n", 132 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 133 | "for i in range(k):\n", 134 | " c_idx = (true_mixture_idxs == i)\n", 135 | " ax.plot(obs_data[c_idx, 0], obs_data[c_idx, 1], 'o', alpha=0.3, color=cmap[i], label=i)\n", 136 | "ax.set_aspect('equal')\n", 137 | "ax.set_title('Observations')\n", 138 | "ax.set_xlabel('x')\n", 139 | "ax.set_ylabel('y')\n", 140 | "ax.set_aspect('equal')\n", 141 | "ax.legend()\n", 142 | "plt.show()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "b8552233", 148 | "metadata": {}, 149 | "source": [ 150 | "## Mixture distribution" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "06a57a95", 156 | "metadata": {}, 157 | "source": [ 158 | "### Gaussian Mixture Model" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "id": "fccdc12a", 164 | "metadata": {}, 165 | "source": [ 166 | "TODO:\n", 167 | "- Check if scales multiplies with corr_lower indeed give the cholesky matrix of the covariance with the same shapes... 
(I doubt so)\n", 168 | " - More info: https://www2.stat.duke.edu/courses/Spring12/sta104.1/Lectures/Lec22.pdf" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "id": "17c3c21f", 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "@jax.vmap\n", 179 | "def create_chol_lower(scale, corr_lower):\n", 180 | " return scale[..., None] * corr_lower\n", 181 | "\n", 182 | "\n", 183 | "def gmm_model(d: int, k: int, obs=None):\n", 184 | " \"\"\"\n", 185 | " :param d: Dimension of Gaussian.\n", 186 | " :param k: Number of mixtures\n", 187 | " :param obs: Observations\n", 188 | " \"\"\"\n", 189 | " # Prior for cluster probabilities\n", 190 | " mixing_prob = numpyro.sample('mixing_probabilities', dist.Dirichlet(concentration=jnp.ones((k, ))))\n", 191 | " # Prior on cluster means\n", 192 | " with numpyro.plate('mixture_plate', k, dim=-2):\n", 193 | " scales = numpyro.sample(\"scales\", dist.HalfCauchy(scale=jnp.ones(d)*2))\n", 194 | " locs = numpyro.sample('locs', dist.Cauchy(loc=jnp.zeros(d), scale=jnp.ones(d)*2))\n", 195 | " # Prior on correlation trough LKJ prior\n", 196 | " with numpyro.plate('mixture_plate', k, dim=-1):\n", 197 | " corr_lower = numpyro.sample(\"corr_lower\", dist.LKJCholesky(dimension=d, concentration=1.))\n", 198 | " # Extract correlation for later analysis\n", 199 | " corrs = numpyro.deterministic(\"correlations\", corr_lower[:, 1, 0])\n", 200 | " # Mixing distribution\n", 201 | " mixing_dist = dist.Categorical(probs=mixing_prob)\n", 202 | " # Mixture components\n", 203 | " lower_cholesky = create_chol_lower(scales, corr_lower)\n", 204 | " component_dist = dist.MultivariateNormal(loc=locs, scale_tril=lower_cholesky)\n", 205 | " # Mixture distribution\n", 206 | " gmm_dist = dist.MixtureSameFamily(mixing_distribution=mixing_dist, component_distribution=component_dist)\n", 207 | " numpyro.sample('obs', gmm_dist, obs=obs)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "5d449bcf", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "rng_key = jax.random.PRNGKey(42)\n", 218 | "\n", 219 | "num_warmup, num_samples = 1000, 3000\n", 220 | "\n", 221 | "# Run NUTS.\n", 222 | "kernel = NUTS(gmm_model)\n", 223 | "mcmc = MCMC(\n", 224 | " kernel,\n", 225 | " num_warmup=num_warmup,\n", 226 | " num_samples=num_samples,\n", 227 | " num_chains=4,\n", 228 | " chain_method='parallel',\n", 229 | ")\n", 230 | "mcmc.run(rng_key, d=2, k=3, obs=obs_data)\n", 231 | "posterior_samples = mcmc.get_samples()" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "05ec156b", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "az.summary(mcmc, var_names=[\"~corr_lower\"], round_to=2)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "e896b172", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "rng_key = jax.random.PRNGKey(42)\n", 252 | "\n", 253 | "posterior_predictive = Predictive(gmm_model, posterior_samples=posterior_samples, batch_ndims=0)\n", 254 | "posterior_predictions = posterior_predictive(rng_key, d=2, k=3, obs=None)\n", 255 | "print('Posterior predictions: ', posterior_predictions['obs'].shape)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "id": "61ffe94a", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "inference_data = az.from_numpyro(\n", 266 | " posterior=mcmc,\n", 267 | " 
posterior_predictive=posterior_predictions,\n", 268 | " coords={\"mixture\": np.arange(k), \"dim\": np.arange(2)},\n", 269 | " dims={\"locs\": [\"mixture\", \"dim\"], \"scales\": [\"mixture\", \"dim\"], \"mixing_probabilities\": [\"mixture\"]}\n", 270 | ")\n", 271 | "display(inference_data)\n" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "654f5fad", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "az.plot_trace(\n", 282 | " inference_data,\n", 283 | " compact=True,\n", 284 | " var_names=[\"~corr_lower\"],\n", 285 | " lines=[\n", 286 | " (\"correlations\", {}, true_corrs),\n", 287 | " (\"locs\", {}, true_locs),\n", 288 | " (\"scales\", {}, true_scales),\n", 289 | " (\"mixing_probabilities\", {}, true_mixture_probs)\n", 290 | " ],\n", 291 | ")\n", 292 | "plt.suptitle('Trace plots', fontsize=18)\n", 293 | "plt.show()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "a460fb8b", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n", 304 | "ax.plot(posterior_predictions[\"obs\"][:, 0], posterior_predictions[\"obs\"][:, 1], 'o', alpha=0.1)\n", 305 | "ax.set_aspect('equal')\n", 306 | "ax.set_title('Posterior predicted samples')\n", 307 | "ax.set_xlabel('x')\n", 308 | "ax.set_ylabel('y')\n", 309 | "plt.show()" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "id": "b73546d2", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [] 319 | } 320 | ], 321 | "metadata": { 322 | "kernelspec": { 323 | "display_name": "Python 3 (ipykernel)", 324 | "language": "python", 325 | "name": "python3" 326 | }, 327 | "language_info": { 328 | "codemirror_mode": { 329 | "name": "ipython", 330 | "version": 3 331 | }, 332 | "file_extension": ".py", 333 | "mimetype": "text/x-python", 334 | "name": "python", 335 | "nbconvert_exporter": "python", 336 | "pygments_lexer": "ipython3", 337 | "version": "3.9.6" 338 | } 339 | }, 340 | "nbformat": 4, 341 | "nbformat_minor": 5 342 | } 343 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/numpyro_getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "plastic-limit", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import numpyro\n", 12 | "\n", 13 | "import numpyro.distributions as dist\n", 14 | "from jax import random\n", 15 | "\n", 16 | "from numpyro.infer import MCMC, NUTS\n", 17 | "from numpyro.infer.reparam import TransformReparam\n", 18 | "from numpyro.infer import Predictive" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "perceived-leadership", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "J = 8\n", 29 | "\n", 30 | "y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", 31 | "\n", 32 | "sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "ancient-webster", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Eight Schools example\n", 43 | "def eight_schools(J, sigma, y=None):\n", 44 | " mu = numpyro.sample('mu', dist.Normal(0, 5))\n", 45 | " tau = numpyro.sample('tau', dist.HalfCauchy(5))\n", 46 | " with numpyro.plate('J', 
J):\n", 47 | " theta = numpyro.sample('theta', dist.Normal(mu, tau))\n", 48 | " numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "noticed-rachel", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "nuts_kernel = NUTS(eight_schools)\n", 59 | "\n", 60 | "mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)\n", 61 | "\n", 62 | "rng_key = random.PRNGKey(0)\n", 63 | "\n", 64 | "mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))\n", 65 | "\n", 66 | "\n", 67 | "mcmc.print_summary()\n", 68 | "\n", 69 | "pe = mcmc.get_extra_fields()['potential_energy']\n", 70 | "print('Expected log joint density: {:.2f}'.format(np.mean(-pe))) " 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "unique-military", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Eight Schools example - Non-centered Reparametrization\n", 81 | "def eight_schools_noncentered(J, sigma, y=None):\n", 82 | " mu = numpyro.sample('mu', dist.Normal(0, 5))\n", 83 | " tau = numpyro.sample('tau', dist.HalfCauchy(5))\n", 84 | " with numpyro.plate('J', J):\n", 85 | " with numpyro.handlers.reparam(config={'theta': TransformReparam()}):\n", 86 | " theta = numpyro.sample(\n", 87 | " 'theta',\n", 88 | " dist.TransformedDistribution(dist.Normal(0., 1.),\n", 89 | " dist.transforms.AffineTransform(mu, tau)))\n", 90 | " numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)\n", 91 | "\n", 92 | "\n", 93 | "\n", 94 | "nuts_kernel = NUTS(eight_schools_noncentered)\n", 95 | "\n", 96 | "mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)\n", 97 | "\n", 98 | "rng_key = random.PRNGKey(0)\n", 99 | "\n", 100 | "mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))\n", 101 | "\n", 102 | "mcmc.print_summary(exclude_deterministic=False) \n", 103 | "pe = mcmc.get_extra_fields()['potential_energy']\n", 104 | "# Compare with the earlier value\n", 105 | "print('Expected log joint density: {:.2f}'.format(np.mean(-pe))) " 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "capable-alpha", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# New School\n", 116 | "def new_school():\n", 117 | " mu = numpyro.sample('mu', dist.Normal(0, 5))\n", 118 | " tau = numpyro.sample('tau', dist.HalfCauchy(5))\n", 119 | " return numpyro.sample('obs', dist.Normal(mu, tau))\n", 120 | "\n", 121 | "\n", 122 | "\n", 123 | "predictive = Predictive(new_school, mcmc.get_samples())\n", 124 | "\n", 125 | "samples_predictive = predictive(random.PRNGKey(1))\n", 126 | "print(np.mean(samples_predictive['obs'])) " 127 | ] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3 (ipykernel)", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.9.6" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 5 151 | } 152 | -------------------------------------------------------------------------------- /probabilistic_programming/numpyro/test-batch-shape.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | 
"execution_count": null, 6 | "id": "48a5735f-a1b0-49e9-bb61-80e9b3b545f4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "import warnings\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "\n", 15 | "import jax\n", 16 | "import jax.numpy as jnp\n", 17 | "from jax import lax\n", 18 | "\n", 19 | "import numpyro\n", 20 | "from numpyro.infer import MCMC, NUTS, Predictive\n", 21 | "import numpyro.distributions as dist\n", 22 | "\n", 23 | "import matplotlib\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from matplotlib import cm # Colormaps\n", 26 | "import seaborn as sns\n", 27 | "import arviz as az\n", 28 | "\n", 29 | "from tqdm import tqdm_notebook as tqdm" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "faa9d512-a3d1-40e4-9dbc-3b2a55f84778", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "sns.set_style('darkgrid')\n", 40 | "az.rcParams['stats.hdi_prob'] = 0.90\n", 41 | "az.style.use(\"arviz-darkgrid\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "b3fc99d6-c7cb-454c-87fc-526e007d6192", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "numpyro.set_platform('cpu')\n", 52 | "numpyro.set_host_device_count(8)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "0f3e3909-54d0-4f3d-a18e-a0e8346974f6", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "np.random.seed(42)\n", 63 | "rng_key = jax.random.PRNGKey(42)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "827455c5-b6a0-44f9-a11a-97fa7e7c4040", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "normal = dist.Normal(loc=0., scale=1.)\n", 74 | "print(\"normal.shape(): \", normal.shape())\n", 75 | "print(\"normal.batch_shape: \", normal.batch_shape)\n", 76 | "print(\"normal.event_shape: \", normal.event_shape)\n", 77 | "print(\"normal.event_dim: \", normal.event_dim)\n", 78 | "print('')\n", 79 | "samples = normal.sample(rng_key, (100,))\n", 80 | "print('samples: ', samples.shape)\n", 81 | "print('')\n", 82 | "lp = normal.log_prob(samples)\n", 83 | "print('lp: ', lp.shape)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "f8d96425-0cf8-410d-a1da-de64851d42cf", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "k = 3\n", 94 | "loc = jnp.zeros(k)\n", 95 | "print(\"loc: \", loc.shape)\n", 96 | "scale = jnp.ones(k)\n", 97 | "print(\"scale: \", scale.shape)\n", 98 | "print('')\n", 99 | "normal = dist.Normal(loc=loc, scale=scale)\n", 100 | "print(\"normal.shape(): \", normal.shape())\n", 101 | "print(\"normal.batch_shape: \", normal.batch_shape)\n", 102 | "print(\"normal.event_shape: \", normal.event_shape)\n", 103 | "print(\"normal.event_dim: \", normal.event_dim)\n", 104 | "print('')\n", 105 | "samples = normal.sample(rng_key, (100,))\n", 106 | "print('samples: ', samples.shape)\n", 107 | "print('')\n", 108 | "lp = normal.log_prob(samples)\n", 109 | "print('lp: ', lp.shape)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "3eb964ee-a0fe-49e0-a921-d31d1fa1defc", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "k = 3\n", 120 | "s = 4\n", 121 | "loc = jnp.vstack([jnp.zeros(k) / k for _ in range(s)])\n", 122 | "print(\"loc: \", loc.shape)\n", 123 | "scale =jnp.vstack([jnp.ones(k) / k for _ in range(s)])\n", 124 | "print(\"scale: \", scale.shape)\n", 125 | "print('')\n", 126 | 
"normal = dist.Normal(loc=loc, scale=scale)\n", 127 | "print(\"normal.shape(): \", normal.shape())\n", 128 | "print(\"normal.batch_shape: \", normal.batch_shape)\n", 129 | "print(\"normal.event_shape: \", normal.event_shape)\n", 130 | "print(\"normal.event_dim: \", normal.event_dim)\n", 131 | "print('')\n", 132 | "samples = normal.sample(rng_key, (100,))\n", 133 | "print('samples: ', samples.shape)\n", 134 | "print('')\n", 135 | "lp = normal.log_prob(samples)\n", 136 | "print('lp: ', lp.shape)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "2b62acd3-e83c-4351-9d08-6bbf1024a728", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "d = 2\n", 147 | "loc = jnp.zeros((d,))\n", 148 | "print(\"loc: \", loc.shape)\n", 149 | "cov_matrix = jnp.eye(d, d)\n", 150 | "print(\"cov_matrix: \", cov_matrix.shape)\n", 151 | "print('')\n", 152 | "normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)\n", 153 | "print(\"normal.shape(): \", normal.shape())\n", 154 | "print(\"normal.batch_shape: \", normal.batch_shape)\n", 155 | "print(\"normal.event_shape: \", normal.event_shape)\n", 156 | "print(\"normal.event_dim: \", normal.event_dim)\n", 157 | "print('')\n", 158 | "samples = normal.sample(rng_key, (100,))\n", 159 | "print('samples: ', samples.shape)\n", 160 | "print('')\n", 161 | "lp = normal.log_prob(samples)\n", 162 | "print('lp: ', lp.shape)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "196b7b59-3cd3-44c8-91c1-46c784c91b4d", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "k = 3\n", 173 | "d = 2\n", 174 | "loc = jnp.zeros((k, d))\n", 175 | "print(\"loc: \", loc.shape)\n", 176 | "cov_matrix = jnp.repeat(jnp.expand_dims(jnp.eye(d, d), 0), k, axis=0)\n", 177 | "print(\"cov_matrix: \", cov_matrix.shape)\n", 178 | "print('')\n", 179 | "normal = dist.MultivariateNormal(loc=loc, covariance_matrix=cov_matrix)\n", 180 | "print(\"normal.shape(): \", normal.shape())\n", 181 | "print(\"normal.batch_shape: \", normal.batch_shape)\n", 182 | "print(\"normal.event_shape: \", normal.event_shape)\n", 183 | "print(\"normal.event_dim: \", normal.event_dim)\n", 184 | "print('')\n", 185 | "samples = normal.sample(rng_key, (100,))\n", 186 | "print('samples: ', samples.shape)\n", 187 | "print('')\n", 188 | "lp = normal.log_prob(samples)\n", 189 | "print('lp: ', lp.shape)" 190 | ] 191 | } 192 | ], 193 | "metadata": { 194 | "kernelspec": { 195 | "display_name": "Python 3 (ipykernel)", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.9.6" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 5 214 | } 215 | --------------------------------------------------------------------------------