├── README.md ├── bayesian-statistics ├── Modelling_M&Ms_AHW.ipynb ├── Modelling_M&Ms_Solutions_AHW2022.ipynb ├── dhuppenkothen_ahw_bayesianstats.pdf ├── mm_counts.csv └── mm_smallsample.jpg ├── day2_ml_tutorial ├── 01-intro-machine-learning.ipynb ├── 01-intro-machine-learning_answers.ipynb ├── 02-deep-learning.ipynb ├── 02-deep-learning_answers.ipynb └── README.md └── monday_ml_tutorial ├── .gitignore ├── README.md ├── deepmerge ├── download_data.py ├── prepare_data.py └── train_cnn.py ├── original_notebook.ipynb ├── requirements.txt └── trim_fits.py /README.md: -------------------------------------------------------------------------------- 1 | # Astro Hack Week 2022 2 | 3 | This Repository contains the materials for Astro Hack Week 2022. 4 | -------------------------------------------------------------------------------- /bayesian-statistics/Modelling_M&Ms_AHW.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Modeling M&M Colour Distributions \n", 8 | "\n", 9 | "In this exercise, we will model M&M colour distributions using Bayesian inference. \n", 10 | "\n", 11 | "Here's the basic problem: You have a bag of M&Ms, and you want to know how many blue ones are in it. Easy, right? You open the bag, and count all the blue ones! But what if you have a [10 pound bag](https://www.mymms.com/product/bulk-mms-candy-10-lb.do?sortby=ourPicksAscend&refType=&from=fn&ecList=7&ecCategory=100601) of M&Ms? Counting them all would take you ages, wouldn't it? Or what if you had several bags. The exact number of blue M&Ms as well as the total number of M&Ms in each bag might vary! So really, what we need is a model of the *average* number of blue M&Ms per bag. We don't just need any model, but we need a *statistical* model that describes the number of blue M&Ms we would get out of a bag of M&Ms given some underlying true fraction of blue M&Ms. \n", 12 | "\n", 13 | "#### Imports\n", 14 | "\n", 15 | "This exercise initially only requires `numpy`, `scipy` and `matplotlib` for plotting, along with some `ipython widgets` if you want to run the interactive plots. Below, you might also want to use `pymc3` for sampling the hierarchical model." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "!pip install matplotlib\n", 25 | "!pip install seaborn\n", 26 | "!pip install ipywidgets\n", 27 | "!pip install numpy\n", 28 | "!pip install scipy" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "%matplotlib inline\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "from ipywidgets import interact, interactive, fixed, interact_manual\n", 40 | "import ipywidgets as widgets\n", 41 | "\n", 42 | "# prettier plotting;\n", 43 | "# comment out if seaborn is not installed\n", 44 | "import seaborn as sns\n", 45 | "sns.set_style(\"whitegrid\")\n", 46 | "\n", 47 | "import numpy as np\n", 48 | "import scipy.stats\n", 49 | "import scipy.special" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## The Binomial Likelihood\n", 57 | "\n", 58 | "In statistics, the statistical model for how one draws observations given some underlying process is called the *likelihood*. 
\n", 59 | "\n", 60 | "In our case, we have two observations: the total number of M&Ms in a bag $N$, and the number of blue M&Ms out of that total number, $k$. \n", 61 | "There are only two things the M&Ms can be in our (simplified) model: blue and not-blue. It is worth noting at this point that virtually any model used to make sense of data is always a *simplification* of the true underlying process. In reality, M&Ms come in six different colours. We have *simplified* this to just two. This is fine as long as blue M&Ms are all we care about. If we suddenly also cared about green M&Ms, we'd need to make our model more complicated to account for this (more on this later)!\n", 62 | "\n", 63 | "Back to our blue M&Ms. Every time you draw an M&M out of a bag, you get one of two outcomes: blue or not-blue. In a more statistical language, a draw of an M&M out of a bag is called a *trial*, drawing a blue M&M is called a *success* (a not-blue M&M is called a *failure*). You can do this $N$ times and then record the number of successes, $k$. \n", 64 | "\n", 65 | "Assuming that there is some underlying fraction $q$ of blue M&Ms being produced and put into bags, then for every time you draw an $M&M$ out of the bag, you will draw a blue one with probability $q$ and a not-blue one with probability $(1-q)$ (since these are our only two, mutually exclusive options, and all probabilities must sum up to $1$). \n", 66 | "\n", 67 | "$$\n", 68 | "p(k | N, q) = {N \\choose k} q^k (1-q)^(N-k) \\; .\n", 69 | "$$\n", 70 | "\n", 71 | "Let's talk about how to read this equation. On the right side is a probability $p$, and it's the probability of getting $k$ blue M&Ms out of a bag with $N$ total M&Ms, and an underlying fraction of $q$ blue M&Ms per total. The $|$ symbol always denotes the term *given*, which implies *truths* about the world, or things we know. In this case, we *know* that we've drawn $N$ M&Ms out of the bag, and we're trying to figure out how probable it is that $k$ will be blue, given some true underlying rate $q$. Note that here, we assume that we actually *know* what the true number of blue M&Ms per bag should be, but in reality, we don't! \n", 72 | "Keep this in the back of your head, we'll get back to it in a little while!\n", 73 | "\n", 74 | "On the left-hand side is the definition of the probability distribution we are interested in. The probability of drawing $k$ blue M&Ms is $q^k$ (if the draws are all independent). Then we have $N-k$ not-blue M&Ms left, and the probability of drawing those is $(1-q)^{N-k}$. The ${N choose k}$ term in the front of the expression comes from the fact that $q^k (1-q)^{N-k}$ is the probability of one *specific* sequence. For example, you could have drawn something like \\[blue, blue, not-blue, blue, not-blue, not-blue, not-blue, not-blue, blue\\], which is a specific sequence. But we don't really care about whether you draw a blue or not-blue first, all we care about is the total number of blue M&Ms out of the total. The first term corrects the expression for all possible permutations of sequences that could produce $k$ blue M&Ms out of $N$ total. \n", 75 | "\n", 76 | "This expression, as a whole, is called the *binomial distribution*, and is the likelihood we're going to use.\n", 77 | "\n", 78 | "**Exercise 1**: Without looking at any M&Ms, take a guess for what a reasonable value might be for $q$. 
Then look at the image below (where I've recorded some M&M data for you).\n", 79 | "Calculate the probability of that number of blue M&Ms out of your set of 30 trials, given the value of $q$ you've chosen. How large is that probability? \n", 80 | "\n", 81 | "Here's a picture of the M&Ms:\n", 82 | "\n", 83 | "![a handful of M&Ms, sorted by colours](mm_smallsample.jpg)\n", 84 | "\n", 85 | "If you count carefully, you'll find that the total number of M&Ms in the image is $n = 30$. The number of blue M&Ms is $k = 10$. Let's add those numbers below.\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "##################################################################\n", 95 | "## ADD YOUR VALUES BELOW\n", 96 | "##################################################################n = # number of draws out of bag\n", 97 | "k = # add the number of blue M&Ms you drew to this variable\n", 98 | "\n", 99 | "q = # add the value you chose for q here" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "Now we need to write down the equation of the binomial distribution. I've done this for you so you can play with it rather than get lost in coding exercises:\n", 107 | "\n", 108 | "**Hint**: The function `scipy.special.comb` allows you to calculate the combinatorial pre-factor of the binomial distribution:" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "from scipy.special import comb" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def binomial_distribution(n, k, q):\n", 127 | " \"\"\"\n", 128 | " Calculate the probability of $k$ successes out \n", 129 | " of $n$ trials, given an underlying success rate $q$.\n", 130 | " \n", 131 | " Parameters\n", 132 | " ----------\n", 133 | " n : int\n", 134 | " The total number of trials\n", 135 | " \n", 136 | " k : int\n", 137 | " The number of successful draws out of $n$\n", 138 | " \n", 139 | " q : float, [0,1]\n", 140 | " The success rate\n", 141 | " \n", 142 | " Returns\n", 143 | " -------\n", 144 | " prob : float [0,1]\n", 145 | " The binomial probability of $k$ draws out \n", 146 | " of $n$ trials\n", 147 | " \"\"\"\n", 148 | " \n", 149 | " bin_fac = comb(n,k)\n", 150 | " \n", 151 | " first_prob = q ** k\n", 152 | " second_prob = (1. - q) ** (n - k)\n", 153 | " \n", 154 | " prob = bin_fac * first_prob * second_prob\n", 155 | "\n", 156 | " return prob" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "Let's use this function to calculate the probability of our M&Ms above, given the value we assumed for the fraction of blue M&Ms:" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# calculate the binomial probability \n", 173 | "print(\"Probability: \" + str(binomial_distribution(n, k, q)))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "For illustration, let's also imagine we have a friend at the M&M factory, who has told us that 50% of all M&Ms produced are blue. 
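The next step is to evaluate this probability for every possible outcome $k$ from $0$ to $N$. Here is one minimal way to do that, reusing the `binomial_distribution` function defined above (a sketch, not the only sensible approach):

```python
import numpy as np
import matplotlib.pyplot as plt

n_demo, q_demo = 30, 0.5                 # bag size and the fraction our factory friend quoted
k_values = np.arange(n_demo + 1)         # every possible number of blue M&Ms, 0..N

# probability of each outcome under the binomial model
probs = [binomial_distribution(n_demo, k, q_demo) for k in k_values]

plt.bar(k_values, probs, color="black")
plt.xlabel(r"number of blue M&Ms $k$")
plt.ylabel(r"$p(k\,|\,N,q)$")
plt.show()
```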
Calculate the probability for all possible outcomes (from no blue M&Ms in the bag, all the way to all M&Ms are blue), and plot the result.\n", 181 | "\n", 182 | "**Exercise 2**: How probable is drawing $k=10$ blue M&Ms out of our bag of $n=30$ if the fraction of blue M&Ms produced is indeed $q=0.5$? How many blue M&Ms would you *expect* to draw?\n", 183 | "\n", 184 | "For fun, I've created for you an interactive version of the plot you've just created, where you can change the total number of M&Ms, the observed fraction of blue M&Ms and the true underlying fraction of blue M&Ms produced at the factory.\n" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "fraction_slider = widgets.FloatSlider(\n", 194 | " value=0.15,\n", 195 | " min=0.0,\n", 196 | " max=1.0,\n", 197 | " step=0.01,\n", 198 | " description='Fraction of blue M&Ms:',\n", 199 | " readout_format='.2f',\n", 200 | ")\n", 201 | "@widgets.interact(\n", 202 | " all_mms=\"100\", blue_mms=\"10\", fraction=(0.0, 1.0, 0.05), max_plot=(0,200,1))\n", 203 | "def plot(all_mms=100, blue_mms=10, fraction=0.15, max_plot=200):\n", 204 | "\n", 205 | " nexp = 1000\n", 206 | " \n", 207 | " all_mms = int(all_mms)\n", 208 | "\n", 209 | " blue_mms = int(blue_mms)\n", 210 | "\n", 211 | " dist = scipy.stats.binom(all_mms, fraction)\n", 212 | "\n", 213 | " x = np.arange(0,all_mms, 1.0)\n", 214 | "\n", 215 | " #x = np.arange(0,all_mms, 1.0)\n", 216 | " rvs = dist.rvs(size=1000)\n", 217 | "\n", 218 | " pmf, bins = np.histogram(rvs, bins=all_mms, range=[0,all_mms])\n", 219 | " \n", 220 | " fig, ax = plt.subplots(1, 1, figsize=(8,6))\n", 221 | "\n", 222 | "\n", 223 | " ax.vlines(x, 0, pmf, colors='k', linestyles='-', lw=1,\n", 224 | " label='frozen pmf')\n", 225 | " ax.axvline(blue_mms, color=\"red\", lw=2)\n", 226 | " ax.set_ylabel(\"Probability mass function\", fontsize=20)\n", 227 | " ax.set_xlabel(\"Number of blue M&Ms\", fontsize=20)\n", 228 | " ax.set_xlim(0, max_plot)\n", 229 | " ax.set_ylim(0, np.max(pmf)+5)\n", 230 | "\n" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## Calculating the Likelihood Function\n", 238 | "\n", 239 | "There's a fundamental problem with the process above: it assumes we know $q$ (it's on the right side of the \"|\") and that $k$ is a random variable. But this is often not what we observe in reality! Often, we can make *observations* of the process (here: $k$ blue M&Ms out of $N$), and we care about the *parameters* of the underlying model (here: the success rate $q$). In reality, we probably don't have a friend at the M&M factory who can tell us the true fractions of colours produced! What we really want to know is not $p(k | N, q)$, but $p(q | N, k)$. It is important to realize that these two are not the same! For an illustrative example, consider a simpler case. Consider that you're given the information that it is raining outside. What can you conclude about the cloud cover overhead? Conversely, imagine you're being told the it is cloudy. Can you conclude with equal probability that it is also raining? \n", 240 | "\n", 241 | "So, if they're not the same, do we get from $p(k | N, q)$, which we've shown we can calculate, to $p(q | N, k)$? In principle, nothing stops you from measuring your $k$ successes out of $N$ trials, and then calculating $p(k | N, q)$ for different values of $q$. 
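For instance, a quick sketch with placeholder data ($k=10$ out of $N=30$); note that `scipy` broadcasts over an array of $q$ values, so no explicit loop is needed:

```python
import numpy as np
import scipy.stats

n_obs, k_obs = 30, 10                     # placeholder data: 10 blue out of 30
q_grid = np.linspace(0.001, 0.999, 500)   # candidate values of q

likelihood = scipy.stats.binom.pmf(k_obs, n_obs, q_grid)

# the value of q with the highest likelihood sits near k/N = 1/3
print(q_grid[np.argmax(likelihood)])
```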
However, there is a reason this is called a likelihood *function*: it is *not* a probability distribution of the parameter $q$, because $q$ is on the right-hand side of the \"|\" sign. It is fixed, known, assumed to be true. The binomial probability is a probability distribution in $k$, not $q$. This may sound subtle, but has huge consequences, one of them being that $p(k | N, q)$ as a function of $q$ does not integrate to 1, like a proper probability distribution.\n", 242 | "\n", 243 | "Let's calculate the likelihood function for our problem. For this, we keep the number of blue M&Ms $k$ and the number of total M&Ms $N$ fixed, but vary the fraction of blue M&Ms produced at the factory $q$.\n", 244 | "\n", 245 | "**Exercise 3**: Remember that your friend told you that the fraction of M&Ms produced at the factory is $q = 0.5$. Do you think that your data is very probable given that knowledge? What if you only had 3 blue M&Ms in the package? Plot likelihood functions for both values of $k$ and compare them.\n" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "##################################################################\n", 255 | "## ADD YOUR VALUES BELOW\n", 256 | "##################################################################\n", 257 | "\n", 258 | "n = # total number of M&Ms\n", 259 | "k = # blue M&Ms in bag\n", 260 | "\n", 261 | "##################################################################\n", 262 | "\n", 263 | "\n", 264 | "# generate a list of all possible values of $q$, going \n", 265 | "# from 0 to 1:\n", 266 | "qtrial = np.arange(0, 1, 0.001)\n", 267 | "\n", 268 | "# make a list of zeros to store the probabilities for \n", 269 | "# each value of q in\n", 270 | "prob = np.zeros_like(qtrial)\n", 271 | "\n", 272 | "# loop through all values of k and calculate the binomial probability\n", 273 | "# for this case\n", 274 | "for i,q in enumerate(qtrial):\n", 275 | " prob[i] = binomial_distribution(n, k, q) \n", 276 | " \n", 277 | "# plot the results\n", 278 | "fig, ax = plt.subplots(1, 1, figsize=(6,4))\n", 279 | "ax.plot(qtrial, prob, lw=2, color=\"black\")\n", 280 | "ax.set_xlabel(r\"fraction of blue M&Ms $q$\")\n", 281 | "ax.set_ylabel(r\"$p(k|N,q)$\")\n", 282 | "\n", 283 | "# second distribution\n", 284 | "\n", 285 | "##################################################################\n", 286 | "## ADD YOUR VALUES BELOW\n", 287 | "##################################################################\n", 288 | "k2 = # blue M&Ms in bag\n", 289 | "##################################################################\n", 290 | "\n", 291 | "\n", 292 | "qtrial = np.arange(0, 1, 0.001)\n", 293 | "prob2 = np.zeros_like(qtrial)\n", 294 | "\n", 295 | "for i,q in enumerate(qtrial):\n", 296 | " prob2[i] = binomial_distribution(n, k2, q) \n", 297 | "\n", 298 | "ax.plot(qtrial, prob2, lw=2, color=\"orange\")\n", 299 | "\n", 300 | "\n", 301 | "plt.tight_layout()\n" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "\n", 309 | "\n", 310 | "### Going from Likelihood to Posterior\n", 311 | "\n", 312 | "It's not that the likelihood isn't useful: it often gives you a pretty good guess which parameter might do a good job of producing the data you've observed. But the crux is that it doesn't tell you what you want to know, because it is *not* a probability of the parameter $q$, but of the outcomes $k$. 
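You can check the normalisation claim numerically with a simple Riemann sum over $q$ (again with placeholder data):

```python
import numpy as np
import scipy.stats

n_obs, k_obs = 30, 10
q_grid = np.linspace(0.0, 1.0, 1001)
like = scipy.stats.binom.pmf(k_obs, n_obs, q_grid)

# for a proper probability density in q this would come out close to 1;
# here it is roughly 1/(N+1) instead
print(np.sum(like) * (q_grid[1] - q_grid[0]))
```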
\n", 313 | "\n", 314 | "So can we calculate the actual probability we're interested in, $p(q | N, k)$? \n", 315 | "\n", 316 | "Well, this is where Bayes' theorem comes in handy. Bayes' theorem can be derived directly from some fundamental rules of probability, most importantly the *joint* probability distribution of two variables:\n", 317 | "\n", 318 | "$$\n", 319 | "P(A, B) = P(A|B)P(B) = P(B|A)P(A)\n", 320 | "$$\n", 321 | "\n", 322 | "for some generic random variables $A$ and $B$ (e.g. whether it's raining outside, and whether it is cloudy or sunny). What does this term say? \n", 323 | "\n", 324 | "Let's make a little table for the four possible outcomes (cloudy/sunny, rain/no rain):\n", 325 | "\n", 326 | "| categories | rain (r) | no rain (nr) |\n", 327 | "|------------|----------|--------------|\n", 328 | "| cloudy (c) | 0.1 | 0.4 |\n", 329 | "| sunny (s) | 0.0 | 0.5 |\n", 330 | "\n", 331 | "This table expresses the *joint* probabilities of all possible outcomes. For example, the joint probability of it currently being cloudy without rain is $p(\\mathrm{pr}, \\mathrm{f}) = 0.4$. The probability that the it's both sunny and raining is zero (where would the rain come from without clouds?).\n", 332 | "\n", 333 | "What does this have to do with our case above? Well, what's the joint probability of it being both cloudy and not raining? $p(\\mathrm{c}, \\mathrm{nr}) = 0.4$ given our table above. \n", 334 | "\n", 335 | "Let's ask a harder question: what's the probability of a it being cloudy, $p(\\mathrm{c})$? To answer this question, it doesn't matter whether it's raining or not raining, so we just have to sum up both columns, $p(\\mathrm{c}) = 0.1 + 0.4 = 0.5$. In reality, our variables are often continuous, so this often requires an integral instead of a simple sum. \n", 336 | "\n", 337 | "Let's ask something a little bit more complex: what's the probability that it's raining, *given* that it is cloudy, $p(\\mathrm{r} | \\mathrm{c}) = 0.1$? Note that this is *not* the same as the *joint* probability. In the latter case, I don't know it's cloudy, and I'm trying to calculate the probability that it is both cloudy and raining. In the case we're currently looking at, I already *know* that it's cloudy (maybe I've looked out of the window), and I'm curious whether I might be able to get to work without getting wet. So I already have one piece of information (it's cloudy). Because I already know this, the whole row labelled \"sunny\" no longer matters, and I only have two cases left (rain and not-rain). However, the sum of those two options is $0.5$, and I said earlier that probabilities must sum up to 1! So we'll need to re-normalize the probability to sum up to 1: \n", 338 | "\n", 339 | "$$\n", 340 | "p(\\mathrm{r} | \\mathrm{c}) = \\frac{p(\\mathrm{r}, \\mathrm{c})}{p(\\mathrm{c})} \n", 341 | "$$\n", 342 | "\n", 343 | "So the probability that it is raining given that it is cloudy is $0.1/(0.1 + 0.5) = 0.2$. 
\n", 344 | "\n", 345 | "If you move $p(\\mathrm{c})$ on the other side, you get an expression for the joint probability:\n", 346 | "\n", 347 | "$$\n", 348 | "p(\\mathrm{r} , \\mathrm{c}) = p(\\mathrm{r} | \\mathrm{c})p(\\mathrm{c})\n", 349 | "$$\n", 350 | "\n", 351 | "Note that you can turn that expression around: the joint probability for it being both cloudy and raining is:\n", 352 | "\n", 353 | "$$\n", 354 | "p(\\mathrm{r} , \\mathrm{c}) = p(\\mathrm{c} | \\mathrm{r})p(\\mathrm{r})\n", 355 | "$$\n", 356 | "\n", 357 | "You can put these two together, and you've got Bayes rule as stated above (I'm going to go back to the generic variables $A$ and $B$ for this):\n", 358 | "\n", 359 | "$$\n", 360 | "P(A | B) = \\frac{p(B|A)P(A)}{P(B)} \\, .\n", 361 | "$$\n", 362 | "\n", 363 | "And this is Bayes' rule! This particular theorem has many more implications than simply tallying up probabilities as we've done in the example above. In particular, there are fundamental philosophical differences between Bayesian statistics and its alternative--often also called frequentist statistics--in how one sees probabilities. In Bayesian statistics, almost anything can be a random variable, and Bayesians see probabilities as encoding our uncertainty or lack of knowledge about the world. Frequentists tend to have a more literal view of the world, and interpret probabilities as frequencies of truly random events, e.g. rolls of dice. \n", 364 | "\n", 365 | "## The Posterior Probability Distribution\n", 366 | "\n", 367 | "What does all of this have to do with our M&Ms? Well, above, we have basically written down the blue-print for how to get from $p(k | N, q)$ to $p(q | N, k)$. We can stick these particular variables in our equation above:\n", 368 | "\n", 369 | "$$\n", 370 | "p(q | N, k) = \\frac{p(k | N, q) p(q)}{p(k)} \\, .\n", 371 | "$$\n", 372 | "\n", 373 | "In theory, this tells us exactly how to calculate the probability of the *parameter* $q$ that we're looking for, given that we've drawn $k$ blue M&Ms out of a bag with $N$ M&Ms total. $p(q | N, k)$ is generally called the *posterior probability distribution*. We're not done, though. In particular, we've written down the equation, and we know how to calculate $p(k | N, q)$, but what are $p(q)$ and $p(x)$? \n", 374 | "\n", 375 | "I've made a big production above that $p(k | N, q)$ isn't normalized, and that this is important. The product $p(k | N, q) p(q)$ is still not normalized, but $p(q | N, k)$, so $p(k)$ is effectively a normalizing constant:\n", 376 | "\n", 377 | "$$\n", 378 | "p(k) = \\int{p(k | N, q) p(q) dq} \n", 379 | "$$\n", 380 | "\n", 381 | "such that the whole probability distribution integrates to 1. In practice, this is the probability of observing the data $k$ times $p(q)$, integrated over all possible values of $q$. This is also called the *marginal likelihood* or *evidence*. While this no longer depends on $q$, this doesn't mean it has no assumptions. For example, we've assumed above that our data can be modelled by a binomial distribution. This may not be true, and we should probably have included another variable $B$ on the given side of our probabilities to indicate there's an underlying assumption there, e.g. $p(k | B)$. Notice that this looks a lot like a likelihood? Well it is, but it is now the likelihood of observing the data given the generalized assumption that the data were drawn from any binomial distribution. 
If we had another model for the data, say a Poisson distribution, we could also calculate $p(k | P)$ (where $P$ stands for the Poisson model) and compare the two. This is why the marginal likelihood is often used for *model comparison*.\n", 382 | "\n", 383 | "In this tutorial, we're not going to worry about comparing different kinds of (statistical) models, but instead worry about estimating the parameter $q$. For this the normalizing constant $p(k)$ is exactly that, a constant, the same for all possible values of $q$. if we're only interested in the relative probabilities of a specific $q_0$ to a different $q_1$, we can ignore that constant and write\n", 384 | "\n", 385 | "$$\n", 386 | "p(q | N, k) \\propto p(k | N, q) p(q) \\, .\n", 387 | "$$\n", 388 | "\n", 389 | "which is going to make our lives a whole lot easier, because $p(k)$ is often very hard to compute in practice\n", 390 | "\n", 391 | "## Priors\n", 392 | "\n", 393 | "The one thing we *haven't* talked about yet is $p(q)$. You'll notice that this is a probability distribution of $q$ only, without the data playing a role. This is often called the **prior probability distribution**, and it encodes whatever prior knowledge you might have about this parameter before you've looked at the data. For example, you might know that there are six colours in a package of M&M, so you might conclude that it's extremely unlikely that $q=1$, i.e. that there are only blue M&Ms in your bag. \n", 394 | "\n", 395 | "**Exercise 4**: Think about what you know about M&Ms. Do you have any prior knowledge about how many M&Ms there might be in a bag? Think about the number you picked for $q$ when you calculated the binomial probability earlier. Why did you pick that value? Can you draw a distribution of your prior expectations for all values of $q$ between 0 and 1? \n", 396 | "\n", 397 | "**Careful**: Of course, you've already seen some of the data! This is generally not how you go about things, so you're going to have to pretend you haven't. \n", 398 | "\n", 399 | "Note that I've said earlier that $p(q)$ is a probability *distribution*, so it has to be more than one value. It has to encode your knowledge about $q$ for all possible values of $q$, which can in principle be anywhere in the range between 0 and 1. One simple choice is to make all possible values of $q$ equally likely, but we've already said earlier that this is probably not a good assumption, because we don't think our bag will be all blue M&Ms. In general, this kind of prior is called a *uniform distribution*, and while it may seem like the choice that is least affected by what you know, this is in practice *not* always true! We won't go into the details here of why or when this is the case, but be mindful that this is something you might have to think about in practice. \n", 400 | "There's another thing to be aware of with the uniform distribution: it makes a very, very strong assumption about what values $q$ is allowed to take. If we set the prior to be uniform between $0$ and $1$, this would be an okay choice, because these are all the values $q$ can take in practice. However, imagine you picked a prior for $q$ between 0.5 and 1. You have just assumed that $q$ can *never, ever* be lower than 0.5, *no matter* what your data tell you! This is a really strong assumption to make, and you'd better be really sure that it's a reasonable one! 
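To see this hard cutoff in action, here is a minimal check with `scipy.stats.uniform` (purely illustrative):

```python
import scipy.stats

# a uniform prior on [0.5, 1]: loc is the lower edge, scale the width of the interval
hard_prior = scipy.stats.uniform(loc=0.5, scale=0.5)

print(hard_prior.pdf(0.3))   # 0.0 -- values below 0.5 are impossible under this prior,
                             # and no amount of data can ever resurrect them
print(hard_prior.pdf(0.7))   # 2.0 -- all the prior mass is squeezed into [0.5, 1]
```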
\n", 401 | "In practice, it's often better to choose distributions that fall off sharply, but retain some (small, but not impossible) prior probability in all theoretically allowed values, unless you're absolutely confident that these values cannot be true.\n", 402 | "\n", 403 | "### Conjugate Priors\n", 404 | "\n", 405 | "So in principle, you could choose any distribution for $q$. Maybe you eat a package of M&Ms every day, so you have a pretty good feeling for $q$. You could choose a normal distribution around your estimated value of $q = \\mu$, assign some narrow width $\\sigma$ to the distribution, and you'd have a perfectly good prior: $p(k | \\mu, \\sigma) \\sim \\mathcal{N}(\\mu, \\sigma)$. Note that in this case $\\mu$ and $\\sigma$ define the shape of the prior distribution, and are called **hyperparameters**. They're given (i.e we've set them in advance), so they're on the right-hand side of the \"|\".\n", 406 | "\n", 407 | "One issue with this is that you don't just want to calculate $p(k | \\mu, \\sigma)$, but $\\propto p(k | N, q) p(q | \\mu, \\sigma)$, and there's no guarantee that the latter will be an analytically solveable equation for any choice of $p(k)$. However, for most likelihood functions, there do exist functions that you can use as priors that will lead to analytical expressions for the posterior. These are called **conjugate priors** and are a good choice when you don't have much prior information about your parameter $q$ and/or the conjugate prior matches the shape of what you *do* know about $q$. \n", 408 | "\n", 409 | "The conjugate prior for the binomial distribution is the [beta distribution](https://en.wikipedia.org/wiki/Beta_distribution). This distribution has two parameters, $\\alpha$ and $\\beta$ and is defined as \n", 410 | "\n", 411 | "$$\n", 412 | "p(x | \\alpha ,\\beta ) =\\mathrm{constant} \\cdot x^{\\alpha -1}(1-x)^{\\beta -1} \\; .\n", 413 | "$$\n", 414 | "\n", 415 | "It can take many different shapes.\n", 416 | "\n", 417 | "**Exercise 5a**: Calculate the prior probability density for three different values of $\\alpha$ and $\\beta$, and plot the result. How does the shape of the distribution change for different values of the two parameters? Which combination of parameters makes a good prior for $q$ in your opinion?\n", 418 | "\n", 419 | "**Hint**: You don't have to write your own version of the beta-distribution. 
The `scipy.stats` package contains a large list of distributions ready-made for you, including the [beta distribution](https://docs.scipy.org/doc/scipy-0.19.1/reference/generated/scipy.stats.beta.html).\n" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "fig, ax = plt.subplots(1, 1, figsize=(6,4))\n", 429 | "qtrial = np.linspace(0, 1, 500) # trial values of q\n", 430 | "\n", 431 | "##################################################################\n", 432 | "## ADD YOUR VALUES BELOW \n", 433 | "##################################################################\n", 434 | "alpha = # set the values for alpha\n", 435 | "beta = # set the values for beta\n", 436 | "##################################################################\n", 437 | "\n", 438 | "for a, b in zip(alpha, beta):\n", 439 | " # set up the probability distribution\n", 440 | " beta_dist = scipy.stats.beta(a, b)\n", 441 | "\n", 442 | " # calculate the probability density for qtrial\n", 443 | " beta_pdf = beta_dist.pdf(qtrial)\n", 444 | " \n", 445 | " # plot the results\n", 446 | " ax.plot(qtrial, beta_pdf, lw=2, label=r\"$\\alpha=%.1f$, $\\beta=%.1f$\"%(a,b))\n", 447 | " ax.set_label(r\"Parameter $q$\")\n", 448 | " ax.set_ylabel(r\"Prior distribution $p(q)$\")\n", 449 | " ax.set_title(\"beta-distribution prior\")\n", 450 | " \n", 451 | "plt.legend()" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "As with the previous problem, I've written you a little interactive widget to explore how this distribution changes as a function of its parameters. \n", 459 | "\n", 460 | "**Exercise 5b**: How does the shape of the distribution change for different values of the two parameters? Which combination of parameters makes a good prior for $q$ in your opinion?" 
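One way to reason about this (a hint, not the only answer): the mean of a $\mathrm{Beta}(\alpha, \beta)$ distribution is $\alpha/(\alpha+\beta)$, and the larger $\alpha + \beta$ is, the narrower the distribution becomes. You can check this numerically for a few arbitrary example values:

```python
import scipy.stats

for a, b in [(2, 2), (2, 10), (20, 100)]:   # arbitrary example hyperparameters
    dist = scipy.stats.beta(a, b)
    print(a, b, dist.mean(), dist.std(), dist.interval(0.9))
```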
461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "alpha_slider = widgets.FloatSlider(\n", 470 | " value=2,\n", 471 | " min=0,\n", 472 | " max=100,\n", 473 | " step=0.01,\n", 474 | " description='Alpha',\n", 475 | " readout_format='.2f',\n", 476 | ")\n", 477 | "\n", 478 | "beta_slider = widgets.FloatSlider(\n", 479 | " value=2,\n", 480 | " min=0,\n", 481 | " max=100,\n", 482 | " step=0.01,\n", 483 | " description='Beta',\n", 484 | " readout_format='.2f',\n", 485 | ")\n", 486 | "\n", 487 | "\n", 488 | "@widgets.interact(\n", 489 | " alpha=(0,10, 0.01), beta=(0, 10, 0.01))\n", 490 | "def plot(alpha=2, beta=2):\n", 491 | "\n", 492 | " alpha = np.float(alpha)\n", 493 | " beta = np.float(beta)\n", 494 | "\n", 495 | " qtrial = np.linspace(0, 1, 500) # trial values of q\n", 496 | "\n", 497 | " # set up the probability distribution\n", 498 | " beta_dist = scipy.stats.beta(alpha, beta)\n", 499 | "\n", 500 | " # calculate the probability density for qtrial\n", 501 | " beta_pdf = beta_dist.pdf(qtrial)\n", 502 | " \n", 503 | " fig, ax = plt.subplots(1, 1, figsize=(6,4))\n", 504 | " # plot the results\n", 505 | " ax.plot(qtrial, beta_pdf, lw=2, label=r\"$\\alpha=%.1f$, $\\beta=%.1f$\"%(alpha,beta))\n", 506 | " ax.set_label(r\"Parameter $q$\")\n", 507 | " ax.set_ylabel(r\"Prior distribution $p(q)$\")\n", 508 | " ax.set_title(\"beta-distribution prior\")\n", 509 | "\n" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "\n", 517 | "Chosen values for the hyper-parameters\n", 518 | "* $\\alpha = $\n", 519 | "* $\\beta = $\n", 520 | "\n", 521 | "**Exercise 6**: Share your results with the class and discuss. Did you all pick similar values? Are your choices significantly different to those of others? How do your assumptions differ?\n", 522 | "\n", 523 | "It is important to notice that there is no one correct choice for a prior, because by its very definition, it depends on your *prior knowledge* about the problem you're trying to solve! Someone who has eaten M&Ms regularly since childhood might have a different knowledge about the fraction of blue M&Ms in a bag than someone who has never had any before today! This may at first seem like a disadvantage, because making different assumptions about $q$ seems like it's not very objective, and science is supposed to be objective, right?\n", 524 | "\n", 525 | "Well, it's not that easy, because the idea that science is absolutely objective is itself a fallacy. Whenever we write down a model for observations, we *always* make assumptions (as for example, we pointed out explicitly above with the binomial model), and those assumptions can differ from researcher to researcher and change over time. \n", 526 | "A lack of explicit prior probability distribution does *not* equal a lack of assumptions. The assumptions might not be explicit, but they exist. An advantage of Bayesian statistics is that it requires you to state your assumptions explicitly, which means the can be examined and discussed like anything else we do. \n", 527 | "\n", 528 | "### Calculating the Posterior\n", 529 | "\n", 530 | "Okay, now we've got all of our components in place, which means we can calculate our posterior probability density. And there are some good news: because we've chosen a conjugate prior for our likelihood, the posterior is analytical. 
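Here is the one-line reason, dropping every factor that does not depend on $q$:

$$
p(q \,|\, k, N) \;\propto\; \underbrace{q^{k}(1-q)^{N-k}}_{\text{likelihood}} \;\; \underbrace{q^{\alpha-1}(1-q)^{\beta-1}}_{\text{prior}} \;=\; q^{(\alpha+k)-1}(1-q)^{(\beta+N-k)-1} \; ,
$$

which is again the functional form of a beta distribution, just with updated parameters.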
In fact, the posterior to a binomial likelihood and a beta-prior is also a beta-distribution,\n", 531 | "\n", 532 | "$$\n", 533 | "p(q | k, N) = \\mathrm{Beta}(\\alpha+k,\\beta+N-k)\n", 534 | "$$\n", 535 | "\n", 536 | "**Exercise 7**: Now it's finally time to open your bags of M&Ms and count your M&Ms to find the total number of M&Ms $N$ and the number of blue M&Ms $k$ for your data set! Use these numbers and your results of the previous exercise to calculate both the prior for your chosen values of $\\alpha$ and $\\beta$ and the posterior and plot them in the same figure. How has the posterior changed from your prior? " 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": null, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "def mm_posterior(n, k, alpha=2.0, beta=2.0, nqtrial=500):\n", 546 | " \"\"\"\n", 547 | " A function to calculate the posterior for M&Ms\n", 548 | " \n", 549 | " Parameters\n", 550 | " ----------\n", 551 | " n : int\n", 552 | " The total number of M&Ms in your sample\n", 553 | " \n", 554 | " k : int\n", 555 | " The number of blue M&Ms in the sample\n", 556 | " \n", 557 | " alpha, beta : float\n", 558 | " Hyperparameters for the beta-distribution\n", 559 | " prior on q\n", 560 | " \n", 561 | " nqtrial : int, default 500\n", 562 | " The number of steps in $q$ for which to \n", 563 | " compute the posterior distribution\n", 564 | " \n", 565 | " Returns\n", 566 | " -------\n", 567 | " \n", 568 | " mm_post_pdf = numpy.ndarray\n", 569 | " An array with posterior probabilities for all values \n", 570 | " in `qtrial`, given values for `n`, `k`, `alpha` and `beta`\n", 571 | " \n", 572 | " \"\"\"\n", 573 | " \n", 574 | " qtrial = np.linspace(0, 1, nqtrial) # trial values of q\n", 575 | "\n", 576 | "\n", 577 | " # set up the probability distribution\n", 578 | " beta_prior = scipy.stats.beta(alpha, beta)\n", 579 | "\n", 580 | " # calculate the probability density for qtrial\n", 581 | " beta_prior_pdf = beta_prior.pdf(qtrial)\n", 582 | "\n", 583 | " ### Now let's calculate the posterior\n", 584 | "\n", 585 | " a_post = alpha + k # alpha + k\n", 586 | " b_post = beta + n - k # beta + N - k\n", 587 | "\n", 588 | " print(\"The alpha parameter of the posterior is: \" + str(a_post))\n", 589 | " print(\"The beta parameter of the posterior is: \" + str(b_post))\n", 590 | "\n", 591 | " # set up the probability distribution \n", 592 | " beta_posterior = scipy.stats.beta(a_post, b_post)\n", 593 | "\n", 594 | " # calculate PDF\n", 595 | " mm_post_pdf = beta_posterior.pdf(qtrial)\n", 596 | " \n", 597 | " return mm_post_pdf" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": null, 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [ 606 | "##################################################################\n", 607 | "## RECORD YOUR M&M VALUES BELOW\n", 608 | "##################################################################\n", 609 | "n = # total number of M&Ms in your bag\n", 610 | "k = # number of blue M&Ms in your bag\n", 611 | "\n", 612 | "alpha = # final value for alpha\n", 613 | "beta = # final value for beta\n", 614 | "\n", 615 | "##################################################################\n", 616 | "\n", 617 | "# set up beta distribution for the prior\n", 618 | "beta_dist = scipy.stats.beta(alpha, beta)\n", 619 | "\n", 620 | "# calculate the probability density for qtrial\n", 621 | "beta_prior_pdf = beta_dist.pdf(qtrial)\n", 622 | " \n", 623 | "# calculate posterior\n", 624 | "beta_post_pdf = 
mm_posterior(n, k, alpha, beta)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": {}, 631 | "outputs": [], 632 | "source": [ 633 | "# plot the results\n", 634 | "fig, ax = plt.subplots(1, 1, figsize=(6,4))\n", 635 | "\n", 636 | "ax.plot(qtrial, beta_prior_pdf, lw=2, color=\"black\", label=\"prior\")\n", 637 | "ax.plot(qtrial, beta_post_pdf, lw=2, color=\"black\", \n", 638 | " linestyle=\"dashed\", label=\"posterior\")\n", 639 | "\n", 640 | "ax.set_label(r\"Parameter $q$\")\n", 641 | "ax.set_ylabel(r\"Prior distribution $p(q)$\")\n", 642 | "ax.set_title(\"beta-distribution prior\")\n", 643 | "\n", 644 | "ax.legend()\n", 645 | "plt.tight_layout()" 646 | ] 647 | }, 648 | { 649 | "cell_type": "markdown", 650 | "metadata": {}, 651 | "source": [ 652 | "**Exercise 8**: Imagine that you'd chosen values for $\\alpha$ and $\\beta$ that are very unlikely to be true (e.g. a distribution that rises towards $q=1$. Repeat the comparison between prior and posterior above with these unlikely values. Does the different prior affect the results? How? \n", 653 | "\n", 654 | "**Important Note**: The above exercise, i.e. to change the prior and go back to re-calculate the posterior, is an academic exercise only! In practice, you **cannot** go back and change your prior once you've looked at your data and calculated your posterior! The prior *only* encodes knowledge about $q$ *before* you looked at the data. If you look at the data, then change your prior and calculate the posterior again, you've effectively used the data twice! In practice, this will lead you to be unreasonably overconfident in your results. Once you've looked at your data, your only real solution is to gather more data and use the posterior from your current analysis as a prior for the future (more M&Ms! Oh No! :) ). \n" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "##################################################################\n", 664 | "## ADD YOUR CODE BELOW\n", 665 | "##################################################################\n", 666 | "\n", 667 | "\n", 668 | "\n", 669 | "\n", 670 | "\n" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": null, 676 | "metadata": {}, 677 | "outputs": [], 678 | "source": [] 679 | }, 680 | { 681 | "cell_type": "code", 682 | "execution_count": null, 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "# plot the results\n", 687 | "fig, ax = plt.subplots(1, 1, figsize=(6,4))\n", 688 | "\n", 689 | "ax.plot(qtrial, beta_prior_pdf, lw=2, color=\"black\", label=\"prior\")\n", 690 | "ax.plot(qtrial, beta_post_pdf, lw=2, color=\"black\", \n", 691 | " linestyle=\"dashed\", label=\"posterior\")\n", 692 | "\n", 693 | "ax.set_label(r\"Parameter $q$\")\n", 694 | "ax.set_ylabel(r\"Prior distribution $p(q)$\")\n", 695 | "ax.set_title(\"beta-distribution prior\")\n", 696 | "\n", 697 | "ax.legend()\n", 698 | "plt.tight_layout()" 699 | ] 700 | }, 701 | { 702 | "cell_type": "markdown", 703 | "metadata": {}, 704 | "source": [ 705 | "### Adding More Information\n", 706 | "\n", 707 | "These are the results for one package of M&Ms. Can we actually make this better? Yes, because you have classmates all around you who also have counted blue M&Ms! \n", 708 | "\n", 709 | "**Exercise 9**: Tally up the total number of blue M&Ms counted by everyone in the class, and the total number of M&Ms from everyone. 
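A useful property of the conjugate update, and one way to think about this exercise: pooling everyone's counts in one go gives exactly the same answer as updating bag by bag, because under the assumption that all bags share the same $q$, the combined posterior is

$$
p(q \,|\, \text{all bags}) = \mathrm{Beta}\!\left(\alpha + \sum_b k_b,\; \beta + \sum_b (N_b - k_b)\right) \; .
$$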
Then use the new numbers for $k$ and $N$ to calculate and plot the posterior as well as the prior.\n", 710 | "\n" 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "execution_count": null, 716 | "metadata": {}, 717 | "outputs": [], 718 | "source": [ 719 | "##################################################################\n", 720 | "## ADD YOUR CODE BELOW\n", 721 | "##################################################################\n", 722 | "\n", 723 | "\n", 724 | "\n", 725 | "\n", 726 | "\n", 727 | "\n" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": {}, 734 | "outputs": [], 735 | "source": [ 736 | "# plot the results\n", 737 | "fig, ax = plt.subplots(1, 1, figsize=(6,4))\n", 738 | "\n", 739 | "ax.plot(qtrial, beta_prior_pdf, lw=2, color=\"black\", label=\"prior\")\n", 740 | "ax.plot(qtrial, beta_post_pdf, lw=2, color=\"black\", \n", 741 | " linestyle=\"dashed\", label=\"posterior\")\n", 742 | "\n", 743 | "ax.set_label(r\"Parameter $q$\")\n", 744 | "ax.set_ylabel(r\"Prior distribution $p(q)$\")\n", 745 | "ax.set_title(\"beta-distribution prior\")\n", 746 | "\n", 747 | "ax.legend()\n", 748 | "plt.tight_layout()" 749 | ] 750 | }, 751 | { 752 | "cell_type": "code", 753 | "execution_count": null, 754 | "metadata": {}, 755 | "outputs": [], 756 | "source": [] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": null, 768 | "metadata": {}, 769 | "outputs": [], 770 | "source": [] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "metadata": {}, 775 | "source": [ 776 | "## Adding the Full Data Set\n", 777 | "\n", 778 | "If you don't have access to the full data set of the entire group, you can use the following code to load the data:" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": null, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "import pandas as pd" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": null, 793 | "metadata": {}, 794 | "outputs": [], 795 | "source": [ 796 | "df = pd.read_csv(\"../data/mm_counts.csv\", sep=\",\")\n", 797 | "df.head()" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": null, 803 | "metadata": {}, 804 | "outputs": [], 805 | "source": [ 806 | "# only peanut M&Ms\n", 807 | "df = df[df.iloc[:,1] == 'Peanut (yellow packaging)']" 808 | ] 809 | }, 810 | { 811 | "cell_type": "markdown", 812 | "metadata": {}, 813 | "source": [ 814 | "Let's calculate the sums for all samples of each colour, and also the sum of all colours for each row (i.e. 
the total number of M&Ms per row):" 815 | ] 816 | }, 817 | { 818 | "cell_type": "code", 819 | "execution_count": null, 820 | "metadata": {}, 821 | "outputs": [], 822 | "source": [ 823 | "sum_rows = df[df.columns[2:8]].sum(axis=0)\n", 824 | "sum_rows" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "execution_count": null, 830 | "metadata": {}, 831 | "outputs": [], 832 | "source": [ 833 | "sum_cols = df[df.columns[2:8]].sum(axis=1)\n", 834 | "sum_cols" 835 | ] 836 | }, 837 | { 838 | "cell_type": "markdown", 839 | "metadata": {}, 840 | "source": [ 841 | "The total number of M&Ms among all of them is the sum of all M&Ms in each row:" 842 | ] 843 | }, 844 | { 845 | "cell_type": "code", 846 | "execution_count": null, 847 | "metadata": {}, 848 | "outputs": [], 849 | "source": [ 850 | "total_n = sum_cols.sum()\n", 851 | "print(\"The entire data set contains %i M&Ms.\"%total_n)" 852 | ] 853 | }, 854 | { 855 | "cell_type": "markdown", 856 | "metadata": {}, 857 | "source": [ 858 | "We can also extract the total number of blue M&Ms:" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": null, 864 | "metadata": {}, 865 | "outputs": [], 866 | "source": [ 867 | "blue_question = \"How many *blue* M&Ms were in your bag?\"\n", 868 | "\n", 869 | "total_k = sum_rows[blue_question]\n", 870 | "print(\"There are %i blue M&Ms in our data.\"%total_k)" 871 | ] 872 | }, 873 | { 874 | "cell_type": "markdown", 875 | "metadata": {}, 876 | "source": [ 877 | "**Exercise 10**: Now calculate the posterior for these values of $n$ and $k$, and compare it to the posterior you got on your own smaller dataset. What changed?" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": null, 883 | "metadata": {}, 884 | "outputs": [], 885 | "source": [] 886 | }, 887 | { 888 | "cell_type": "code", 889 | "execution_count": null, 890 | "metadata": {}, 891 | "outputs": [], 892 | "source": [] 893 | }, 894 | { 895 | "cell_type": "markdown", 896 | "metadata": {}, 897 | "source": [ 898 | "## Further Reading: Markov Chain Monte Carlo\n", 899 | "\n", 900 | "### Or: What to do when your posterior is not analytical.\n", 901 | "\n", 902 | "In practice, you will often end up in situations where conjugate priors are not a good choice, and your posterior will not be analytical. What you do in this case depends on what you want to know. For example, you might only be interested in the *most probable* value of $q$. In this case, optimization algorithms are often a good choice. This is sometimes also your only option for example if the likelihood function is very expensive to calculate.\n", 903 | "\n", 904 | "However, often the posterior probability can be complex, and trying to find the most probable value isn't good enough. Imagine you had a probability distribution with two roughly equally tall peaks, at 0.2 and 0.8, with a valley of low probability in between. An optimization algorithm will always end up in one of the two peaks, and will give you a single value, but you might never find out about the other peak!\n", 905 | "\n", 906 | "So what can we do? If you need to map out the probability distribution as a whole, there are several approaches. The simplest and straightforward is to make a grid in $q$, e.g. of a 100 points, and calculate $p(q | N, k)$ for each of those points, and then plot the result. Easy, right? This works well for problems with very low dimensions, like ours, where we only have a single parameter. What if you don't have a single parameter, but 50? 
You now need to sample 100 points in each of those 50 dimensions, meaning you need $100^{50}$ points. If your posterior takes a microsecond to calculate, you'll still need longer than the age of the universe to calculate all of those points! This is clearly impossible.\n", 907 | "\n", 908 | "So can we do something smarter than a grid? Yes! In fact, we can find clever ways to jump through parameter space in such a way that we'll evaluate our posterior often in regions where the posterior is large, and less often in regions where the posterior is low. There are a whole range of different algorithms that can do this, but **Markov Chain Monte Carlo (MCMC)** is the most common and most popular one. \n", 909 | "\n", 910 | "A tutorial on MCMC goes beyond the scope of this tutorial. For a good starting point, have a look at [Hogg & Foreman-Mackey (2017)](https://arxiv.org/abs/1710.06068).\n", 911 | "\n" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "execution_count": null, 917 | "metadata": {}, 918 | "outputs": [], 919 | "source": [] 920 | }, 921 | { 922 | "cell_type": "code", 923 | "execution_count": null, 924 | "metadata": {}, 925 | "outputs": [], 926 | "source": [] 927 | }, 928 | { 929 | "cell_type": "markdown", 930 | "metadata": {}, 931 | "source": [ 932 | "## Advanced: A model for all M&M colours\n", 933 | "\n", 934 | "Let's make our problem above more complicated. What if we're not just interested in the blue M&Ms, but want to know the distribution of all six colour? Well, where the binomial distribution only considered *success* and *failure*, there is a generalization to this distribution that considers *multiple categorical outcomes* (in our case six colours). In this case, we don't have a single $k$ given $N$ trials, but multiple $\\mathbf{k} = \\{k_1, k_2, ..., k_l\\}$ for $l$ possible outcomes. In our case, $l=6$, and each $k_i$ stands for a single colour (e.g. $k_0 = \\mathrm{blue}$,$k_1 = \\mathrm{green}$, ...). Similarly, we now have a vector $\\mathbf{q} = \\{q_1, q_2, ..., q_l\\}$ for the underlying true fraction of each colour. \n", 935 | "\n", 936 | "This generalization is the [multinomial distribution](https://en.wikipedia.org/wiki/Multinomial_distribution), defined as:\n", 937 | "\n", 938 | "$$\n", 939 | " p(\\mathbf{k} | \\mathbf{q}, N)= \n", 940 | "\\begin{cases}\n", 941 | " \\frac{N!}{k_1! k_2! ... k_l!}q_1^{k_1}q_2^{k_2} ... q_l^{k_l},& \\text{when } \\sum_{i=1}^{l}k_i=N \\\\\n", 942 | " 0, & \\text{otherwise}\n", 943 | "\\end{cases}\n", 944 | "$$\n", 945 | "\n", 946 | "Our measurements are now the number of M&Ms for each colour. Our parameters are the underlying fractions $q_i$ for each colour. We now have a six-dimensional measurement, and six parameters for our new model.\n", 947 | "\n", 948 | "**Exercise**: Define a six-element vector with your prior expectations for what you think the different $q_i$ should be. 
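Keep in mind that the six fractions have to sum to one. If you start from rough guesses, you can always renormalise them, for example (placeholder numbers only):

```python
import numpy as np

q_guess = np.array([0.2, 0.2, 0.2, 0.1, 0.1, 0.1])   # rough guesses, one per colour
q_guess_normalised = q_guess / q_guess.sum()         # force the fractions to sum to exactly 1
print(q_guess_normalised, q_guess_normalised.sum())
```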
Do you think all colours are represented equally?\n" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": null, 954 | "metadata": {}, 955 | "outputs": [], 956 | "source": [ 957 | "q_blue = # fraction of blue M&Ms\n", 958 | "q_green = # fraction of green M&Ms\n", 959 | "q_red = # fraction of red M&Ms\n", 960 | "q_yellow = # fraction of yellow M&Ms\n", 961 | "q_orange = # fraction of orange M&Ms\n", 962 | "q_brown = # fraction of brown M&Ms\n", 963 | "\n", 964 | "q_all = np.array([q_blue, q_green, q_red, \n", 965 | " q_yellow, q_orange, q_brown])" 966 | ] 967 | }, 968 | { 969 | "cell_type": "markdown", 970 | "metadata": {}, 971 | "source": [ 972 | "Now tally up all the colours in your package of M&Ms and write down the result:" 973 | ] 974 | }, 975 | { 976 | "cell_type": "code", 977 | "execution_count": null, 978 | "metadata": {}, 979 | "outputs": [], 980 | "source": [ 981 | "k_blue = # blue M&Ms\n", 982 | "k_green = # green M&Ms\n", 983 | "k_red = # red M&Ms\n", 984 | "k_yellow = # yellow M&Ms\n", 985 | "k_orange = # orange M&Ms\n", 986 | "k_brown = # brown M&Ms\n", 987 | "\n", 988 | "# all measurements together\n", 989 | "k_all = np.array([k_blue, k_green, k_red, \n", 990 | " k_yellow, k_orange, k_brown])\n", 991 | "\n", 992 | "n_total = # total number of M&Ms in package" 993 | ] 994 | }, 995 | { 996 | "cell_type": "markdown", 997 | "metadata": {}, 998 | "source": [ 999 | "Let's calculate the multinomial probability for these measurements:" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": null, 1005 | "metadata": {}, 1006 | "outputs": [], 1007 | "source": [ 1008 | "# define the distribution\n", 1009 | "mult = scipy.stats.multinomial(n=n_total, p=q_all)\n", 1010 | "\n", 1011 | "# calculate the probability for our measurements:\n", 1012 | "print(\"multinomial probability: \" + str(mult.pdf(k_all)))" 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "markdown", 1017 | "metadata": {}, 1018 | "source": [ 1019 | "Ideally, we'd like to calculate a posterior probability for this, too, so we'll need to define a prior for $\\mathbf{q}$. The conjugate prior for the multinomial distribution is a [*Dirichlet distribution*](https://en.wikipedia.org/wiki/Dirichlet_distribution), the multivariate generalization of the beta-ditribution. The Dirichlet distribution appears fairly often in problems with categorical variables and is very useful to know. A nice conceptual introduction can be found [here](http://blog.bogatron.net/blog/2014/02/02/visualizing-dirichlet-distributions/).\n", 1020 | "\n", 1021 | "For our 6 different categories (colours), the Dirichlet distribution has six parameters called *concentration parameters*, $\\mathbf{\\alpha} = \\{\\alpha_1, \\alpha_2, ..., \\alpha_l\\} \\, , \\, \\alpha_i > 0$. \n", 1022 | "Note that it is only defined on the interval $(0,1)$, and also only in the region where $\\sum_{i=1}^{l}q_i = 1$ (remember: our $q_i$ are relative fractions of colour $i$, and if we take all fractions for all colours, they must make up all of our M&Ms).\n", 1023 | "\n", 1024 | "Of course, `scipy.stats` also has an implementation of the [Dirichlet distribution](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet.html#scipy.stats.dirichlet).\n", 1025 | "\n", 1026 | "In practice, the PDF of the Dirichlet distribution is a bit tricky to plot, because of the way that your trial values of $q_i$ need to sum up to $1$. 
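One low-effort way to get a feel for the Dirichlet anyway is to draw random samples from it and look at each component separately (a sketch with arbitrary concentration parameters):

```python
import scipy.stats

alpha_demo = [2.0, 2.0, 2.0]    # arbitrary concentration parameters for a 3-category example
samples = scipy.stats.dirichlet(alpha_demo).rvs(size=5000)

print(samples.sum(axis=1)[:5])   # every sample sums to 1 by construction
print(samples.mean(axis=0))      # component means are alpha_i / sum(alpha)
```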
You can look at [this illustration](https://en.wikipedia.org/wiki/Dirichlet_distribution#/media/File:Dirichlet-3d-panel.png) to see how the PDF changes for different values of $\\alpha$. \n", 1027 | "\n", 1028 | "**Exercise**: Let's plot the PDF of a Dirichlet distribution with two categories, i.e. two concentration parameters $\\alpha_i$, and plot the results. Repeat for different values and combinations of $\\alpha_1$ and $\\alpha_2$. How does the distribution change? What do you think are reasonable values for the different values of $\\alpha_i$?" 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "execution_count": null, 1034 | "metadata": {}, 1035 | "outputs": [], 1036 | "source": [ 1037 | "alpha1 = # add your guess for alpha1\n", 1038 | "alpha2 = # add your guess for alpha2\n", 1039 | "alpha = [alpha1, alpha2] # add\n", 1040 | "\n", 1041 | "# define the dirichlet distribution\n", 1042 | "dirichlet = scipy.stats.dirichlet(alpha=alpha)\n" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": null, 1048 | "metadata": {}, 1049 | "outputs": [], 1050 | "source": [ 1051 | "x1 = np.linspace(0, 1, 1000)\n", 1052 | "x2 = 1.0 - x1" 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "code", 1057 | "execution_count": null, 1058 | "metadata": {}, 1059 | "outputs": [], 1060 | "source": [ 1061 | "pdf = dirichlet.pdf([x1, x2])" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "code", 1066 | "execution_count": null, 1067 | "metadata": {}, 1068 | "outputs": [], 1069 | "source": [ 1070 | "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8,4))\n", 1071 | "\n", 1072 | "ax1.plot(x1, pdf, lw=2, color=\"black\")\n", 1073 | "ax1.set_xlim(0,1)\n", 1074 | "ax1.set_xlabel(r\"$q_1$\")\n", 1075 | "\n", 1076 | "ax2.plot(x2, pdf, lw=2, color=\"black\")\n", 1077 | "ax2.set_xlim(0,1)\n", 1078 | "ax2.set_xlabel(r\"$q_2$\")" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "markdown", 1083 | "metadata": {}, 1084 | "source": [ 1085 | "Now we can set up our posterior inference! 
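Incidentally, the two-category Dirichlet you just plotted is nothing but the beta distribution from earlier in this tutorial, which you can verify numerically (arbitrary example values):

```python
import scipy.stats

a1, a2, x = 3.0, 2.0, 0.3    # arbitrary example values

print(scipy.stats.dirichlet([a1, a2]).pdf([x, 1.0 - x]))   # two-category Dirichlet density
print(scipy.stats.beta(a1, a2).pdf(x))                     # beta density -- the same number
```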
\n", 1086 | "\n", 1087 | "First, set your concentration parameters for each of the values of $\\alpha_i$:" 1088 | ] 1089 | }, 1090 | { 1091 | "cell_type": "code", 1092 | "execution_count": null, 1093 | "metadata": {}, 1094 | "outputs": [], 1095 | "source": [ 1096 | "alpha_blue = # blue M&Ms concentration parameter\n", 1097 | "alpha_green = # green M&Ms concentration parameter\n", 1098 | "alpha_red = # red M&Ms concentration parameter\n", 1099 | "alpha_yellow = # yellow M&Ms concentration parameter\n", 1100 | "alpha_orange = # orange M&Ms concentration parameter\n", 1101 | "alpha_brown = # brown M&Ms concentration parameter\n", 1102 | "\n", 1103 | "# all parameters together\n", 1104 | "alpha_all = np.array([alpha_blue, alpha_green, alpha_red, \n", 1105 | " alpha_yellow, alpha_orange, alpha_brown])" 1106 | ] 1107 | }, 1108 | { 1109 | "cell_type": "markdown", 1110 | "metadata": {}, 1111 | "source": [ 1112 | "The posterior distribution of a multinomial likelihood with a Dirichlet prior is also a Dirichlet distribution, with concentration parameter $\\mathbf{\\alpha}_{\\mathrm{posterior}} = \\mathbf{\\alpha}_{\\mathrm{prior}} + \\mathbf{k}$:" 1113 | ] 1114 | }, 1115 | { 1116 | "cell_type": "code", 1117 | "execution_count": null, 1118 | "metadata": {}, 1119 | "outputs": [], 1120 | "source": [ 1121 | "alpha_post = alpha_all + k_all\n", 1122 | "\n", 1123 | "dir_post = scipy.stats.dirichlet(alpha=alpha_post)" 1124 | ] 1125 | }, 1126 | { 1127 | "cell_type": "markdown", 1128 | "metadata": {}, 1129 | "source": [ 1130 | "For a 6-dimensional distribution, it's much harder to think about where the Dirichlet distribution is even defined (it should lie on a 5-dimensional volume in six dimensions for which $\\sum_{i=1}^{l}q_i = 1$. Instead of calculating the posterior for a grid of values for $q_i$, we're just going to draw samples directly from the posterior distribution and then plot them:" 1131 | ] 1132 | }, 1133 | { 1134 | "cell_type": "code", 1135 | "execution_count": null, 1136 | "metadata": {}, 1137 | "outputs": [], 1138 | "source": [ 1139 | "# generate random samples from the posterior\n", 1140 | "post_rvs = dir_post.rvs(size=1000000)" 1141 | ] 1142 | }, 1143 | { 1144 | "cell_type": "code", 1145 | "execution_count": null, 1146 | "metadata": {}, 1147 | "outputs": [], 1148 | "source": [ 1149 | "# plot marginal distributions\n", 1150 | "fig, axes = plt.subplots(2, 3, figsize=(8,6), sharex=True, sharey=True)\n", 1151 | "\n", 1152 | "# flatten the array of axis objects\n", 1153 | "axes = np.hstack(axes)\n", 1154 | "\n", 1155 | "# we have six colours, so we're going to loop over each\n", 1156 | "for i in range(6):\n", 1157 | " axes[i].hist(post_rvs[:,i], bins=50, histtype=\"stepfilled\", \n", 1158 | " color=\"black\", alpha=0.4, density=True)\n", 1159 | " axes[i].set_xlabel(r\"$q_%i$\"%i)\n", 1160 | " # set the y-axis labels only on the left-most plots\n", 1161 | " if i == 0 or i == 3:\n", 1162 | " axes[i].set_ylabel(\"posterior pdf\")\n", 1163 | "\n", 1164 | "# automatically improve spacings between subplots\n", 1165 | "plt.tight_layout()" 1166 | ] 1167 | }, 1168 | { 1169 | "cell_type": "markdown", 1170 | "metadata": {}, 1171 | "source": [ 1172 | "## Bayesian Hierarchical Modelling\n", 1173 | "\n", 1174 | "Now that we can model all colours at once, it's time to let you in on a little secret: M&Ms in the US are produced by two different factories, one in Tennesse and one in New Jersey. The interesting part is that they produce different distributions of colours! Why? 
Nobody is really sure (except probably the Mars Company, which makes M&Ms).\n", 1175 | "\n", 1176 | "So each of you has their own package of M&Ms, and you've all recorded the number of different colours and calculated the posterior distribution for your parameters, but now you'd like to figure out which factory your M&Ms came from. However, while you know that the two factories produce different colour distributions, you don't know the distribution each one makes, and you also don't know which factory your particular package came from! This seems like an insurmountable lack of knowledge, but fear not! Through the power of sharing information between you, you'll be able to figure all of that out.\n", 1177 | "\n", 1178 | "In the previous examples, you pooled your information for the entire class in order to improve your posterior. However, we glossed over the fact that your packages did not come from the same factory! How can we take better account of that fact? Through Bayesian hierarchical modelling! \n", 1179 | "\n", 1180 | "In the previous models you've built, you had a prior distribution on your parameters, and the hyperparameters of that prior distribution were fixed. They were numbers you chose based on your prior information and intuition about the problem. In a hierarchical model, the parameters describing the prior are *not* fixed, but something that we *infer* along with the parameters of the colour distributions. Instead of describing prior knowledge, they describe the *population* of data sets, in our case the *population* of bags. \n", 1181 | "\n", 1182 | "We go from a model like this\n", 1183 | "\n", 1184 | "$$\n", 1185 | "p(q | \\{\\mathbf{k}_b\\}_{b=1}^{B}) \\propto p(q | \\alpha) \\prod_{b=1}^{B} p(\\mathbf{k}_b | q)\n", 1186 | "$$\n", 1187 | "\n", 1188 | "where $\\alpha$ were fixed hyperparameters, to a model with one more layer of parameters:\n", 1189 | "\n", 1190 | "$$\n", 1191 | "p(\\{q_b\\}_{b=1}^{B}, \\alpha | \\{\\mathbf{k}_b\\}_{b=1}^{B}) \\propto p(\\alpha | \\beta) \\prod_{b=1}^{B}p(\\mathbf{k}_b | q_b) p(q_b | \\alpha)\n", 1192 | "$$\n", 1193 | "\n", 1194 | "where now $q$ isn't shared anymore among the individual data sets (i.e. bags of M&Ms), and we're inferring the population parameters $\\alpha$ along with a separate $q_b$ for each bag of M&Ms.\n", 1195 | "\n", 1196 | "In our case, the difference from our previous model is that we now have *two* colour distributions--one for each factory--and that each bag comes from one of those factories according to some unknown mixture distribution.\n", 1197 | "\n", 1198 | "How can we write that down? Well, we are going to introduce a new variable $\\theta$ to describe the probability that a bag of M&Ms comes from the New Jersey factory as opposed to the Tennessee factory. And we're also going to give each bag a new variable $z_b$, drawn from the discrete distribution defined by $\\theta$, which records the factory assignment of that individual bag. There we have our hierarchy in the model: each bag has a probability of coming from the NJ or TN factory, and together, these factory assignments are drawn from a distribution describing the overall proportions of *all* bags coming from either factory. We're going to infer both together.\n", 1199 | "\n", 1200 | "The rest of the model doesn't really change, except that we need a prior for $\\theta$. Much like our initial example, where we only had two possible outcomes, we only have two factories, so our prior in this case is also a beta-distribution." 

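If you want to try sampling this hierarchical mixture model with `pymc3`, here is a minimal sketch of one way to write it down. Everything below is illustrative rather than a reference solution: the arrays `counts` (per-bag colour counts) and `totals` (per-bag totals) are toy stand-ins for the class data, the made-up `true_q` rows are only there so the example runs end-to-end, and the flat Beta/Dirichlet priors are one possible choice.

    import numpy as np
    import pymc3 as pm

    # toy stand-in data: 10 bags, 6 colours per bag (replace with the real class counts)
    rng = np.random.default_rng(42)
    true_q = np.array([[0.207, 0.205, 0.198, 0.135, 0.131, 0.124],   # made-up "factory 0" mix
                       [0.106, 0.200, 0.131, 0.198, 0.207, 0.158]])  # made-up "factory 1" mix
    totals = rng.integers(50, 60, size=10)
    factory = rng.integers(0, 2, size=10)
    counts = np.array([rng.multinomial(n, true_q[f]) for n, f in zip(totals, factory)])

    with pm.Model() as factory_model:
        # probability that any given bag comes from factory 1 (as opposed to factory 0)
        theta = pm.Beta("theta", alpha=1.0, beta=1.0)
        # per-bag factory assignment z_b
        z = pm.Bernoulli("z", p=theta, shape=len(totals))
        # one 6-colour distribution per factory, each with a flat Dirichlet prior
        q = pm.Dirichlet("q", a=np.ones((2, 6)), shape=(2, 6))
        # each bag's counts are a multinomial draw from its factory's colour distribution
        obs = pm.Multinomial("obs", n=totals, p=q[z], observed=counts)
        trace = pm.sample(2000, tune=2000, return_inferencedata=True)

Note that the two factories are exchangeable in this parameterisation, so the sampler can suffer from label switching; marginalising out the discrete $z_b$ (for example with `pm.Mixture`) is a common way to make the sampling better behaved.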
1201 | ] 1202 | }, 1203 | { 1204 | "cell_type": "code", 1205 | "execution_count": null, 1206 | "metadata": {}, 1207 | "outputs": [], 1208 | "source": [] 1209 | } 1210 | ], 1211 | "metadata": { 1212 | "kernelspec": { 1213 | "display_name": "Python 3", 1214 | "language": "python", 1215 | "name": "python3" 1216 | }, 1217 | "language_info": { 1218 | "codemirror_mode": { 1219 | "name": "ipython", 1220 | "version": 3 1221 | }, 1222 | "file_extension": ".py", 1223 | "mimetype": "text/x-python", 1224 | "name": "python", 1225 | "nbconvert_exporter": "python", 1226 | "pygments_lexer": "ipython3", 1227 | "version": "3.8.5" 1228 | } 1229 | }, 1230 | "nbformat": 4, 1231 | "nbformat_minor": 2 1232 | } 1233 | -------------------------------------------------------------------------------- /bayesian-statistics/dhuppenkothen_ahw_bayesianstats.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AstroHackWeek/AstroHackWeek2022/2343873130be3d42a947300bfaecab1716419ae6/bayesian-statistics/dhuppenkothen_ahw_bayesianstats.pdf -------------------------------------------------------------------------------- /bayesian-statistics/mm_counts.csv: -------------------------------------------------------------------------------- 1 | Timestamp,What kind of M&Ms do you have?,How many *red* M&Ms were in your bag?,How many *orange* M&Ms were in your bag?,How many *yellow* M&Ms were in your bag?,How many *green* M&Ms were in your bag?,How many *blue* M&Ms were in your bag?,How many *brown* M&Ms were in your bag?,"If there are other colours in your M&Ms, please tell us the colour and the number below:",In what country did you buy your M&Ms? ,"Can you let us know the last three letters of the factory code? In the example below, this would be ""CLV""",Total number of M&Ms 2 | 27/06/2019 11:32:12,Milk chocolate (brown packaging),5,11,4,6,20,8,,USA,CLV,54 3 | 01/07/2019 12:08:13,Milk chocolate (brown packaging),6,14,6,8,14,8,,United States,CLV,56 4 | 02/07/2019 14:07:06,Milk chocolate (brown packaging),7,13,8,5,14,9,,USA,CLV,56 5 | 22/07/2019 01:55:48,Milk chocolate (brown packaging),7,12,9,13,11,4,No,America,CLV,56 6 | 22/07/2019 01:56:37,Milk chocolate (brown packaging),4,11,7,11,12,10,,USA,CLV,55 7 | 22/07/2019 01:55:52,Milk chocolate (brown packaging),2,14,5,13,13,8,,United States,CLV,55 8 | 22/07/2019 01:56:12,Milk chocolate (brown packaging),5,12,6,10,19,3,,United States,CLV,55 9 | 22/07/2019 01:56:51,Milk chocolate (brown packaging),3,12,10,15,11,4,,USA,CLV,55 10 | 22/07/2019 01:56:41,Milk chocolate (brown packaging),4,10,5,14,11,8,,US,CLV,52 11 | 22/07/2019 01:57:07,Milk chocolate (brown packaging),3,15,12,9,11,6,,US,CLV,56 12 | 22/07/2019 01:57:56,Milk chocolate (brown packaging),3,14,6,13,14,5,,USA,CLV,55 13 | 22/07/2019 01:57:57,Milk chocolate (brown packaging),8,12,6,11,12,7,0,united states,clv,56 14 | 22/07/2019 01:57:12,Milk chocolate (brown packaging),2,9,12,11,17,5,,USA,CLV,56 15 | 22/07/2019 01:58:00,Milk chocolate (brown packaging),4,10,11,10,12,2,,USA,CLV,49 16 | 22/07/2019 01:58:05,Milk chocolate (brown packaging),6,9,8,11,14,9,,USA,CLV,57 17 | 22/07/2019 01:58:18,Milk chocolate (brown packaging),3,14,10,10,11,7,we have a half orange m&m,us,CLV,55 18 | 22/07/2019 01:58:13,Milk chocolate (brown packaging),6,12,9,8,18,5,No other colors,US,CLV,58 19 | 22/07/2019 01:58:14,Milk chocolate (brown packaging),3,11,7,12,3,3,no,USA,CLV,39 20 | 22/07/2019 01:59:56,Milk chocolate (brown packaging),2,9,11,11,15,6,,USA,CLV,54 21 | 22/07/2019 01:58:25,Milk 
chocolate (brown packaging),2,14,10,6,17,7,,US,CLV,56 22 | 22/07/2019 02:01:13,Milk chocolate (brown packaging),5,9,6,8,19,6,,US,CLV,53 23 | 22/07/2019 02:01:38,Milk chocolate (brown packaging),10,13,5,14,9,4,,USA,CLV,55 24 | 22/07/2019 02:01:57,Milk chocolate (brown packaging),6,18,5,11,12,4,,USA,CLV,56 25 | 22/07/2019 02:40:23,Milk chocolate (brown packaging),4,9,10,10,18,3,,USA,CLV,54 26 | 24/07/2019 15:00:45,Milk chocolate (brown packaging),9,12,7,6,12,6,,USA,CLV,52 27 | 30/07/2019 11:21:52,Milk chocolate (brown packaging),3,11,6,15,16,3,,USA,CLV,54 28 | 16/08/2019 12:53:58,Milk chocolate (brown packaging),4,6,6,20,15,5,,USA,CLV,56 29 | 26/08/2019 10:34:15,Milk chocolate (brown packaging),4,20,5,10,11,6,,USA,CLV,56 30 | 09/09/2019 13:40:48,Milk chocolate (brown packaging),9,14,7,10,11,4,,USA,CLV,55 31 | 06/11/2019 09:38:54,Peanut (yellow packaging),1,2,8,2,4,1,,Canada,TOL,18 32 | 06/11/2019 10:15:10,Peanut (yellow packaging),5,3,2,1,4,4,,Spain,TOP,19 33 | 07/11/2019 14:13:34,Peanut (yellow packaging),4,0,5,1,3,5,,Canada,TOP,18 34 | 11/11/2019 11:19:28,Peanut (yellow packaging),1,2,3,5,1,5,,Canada,TOP,17 35 | 14/11/2019 12:19:08,Peanut (yellow packaging),3,5,3,1,3,3,,Canada,TOP,18 36 | 18/11/2019 02:34:51,Peanut (yellow packaging),0,5,5,3,4,2,,Italy,HAG,19 37 | 03/03/2020 21:28:18,Peanut (yellow packaging),2,10,10,11,10,9,none,Canada,HKP,52 38 | 04/03/2020 08:00:28,Milk chocolate (brown packaging),20,26,19,39,21,15,,Canada,CLV,140 39 | 04/03/2020 08:01:46,Peanut (yellow packaging),3,14,6,10,14,4,,Canada,HKP,51 40 | 04/03/2020 08:03:10,Milk chocolate (brown packaging),17,25,19,24,30,22,,Canada,CLV,137 41 | 04/03/2020 08:04:22,Peanut (yellow packaging),4,12,11,7,11,1,,Canada,HKP,46 42 | 04/03/2020 09:57:29,Milk chocolate (brown packaging),21,32,15,26,27,15,,Canada,CLV,136 43 | 04/03/2020 13:14:22,Peanut (yellow packaging),5,11,8,10,11,5,,Canada,HKP,50 44 | 22/07/2021 05:25:22,Peanut (yellow packaging),6,11,16,14,9,7,,Germany,HAG,63 45 | ,,,,,,,,,,,0 46 | ,,,,,,,,,,,0 47 | ,,,,,,,,,,,0 48 | ,,,,,,,,,,,0 49 | ,,,,,,,,,,,0 50 | ,,,,,,,,,,,0 51 | ,,,,,,,,,,,0 52 | ,,,,,,,,,,,0 53 | ,,,,,,,,,,,0 54 | ,,,,,,,,,,,0 55 | ,,,,,,,,,,,0 56 | ,,,,,,,,,,,0 57 | ,,,,,,,,,,,0 58 | ,,,,,,,,,,,0 59 | ,,,,,,,,,,,0 60 | ,,,,,,,,,,,0 61 | ,,,,,,,,,,,0 62 | ,,,,,,,,,,,0 63 | ,,,,,,,,,,,0 64 | ,,,,,,,,,,,0 65 | ,,,,,,,,,,,0 66 | ,,,,,,,,,,,0 67 | ,,,,,,,,,,,0 68 | ,,,,,,,,,,,0 69 | ,,,,,,,,,,,0 70 | ,,,,,,,,,,,0 71 | ,,,,,,,,,,,0 72 | ,,,,,,,,,,,0 73 | ,,,,,,,,,,,0 74 | ,,,,,,,,,,,0 75 | ,,,,,,,,,,,0 76 | ,,,,,,,,,,,0 77 | ,,,,,,,,,,,0 78 | ,,,,,,,,,,,0 79 | ,,,,,,,,,,,0 80 | ,,,,,,,,,,,0 81 | ,,,,,,,,,,,0 82 | ,,,,,,,,,,,0 83 | ,,,,,,,,,,,0 84 | ,,,,,,,,,,,0 85 | ,,,,,,,,,,,0 86 | ,,,,,,,,,,,0 87 | ,,,,,,,,,,,0 88 | ,,,,,,,,,,,0 89 | ,,,,,,,,,,,0 -------------------------------------------------------------------------------- /bayesian-statistics/mm_smallsample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AstroHackWeek/AstroHackWeek2022/2343873130be3d42a947300bfaecab1716419ae6/bayesian-statistics/mm_smallsample.jpg -------------------------------------------------------------------------------- /day2_ml_tutorial/README.md: -------------------------------------------------------------------------------- 1 | # Day 2 tutorial: Practical machine learning for astronomers 2 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jwuphysics/AstroHackWeek2022) 3 | 4 | This 
repository contains two machine learning tutorials. We will learn how to apply several machine learning algorithms, ranging from polynomial regression and random forests to deep convolutional neural networks, in order to answer a science qeustion (see e.g. [Wu 2020](https://ui.adsabs.harvard.edu/abs/2020ApJ...900..142W/abstract)). These notebooks will also help you gain familiarity with the `scikit-learn` and `pytorch`/`fastai` Python packages for machine learning. 5 | 6 | Presented by John F. Wu ([@jwuphysics](https://github.com/jwuphysics)). 7 | 8 | # Getting started 9 | 10 | ## Running on Google Colab (*recommended*) 11 | Open the Colab notebooks for the [introductory machine learning (part 1)](https://colab.research.google.com/github/jwuphysics/AstroHackWeek2022/blob/main/day2/01-intro-machine-learning.ipynb) and the [deep learning (part 2)](https://colab.research.google.com/github/jwuphysics/AstroHackWeek2022/blob/main/day2/02-deep-learning.ipynb) sessions. 12 | 13 | ## Running locally (*not recommended*) 14 | If you want to run these notebooks locally, then you should clone the repository and set up a conda environment with the necessary packages (`numpy`, `scipy`, `matplotlib`, `pandas`, `scikit-learn`, `pytorch`, `fastai`). The installation process might depend on (a) whether you have an NVIDIA graphics card, and (b) what version of CUDA your system is running. To avoid these complications, just use the Colab notebook! 15 | 16 | # Part 1 - Introductory Machine learning 17 | [![Colab - Part 1](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jwuphysics/AstroHackWeek2022/blob/main/day2/01-intro-machine-learning.ipynb) 18 | 19 | 0. Before getting started 20 | 1. Can we predict a galaxy's neutral hydrogen (HI) content? 21 | - Examine data with pandas 22 | - A very simplified glossary for xGASS 23 | - Examine and clean features 24 | - Visualize correlations 25 | 2. Polynomial regression 26 | - Multivariate linear regression 27 | - Train-test split 28 | - Cross-validation 29 | - Quadratic and higher-order polynomial models 30 | - Overfitting 31 | 3. Decision trees 32 | - How scikit-learn does it 33 | - Random forests 34 | - Optimize hyperparameters 35 | - Feature importances 36 | 37 | # Part 2 - Deep learning 38 | [![Colab - Part 2](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jwuphysics/AstroHackWeek2022/blob/main/day2/02-deep-learning.ipynb) 39 | 40 | 1. Introducing the science problem (again) 41 | 2. Solving the task with a CNN 42 | 3. Understanding convolutions 43 | 4. Other neural network ingredients 44 | 5. A simple CNN model in action (forward) 45 | 6. Optimization 46 | 7. A simple CNN model in action (forward + backward) 47 | 8. 
Hyperparameters 48 | 49 | # Other resources 50 | 51 | - [Data analysis recipes: Fitting a model to data](https://arxiv.org/abs/1008.4686) 52 | - [Scikit-learn - Machine learning with Python](https://scikit-learn.org/stable/) 53 | - [Fast.ai - Practical Deep Learning for Coders](https://course.fast.ai/) 54 | - [Training a deep neural network on astronomical data from scratch](https://jwuphysics.github.io/blog/galaxies/astrophysics/deep%20learning/computer%20vision/fastai/2020/05/26/training-a-deep-cnn.html) 55 | -------------------------------------------------------------------------------- /monday_ml_tutorial/.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.npy 3 | *.fits 4 | 5 | deepmerge/models 6 | deepmerge/prepared_data 7 | wandb 8 | -------------------------------------------------------------------------------- /monday_ml_tutorial/README.md: -------------------------------------------------------------------------------- 1 | # Example DVC Pipeline 2 | 3 | This folder demonstrates how to set up a [DVC pipeline](https://dvc.org/doc/start/data-management/data-versioning) with existing code. 4 | 5 | 1. Copy your code out of notebooks etc. and into scripts. 6 | 2. Execute each stage with `dvc run -n {name} -d {dependency} -o {output} {command}` 7 | 3. Check the pipeline with `dvc dag` 8 | 4. Commit `dvc.yaml` and `dvc.lock` files to git 9 | 10 | Then you can make edits to your code or data and re-run the pipeline with `dvc repro`. DVC will automatically re-run only the stages that need re-running. 11 | 12 | I've helped out with Step 1, converting to scripts, by script-ifying the HelloUniverse / Ciprianovic "DeepMerge" notebook Michelle introduced in Part 1. 13 | 14 | ## Running the Scripts 15 | 16 | Install the Python requirements 17 | 18 | pip install -r requirements.txt 19 | 20 | Download the data (images and labels, as a single FITS file) 21 | 22 | python deepmerge/download_data.py deepmerge/latest_data.fits 23 | 24 | Preprocess the data and save to deepmerge/prepared_data/*.npy 25 | 26 | python deepmerge/prepare_data.py 27 | 28 | Train a CNN on the preprocessed data, save to deepmerge/models/latest 29 | 30 | python deepmerge/train_cnn.py 31 | 32 | ## Setting up DVC Pipeline 33 | 34 | This bit is your job! Remember: describe the scripts as a series of steps, each with a name, dependencies, and outputs. 35 | 36 | dvc init --subdir 37 | 38 | # TODO for you 39 | 40 | I recommend the `dvc run` [docs](https://dvc.org/doc/command-reference/run#run) and the versioning [tutorial](https://dvc.org/doc/use-cases/versioning-data-and-models/tutorial). 41 | 42 | For extra bonus points, add Weights and Biases tracking in train_cnn.py. 
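As a concrete starting point, a minimal version of that tracking might look something like the sketch below (the project name and logged config values are placeholders; adapt them to your run):

    import wandb
    from wandb.keras import WandbCallback

    # placeholder project name - pick your own
    wandb.init(project="deepmerge-demo",
               config={"nb_epoch": 2, "batch_size": 128, "optimizer": "adam"})

    # then include wandb's Keras callback among the callbacks passed to cnn.fit(),
    # e.g. via the default_callbacks list in train_cnn.py
    default_callbacks = [WandbCallback()]
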
See the [TensorFlow guide](https://docs.wandb.ai/guides/integrations/tensorflow) 43 | -------------------------------------------------------------------------------- /monday_ml_tutorial/deepmerge/download_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import urllib.request 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 | parser = argparse.ArgumentParser( 9 | 'Download data', 10 | 'Download simulated merger images and labels', 11 | ) 12 | parser.add_argument('fits_loc', default='deepmerge/latest_data.fits') 13 | parser.add_argument('--noise', dest='noise', default=False, action='store_true') 14 | 15 | args = parser.parse_args() 16 | 17 | if args.noise: 18 | # version = 'noise' 19 | raise NotImplementedError 20 | else: 21 | # version = 'pristine' 22 | fits_url = 'https://dl.dropboxusercontent.com/s/5wt3ctqx3xlqul8/latest_data.fits' 23 | # demo version, equivalent to first 3k rows of 24 | # 'https://archive.stsci.edu/hlsps/deepmerge/hlsp_deepmerge_hst-jwst_acs-wfc3-nircam_illustris-z2_f814w-f160w-f356w_v1_sim-'+version+'.fits' 25 | 26 | print('Downloading from {} to {} - please wait'.format(fits_url, args.fits_loc)) 27 | urllib.request.urlretrieve(fits_url, args.fits_loc) 28 | -------------------------------------------------------------------------------- /monday_ml_tutorial/deepmerge/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from astropy.io import fits 5 | from sklearn.model_selection import train_test_split 6 | from sklearn.preprocessing import StandardScaler 7 | 8 | 9 | def split(X, y, random_state): 10 | 11 | 12 | # First split off 30% of the data for validation+testing 13 | X_train, X_split, y_train, y_split = train_test_split(X, y, test_size=0.3, random_state=random_state, shuffle=True) 14 | 15 | # Then divide this subset into training and testing sets 16 | X_valid, X_test, y_valid, y_test = train_test_split(X_split, y_split, test_size=0.666, random_state=random_state, shuffle=True) 17 | 18 | return X_train, y_train, X_valid, y_valid, X_test, y_test 19 | 20 | 21 | def normalise(X): 22 | # divide each channel by channel mean 23 | channel_means = X.mean(axis=(0, 1, 2)) 24 | return X / channel_means 25 | 26 | 27 | 28 | if __name__ == '__main__': 29 | 30 | fits_loc = 'deepmerge/latest_data.fits' 31 | 32 | hdu = fits.open(fits_loc) 33 | X = hdu[0].data.transpose(0, 2, 3, 1) # channels last 34 | y = hdu[1].data 35 | 36 | 37 | X = np.asarray(X).astype('float32') 38 | y = np.asarray(y).astype('float32') 39 | 40 | X = normalise(X) 41 | print('Normalised to: ', X.mean(axis=(0, 1, 2))) 42 | 43 | random_state = 42 44 | X_train, y_train, X_valid, y_valid, X_test, y_test = split(X, y, random_state) 45 | 46 | prepared_data_dir = 'deepmerge/prepared_data' 47 | os.mkdir(prepared_data_dir) 48 | 49 | for name, data in [ 50 | ('X_train', X_train), 51 | ('X_valid', X_valid), 52 | ('X_test', X_test), 53 | ('y_train', y_train), 54 | ('y_valid', y_valid), 55 | ('y_test', y_test) 56 | ]: 57 | path = os.path.join(prepared_data_dir, name + '.npy') 58 | with open(path, 'w') as f: 59 | np.save(path, data) -------------------------------------------------------------------------------- /monday_ml_tutorial/deepmerge/train_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from keras.models import Model 5 | from keras.layers import Input, Flatten, Dense, Dropout, 
BatchNormalization 6 | from keras.layers.convolutional import Convolution2D, MaxPooling2D 7 | from keras.regularizers import l2 8 | from keras.callbacks import ModelCheckpoint, EarlyStopping 9 | 10 | # for bonus points: 11 | # import wandb 12 | # and either 13 | # from keras.callbacks import TensorBoard 14 | # or 15 | # from wandb.keras import WandbCallback 16 | # and the rest is up to you 17 | 18 | 19 | def get_cnn_architecture(imsize=75, channels=3): 20 | input_shape = (imsize, imsize, channels) 21 | 22 | x_in = Input(shape=input_shape) 23 | c0 = Convolution2D(8, (5, 5), activation='relu', strides=(1, 1), padding='same')(x_in) 24 | b0 = BatchNormalization()(c0) 25 | d0 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b0) 26 | e0 = Dropout(0.5)(d0) 27 | 28 | c1 = Convolution2D(16, (3, 3), activation='relu', strides=(1, 1), padding='same')(e0) 29 | b1 = BatchNormalization()(c1) 30 | d1 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b1) 31 | e1 = Dropout(0.5)(d1) 32 | 33 | c2 = Convolution2D(32, (3, 3), activation='relu', strides=(1, 1), padding='same')(e1) 34 | b2 = BatchNormalization()(c2) 35 | d2 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b2) 36 | e2 = Dropout(0.5)(d2) 37 | 38 | f = Flatten()(e2) 39 | z0 = Dense(64, activation='softmax', kernel_regularizer=l2(0.0001))(f) 40 | z1 = Dense(32, activation='softmax', kernel_regularizer=l2(0.0001))(z0) 41 | y_out = Dense(1, activation='sigmoid')(z1) 42 | 43 | return Model(inputs=x_in, outputs=y_out) 44 | 45 | 46 | def get_prepared_data(data_dir='deepmerge/prepared_data'): 47 | return [load_by_name(name, data_dir) for name in [ 48 | 'X_train', 49 | 'X_valid', 50 | 'X_test', 51 | 'y_train', 52 | 'y_valid', 53 | 'y_test' 54 | ]] 55 | 56 | def load_by_name(name, data_dir): 57 | path = os.path.join(data_dir, name + '.npy') 58 | with open(path, 'r') as f: 59 | return np.load(path) 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | # extra super duper bonus points - refactor out the hyperparameters 65 | # then add them to the wandb init's config= arg 66 | cnn = get_cnn_architecture() 67 | prepared_data = get_prepared_data() 68 | X_train, X_valid, X_test, y_train, y_valid, y_test = prepared_data 69 | 70 | optimizer = 'adam' 71 | fit_metrics = ['accuracy'] 72 | loss = 'binary_crossentropy' 73 | cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics) 74 | 75 | nb_epoch = 2 76 | batch_size = 128 77 | shuffle = True 78 | 79 | # model checkpoints will be saved here (only the best) 80 | # directory "../latest", checkpoint name "model" 81 | model_dir = 'deepmerge/models/latest' 82 | os.makedirs(model_dir) 83 | 84 | model_name = 'model' 85 | model_path = os.path.join(model_dir, model_name) 86 | 87 | # TODO may want some wandb stuff here 88 | 89 | default_callbacks = [ 90 | # TODO for wandb, you may want to add some extra callbacks here... 
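        # (sketch) once wandb.init(...) has been called earlier in the script, un-commenting
        # the line below will log losses/metrics to Weights & Biases each epoch
        # WandbCallback(),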
91 | ] 92 | fit_callbacks = default_callbacks + [ 93 | ModelCheckpoint(model_path, save_best_only=True, save_weights_only=True), 94 | EarlyStopping(patience=5, restore_best_weights=True) 95 | ] 96 | 97 | # Train 98 | history = cnn.fit(X_train, y_train, 99 | batch_size=batch_size, 100 | epochs=nb_epoch, 101 | validation_data=(X_valid, y_valid), 102 | shuffle=shuffle, 103 | callbacks=fit_callbacks, 104 | verbose=1) 105 | 106 | cnn.evaluate(X_test, y_test, callbacks=default_callbacks) 107 | -------------------------------------------------------------------------------- /monday_ml_tutorial/original_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d2ad3bde", 6 | "metadata": {}, 7 | "source": [ 8 | "\n", 9 | "# Classifying JWST-HST galaxy mergers with CNNs" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "id": "3a7f9875", 15 | "metadata": { 16 | "slideshow": { 17 | "slide_type": "skip" 18 | } 19 | }, 20 | "source": [ 21 | "***" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "ebe2a9d0", 27 | "metadata": {}, 28 | "source": [ 29 | "## Learning Goals\n", 30 | "\n", 31 | "\n", 32 | "**In this tutorial, you will see an example of building, compiling, and training a CNN on simulated astronomical data.**\n", 33 | "By the end of this tutorial you will have a working example of a simple Convolutional Neural Network (CNN) in `Keras`. " 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "324aa021", 39 | "metadata": {}, 40 | "source": [ 41 | "## Introduction\n", 42 | "CNNs are a class of machine learning (ML) algorithms that can extract information from images.\n", 43 | "In this notebook, you will walk through the basic steps of applying a CNN to data:\n", 44 | "1. Load the data and visualize a sample of the data.\n", 45 | "2. Divide the data into training, validation, and testing sets.\n", 46 | "3. Build a CNN in `Keras`.\n", 47 | "4. Compile the CNN.\n", 48 | "5. Train the CNN to perform a classification task.\n", 49 | "6. Evaluate the results.\n", 50 | "\n", 51 | "CNNs can be applied to a wide range of image recognition tasks, including classification and regression.\n", 52 | "In this tutorial, we will build, compile, and train CNN to classify whether a galaxy has undergone a merger, using simulated Hubble Space Telescope images of galaxies.\n", 53 | "This work is based on the public data and code from DeepMerge (Ciprijanovic et al. 2020). \n", 54 | "\n", 55 | "**NOTE:** *The DeepMerge team has [publicly-available code](https://github.com/deepskies/deepmerge-public) for demonstrating the architecture and optimal performace of the model, which we encourage you to check out! The goal of this notebook is to step through the model building and training process.*\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "618ee480", 61 | "metadata": {}, 62 | "source": [ 63 | "## Imports\n", 64 | "This notebook uses the following packages:\n", 65 | "- `numpy` to handle array functions\n", 66 | "- `astropy` for downloading and accessing FITS files\n", 67 | "- `matplotlib.pyplot` for plotting data\n", 68 | "- `keras` and `tensorflow` for building the CNN\n", 69 | "- `sklearn` for some utility functions\n", 70 | "\n", 71 | "If you do not have these packages installed, you can install them using [`pip`](https://pip.pypa.io/en/stable/) or [`conda`](https://docs.conda.io/en/latest/)." 
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "519891d4", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# arrays\n", 82 | "import numpy as np\n", 83 | "\n", 84 | "# fits\n", 85 | "from astropy.io import fits\n", 86 | "from astropy.utils.data import download_file\n", 87 | "from astropy.visualization import simple_norm\n", 88 | "\n", 89 | "# plotting\n", 90 | "from matplotlib import pyplot as plt\n", 91 | "\n", 92 | "# keras\n", 93 | "from keras.models import Model\n", 94 | "from keras.layers import Input, Flatten, Dense, Activation, Dropout, BatchNormalization\n", 95 | "from keras.layers.convolutional import Convolution2D, MaxPooling2D\n", 96 | "from keras.regularizers import l2\n", 97 | "from keras.callbacks import EarlyStopping\n", 98 | "\n", 99 | "# sklearn (for machine learning)\n", 100 | "from sklearn.model_selection import train_test_split\n", 101 | "from sklearn import metrics\n", 102 | "\n", 103 | "# from IPython import get_ipython\n", 104 | "# get_ipython().run_line_magic('matplotlib', 'notebook')" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "62a3fcb1", 110 | "metadata": {}, 111 | "source": [ 112 | "### 1. Load the data and visualize a sample of the data\n", 113 | "\n", 114 | "Load the simulated galaxy observations (3-band images) and merger probabilities (output labels).\n", 115 | "\n", 116 | "In total, there are 15,426 simulated images, each in three filters (F814W from the Advanced Camera for Surveys and F160W from the Wide Field Camera 3 on the Hubble Space Telescope (HST), and F160W and F356W from Near Infrared Camera on the James Webb Space Telescope (JWST)), retrieved and augmented from synthetic observations of the Illustris cosmological simulation. The sample includes 8120 galaxy mergers and 7306 non-mergers. Two versions of the sample are available, with and without realistic observational and experimental noise (\"pristine\" and \"noisy\"). The sample construction and augmentation process for the HST images is described in detail in [Ciprijanovic et al. 2020](https://doi.org/10.1016/j.ascom.2020.100390), and is identical for the mock JWST images. \n", 117 | "\n", 118 | "These datasets are hosted at the Mikulski Archive for Space Telescopes as an the [DEEPMERGE](https://archive.stsci.edu/doi/resolve/resolve.html?doi=10.17909/t9-vqk6-pc80) high-level science product (HLSP). \n", 119 | "\n", 120 | "The CNN will be trained to distinguish between merging and non-merging galaxies. " 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "c8a2527c", 126 | "metadata": {}, 127 | "source": [ 128 | "#### Load the data" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "9944bd61", 134 | "metadata": {}, 135 | "source": [ 136 | "The simulated images are stored in FITS format. We refer you to the [Astropy Documentation](https://docs.astropy.org/en/stable/io/fits/index.html) for further information about this format. " 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "182a55ee", 142 | "metadata": {}, 143 | "source": [ 144 | "For this example, we will download the \"pristine\" set of galaxy images, i.e., those without added observational noise. To select the \"noisy\" sample, change the version below. Alternatively, you can download data files from the [DEEPMERGE](https://stdatu.stsci.edu/hlsp/deepmerge) website." 
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "c01626c8", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "version = 'pristine'" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "04f9869c", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "%%time\n", 165 | "file_url = 'https://archive.stsci.edu/hlsps/deepmerge/hlsp_deepmerge_hst-jwst_acs-wfc3-nircam_illustris-z2_f814w-f160w-f356w_v1_sim-'+version+'.fits'\n", 166 | "c(download_file(file_url, cache=True, show_progress=True))" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "18e3a0dd", 172 | "metadata": {}, 173 | "source": [ 174 | "Explore the header of the file for information about its contents" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "73588651", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "hdu[0].header" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "e4d36518", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "hdu[1].header" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "cc589f00", 200 | "metadata": {}, 201 | "source": [ 202 | "The file includes a primary header card with overall information, an image card with the simulated images, and a bintable with the merger labels for the images (1=merger, 0=non-merger)." 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "aeaef9ba", 208 | "metadata": {}, 209 | "source": [ 210 | "#### Plot example images\n", 211 | "\n", 212 | "For a random selection of images, plot the images and their corresponding labels:" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "a41380cd", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "hdu[0].data.shape" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "2e79adba", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "# set the random seed to get the same random set of images each time, or comment it out to get different ones!\n", 233 | "# np.random.seed(206265)\n", 234 | "\n", 235 | "# select 16 random image indices:\n", 236 | "example_ids = np.random.choice(hdu[1].data.shape[0], 16)\n", 237 | "# pull the F160W image (index=1) from the simulated dataset for these selections\n", 238 | "examples = [hdu[0].data[j, 1, :, :] for j in example_ids]\n", 239 | "\n", 240 | "# initialize your figure\n", 241 | "fig = plt.figure(figsize=(8, 8)) \n", 242 | "\n", 243 | "# loop through the randomly selected images and plot with labels\n", 244 | "for i, image in enumerate(examples):\n", 245 | " ax = fig.add_subplot(4, 4, i+1)\n", 246 | " norm = simple_norm(image, 'log', max_percent=99.75)\n", 247 | "\n", 248 | " ax.imshow(image, aspect='equal', cmap='binary_r', norm=norm)\n", 249 | " ax.set_title('Merger='+str(bool(hdu[1].data[example_ids[i]][0])))\n", 250 | " \n", 251 | " ax.axis('off')\n", 252 | " \n", 253 | "plt.show()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "id": "408afadf", 259 | "metadata": {}, 260 | "source": [ 261 | "## 2. 
Divide data into training, validation, and testing sets" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "id": "857ab807", 267 | "metadata": {}, 268 | "source": [ 269 | "To divide the data set into training, validation, and testing data we will use Scikit-Learn's [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) function. \n", 270 | "\n", 271 | "We will denote the input images as `X` and their corresponding labels (i.e. the integer indicating whether or not they are a merger) as `y`, following the convention used by `sklearn`." 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "40cef745", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "X = hdu[0].data\n", 282 | "y = hdu[1].data" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "id": "c578392b", 288 | "metadata": {}, 289 | "source": [ 290 | "Following the authors, we will split the data into 70:10:20 ratio of train:validate:test\n" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "id": "65eef99f", 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "# as above, set the random seed to randomly split the images in a repeatable way. Feel free to try different values!\n", 301 | "random_state = 42\n", 302 | "\n", 303 | "X = np.asarray(X).astype('float32')\n", 304 | "y = np.asarray(y).astype('float32')\n", 305 | "\n", 306 | "# First split off 30% of the data for validation+testing\n", 307 | "X_train, X_split, y_train, y_split = train_test_split(X, y, test_size=0.3, random_state=random_state, shuffle=True)\n", 308 | "\n", 309 | "# Then divide this subset into training and testing sets\n", 310 | "X_valid, X_test, y_valid, y_test = train_test_split(X_split, y_split, test_size=0.666, random_state=random_state, shuffle=True)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "id": "8a769e47", 316 | "metadata": {}, 317 | "source": [ 318 | "Next, reshape the image array as follows: (number_of_images, image_width, image_length, 3).\n", 319 | "This is referred to as a \"channels last\" approach, where the final axis denotes the number of \"colors\" or \"channels\".\n", 320 | "The three-filter images have three channels, similar to RGB images like `jpg` and `png` image formats.\n", 321 | "CNN's will work with an arbitrary number of channels." 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "id": "19f0cf02", 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "imsize = np.shape(X_train)[2]\n", 332 | "\n", 333 | "X_train = X_train.reshape(-1, imsize, imsize, 3)\n", 334 | "X_valid = X_valid.reshape(-1, imsize, imsize, 3)\n", 335 | "X_test = X_test.reshape(-1, imsize, imsize, 3)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "id": "fe1a5163", 341 | "metadata": {}, 342 | "source": [ 343 | "### 3. Build a CNN in `Keras`\n", 344 | "\n", 345 | "Here, we will build the model described in Section 3 of [Ciprijanovic et al. 2020](https://doi.org/10.1016/j.ascom.2020.100390).\n", 346 | "\n", 347 | "Further details about `Conv2D`, `MaxPooling2D`, `BatchNormalization`, `Dropout`, and Dense layers can be found in the [Keras Layers Documentation](https://keras.io/api/layers/). \n", 348 | "Further details about the sigmoid and softmax activation function can be found in the [Keras Activation Function Documentation](https://keras.io/api/layers/activations/)." 
349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "id": "580ca93f", 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "# ------------------------------------------------------------------------------\n", 359 | "# generate the model architecture\n", 360 | "# Written for Keras 2\n", 361 | "# ------------------------------------------------------------------------------\n", 362 | "\n", 363 | "# Define architecture for model\n", 364 | "data_shape = np.shape(X)\n", 365 | "input_shape = (imsize, imsize, 3)\n", 366 | "\n", 367 | "x_in = Input(shape=input_shape)\n", 368 | "c0 = Convolution2D(8, (5, 5), activation='relu', strides=(1, 1), padding='same')(x_in)\n", 369 | "b0 = BatchNormalization()(c0)\n", 370 | "d0 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b0)\n", 371 | "e0 = Dropout(0.5)(d0)\n", 372 | "\n", 373 | "c1 = Convolution2D(16, (3, 3), activation='relu', strides=(1, 1), padding='same')(e0)\n", 374 | "b1 = BatchNormalization()(c1)\n", 375 | "d1 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b1)\n", 376 | "e1 = Dropout(0.5)(d1)\n", 377 | "\n", 378 | "c2 = Convolution2D(32, (3, 3), activation='relu', strides=(1, 1), padding='same')(e1)\n", 379 | "b2 = BatchNormalization()(c2)\n", 380 | "d2 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(b2)\n", 381 | "e2 = Dropout(0.5)(d2)\n", 382 | "\n", 383 | "f = Flatten()(e2)\n", 384 | "z0 = Dense(64, activation='softmax', kernel_regularizer=l2(0.0001))(f)\n", 385 | "z1 = Dense(32, activation='softmax', kernel_regularizer=l2(0.0001))(z0)\n", 386 | "y_out = Dense(1, activation='sigmoid')(z1)\n", 387 | "\n", 388 | "cnn = Model(inputs=x_in, outputs=y_out)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "id": "2454f58f", 394 | "metadata": {}, 395 | "source": [ 396 | "### 4. Compile the CNN\n", 397 | "\n", 398 | "Next, we compile the model.\n", 399 | "As in [Ciprijanovic et al. 2020](https://doi.org/10.1016/j.ascom.2020.100390), we select the Adam opmimizer and the binary cross entropy loss function (as this is a binary classification problem).\n", 400 | "\n", 401 | "You can learn more about [optimizers](https://keras.io/api/optimizers/) and more about [loss functions for regression tasks](https://keras.io/api/losses/) in the [Keras documentation](https://keras.io/)." 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "id": "b90b61ba", 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "# Compile Model\n", 412 | "optimizer = 'adam'\n", 413 | "fit_metrics = ['accuracy']\n", 414 | "loss = 'binary_crossentropy'\n", 415 | "cnn.compile(loss=loss, optimizer=optimizer, metrics=fit_metrics)\n", 416 | "cnn.summary()" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "id": "81fca68b", 422 | "metadata": {}, 423 | "source": [ 424 | "### 5. Train the CNN to perform a classification task\n", 425 | "\n", 426 | "We will start with training for 20 epochs, but this almost certainly won't be long enough to get great results. We set the \"batch size\" of the network (i.e., the number of samples to be propagated through the network, see the keras documentation [here](https://keras.io/api/models/model_training_apis/)) to 128. Once you've run your model and evaluated the fit, you can come back here and run the next cell again for 100 epochs or longer.\n", 427 | "This step will likely take many minutes. 
The training step is typically the computational bottleneck for using CNNs.\n", 428 | "However, once a CNN is trained, it can effectively be \"packaged up\" for future use on the original or other machines.\n", 429 | "In other words, it doesn't have to be retrained every time one wants to use it!\n", 430 | "\n", 431 | "You can learn more about `model.fit` [here](https://keras.rstudio.com/reference/fit.html)." 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "id": "b7b0bf55", 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "nb_epoch = 20\n", 442 | "batch_size = 128\n", 443 | "shuffle = True\n", 444 | "\n", 445 | "# Train\n", 446 | "history = cnn.fit(X_train, y_train, \n", 447 | " batch_size=batch_size, \n", 448 | " epochs=nb_epoch, \n", 449 | " validation_data=(X_valid, y_valid),\n", 450 | " shuffle=shuffle,\n", 451 | " verbose=False)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "id": "7f92fad3", 457 | "metadata": {}, 458 | "source": [ 459 | "### 6. Visualize CNN performance\n", 460 | "\n", 461 | "To visualize the performance of the CNN, we plot the evolution of the accuracy and loss as a function of training epochs, for the training set and for the validation set. " 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "id": "ae194501", 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [ 471 | "# plotting from history\n", 472 | "\n", 473 | "loss = history.history['loss']\n", 474 | "val_loss = history.history['val_loss']\n", 475 | "acc = history.history['accuracy']\n", 476 | "val_acc = history.history['val_accuracy']\n", 477 | "\n", 478 | "epochs = list(range(len(loss)))\n", 479 | "\n", 480 | "figsize = (6, 4)\n", 481 | "fig, axis1 = plt.subplots(figsize=figsize)\n", 482 | "plot1_lacc = axis1.plot(epochs, acc, 'navy', label='accuracy')\n", 483 | "plot1_val_lacc = axis1.plot(epochs, val_acc, 'deepskyblue', label=\"validation accuracy\")\n", 484 | "\n", 485 | "plot1_loss = axis1.plot(epochs, loss, 'red', label='loss')\n", 486 | "plot1_val_loss = axis1.plot(epochs, val_loss, 'lightsalmon', label=\"validation loss\")\n", 487 | "\n", 488 | "\n", 489 | "plots = plot1_loss + plot1_val_loss\n", 490 | "labs = [plot.get_label() for plot in plots]\n", 491 | "axis1.set_xlabel('Epoch')\n", 492 | "axis1.set_ylabel('Loss/Accuracy')\n", 493 | "plt.title(\"Loss/Accuracy History (Pristine Images)\")\n", 494 | "plt.tight_layout()\n", 495 | "axis1.legend(loc='lower right')\n", 496 | "plt.show()" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "id": "7cab9df0", 502 | "metadata": {}, 503 | "source": [ 504 | "Observe how the loss for the validation set is higher than for the training set (and conversely, the accuracy for the validation set is lower than for the training set), suggesting that this model is suffering from [overfitting](https://www.tensorflow.org/tutorials/keras/overfit_and_underfit). Revisit [the original paper](https://ui.adsabs.harvard.edu/abs/2020A%26C....3200390C/abstract) and notice the strategies they employ to improve the validation accuracy. Observe [their Figure 2](https://www.sciencedirect.com/science/article/pii/S2213133720300445) for an example of what the results of a properly-trained network look like!\n" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "id": "d1f1309e", 510 | "metadata": {}, 511 | "source": [ 512 | "### 7. Predict mergers!" 
513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "id": "0b9eed13", 518 | "metadata": {}, 519 | "source": [ 520 | "Apply the CNN to predict mergers in the \"test\" set, not used for training or validating the CNN." 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "id": "76b5d921", 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | "test_predictions = cnn.predict(X_test)" 531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "id": "a4fc6fbe", 536 | "metadata": {}, 537 | "source": [ 538 | "Below, we use a confusion matrix to evaluate the model performance on the test data. See the documentation from [sklearn on confusion matrices](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html) for more information." 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "id": "52003db9", 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "def plot_confusion_matrix(cnn, input_data, input_labels):\n", 549 | " \n", 550 | " # Compute merger predictions for the test dataset\n", 551 | " predictions = cnn.predict(input_data)\n", 552 | "\n", 553 | " # Convert to binary classification \n", 554 | " predictions = (predictions > 0.5).astype('int32') \n", 555 | " \n", 556 | " # Compute the confusion matrix by comparing the test labels (ds.test_labels) with the test predictions\n", 557 | " cm = metrics.confusion_matrix(input_labels, predictions, labels=[0, 1])\n", 558 | " cm = cm.astype('float')\n", 559 | "\n", 560 | " # Normalize the confusion matrix results. \n", 561 | " cm_norm = cm / cm.sum(axis=1)[:, np.newaxis]\n", 562 | " \n", 563 | " # Plotting\n", 564 | " fig = plt.figure()\n", 565 | " ax = fig.add_subplot(111)\n", 566 | "\n", 567 | " ax.matshow(cm_norm, cmap='binary_r')\n", 568 | "\n", 569 | " plt.title('Confusion matrix', y=1.08)\n", 570 | " \n", 571 | " ax.set_xticks([0, 1])\n", 572 | " ax.set_xticklabels(['Merger', 'No Merger'])\n", 573 | " \n", 574 | " ax.set_yticks([0, 1])\n", 575 | " ax.set_yticklabels(['Merger', 'No Merger'])\n", 576 | "\n", 577 | " plt.xlabel('Predicted')\n", 578 | " plt.ylabel('True')\n", 579 | "\n", 580 | " fmt = '.2f'\n", 581 | " thresh = cm_norm.max() / 2.\n", 582 | " for i in range(cm_norm.shape[0]):\n", 583 | " for j in range(cm_norm.shape[1]):\n", 584 | " ax.text(j, i, format(cm_norm[i, j], fmt), \n", 585 | " ha=\"center\", va=\"center\", \n", 586 | " color=\"white\" if cm_norm[i, j] < thresh else \"black\")\n", 587 | " plt.show()" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "id": "19e4335d", 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "plot_confusion_matrix(cnn, X_test, y_test)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "markdown", 602 | "id": "0591d8de", 603 | "metadata": {}, 604 | "source": [ 605 | "## FAQ\n", 606 | "\n", 607 | "- **How do I interpret theses results?** The confusion matrix shows the model predicts a large fraction of false positive (roughly 25%) and false negative (roughly 36%) merger events. The published models from [Ciprijanovic et al. 2020](https://doi.org/10.1016/j.ascom.2020.100390) perform much better. 
We note that in this notebook we are training for only a subset of the optimal number of epochs for space and time considerations, but you are welcome to agument these restricitons, and as always check out [the DeepMerge code](https://github.com/deepskies/deepmerge-public) for more information!\n", 608 | "\n", 609 | "\n", 610 | "- **Can I improve the model by changing it?** We only trained for 20 epochs, which is many fewer than the published model. Go back to Section 4 (\"Train the CNN to perform a classification task\") and increase the number of epochs to 100 (or more!) and train again. Does your model perform better? Your results may look better/worse/different from the published results due to the stochastic nature of training. \n", 611 | "\n", 612 | "\n", 613 | "- **Can I try a different model? I think the results could be improved.** Yes! You can try adding layers, swapping out the max pooling, changing the activation functions, swapping out the loss function, or trying a different optimizer or learning rate. Experiment and see what model changes give the best results. You should be aware: when you start training again, you pick up where your model left off. If you want to \"reset\" your model to epoch 0 and random weights, you should run the cells to make and compile the model again.\n", 614 | "\n", 615 | "\n", 616 | "- **I want to test my model on my training data!** No. You will convince yourself that your results are much better than they actually are. Always keep your training, validation, and testing sets completely separate!\n" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "id": "4aee2cf0", 622 | "metadata": {}, 623 | "source": [ 624 | "### Extensions/Exercises\n", 625 | "\n", 626 | "- **Effect of noise?** Try re-training the network with \"noisy\" data (i.e., modify the `version` in Section 1 to \"noisy\" and download the associated data product). Do the results change? If so, how and why? What are the pros and cons of using noisy vs. pristine data to train a ML model? \n", 627 | "\n", 628 | "\n", 629 | "\n", 630 | "- **Effect of wavelength?** The [DEEPMERGE HLSP](https://archive.stsci.edu/doi/resolve/resolve.html?doi=10.17909/t9-vqk6-pc80) includes mock galaxy images in 2 filters only (only HST data). If you train the network with this data (hint: this will require downloading it from the website, or modifying the download cells to point to the correct URL; and also modifying the shapes of the training, validation and test data, as well as the network inputs), how do the results change? \n", 631 | "\n", 632 | "\n", 633 | "\n", 634 | "- **Early stopping?** The DeepMerge team employed \"early stopping\" to minimize overfitting. Try implementing it in the network here! The Keras library for [early stopping](https://keras.io/api/callbacks/early_stopping/) functions will be useful. 
For example, you can recompile the model, train for many more epochs, and include a `callback`, in `cnn.train` e.g.,\n", 635 | "\n", 636 | " `callback = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=50)`\n", 637 | " \n", 638 | "\n", 639 | "*Don't forget, [the DeepMerge team provides code](https://github.com/deepskies/deepmerge-public) for building their production-level model and verifying its results, please check them out for more extensions and ideas!*\n" 640 | ] 641 | }, 642 | { 643 | "cell_type": "markdown", 644 | "id": "3b2c2acc", 645 | "metadata": {}, 646 | "source": [ 647 | "## About this Notebook\n", 648 | "\n", 649 | "**Author:** \n", 650 | "Claire Murray, Assistant Astronomer, cmurray1@stsci.edu \n", 651 | "\n", 652 | "**Additional Contributors:** \n", 653 | "Yotam Cohen, STScI Staff Scientist, ycohen@stsci.edu\n", 654 | "\n", 655 | "**Info:** \n", 656 | "This notebook is based on the code repository for the paper \"DeepMerge: Classifying High-redshift Merging Galaxies with Deep Neural Networks\", A. Ćiprijanović, G.F. Snyder, B. Nord, J.E.G. Peek, Astronomy & Computing, Volume 32, July 2020, and the notebook \"CNN_for_cluster_masses\" by Michelle Ntampaka, Assistant Astronomer, mntampaka@stsci.edu.\n", 657 | "\n", 658 | "**Updated On:** 2022-5-25" 659 | ] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "id": "3d60053a", 664 | "metadata": {}, 665 | "source": [ 666 | "## Citations\n", 667 | "\n", 668 | "If you use this data set, `astropy`, or `keras` for published research, please cite the\n", 669 | "authors. Follow these links for more information:\n", 670 | "\n", 671 | "* [Citing the data set](https://www.sciencedirect.com/science/article/pii/S2213133720300445#fn3)\n", 672 | "* [Citing `astropy`](https://www.astropy.org/acknowledging.html)\n", 673 | "* [Citing `keras`](https://keras.io/getting_started/faq/#how-should-i-cite-keras)\n", 674 | "\n" 675 | ] 676 | }, 677 | { 678 | "cell_type": "markdown", 679 | "id": "d10703b3", 680 | "metadata": {}, 681 | "source": [ 682 | "[Top of Page](#top)\n", 683 | "\"Space " 684 | ] 685 | } 686 | ], 687 | "metadata": { 688 | "kernelspec": { 689 | "display_name": "Python 3.8.10 ('zoobot')", 690 | "language": "python", 691 | "name": "python3" 692 | }, 693 | "language_info": { 694 | "codemirror_mode": { 695 | "name": "ipython", 696 | "version": 3 697 | }, 698 | "file_extension": ".py", 699 | "mimetype": "text/x-python", 700 | "name": "python", 701 | "nbconvert_exporter": "python", 702 | "pygments_lexer": "ipython3", 703 | "version": "3.8.10" 704 | }, 705 | "vscode": { 706 | "interpreter": { 707 | "hash": "f17685d2e70c07ccb24ff33fa38795c5cafdc43f943119294282af8db3ae350a" 708 | } 709 | } 710 | }, 711 | "nbformat": 4, 712 | "nbformat_minor": 5 713 | } 714 | -------------------------------------------------------------------------------- /monday_ml_tutorial/requirements.txt: -------------------------------------------------------------------------------- 1 | astropy>=5.0.2 2 | keras>=2.8.0 3 | matplotlib>=3.5.1 4 | numpy>=1.22.3 5 | sklearn>=0.0 6 | tensorflow>=2.8.0 7 | protobuf==3.18 8 | 9 | # adding 10 | dvc 11 | wandb 12 | fsspec==2022.8.2 13 | typing-extensions==4.4.0 # astropy incompat -------------------------------------------------------------------------------- /monday_ml_tutorial/trim_fits.py: -------------------------------------------------------------------------------- 1 | from astropy.io import fits 2 | import numpy as np 3 | 4 | if __name__ == '__main__': 5 | 6 | # just a little dev script to make the data 
smaller by taking the first N images/labels 7 | # not generally advised, just makes the download quicker for a demo 8 | 9 | fits_loc = 'deepmerge/latest_data_untrimmed.fits' 10 | max_examples = 3000 11 | 12 | hdu = fits.open(fits_loc) 13 | random_indices = np.random.choice(np.arange(len(hdu[0].data)), size=max_examples, replace=False) 14 | hdu[0].data = hdu[0].data[random_indices] 15 | hdu[1].data = hdu[1].data[random_indices] 16 | 17 | hdu.writeto('deepmerge/latest_data.fits', overwrite=True) 18 | --------------------------------------------------------------------------------