├── README.md ├── Week1 ├── READMe.txt ├── W1.1_Introduction.pptx └── W1.2_Introduction.pptx ├── Week12 ├── GNN.png └── WOA7015_Wk12.ipynb ├── Week2 ├── CoinFlip.png ├── EstimatePiFromCircleSquare.png ├── Exercise3_Multivariate_Image.png ├── READMe.txt ├── W2_FoundationI.pptx └── WOA7015_Wk2.ipynb ├── Week3 ├── READMe.txt ├── W3.pptx ├── WOA7015_Wk3.ipynb └── XOR.jpg ├── Week4 ├── MnistExamples.png ├── README ├── W4.pptx └── WOA7015_Wk4.ipynb ├── Week5 ├── CV.png ├── CV_test.png ├── W5.pptx ├── WOA7015_Wk5.ipynb ├── iris.PNG └── train_test.png ├── Week6 ├── W6.pptx └── WOA7015_Wk6.ipynb ├── Week7 └── W7.pptx ├── Week8 └── W8.pptx └── Week9 ├── Adj_matrix.PNG ├── Message_passing.PNG ├── W9.pptx ├── WOA7015_Wk9.ipynb └── WOA7015_Wk9_Extra.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Advanced_ML 2 | 3 | The teaching materials developed here are inspired by 4 | 1. Kevin Murphy - https://probml.github.io/pml-book/book1.html 5 | 6 | -------------------------------------------------------------------------------- /Week1/READMe.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Week1/W1.1_Introduction.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week1/W1.1_Introduction.pptx -------------------------------------------------------------------------------- /Week1/W1.2_Introduction.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week1/W1.2_Introduction.pptx -------------------------------------------------------------------------------- /Week12/GNN.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week12/GNN.png -------------------------------------------------------------------------------- /Week2/CoinFlip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week2/CoinFlip.png -------------------------------------------------------------------------------- /Week2/EstimatePiFromCircleSquare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week2/EstimatePiFromCircleSquare.png -------------------------------------------------------------------------------- /Week2/Exercise3_Multivariate_Image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week2/Exercise3_Multivariate_Image.png -------------------------------------------------------------------------------- /Week2/READMe.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Week2/W2_FoundationI.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week2/W2_FoundationI.pptx -------------------------------------------------------------------------------- /Week2/WOA7015_Wk2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "collapsed_sections": [], 8 | 
"authorship_tag": "ABX9TyMPAj70b1tfVLCW+dBMWfV9", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "HLjtuMaImRla" 34 | }, 35 | "source": [ 36 | "# Welcome to WOA7015 Advance Machine Learning Lab - Week 2 \n", 37 | "This code is generated for the purpose of WOA7015 module.\n", 38 | "The code is available in github https://github.com/shiernee/Advanced_ML \n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "Jxdp_95lYsoV" 45 | }, 46 | "source": [ 47 | "# The Gaussian Distribution\n", 48 | "\n", 49 | "The p.d.f of random variable $Z$ with a gaussian / normal distribution is shown below\n", 50 | "\n", 51 | "$$ p(z) = \\frac{1}{\\sqrt{2\\pi}} e^{-z^2 / 2}. 
$$\n", 52 | "\n", 53 | "It is defined for all real values $z$, from $-\\infty$ to $\\infty$.\n", 54 | "\n", 55 | "The distribution looks like this:" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "colab": { 62 | "background_save": true 63 | }, 64 | "id": "MPLuZCPQZ9cK" 65 | }, 66 | "source": [ 67 | "# import symbulate https://dlsun.github.io/symbulate/index.html \n", 68 | "\n", 69 | "!pip install -q symbulate\n", 70 | "from symbulate import *" 71 | ], 72 | "execution_count": null, 73 | "outputs": [] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "metadata": { 78 | "colab": { 79 | "base_uri": "https://localhost:8080/", 80 | "height": 257 81 | }, 82 | "id": "QBS_72HxYozq", 83 | "outputId": "f571f724-7a9b-418e-b1f8-843cd5ce2c77" 84 | }, 85 | "source": [ 86 | "Normal().plot()" 87 | ], 88 | "execution_count": null, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "image/png": "\n", 93 | "text/plain": [ 94 | "
" 95 | ] 96 | }, 97 | "metadata": {}, 98 | "output_type": "display_data" 99 | } 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": { 105 | "id": "4jjeZ6B4ZNfI" 106 | }, 107 | "source": [ 108 | "### Expected Value\n", 109 | "\n", 110 | "The expected value of a standard normal random variable, $E[Z]$, is...\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "metadata": { 116 | "colab": { 117 | "background_save": true 118 | }, 119 | "id": "mt-8KUGKZRct", 120 | "outputId": "34924efb-6deb-49a2-ca0c-4f0db325cc53" 121 | }, 122 | "source": [ 123 | "Normal().mean()" 124 | ], 125 | "execution_count": null, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "0.0" 131 | ] 132 | }, 133 | "execution_count": null, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": { 142 | "id": "IR_V8vj5ZS5s" 143 | }, 144 | "source": [ 145 | "### Variance\n", 146 | "\n", 147 | "The variance of a standard normal random variable, $\\text{Var}[Z]$, is..." 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "metadata": { 153 | "id": "hcrji_ZGZU-u" 154 | }, 155 | "source": [ 156 | "Normal().var()" 157 | ], 158 | "execution_count": null, 159 | "outputs": [] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": { 164 | "id": "K0RWPxPoZYT1" 165 | }, 166 | "source": [ 167 | "## The (General) Normal Distribution\n", 168 | "\n", 169 | "The standard normal distribution is centered at 0 with a variance of 1. In general, we can\n", 170 | "- scale the bell shape to be as wide as we want, \n", 171 | "- shift the bell shape to be centered wherever we want.\n", 172 | "\n", 173 | "If $Z$ is standard normal, then \n", 174 | "$$ X = \\mu + \\sigma Z $$\n", 175 | "is $\\text{Normal}(\\mu, \\sigma)$. The parameter $\\mu$ is the expected value, and the parameter $\\sigma$ is the standard deviation. 
(So $\\sigma^2$ is the variance.)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": { 181 | "id": "-vWADOjsbZ4i" 182 | }, 183 | "source": [ 184 | "## Exercise 1\n", 185 | "Generate a normal distribution with \n", 186 | "1. mean=1, stdev=0.25\n", 187 | "2. mean=1, stdev=0.5\n", 188 | "3. mean=1, stdev=0.75\n", 189 | "4. mean=3, stdev=0.25\n", 190 | "5. mean=3, stdev=0.5\n", 191 | "6. mean=3, stdev=0.75\n", 192 | "\n", 193 | "in the same plot with different colors with legends." 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "metadata": { 199 | "id": "R6kAL_mAekJp" 200 | }, 201 | "source": [ 202 | "# Your code here" 203 | ], 204 | "execution_count": null, 205 | "outputs": [] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "metadata": { 210 | "cellView": "form", 211 | "id": "3Hyw2KPXfobM" 212 | }, 213 | "source": [ 214 | "#@title ## Exercise 1 Solution - Try yourself first.\n", 215 | "\n", 216 | "legend_list = []\n", 217 | "for m in [1, 3]:\n", 218 | " for sd in [0.25, 0.5, 0.75]:\n", 219 | " Normal(mean=m, sd=sd).plot()\n", 220 | " legend_list.append([\"mean=%d sd=%0.2f\" %(m, sd)])\n", 221 | "\n", 222 | "\n", 223 | "import matplotlib.pyplot as plt\n", 224 | "plt.legend(legend_list)" 225 | ], 226 | "execution_count": null, 227 | "outputs": [] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "70GimY8YZDsW" 233 | }, 234 | "source": [ 235 | "# Probability\n", 236 | "\n", 237 | "To calculate probabilities, we integrate the p.d.f. over the relevant region. For example,\n", 238 | "\n", 239 | "$$ P(Z \\leq 1) = \\int_{-\\infty}^1 \\frac{1}{\\sqrt{2\\pi}} e^{-z^2 / 2}\\,dz. $$\n", 240 | "\n", 241 | "Unlike other continuous distributions we have studied, the p.d.f. $p(z)$ has no elementary antiderivative. That means that you will not be able to evaluate this integral by paper and pencil, using techniques you learned in calculus. It has to be evaluated numerically. 
Fortunately, you can do this easily in Symbulate. \n", 242 | "\n", 243 | "For example, $P(Z \\leq 1)$ is just the c.d.f. evaluated at $1$. The c.d.f. of the standard normal distribution is often represented by $\\Phi(z)$. So we need to calculate $\\Phi(1)$." 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "metadata": { 249 | "id": "VdEJLbBhZJFx" 250 | }, 251 | "source": [ 252 | "Normal().cdf(1)" 253 | ], 254 | "execution_count": null, 255 | "outputs": [] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": { 260 | "id": "SM7-6Y98ZK3f" 261 | }, 262 | "source": [ 263 | "## Exercise 2:\n", 264 | "How would you calculate $P(-2 < Z < 2)$?" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "metadata": { 270 | "id": "BJCxH08IZOBx" 271 | }, 272 | "source": [ 273 | "# YOUR CODE HERE\n" 274 | ], 275 | "execution_count": null, 276 | "outputs": [] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "metadata": { 281 | "cellView": "form", 282 | "id": "7DJQ7MKMgLYs" 283 | }, 284 | "source": [ 285 | "#@title Solution 2 - Try yourself first \n", 286 | "# P(-2 < Z < 2) = P(Z < 2) - P(Z < -2)\n", 287 | "\n", 288 | "Normal().cdf([2]) - Normal().cdf([-2]) \n" 289 | ], 290 | "execution_count": null, 291 | "outputs": [] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": { 296 | "id": "qplX9C0R2xiO" 297 | }, 298 | "source": [ 299 | "# Monte Carlo Approximation \n", 300 | "\n", 301 | "## Example 1: Coin Flip Example\n", 302 | "\n", 303 | "The probability of a head for a fair coin is 1/2. Here we use the Monte Carlo method to simulate 10000 coin flips and observe how the estimated probability of a head or a tail converges to 1/2. 
\n", 304 | "\n", 305 | "" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "VDZ8iQzRDetH" 312 | }, 313 | "source": [ 314 | "# import require libraries\n", 315 | "import random\n", 316 | "import numpy as np\n", 317 | "import matplotlib.pyplot as plt\n" 318 | ], 319 | "execution_count": null, 320 | "outputs": [] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "metadata": { 325 | "id": "E51dVvVuDlCv" 326 | }, 327 | "source": [ 328 | "# coin flip function:\n", 329 | "# 0 --> Head\n", 330 | "# 1 --> Tail\n", 331 | "\n", 332 | "def coin_flip():\n", 333 | " return random.randint(0, 1)\n", 334 | "\n", 335 | "# check the output of coin_flip\n", 336 | "for i in range(10):\n", 337 | " print('iteration' + str(i) + '--> ' + str(coin_flip()))" 338 | ], 339 | "execution_count": null, 340 | "outputs": [] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "metadata": { 345 | "id": "FNTszy01EA7C" 346 | }, 347 | "source": [ 348 | "# Monte Carlo Simulation\n", 349 | "list1= []\n", 350 | "\n", 351 | "def monte_carlo(n):\n", 352 | " results = 0\n", 353 | " plt.axhline(y=0.5, color='r', linestyle='-')\n", 354 | "\n", 355 | " for i in range(n):\n", 356 | " flip_result = coin_flip()\n", 357 | " results = results + flip_result\n", 358 | "\n", 359 | " # calculate probabibility valuue\n", 360 | " prob_value = results / (i+1)\n", 361 | "\n", 362 | " # append probability to list1\n", 363 | " list1.append(prob_value)\n", 364 | "\n", 365 | " # plot results \n", 366 | " plt.xlabel('iteration')\n", 367 | " plt.ylabel('probability')\n", 368 | " plt.plot(list1)\n", 369 | "\n", 370 | " return results / n" 371 | ], 372 | "execution_count": null, 373 | "outputs": [] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "metadata": { 378 | "id": "OIgkm927Eiml" 379 | }, 380 | "source": [ 381 | "# call monte carlo functioin\n", 382 | "answer = monte_carlo(10000)\n", 383 | "print('final value of probability: ', answer)" 384 | ], 385 | "execution_count": null, 386 | "outputs": 
[] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": { 391 | "id": "uq4nxGfCI5U9" 392 | }, 393 | "source": [ 394 | "# Example 2: Estimating Pi from Circle and Square\n", 395 | "\n", 396 | "To know about the history of pi - read https://www.scientificamerican.com/article/what-is-pi-and-how-did-it-originate/
\n", 397 | "\n", 398 | " To estimate the value of Pi, we can use the area of circle and square. \n", 399 | "$$ \\frac{Area \\ Circle}{Area \\ Square} = \\frac{\\pi*r^2}{2r * 2r} $$
\n", 400 | "$$ \\frac{Area \\ Circle}{Area \\ Square} = \\frac{\\pi}{4} $$
\n", 401 | "\n", 402 | "The value of $\\pi$ can be estimated using the following formula\n", 403 | "$$ \\pi = 4* \\frac{Area \\ Circle}{Area \\ Square} $$\n", 404 | "\n", 405 | "\n", 406 | " \n", 407 | "\n", 408 | "Assuming r = 0.5\n", 409 | "\n", 410 | "length_of_field = 2r = 1.0\n", 411 | "\n" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "metadata": { 417 | "id": "IDsUuJvXNsHl" 418 | }, 419 | "source": [ 420 | "import turtle\n", 421 | "from random import random \n", 422 | "import matplotlib.pyplot as plt\n", 423 | "import math" 424 | ], 425 | "execution_count": null, 426 | "outputs": [] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "metadata": { 431 | "id": "YC4A6swHNwUK" 432 | }, 433 | "source": [ 434 | "# simulate raindrop \n", 435 | "# return x and y coordinates of raindrop\n", 436 | "\n", 437 | "def rain_drop(length_of_field=1):\n", 438 | " \"\"\"\n", 439 | " Simulate a random rain drop\n", 440 | " \"\"\"\n", 441 | " return [(.5 - random()) * length_of_field, (.5 - random()) * length_of_field]\n" 442 | ], 443 | "execution_count": null, 444 | "outputs": [] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "metadata": { 449 | "id": "FSjfm-qKOBo7" 450 | }, 451 | "source": [ 452 | "# check if raindrop fall in circle by using circle formula \n", 453 | "\n", 454 | "def is_point_in_circle(point, length_of_field=1):\n", 455 | " \"\"\"\n", 456 | " Return True if point is in inscribed circle\n", 457 | " Use circle formula --> x^2 + y^2 <= r^2\n", 458 | " \"\"\"\n", 459 | " return (point[0]) ** 2 + (point[1]) ** 2 <= (length_of_field / 2) ** 2" 460 | ], 461 | "execution_count": null, 462 | "outputs": [] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "metadata": { 467 | "id": "Siq4wlDBOCf4" 468 | }, 469 | "source": [ 470 | "def plot_rain_drops(drops_in_circle, drops_out_of_circle, length_of_field=1, format='pdf'):\n", 471 | " \"\"\" Function to draw rain drops \"\"\"\n", 472 | " number_of_drops_in_circle = len(drops_in_circle)\n", 473 | " 
number_of_drops_out_of_circle = len(drops_out_of_circle)\n", 474 | " number_of_drops = number_of_drops_in_circle + number_of_drops_out_of_circle\n", 475 | " plt.figure()\n", 476 | " plt.xlim(-length_of_field / 2, length_of_field / 2)\n", 477 | " plt.ylim(-length_of_field / 2, length_of_field / 2)\n", 478 | " plt.scatter([e[0] for e in drops_in_circle], [e[1] for e in drops_in_circle], color='blue', label=\"Drops in circle\")\n", 479 | " plt.scatter([e[0] for e in drops_out_of_circle], [e[1] for e in drops_out_of_circle], color='black', label=\"Drops out of circle\")\n", 480 | " plt.legend(loc=\"center\")\n", 481 | " plt.title(\"%s drops: %s landed in circle, estimating $\\pi$ as %.4f.\" % (number_of_drops, number_of_drops_in_circle, 4 * number_of_drops_in_circle / number_of_drops))\n", 482 | " plt.savefig(\"%s_drops.%s\" % (number_of_drops, format))\n" 483 | ], 484 | "execution_count": null, 485 | "outputs": [] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "metadata": { 490 | "id": "6J5qnEUKOFQS" 491 | }, 492 | "source": [ 493 | "# simulate raindrop \n", 494 | "# return total number of raindrop in circle and in square\n", 495 | "\n", 496 | "def rain(number_of_drops=1000, length_of_field=1, plot=True, format='pdf', dynamic=False):\n", 497 | " \"\"\"\n", 498 | " Function to make rain drops.\n", 499 | " \"\"\"\n", 500 | " number_of_drops_in_circle = 0\n", 501 | " drops_in_circle = []\n", 502 | " drops_out_of_circle = []\n", 503 | " pi_estimate = []\n", 504 | " for k in range(number_of_drops):\n", 505 | " d = (rain_drop(length_of_field))\n", 506 | " if is_point_in_circle(d, length_of_field):\n", 507 | " drops_in_circle.append(d)\n", 508 | " number_of_drops_in_circle += 1\n", 509 | " else:\n", 510 | " drops_out_of_circle.append(d)\n", 511 | " if dynamic: # The dynamic option if set to True will plot every new drop (this can be used to create animations of the simulation)\n", 512 | " print(\"Plotting drop number: %s\" % (k + 1))\n", 513 | " 
plot_rain_drops(drops_in_circle, drops_out_of_circle, length_of_field, format)\n", 514 | " pi_estimate.append(4 * number_of_drops_in_circle / (k + 1)) # This updates the list with the newest estimate for pi.\n", 515 | " # Plot the pi estimates\n", 516 | " plt.figure()\n", 517 | " plt.scatter(range(1, number_of_drops + 1), pi_estimate)\n", 518 | " max_x = plt.xlim()[1]\n", 519 | " plt.hlines(math.pi, 0, max_x, color='black')\n", 520 | " plt.xlim(0, max_x)\n", 521 | " plt.title(\"$\\pi$ estimate against number of rain drops\")\n", 522 | " plt.xlabel(\"Number of rain drops\")\n", 523 | " plt.ylabel(\"$\\pi$\")\n", 524 | " # plt.savefig(\"Pi_estimate_for_%s_drops_thrown.pdf\" % number_of_drops)\n", 525 | "\n", 526 | " if plot and not dynamic:\n", 527 | " # If the plot option is passed and matplotlib is installed this plots\n", 528 | " # the final set of drops\n", 529 | " plot_rain_drops(drops_in_circle, drops_out_of_circle, length_of_field, format)\n", 530 | "\n", 531 | " return [number_of_drops_in_circle, number_of_drops]\n" 532 | ], 533 | "execution_count": null, 534 | "outputs": [] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "metadata": { 539 | "colab": { 540 | "base_uri": "https://localhost:8080/", 541 | "height": 632 542 | }, 543 | "id": "yLLV4UstOHQ7", 544 | "outputId": "04e1adfe-25fa-4527-d301-b3657fc56027" 545 | }, 546 | "source": [ 547 | "# call the function \n", 548 | "number_of_drops = 500\n", 549 | "r = rain(number_of_drops, plot=True, format='png', dynamic=False)\n", 550 | "\n", 551 | "print(\"----------------------\")\n", 552 | "print(\"%s drops\" % number_of_drops)\n", 553 | "print(\"pi estimated as: %s \" % (4 * r[0] / r[1]))\n", 554 | "print(\"----------------------\")" 555 | ], 556 | "execution_count": null, 557 | "outputs": [ 558 | { 559 | "output_type": "stream", 560 | "name": "stdout", 561 | "text": [ 562 | "----------------------\n", 563 | "500 drops\n", 564 | "pi estimated as: 3.128 \n", 565 | "----------------------\n" 566 | ] 567 | }, 568 
| { 569 | "output_type": "display_data", 570 | "data": { 571 | "text/plain": [ 572 | "
" 573 | ], 574 | "image/png": "\n" 575 | }, 576 | "metadata": { 577 | "needs_background": "light" 578 | } 579 | }, 580 | { 581 | "output_type": "display_data", 582 | "data": { 583 | "text/plain": [ 584 | "
" 585 | ], 586 | "image/png": "\n" 587 | }, 588 | "metadata": { 589 | "needs_background": "light" 590 | } 591 | } 592 | ] 593 | }, 594 | { 595 | "cell_type": "markdown", 596 | "metadata": { 597 | "id": "Si-mUBcUSuJr" 598 | }, 599 | "source": [ 600 | "## Now try increasing number_of_drops and check the value of $\\pi$. \n", 601 | "\n", 602 | "## At what value of number_of_drops does the $\\pi$ value approaches 3.14? Write down your answer below. " 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "metadata": { 608 | "id": "Fu1QGY8gTHxa" 609 | }, 610 | "source": [ 611 | "# write your answer here. \n" 612 | ], 613 | "execution_count": null, 614 | "outputs": [] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "metadata": { 619 | "id": "gcHiPdooiubi" 620 | }, 621 | "source": [ 622 | "# ***Let's go back to power point - slide 22***" 623 | ] 624 | }, 625 | { 626 | "cell_type": "markdown", 627 | "metadata": { 628 | "id": "0-_f8CZnVmYO" 629 | }, 630 | "source": [ 631 | "# Multivariate Gaussian Distribution\n", 632 | "\n", 633 | "For two continuous random variables, plot type density uses the simulated $ (x,y)$ to estimate the joint probability density function and plot it.\n", 634 | "\n", 635 | "### Example. Assume
\n", 636 | "mean of X = 1, mean of Y = 2
\n", 637 | "variance of X = 2, variance of Y = 4
\n", 638 | "covariance of xy and yx = 1" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "metadata": { 644 | "id": "lY6YpCNPVpsk" 645 | }, 646 | "source": [ 647 | "mu = [1, 2]\n", 648 | "Sigma = [[2, 1],\n", 649 | " [1, 4]]\n", 650 | "\n", 651 | "X, Y = RV(MultivariateNormal(mean = mu, cov = Sigma))\n", 652 | "Z = X + Y\n" 653 | ], 654 | "execution_count": null, 655 | "outputs": [] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "metadata": { 660 | "colab": { 661 | "background_save": true 662 | }, 663 | "id": "9FVnrgThbEyw", 664 | "outputId": "77a76ae7-3984-45c2-e022-cdfda14d3eae" 665 | }, 666 | "source": [ 667 | "# understand each output \n", 668 | "\n", 669 | "x = X.sim(10000)\n", 670 | "y = Y.sim(10000)\n", 671 | "z = Z.sim(10000)\n", 672 | "print('X mean:', x.mean())\n", 673 | "print('Y mean:', y.mean())\n", 674 | "print('Z mean:', z.mean())\n", 675 | "print('X variance:', x.sd()**2)\n", 676 | "print('Y variance:', y.sd()**2)\n", 677 | "print('Z variance:', z.sd()**2)\n" 678 | ], 679 | "execution_count": null, 680 | "outputs": [ 681 | { 682 | "name": "stdout", 683 | "output_type": "stream", 684 | "text": [ 685 | "X mean: 1.0258520182513167\n", 686 | "Y mean: 1.9758308752410725\n", 687 | "Z mean: 2.984547830463785\n", 688 | "X variance: 1.9987250660153855\n", 689 | "Y variance: 4.012078198155444\n", 690 | "Z variance: 8.089690367918754\n" 691 | ] 692 | } 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "metadata": { 698 | "colab": { 699 | "background_save": true 700 | }, 701 | "id": "BP6E_Ytdd7Km", 702 | "outputId": "60487e35-a564-46c6-df2e-70e0e2d5056b" 703 | }, 704 | "source": [ 705 | "(X & Y).sim(10000).plot(type=\"density\")" 706 | ], 707 | "execution_count": null, 708 | "outputs": [ 709 | { 710 | "data": { 711 | "image/png": "\n", 712 | "text/plain": [ 713 | "
" 714 | ] 715 | }, 716 | "metadata": {}, 717 | "output_type": "display_data" 718 | } 719 | ] 720 | }, 721 | { 722 | "cell_type": "markdown", 723 | "metadata": { 724 | "id": "VPiOeOD1a-FF" 725 | }, 726 | "source": [ 727 | "# Exercise 3 \n", 728 | "Generate the Multivariate Gaussian as shown below given variance of X = 2, variance of Y = 4\n", 729 | "\n", 730 | " " 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "metadata": { 736 | "colab": { 737 | "background_save": true 738 | }, 739 | "id": "0fYiOs5_lVBl" 740 | }, 741 | "source": [ 742 | "# your code here" 743 | ], 744 | "execution_count": null, 745 | "outputs": [] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "metadata": { 750 | "cellView": "code", 751 | "id": "kF_99p4Jg4AP" 752 | }, 753 | "source": [ 754 | "#@title Solution - Try yourself first\n", 755 | "mu = [4, 2]\n", 756 | "Sigma = [[2, -1],\n", 757 | " [-1, 4]]\n", 758 | "\n", 759 | "X, Y = RV(MultivariateNormal(mean = mu, cov = Sigma))\n", 760 | "(X & Y).sim(10000).plot(type=\"density\")" 761 | ], 762 | "execution_count": null, 763 | "outputs": [] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": { 768 | "id": "9lbejY3TVZsJ" 769 | }, 770 | "source": [ 771 | "# ***Let's go back to power point - slide 36***" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "metadata": { 777 | "id": "pQXygBZFK9MJ" 778 | }, 779 | "source": [ 780 | "# Regularization \n", 781 | "\n", 782 | "Here we examine how regularizer in Ridge regression help in reducing overfitting. 
" 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "metadata": { 788 | "id": "p0abOD4LK8oO" 789 | }, 790 | "source": [ 791 | "import numpy as np\n", 792 | "import matplotlib.pyplot as plt\n", 793 | "\n", 794 | "from sklearn.preprocessing import PolynomialFeatures\n", 795 | "from sklearn.linear_model import Ridge\n", 796 | "from sklearn.preprocessing import MinMaxScaler \n", 797 | "from sklearn.metrics import mean_squared_error as mse\n" 798 | ], 799 | "execution_count": null, 800 | "outputs": [] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "1HxAoiAGL3VY" 806 | }, 807 | "source": [ 808 | "# generate 1d regression data \n", 809 | "def make_1dregression_data(n=21):\n", 810 | " np.random.seed(0)\n", 811 | " xtrain = np.linspace(0.0, 20, n)\n", 812 | " xtest = np.arange(0.0, 20, 0.1)\n", 813 | " sigma2 = 4\n", 814 | " w = np.array([-1.5, 1/9.])\n", 815 | " fun = lambda x: w[0]*x + w[1]*np.square(x)\n", 816 | " ytrain = fun(xtrain) + np.random.normal(0, 1, xtrain.shape) * \\\n", 817 | " np.sqrt(sigma2)\n", 818 | " ytest= fun(xtest) + np.random.normal(0, 1, xtest.shape) * \\\n", 819 | " np.sqrt(sigma2)\n", 820 | " return xtrain, ytrain, xtest, ytest\n" 821 | ], 822 | "execution_count": null, 823 | "outputs": [] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "metadata": { 828 | "id": "oKlktL3rM3jI" 829 | }, 830 | "source": [ 831 | "import numpy as np\n", 832 | "n=21\n", 833 | "xtrain = np.linspace(0.0, 20, n)\n", 834 | "w = np.array([-1.5, 1/9.])\n", 835 | "fun = lambda x: w[0]*x + w[1]*np.square(x)\n", 836 | "\n", 837 | "ytrain = fun(xtrain)" 838 | ], 839 | "execution_count": null, 840 | "outputs": [] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "metadata": { 845 | "id": "2-2X7p5E4JZu" 846 | }, 847 | "source": [], 848 | "execution_count": null, 849 | "outputs": [] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "metadata": { 854 | "id": "Kd-943vgL-Gp" 855 | }, 856 | "source": [ 857 | "# split data into train and test\n", 
858 | "xtrain, ytrain, xtest, ytest = make_1dregression_data(n=21)\n", 859 | "\n", 860 | "#Rescaling data\n", 861 | "scaler = MinMaxScaler(feature_range=(-1, 1))\n", 862 | "Xtrain = scaler.fit_transform(xtrain.reshape(-1, 1))\n", 863 | "Xtest = scaler.transform(xtest.reshape(-1, 1))\n" 864 | ], 865 | "execution_count": null, 866 | "outputs": [] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "metadata": { 871 | "id": "A_qzH6A8MCdJ" 872 | }, 873 | "source": [ 874 | "# fit Ridge model with different regularizer strength\n", 875 | "deg = 14\n", 876 | "alphas = np.logspace(-10, 1.3, 10) # Regularization strength\n", 877 | "nalphas = len(alphas)\n", 878 | "mse_train = np.empty(nalphas)\n", 879 | "mse_test = np.empty(nalphas)\n", 880 | "ytest_pred_stored = dict()\n", 881 | "\n", 882 | "\n", 883 | "for i, alpha in enumerate(alphas):\n", 884 | " model = Ridge(alpha=alpha, fit_intercept=False)\n", 885 | " poly_features = PolynomialFeatures(degree=deg, include_bias=False) # create 14 features which is used as X\n", 886 | " Xtrain_poly = poly_features.fit_transform(Xtrain)\n", 887 | " model.fit(Xtrain_poly, ytrain)\n", 888 | " ytrain_pred = model.predict(Xtrain_poly)\n", 889 | " Xtest_poly = poly_features.transform(Xtest)\n", 890 | " ytest_pred = model.predict(Xtest_poly)\n", 891 | " mse_train[i] = mse(ytrain_pred, ytrain) \n", 892 | " mse_test[i] = mse(ytest_pred, ytest)\n", 893 | " ytest_pred_stored[alpha] = ytest_pred\n", 894 | " \n" 895 | ], 896 | "execution_count": null, 897 | "outputs": [] 898 | }, 899 | { 900 | "cell_type": "code", 901 | "metadata": { 902 | "id": "B8Tbv5RrMmRe" 903 | }, 904 | "source": [ 905 | "# Plot MSE vs degree\n", 906 | "fig, ax = plt.subplots()\n", 907 | "mask = [True]*nalphas\n", 908 | "ax.plot(alphas[mask], mse_test[mask], color = 'r', marker = 'x',label='test')\n", 909 | "ax.plot(alphas[mask], mse_train[mask], color='b', marker = 's', label='train')\n", 910 | "ax.set_xscale('log')\n", 911 | "ax.legend(loc='upper right', shadow=True)\n", 912 
| "plt.xlabel('L2 regularizer')\n", 913 | "plt.ylabel('mse')\n", 914 | "plt.show()\n" 915 | ], 916 | "execution_count": null, 917 | "outputs": [] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "metadata": { 922 | "id": "Yu_yqywDMn2U" 923 | }, 924 | "source": [ 925 | "# Plot fitted functions\n", 926 | "chosen_alphas = alphas[[0,5,8]]\n", 927 | "for i, alpha in enumerate(alphas):\n", 928 | " fig, ax = plt.subplots()\n", 929 | " ax.scatter(xtrain, ytrain)\n", 930 | " ax.plot(xtest, ytest_pred_stored[alpha])\n", 931 | " plt.title('L2 regularizer {:0.5e}'.format(alpha))\n", 932 | " plt.show()" 933 | ], 934 | "execution_count": null, 935 | "outputs": [] 936 | }, 937 | { 938 | "cell_type": "markdown", 939 | "metadata": { 940 | "id": "9ms-23NROC0Y" 941 | }, 942 | "source": [ 943 | "# Exercise 4: \n", 944 | "How do you choose what regularizer strength is optimal?? Explain your answer at the following cell. " 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "metadata": { 950 | "id": "MxwId3RjOP2g" 951 | }, 952 | "source": [ 953 | "# your answer here\n" 954 | ], 955 | "execution_count": null, 956 | "outputs": [] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "metadata": { 961 | "id": "7wHhZ-mooapM", 962 | "cellView": "form" 963 | }, 964 | "source": [ 965 | "#@title Solution \n", 966 | "\n", 967 | "Use cross validation - explain the process" 968 | ], 969 | "execution_count": null, 970 | "outputs": [] 971 | }, 972 | { 973 | "cell_type": "markdown", 974 | "source": [ 975 | "# Submission Instructions\n", 976 | "\n", 977 | "Once you are finished, follow these steps:\n", 978 | "\n", 979 | "Restart the kernel and re-run this notebook from beginning to end by going to Kernel > Restart Kernel and Run All Cells. If this process stops halfway through, that means there was an error. Correct the error and repeat Step 1 until the notebook runs from beginning to end. Double check that there is a number next to each code cell and that these numbers are in order. 
Then, submit your lab as follows:\n", 980 | "\n", 981 | "Go to File > Print > Save as PDF. Double check that the entire notebook, from beginning to end, is in this PDF file. Make sure your solution for Exercise 4 is included, as it will be marked. Upload the PDF to Spectrum." 982 | ], 983 | "metadata": { 984 | "id": "lN7_59o5ZkgZ" 985 | } 986 | }, 987 | { 988 | "cell_type": "markdown", 989 | "metadata": { 990 | "id": "FcqpGzfGyWxL" 991 | }, 992 | "source": [ 993 | "# Acknowledgement\n", 994 | "\n", 995 | "This work is inspired by \n", 996 | "1. Normal Distribution - https://colab.research.google.com/github/dlsun/Stat350F19/blob/master/Normal_Distribution.ipynb#scrollTo=4K2s06RQFP_1 \n", 997 | "2. Coin Flip Example - https://pub.towardsai.net/monte-carlo-simulation-an-in-depth-tutorial-with-python-bcf6eb7856c8 \n", 998 | "3. Estimating $\\pi$ from circle and square - https://www.youtube.com/watch?v=VJTFfIqO4TU " 999 | ] 1000 | } 1001 | ] 1002 | } -------------------------------------------------------------------------------- /Week3/READMe.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Week3/W3.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week3/W3.pptx -------------------------------------------------------------------------------- /Week3/WOA7015_Wk3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "WOA7015_Wk3.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyNIgFUm0FAZOu775EIoFBzm", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 
18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "OzY5ZsYGoDdu" 35 | }, 36 | "source": [ 37 | "# Welcome to WOA7015 Advance Machine Learning Lab - Week 3\n", 38 | "This code is generated for the purpose of WOA7015 module.\n", 39 | "The code is available in github https://github.com/shiernee/Advanced_ML \n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "nVerYxHW-ZrQ" 46 | }, 47 | "source": [ 48 | "# The effect of imbalanced data on AUROC \n", 49 | "The following code evaluates the effect of imbalanced data on the AUROC of TPR-FPR curve. \n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "c87yzg0goBrP" 56 | }, 57 | "source": [ 58 | "# roc curve and auc on an imbalanced dataset\n", 59 | "import numpy as np\n", 60 | "from sklearn.datasets import make_classification\n", 61 | "from sklearn.linear_model import LogisticRegression\n", 62 | "from sklearn.model_selection import train_test_split\n", 63 | "from sklearn.metrics import roc_curve\n", 64 | "from sklearn.metrics import roc_auc_score\n", 65 | "import matplotlib.pyplot as plt\n", 66 | "from imblearn.under_sampling import RandomUnderSampler\n" 67 | ], 68 | "execution_count": 36, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "metadata": { 74 | "colab": { 75 | "base_uri": "https://localhost:8080/" 76 | }, 77 | "id": "hcpJntDPEq2J", 78 | "outputId": "6e6cbb8c-8ee1-43a6-91bc-64bdfad9412f" 79 | }, 80 | "source": [ 81 | "# generate 2 class dataset \n", 82 | "X, y = make_classification(n_samples=1000, n_classes=2, random_state=1000)\n", 83 | "\n", 84 | "print(X)\n", 85 | "print('-----------')\n", 86 | "print(y)\n" 87 | ], 88 | "execution_count": 37, 89 | "outputs": [ 90 | { 91 | "output_type": "stream", 92 | "name": "stdout", 93 | 
"text": [ 94 | "[[-0.32584935 0.21897754 0.62061895 ... 2.84071377 -0.02582733\n", 95 | " -0.40885762]\n", 96 | " [-1.12624124 -0.86026727 -0.89264356 ... -0.92962064 0.59483549\n", 97 | " 1.24052468]\n", 98 | " [-0.48993428 -0.7453348 -1.43801838 ... -1.67525801 -0.09994425\n", 99 | " -0.46569289]\n", 100 | " ...\n", 101 | " [ 0.47406074 -1.9209351 0.41681779 ... 1.04574815 1.092832\n", 102 | " -0.01541749]\n", 103 | " [-0.62731673 -0.94336697 -1.50694171 ... -0.85092941 0.99046917\n", 104 | " 2.19583454]\n", 105 | " [ 0.88990126 0.81857103 -2.12551556 ... 1.00271323 -0.88101446\n", 106 | " -0.81149645]]\n", 107 | "-----------\n", 108 | "[1 0 0 0 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 0 0 0 1 1 1 1 1 0 1 1 0 0 0 1 0 0 0\n", 109 | " 0 0 0 1 1 0 1 0 1 1 0 0 0 1 0 0 1 0 0 1 1 0 1 0 0 0 0 1 0 0 1 1 1 0 0 0 0\n", 110 | " 0 0 0 1 0 1 0 1 0 1 0 1 0 0 0 1 1 1 1 0 1 0 0 0 0 1 1 0 1 1 1 1 0 0 0 0 1\n", 111 | " 0 1 1 0 0 1 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 1 1 0 0 1 1 1 0 1 0 0 0 0 1 0 1\n", 112 | " 0 0 1 1 0 0 1 1 1 0 1 1 1 0 1 0 0 0 0 0 0 1 0 1 1 1 1 1 0 1 0 0 0 0 1 0 0\n", 113 | " 1 0 1 1 0 1 1 1 1 0 1 0 0 0 0 1 1 1 0 1 1 0 1 1 0 1 0 1 1 1 1 1 0 0 0 1 0\n", 114 | " 1 1 1 1 1 1 0 1 0 0 0 0 1 1 1 1 0 0 0 1 0 1 1 0 0 1 1 1 0 0 1 0 0 0 1 0 1\n", 115 | " 1 1 1 1 0 0 0 1 1 0 0 0 1 1 0 1 1 0 1 1 1 1 0 0 0 1 1 0 1 0 0 1 0 1 1 1 0\n", 116 | " 1 1 0 0 1 0 0 0 1 0 0 1 1 0 1 0 1 1 0 1 0 0 0 1 1 1 0 0 0 0 1 0 0 0 1 1 0\n", 117 | " 1 1 1 0 0 0 1 0 0 0 0 1 1 1 0 0 0 1 0 1 0 0 0 1 1 0 1 0 0 1 0 1 1 0 1 0 0\n", 118 | " 1 1 0 0 1 0 0 1 1 1 0 0 1 0 0 1 1 0 1 0 1 1 0 1 1 0 0 1 0 1 0 1 0 1 1 1 0\n", 119 | " 0 0 0 0 1 1 1 1 1 1 1 0 1 0 0 0 0 0 1 1 0 1 1 0 0 1 0 1 0 0 0 1 0 1 1 0 0\n", 120 | " 0 1 0 1 0 0 1 0 1 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 1 0 0\n", 121 | " 1 1 1 1 1 1 0 0 0 0 0 0 0 0 1 0 0 1 0 1 1 1 1 1 1 1 1 0 0 1 1 0 1 0 0 0 0\n", 122 | " 1 0 0 0 1 0 1 1 0 1 1 1 0 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 0 1 0 1 0 1\n", 123 | " 0 0 1 0 0 0 0 0 1 0 1 1 1 1 1 1 0 0 1 0 0 1 1 0 0 1 1 0 0 1 1 
0 0 1 1 0 0\n", 124 | " 0 0 1 0 0 0 1 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 0 0 1 1 1 1 0 1 0 0 1 1 0 0 0\n", 125 | " 0 1 1 1 0 1 0 0 1 0 1 0 1 1 0 1 1 1 0 0 1 1 1 1 0 1 0 0 0 1 1 1 0 1 1 1 0\n", 126 | " 0 0 1 0 0 0 0 0 1 0 1 0 0 1 1 1 1 0 1 1 0 1 0 1 0 1 0 0 1 1 1 0 1 0 0 0 1\n", 127 | " 0 0 0 1 0 1 0 1 0 0 0 1 0 0 1 1 1 1 1 1 0 1 1 0 0 1 0 0 0 1 0 0 0 1 1 0 0\n", 128 | " 0 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 1 0 0 0 1 0 1 1 0 1 1 0 1 1 0 0 1 0 0 0 1\n", 129 | " 1 1 0 0 0 1 0 1 0 0 1 1 1 1 1 0 1 0 1 1 0 1 1 1 1 1 1 0 0 0 0 1 0 0 1 0 1\n", 130 | " 1 1 1 1 0 0 0 1 0 1 1 0 1 1 1 0 1 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 0 1 1 0\n", 131 | " 0 0 0 1 1 0 1 1 1 0 0 1 1 1 0 1 0 1 0 1 1 1 1 0 0 1 1 0 0 1 0 0 0 1 1 0 1\n", 132 | " 1 1 1 1 1 0 0 0 1 1 0 0 1 1 0 1 0 0 0 0 0 1 0 0 1 1 0 1 0 1 1 1 0 0 1 1 0\n", 133 | " 1 0 1 0 1 1 0 0 1 1 1 0 0 1 0 1 0 0 0 1 1 0 1 0 0 0 1 0 0 0 1 1 0 0 0 0 0\n", 134 | " 0 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 0 0 1 1 0 1 0 0 1 0 0 1 1 0 1 1 1 0\n", 135 | " 0]\n" 136 | ] 137 | } 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "metadata": { 143 | "colab": { 144 | "base_uri": "https://localhost:8080/" 145 | }, 146 | "id": "wLo12OKVE__J", 147 | "outputId": "a6dbfbe9-fecd-4d53-b281-c7193c79caa2" 148 | }, 149 | "source": [ 150 | "# split into train/test sets\n", 151 | "trainX, testX, trainy, testy = train_test_split(X, y, test_size=0.5, random_state=1000)\n", 152 | "\n", 153 | "print('trainy - class0: ', len(trainy)-trainy.sum())\n", 154 | "print('trainy - class1: ', trainy.sum())\n", 155 | "print('----------------------')\n", 156 | "print('testy - class0: ', len(testy)-testy.sum())\n", 157 | "print('testy - class1: ', testy.sum())\n", 158 | "print('============================')\n", 159 | "\n", 160 | "# make testing dataset balance\n", 161 | "undersample = RandomUnderSampler(sampling_strategy='majority')\n", 162 | "testX, testy = undersample.fit_resample(testX, testy)\n", 163 | "\n", 164 | "print('Balanced Testing date')\n", 165 | "print('testy - class0: ', 
len(testy)-testy.sum())\n", 166 | "print('testy - class1: ', testy.sum())" 167 | ], 168 | "execution_count": 38, 169 | "outputs": [ 170 | { 171 | "output_type": "stream", 172 | "name": "stdout", 173 | "text": [ 174 | "trainy - class0: 253\n", 175 | "trainy - class1: 247\n", 176 | "----------------------\n", 177 | "testy - class0: 249\n", 178 | "testy - class1: 251\n", 179 | "============================\n", 180 | "Balanced Testing date\n", 181 | "testy - class0: 249\n", 182 | "testy - class1: 249\n" 183 | ] 184 | }, 185 | { 186 | "output_type": "stream", 187 | "name": "stderr", 188 | "text": [ 189 | "/usr/local/lib/python3.7/dist-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", 190 | " warnings.warn(msg, category=FutureWarning)\n" 191 | ] 192 | } 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "metadata": { 198 | "id": "Y3QmySaLE7Nm" 199 | }, 200 | "source": [ 201 | "# fit a model with training data\n", 202 | "model = LogisticRegression(solver='lbfgs')\n", 203 | "model.fit(trainX, trainy)\n" 204 | ], 205 | "execution_count": null, 206 | "outputs": [] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "metadata": { 211 | "id": "nOu_783uGpZd" 212 | }, 213 | "source": [ 214 | "# repeat with different skewness \n", 215 | "roc_list = []\n", 216 | "lr_acc = []\n", 217 | "k=1\n", 218 | "for i in range(0, 10):\n", 219 | " pos_ind = np.where(testy==1)[0]\n", 220 | " n = int(i/10 * len(pos_ind))\n", 221 | " tmp_testX, tmp_testy = np.copy(testX), np.copy(testy)\n", 222 | " tmp_testX = np.delete(tmp_testX, pos_ind[:n], axis=0)\n", 223 | " tmp_testy = np.delete(tmp_testy, pos_ind[:n], axis=0)\n", 224 | " print('nth %d:positive: %d negative: %d' \n", 225 | " % (i, tmp_testy.sum(), tmp_testy.shape[0] - tmp_testy.sum()))\n", 226 | " print('---------------------------------------------')\n", 227 | " \n", 228 | " # predict probabilities\n", 229 
| " lr_probs = model.predict_proba(tmp_testX)\n", 230 | " # keep probabilities for the positive outcome only\n", 231 | " lr_probs = lr_probs[:, 1]\n", 232 | " # calculate scores\n", 233 | " lr_auc = roc_auc_score(tmp_testy, lr_probs)\n", 234 | "\n", 235 | " # summarize scores\n", 236 | " # print('iteration %d: Logistic: ROC AUC=%.3f' % (k, lr_auc))\n", 237 | " k += 1\n", 238 | " # calculate roc curves\n", 239 | " lr_fpr, lr_tpr, _ = roc_curve(tmp_testy, lr_probs)\n", 240 | " roc_list.append(lr_auc)\n", 241 | "\n", 242 | "plt.plot(np.arange(0, len(roc_list)), roc_list)\n", 243 | "plt.xlabel('skewness ratio')\n", 244 | "plt.ylabel('AUROC')\n", 245 | "plt.title('decreasing positive sample')\n" 246 | ], 247 | "execution_count": null, 248 | "outputs": [] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": { 253 | "id": "njFP2GHpC1VE" 254 | }, 255 | "source": [ 256 | "# Exercise 1 (2%):\n", 257 | "Does the AUROC (TPR vs FPR) affected by imbalanced class?\n", 258 | "\n", 259 | "\n" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "metadata": { 265 | "id": "vOIpcKliC56h" 266 | }, 267 | "source": [ 268 | "# Your answer here\n" 269 | ], 270 | "execution_count": null, 271 | "outputs": [] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": { 276 | "id": "KK4Cxp0q75PM" 277 | }, 278 | "source": [ 279 | "# The effect of imbalanced data on AUROC of PR curve and F1 score\n", 280 | "The following code evaluates the effect of imbalanced data on the AUROC of Precision-Recall and F1 value. 
\n" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "id": "oYxjJuD_8ewJ" 287 | }, 288 | "source": [ 289 | "# roc curve and auc on an imbalanced dataset\n", 290 | "import numpy as np\n", 291 | "from sklearn.datasets import make_classification\n", 292 | "from sklearn.linear_model import LogisticRegression\n", 293 | "from sklearn.model_selection import train_test_split\n", 294 | "from sklearn.metrics import auc, f1_score\n", 295 | "from sklearn.metrics import precision_recall_curve\n", 296 | "import matplotlib.pyplot as plt\n" 297 | ], 298 | "execution_count": 23, 299 | "outputs": [] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "metadata": { 304 | "id": "nr_uY_mPHLTF" 305 | }, 306 | "source": [ 307 | "# generate 2 class dataset \n", 308 | "X, y = make_classification(n_samples=1000, n_classes=2, random_state=1000)\n", 309 | "\n", 310 | "print(X)\n", 311 | "print('-----------')\n", 312 | "print(y)\n" 313 | ], 314 | "execution_count": null, 315 | "outputs": [] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "metadata": { 320 | "id": "bbmv6pFhHWyQ" 321 | }, 322 | "source": [ 323 | "# split into train/test sets\n", 324 | "trainX, testX, trainy, testy = train_test_split(X, y, test_size=0.5, random_state=1000)\n", 325 | "\n", 326 | "print('trainy - class0: ', len(trainy)-trainy.sum())\n", 327 | "print('trainy - class1: ', trainy.sum())\n", 328 | "print('----------------------')\n", 329 | "print('testy - class0: ', len(testy)-testy.sum())\n", 330 | "print('testy - class1: ', testy.sum())\n", 331 | "print('============================')\n", 332 | "\n", 333 | "# make testing dataset balance\n", 334 | "undersample = RandomUnderSampler(sampling_strategy='majority')\n", 335 | "testX, testy = undersample.fit_resample(testX, testy)\n", 336 | "\n", 337 | "print('Balanced Testing date')\n", 338 | "print('testy - class0: ', len(testy)-testy.sum())\n", 339 | "print('testy - class1: ', testy.sum())\n" 340 | ], 341 | "execution_count": null, 342 | 
"outputs": [] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "metadata": { 347 | "id": "i1ST-YEcHh5I" 348 | }, 349 | "source": [ 350 | "# fit a model\n", 351 | "model = LogisticRegression(solver='lbfgs')\n", 352 | "model.fit(trainX, trainy)" 353 | ], 354 | "execution_count": null, 355 | "outputs": [] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "metadata": { 360 | "id": "scPocodDHRp-" 361 | }, 362 | "source": [ 363 | "# repeat with different skewness \n", 364 | "roc_list = []\n", 365 | "f1_list = []\n", 366 | "\n", 367 | "k=1\n", 368 | "for i in range(0, 10):\n", 369 | " pos_ind = np.where(testy==1)[0]\n", 370 | " n = int(i/10 * len(pos_ind))\n", 371 | " tmp_testX, tmp_testy = np.copy(testX), np.copy(testy)\n", 372 | " tmp_testX = np.delete(tmp_testX, pos_ind[:n], axis=0)\n", 373 | " tmp_testy = np.delete(tmp_testy, pos_ind[:n], axis=0)\n", 374 | " print('nth %d:positive: %d negative: %d' \n", 375 | " % (i, tmp_testy.sum(), tmp_testy.shape[0] - tmp_testy.sum()))\n", 376 | " print('---------------------------------------------')\n", 377 | " \n", 378 | "\n", 379 | " # predict probabilities\n", 380 | " lr_probs = model.predict_proba(tmp_testX)\n", 381 | " # keep probabilities for the positive outcome only\n", 382 | " lr_probs = lr_probs[:, 1]\n", 383 | " # predict class values\n", 384 | " yhat = model.predict(tmp_testX)\n", 385 | " # calculate precision and recall for each threshold\n", 386 | " lr_precision, lr_recall, _ = precision_recall_curve(tmp_testy, lr_probs)\n", 387 | " # calculate scores\n", 388 | " lr_f1, lr_auc = f1_score(tmp_testy, yhat), auc(lr_recall, lr_precision)\n", 389 | " # summarize scores\n", 390 | " # print('iteration%d Logistic: f1=%.3f auc=%.3f' % (k, lr_f1, lr_auc))\n", 391 | " k += 1\n", 392 | " roc_list.append(lr_auc)\n", 393 | " f1_list.append(lr_f1)\n", 394 | "\n", 395 | "plt.plot(np.arange(0, len(roc_list)), roc_list)\n", 396 | "plt.xlabel('skewness ratio')\n", 397 | "plt.ylabel('AUC of PR curve')\n", 398 | "plt.title('decreasing 
positive sample')\n", 399 | "\n", 400 | "plt.figure()\n", 401 | "plt.plot(np.arange(0, len(roc_list)), f1_list)\n", 402 | "plt.xlabel('skewness ratio')\n", 403 | "plt.ylabel('F1')\n", 404 | "plt.title('decreasing positive sample')\n" 405 | ], 406 | "execution_count": null, 407 | "outputs": [] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": { 412 | "id": "2DP3LOZY7v6M" 413 | }, 414 | "source": [ 415 | "# Exercise 2 (4%):\n", 416 | "Does the AUROC (Precision vs Recall), F1 score affected by imbalanced class?" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "iPXGu2E00-c5" 423 | }, 424 | "source": [ 425 | "# Your answer here" 426 | ], 427 | "execution_count": null, 428 | "outputs": [] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "metadata": { 433 | "id": "_h1qUoy4DSc2" 434 | }, 435 | "source": [ 436 | "# ***Let's go back to power point - slide 13***" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "id": "SK3GJxwgzEjP" 443 | }, 444 | "source": [ 445 | "# Convex function\n", 446 | "\n", 447 | "This is the code to generate the graph in slide 38" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "metadata": { 453 | "id": "3v0il6m3zOQb" 454 | }, 455 | "source": [ 456 | "import numpy as np\n", 457 | "import matplotlib.pyplot as plt\n", 458 | "import imageio\n", 459 | "\n", 460 | "x = np.arange(-2, 2, 0.01)\n", 461 | "\n", 462 | "# choose one function to try\n", 463 | "f = lambda x: 0.5 * x ** 2 # Convex\n", 464 | "# f = lambda x: np.cos(np.pi * x) # Nonconvex\n", 465 | "# f = lambda x: -0.5 * x ** 4 # Nonconvex\n", 466 | "\n", 467 | "filenames=[]\n", 468 | "for lamda in np.arange(0, 1, 0.02):\n", 469 | " # LHS\n", 470 | " tmp_x = lamda*x[0] + (1-lamda)*x[-1]\n", 471 | "\n", 472 | " # RHS\n", 473 | " x_line, y_line = np.array([x[0], x[-1]]), np.array([lamda*f(x[0]), (1-lamda)*f(x[-1])])\n", 474 | "\n", 475 | " # compute LHS and RHS\n", 476 | " LHS = f(tmp_x)\n", 477 | " 
RHS = lamda*f(x[0]) + (1-lamda)*f(x[-1])\n", 478 | " if LHS > RHS:\n", 479 | " print('At lamda %0.3f, it is concave' % lamda)\n", 480 | " print('lhs %.5f rhs %.5f' % (LHS, RHS))\n", 481 | "\n", 482 | " plt.figure()\n", 483 | " # original graph\n", 484 | " plt.plot(x, f(x), label='f(x)')\n", 485 | " # plot RHS\n", 486 | " plt.plot(x_line, y_line, label='%0.3f' % lamda)\n", 487 | " # plot LHS\n", 488 | " plt.scatter(tmp_x, f(tmp_x))\n", 489 | " #title, legennd\n", 490 | " plt.title('lhs %.3f rhs %.3f' % (LHS, RHS))\n", 491 | " plt.legend()\n", 492 | " plt.savefig('lamda %0.3f.png' % lamda)\n", 493 | " # plt.close()\n", 494 | " filenames.append('lamda %0.3f.png' % lamda)\n", 495 | "\n", 496 | "# Build GIF\n", 497 | "with imageio.get_writer('mygif.gif', mode='I') as writer:\n", 498 | " for filename in filenames:\n", 499 | " image = imageio.imread(filename)\n", 500 | " writer.append_data(image)" 501 | ], 502 | "execution_count": null, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": { 508 | "id": "l9Wuu9o0pGmg" 509 | }, 510 | "source": [ 511 | "# Understand how learning rate affects your SGD optimization\n", 512 | "\n", 513 | "We will train a neural network for a pretty simple task, i.e. calculating the exclusive-or (XOR) of two input. \n", 514 | "\n", 515 | "
\n", 516 | "\n", 517 | "\n" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "metadata": { 523 | "id": "Qr2IABWKoGHk" 524 | }, 525 | "source": [ 526 | "import random\n", 527 | "import numpy as np" 528 | ], 529 | "execution_count": null, 530 | "outputs": [] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "metadata": { 535 | "colab": { 536 | "base_uri": "https://localhost:8080/" 537 | }, 538 | "id": "3PpJZLhyndlE", 539 | "outputId": "b780d629-bcb3-46dd-82f4-b50699ec8d0a" 540 | }, 541 | "source": [ 542 | "# generate a function for XOR\n", 543 | "x1 = random.randint(0, 1)\n", 544 | "x2 = random.randint(0, 1)\n", 545 | "yy = 0 if (x1 == x2) else 1\n", 546 | "\n", 547 | "print('x1:', x1)\n", 548 | "print('x2:',x2)\n", 549 | "print('yy:',yy)" 550 | ], 551 | "execution_count": 52, 552 | "outputs": [ 553 | { 554 | "output_type": "stream", 555 | "name": "stdout", 556 | "text": [ 557 | "1\n", 558 | "1\n", 559 | "0\n" 560 | ] 561 | } 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "metadata": { 567 | "colab": { 568 | "base_uri": "https://localhost:8080/" 569 | }, 570 | "id": "sucvJe_fntPU", 571 | "outputId": "3dff8ccc-9361-430c-f2a8-5fe510e90b2f" 572 | }, 573 | "source": [ 574 | "x1 = random.randint(0, 1)\n", 575 | "x2 = random.randint(0, 1)\n", 576 | "yy = 0 if (x1 == x2) else 1\n", 577 | "\n", 578 | "# centered at zero\n", 579 | "x1 = 2. * (x1 - 0.5)\n", 580 | "x2 = 2. * (x2 - 0.5)\n", 581 | "yy = 2. 
* (yy - 0.5)\n", 582 | "\n", 583 | "print('x1:', x1)\n", 584 | "print('x2:',x2)\n", 585 | "print('yy:',yy)" 586 | ], 587 | "execution_count": 54, 588 | "outputs": [ 589 | { 590 | "output_type": "stream", 591 | "name": "stdout", 592 | "text": [ 593 | "x1: -1.0\n", 594 | "x2: 1.0\n", 595 | "yy: 1.0\n" 596 | ] 597 | } 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "metadata": { 603 | "colab": { 604 | "base_uri": "https://localhost:8080/" 605 | }, 606 | "id": "aEbi2kK4n9Ci", 607 | "outputId": "75c1ede3-5545-40bb-80bc-01fecf04192f" 608 | }, 609 | "source": [ 610 | "x1 = random.randint(0, 1)\n", 611 | "x2 = random.randint(0, 1)\n", 612 | "yy = 0 if (x1 == x2) else 1\n", 613 | "\n", 614 | "# centered at zero\n", 615 | "x1 = 2. * (x1 - 0.5)\n", 616 | "x2 = 2. * (x2 - 0.5)\n", 617 | "yy = 2. * (yy - 0.5)\n", 618 | "\n", 619 | "# add noise\n", 620 | "x1 += 0.1 * random.random()\n", 621 | "x2 += 0.1 * random.random()\n", 622 | "yy += 0.1 * random.random()\n", 623 | "\n", 624 | "print('x1:', x1)\n", 625 | "print('x2:',x2)\n", 626 | "print('yy:',yy)" 627 | ], 628 | "execution_count": 56, 629 | "outputs": [ 630 | { 631 | "output_type": "stream", 632 | "name": "stdout", 633 | "text": [ 634 | "x1: -0.9574273128896689\n", 635 | "x2: -0.9030175180760335\n", 636 | "yy: -0.9475814083656239\n" 637 | ] 638 | } 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "metadata": { 644 | "id": "6RiENlDbpGS3" 645 | }, 646 | "source": [ 647 | "# make it into function \n", 648 | "def make_data():\n", 649 | " x1 = random.randint(0, 1)\n", 650 | " x2 = random.randint(0, 1)\n", 651 | " yy = 0 if (x1 == x2) else 1\n", 652 | " \n", 653 | " # centered at zero\n", 654 | " x1 = 2. * (x1 - 0.5)\n", 655 | " x2 = 2. * (x2 - 0.5)\n", 656 | " yy = 2. 
* (yy - 0.5)\n", 657 | " \n", 658 | " # add noise\n", 659 | " x1 += 0.1 * random.random()\n", 660 | " x2 += 0.1 * random.random()\n", 661 | " yy += 0.1 * random.random()\n", 662 | " \n", 663 | " return [x1, x2, ], yy\n", 664 | " " 665 | ], 666 | "execution_count": null, 667 | "outputs": [] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "metadata": { 672 | "id": "IRS81-dioL4n" 673 | }, 674 | "source": [ 675 | "# create batch samples\n", 676 | "batch_size = 10\n", 677 | "def make_batch():\n", 678 | " data = [make_data() for ii in range(batch_size)]\n", 679 | " labels = [label for xx, label in data]\n", 680 | " data = [xx for xx, label in data]\n", 681 | " return np.array(data, dtype='float32'), np.array(labels, dtype='float32')\n", 682 | " \n", 683 | "print(make_batch())\n" 684 | ], 685 | "execution_count": null, 686 | "outputs": [] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "metadata": { 691 | "id": "XB7kASWVoZJi" 692 | }, 693 | "source": [ 694 | "# generate 500 train and 50 test data \n", 695 | "train_data = [make_batch() for ii in range(500)]\n", 696 | "test_data = [make_batch() for ii in range(50)]\n" 697 | ], 698 | "execution_count": 106, 699 | "outputs": [] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "metadata": { 704 | "id": "huxK2x7WpUGw" 705 | }, 706 | "source": [ 707 | "# import torch libraries\n", 708 | "import torch\n", 709 | "import torch.nn as nn\n", 710 | "import torch.nn.functional as F\n", 711 | "import torch.optim as optim\n", 712 | "from torch.autograd import Variable\n", 713 | " " 714 | ], 715 | "execution_count": 107, 716 | "outputs": [] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "metadata": { 721 | "id": "BMj_0PO0ojry" 722 | }, 723 | "source": [ 724 | "## Define our neural network class\n", 725 | "torch.manual_seed(42)\n", 726 | " \n", 727 | "class NN(nn.Module):\n", 728 | " def __init__(self):\n", 729 | " super(NN, self).__init__()\n", 730 | " \n", 731 | " self.dense1 = nn.Linear(2, 2)\n", 732 | " self.dense2 = 
nn.Linear(2, 1)\n", 733 | " \n", 734 | " def forward(self, x):\n", 735 | " x = F.tanh(self.dense1(x))\n", 736 | " x = self.dense2(x)\n", 737 | " return torch.squeeze(x)\n", 738 | " \n" 739 | ], 740 | "execution_count": 112, 741 | "outputs": [] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "metadata": { 746 | "id": "ZqLyWCLOtrJS" 747 | }, 748 | "source": [ 749 | "# initialize our network\n", 750 | "model = NN()\n", 751 | "\n", 752 | "## optimizer = stochastic gradient descent\n", 753 | "optimizer = optim.SGD(model.parameters(), lr)" 754 | ], 755 | "execution_count": 127, 756 | "outputs": [] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "metadata": { 761 | "id": "Furjt7ZppdxA" 762 | }, 763 | "source": [ 764 | "## train and test functions\n", 765 | " \n", 766 | "def train(epoch):\n", 767 | " model.train()\n", 768 | " for batch_idx, (data, target) in enumerate(train_data):\n", 769 | " data, target = Variable(torch.from_numpy(data)), Variable(torch.from_numpy(target))\n", 770 | " optimizer.zero_grad()\n", 771 | " output = model(data)\n", 772 | " loss = F.mse_loss(output, target)\n", 773 | " loss.backward()\n", 774 | " optimizer.step()\n", 775 | " if batch_idx % 100 == 0:\n", 776 | " print('Train Epoch: {} {}\\tLoss: {:.4f}'.format(epoch, batch_idx * len(data), loss.item()))\n", 777 | " \n", 778 | "def test():\n", 779 | " model.eval()\n", 780 | " test_loss = 0\n", 781 | " correct = 0\n", 782 | " for data, target in test_data:\n", 783 | " data, target = Variable(torch.from_numpy(data), volatile=True), Variable(torch.from_numpy(target))\n", 784 | " output = model(data)\n", 785 | " test_loss += F.mse_loss(output, target)\n", 786 | " correct += (np.around(output.data.numpy()) == np.around(target.data.numpy())).sum()\n", 787 | " \n", 788 | " test_loss /= len(test_data)\n", 789 | " test_loss = test_loss.item()\n", 790 | " \n", 791 | " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", 792 | " test_loss, correct, batch_size * 
len(test_data), 100. * correct / (batch_size * len(test_data))) )\n", 793 | " " 794 | ], 795 | "execution_count": 128, 796 | "outputs": [] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "metadata": { 801 | "id": "wZ4c_vUzphKy", 802 | "colab": { 803 | "base_uri": "https://localhost:8080/" 804 | }, 805 | "outputId": "60d93a88-ce66-41e1-f823-1de85b60c26e" 806 | }, 807 | "source": [ 808 | "## run experiment \n", 809 | "nepochs = 1000\n", 810 | "lr = 0.001\n", 811 | "\n", 812 | "print('lr=', lr)\n", 813 | "for epoch in range(1, nepochs + 1):\n", 814 | " train(epoch)\n", 815 | " print('---------------------------------------------')\n", 816 | " test()\n", 817 | " \n", 818 | " # everytime rerun this cell, please re initialize your network, and re run the train test function " 819 | ], 820 | "execution_count": null, 821 | "outputs": [ 822 | { 823 | "output_type": "stream", 824 | "name": "stdout", 825 | "text": [ 826 | "Train Epoch: 497 1000\tLoss: 0.4051\n", 827 | "Train Epoch: 497 2000\tLoss: 0.3741\n", 828 | "Train Epoch: 497 3000\tLoss: 0.6958\n", 829 | "Train Epoch: 497 4000\tLoss: 0.5083\n", 830 | "---------------------------------------------\n", 831 | "\n", 832 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 833 | "\n", 834 | "Train Epoch: 498 0\tLoss: 0.2817\n", 835 | "Train Epoch: 498 1000\tLoss: 0.4051\n", 836 | "Train Epoch: 498 2000\tLoss: 0.3741\n", 837 | "Train Epoch: 498 3000\tLoss: 0.6958\n", 838 | "Train Epoch: 498 4000\tLoss: 0.5083\n", 839 | "---------------------------------------------\n", 840 | "\n", 841 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 842 | "\n", 843 | "Train Epoch: 499 0\tLoss: 0.2817\n", 844 | "Train Epoch: 499 1000\tLoss: 0.4051\n", 845 | "Train Epoch: 499 2000\tLoss: 0.3741\n", 846 | "Train Epoch: 499 3000\tLoss: 0.6958\n", 847 | "Train Epoch: 499 4000\tLoss: 0.5083\n", 848 | "---------------------------------------------\n", 849 | "\n", 850 | "Test set: Average loss: 0.5481, Accuracy: 
232/500 (46.40%)\n", 851 | "\n", 852 | "Train Epoch: 500 0\tLoss: 0.2817\n", 853 | "Train Epoch: 500 1000\tLoss: 0.4051\n", 854 | "Train Epoch: 500 2000\tLoss: 0.3741\n", 855 | "Train Epoch: 500 3000\tLoss: 0.6958\n", 856 | "Train Epoch: 500 4000\tLoss: 0.5083\n", 857 | "---------------------------------------------\n", 858 | "\n", 859 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 860 | "\n", 861 | "Train Epoch: 501 0\tLoss: 0.2817\n", 862 | "Train Epoch: 501 1000\tLoss: 0.4051\n", 863 | "Train Epoch: 501 2000\tLoss: 0.3741\n", 864 | "Train Epoch: 501 3000\tLoss: 0.6958\n", 865 | "Train Epoch: 501 4000\tLoss: 0.5083\n", 866 | "---------------------------------------------\n", 867 | "\n", 868 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 869 | "\n", 870 | "Train Epoch: 502 0\tLoss: 0.2817\n", 871 | "Train Epoch: 502 1000\tLoss: 0.4051\n", 872 | "Train Epoch: 502 2000\tLoss: 0.3741\n", 873 | "Train Epoch: 502 3000\tLoss: 0.6958\n", 874 | "Train Epoch: 502 4000\tLoss: 0.5083\n", 875 | "---------------------------------------------\n", 876 | "\n", 877 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 878 | "\n", 879 | "Train Epoch: 503 0\tLoss: 0.2817\n", 880 | "Train Epoch: 503 1000\tLoss: 0.4051\n", 881 | "Train Epoch: 503 2000\tLoss: 0.3741\n", 882 | "Train Epoch: 503 3000\tLoss: 0.6958\n", 883 | "Train Epoch: 503 4000\tLoss: 0.5083\n", 884 | "---------------------------------------------\n", 885 | "\n", 886 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 887 | "\n", 888 | "Train Epoch: 504 0\tLoss: 0.2817\n", 889 | "Train Epoch: 504 1000\tLoss: 0.4051\n", 890 | "Train Epoch: 504 2000\tLoss: 0.3741\n", 891 | "Train Epoch: 504 3000\tLoss: 0.6958\n", 892 | "Train Epoch: 504 4000\tLoss: 0.5083\n", 893 | "---------------------------------------------\n", 894 | "\n", 895 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 896 | "\n", 897 | "Train Epoch: 505 0\tLoss: 0.2817\n", 898 | 
"Train Epoch: 505 1000\tLoss: 0.4051\n", 899 | "Train Epoch: 505 2000\tLoss: 0.3741\n", 900 | "Train Epoch: 505 3000\tLoss: 0.6958\n", 901 | "Train Epoch: 505 4000\tLoss: 0.5083\n", 902 | "---------------------------------------------\n", 903 | "\n", 904 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 905 | "\n", 906 | "Train Epoch: 506 0\tLoss: 0.2817\n", 907 | "Train Epoch: 506 1000\tLoss: 0.4051\n", 908 | "Train Epoch: 506 2000\tLoss: 0.3741\n", 909 | "Train Epoch: 506 3000\tLoss: 0.6958\n", 910 | "Train Epoch: 506 4000\tLoss: 0.5083\n", 911 | "---------------------------------------------\n", 912 | "\n", 913 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 914 | "\n", 915 | "Train Epoch: 507 0\tLoss: 0.2817\n", 916 | "Train Epoch: 507 1000\tLoss: 0.4051\n", 917 | "Train Epoch: 507 2000\tLoss: 0.3741\n", 918 | "Train Epoch: 507 3000\tLoss: 0.6958\n", 919 | "Train Epoch: 507 4000\tLoss: 0.5083\n", 920 | "---------------------------------------------\n", 921 | "\n", 922 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 923 | "\n", 924 | "Train Epoch: 508 0\tLoss: 0.2817\n", 925 | "Train Epoch: 508 1000\tLoss: 0.4051\n", 926 | "Train Epoch: 508 2000\tLoss: 0.3741\n", 927 | "Train Epoch: 508 3000\tLoss: 0.6958\n", 928 | "Train Epoch: 508 4000\tLoss: 0.5083\n", 929 | "---------------------------------------------\n", 930 | "\n", 931 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 932 | "\n", 933 | "Train Epoch: 509 0\tLoss: 0.2817\n", 934 | "Train Epoch: 509 1000\tLoss: 0.4051\n", 935 | "Train Epoch: 509 2000\tLoss: 0.3741\n", 936 | "Train Epoch: 509 3000\tLoss: 0.6958\n", 937 | "Train Epoch: 509 4000\tLoss: 0.5083\n", 938 | "---------------------------------------------\n", 939 | "\n", 940 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 941 | "\n", 942 | "Train Epoch: 510 0\tLoss: 0.2817\n", 943 | "Train Epoch: 510 1000\tLoss: 0.4051\n", 944 | "Train Epoch: 510 2000\tLoss: 
0.3741\n", 945 | "Train Epoch: 510 3000\tLoss: 0.6958\n", 946 | "Train Epoch: 510 4000\tLoss: 0.5083\n", 947 | "---------------------------------------------\n", 948 | "\n", 949 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 950 | "\n", 951 | "Train Epoch: 511 0\tLoss: 0.2817\n", 952 | "Train Epoch: 511 1000\tLoss: 0.4051\n", 953 | "Train Epoch: 511 2000\tLoss: 0.3741\n", 954 | "Train Epoch: 511 3000\tLoss: 0.6958\n", 955 | "Train Epoch: 511 4000\tLoss: 0.5083\n", 956 | "---------------------------------------------\n", 957 | "\n", 958 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 959 | "\n", 960 | "Train Epoch: 512 0\tLoss: 0.2817\n", 961 | "Train Epoch: 512 1000\tLoss: 0.4051\n", 962 | "Train Epoch: 512 2000\tLoss: 0.3741\n", 963 | "Train Epoch: 512 3000\tLoss: 0.6958\n", 964 | "Train Epoch: 512 4000\tLoss: 0.5083\n", 965 | "---------------------------------------------\n", 966 | "\n", 967 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 968 | "\n", 969 | "Train Epoch: 513 0\tLoss: 0.2817\n", 970 | "Train Epoch: 513 1000\tLoss: 0.4051\n", 971 | "Train Epoch: 513 2000\tLoss: 0.3741\n", 972 | "Train Epoch: 513 3000\tLoss: 0.6958\n", 973 | "Train Epoch: 513 4000\tLoss: 0.5083\n", 974 | "---------------------------------------------\n", 975 | "\n", 976 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 977 | "\n", 978 | "Train Epoch: 514 0\tLoss: 0.2817\n", 979 | "Train Epoch: 514 1000\tLoss: 0.4051\n", 980 | "Train Epoch: 514 2000\tLoss: 0.3741\n", 981 | "Train Epoch: 514 3000\tLoss: 0.6958\n", 982 | "Train Epoch: 514 4000\tLoss: 0.5083\n", 983 | "---------------------------------------------\n", 984 | "\n", 985 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 986 | "\n", 987 | "Train Epoch: 515 0\tLoss: 0.2817\n", 988 | "Train Epoch: 515 1000\tLoss: 0.4051\n", 989 | "Train Epoch: 515 2000\tLoss: 0.3741\n", 990 | "Train Epoch: 515 3000\tLoss: 0.6958\n", 991 | "Train Epoch: 515 
4000\tLoss: 0.5083\n", 992 | "---------------------------------------------\n", 993 | "\n", 994 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 995 | "\n", 996 | "Train Epoch: 516 0\tLoss: 0.2817\n", 997 | "Train Epoch: 516 1000\tLoss: 0.4051\n", 998 | "Train Epoch: 516 2000\tLoss: 0.3741\n", 999 | "Train Epoch: 516 3000\tLoss: 0.6958\n", 1000 | "Train Epoch: 516 4000\tLoss: 0.5083\n", 1001 | "---------------------------------------------\n", 1002 | "\n", 1003 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1004 | "\n", 1005 | "Train Epoch: 517 0\tLoss: 0.2817\n", 1006 | "Train Epoch: 517 1000\tLoss: 0.4051\n", 1007 | "Train Epoch: 517 2000\tLoss: 0.3741\n", 1008 | "Train Epoch: 517 3000\tLoss: 0.6958\n", 1009 | "Train Epoch: 517 4000\tLoss: 0.5083\n", 1010 | "---------------------------------------------\n", 1011 | "\n", 1012 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1013 | "\n", 1014 | "Train Epoch: 518 0\tLoss: 0.2817\n", 1015 | "Train Epoch: 518 1000\tLoss: 0.4051\n", 1016 | "Train Epoch: 518 2000\tLoss: 0.3741\n", 1017 | "Train Epoch: 518 3000\tLoss: 0.6958\n", 1018 | "Train Epoch: 518 4000\tLoss: 0.5083\n", 1019 | "---------------------------------------------\n", 1020 | "\n", 1021 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1022 | "\n", 1023 | "Train Epoch: 519 0\tLoss: 0.2817\n", 1024 | "Train Epoch: 519 1000\tLoss: 0.4051\n", 1025 | "Train Epoch: 519 2000\tLoss: 0.3741\n", 1026 | "Train Epoch: 519 3000\tLoss: 0.6958\n", 1027 | "Train Epoch: 519 4000\tLoss: 0.5083\n", 1028 | "---------------------------------------------\n", 1029 | "\n", 1030 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1031 | "\n", 1032 | "Train Epoch: 520 0\tLoss: 0.2817\n", 1033 | "Train Epoch: 520 1000\tLoss: 0.4051\n", 1034 | "Train Epoch: 520 2000\tLoss: 0.3741\n", 1035 | "Train Epoch: 520 3000\tLoss: 0.6958\n", 1036 | "Train Epoch: 520 4000\tLoss: 0.5083\n", 1037 | 
"---------------------------------------------\n", 1038 | "\n", 1039 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1040 | "\n", 1041 | "Train Epoch: 521 0\tLoss: 0.2817\n", 1042 | "Train Epoch: 521 1000\tLoss: 0.4051\n", 1043 | "Train Epoch: 521 2000\tLoss: 0.3741\n", 1044 | "Train Epoch: 521 3000\tLoss: 0.6958\n", 1045 | "Train Epoch: 521 4000\tLoss: 0.5083\n", 1046 | "---------------------------------------------\n", 1047 | "\n", 1048 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1049 | "\n", 1050 | "Train Epoch: 522 0\tLoss: 0.2817\n", 1051 | "Train Epoch: 522 1000\tLoss: 0.4051\n", 1052 | "Train Epoch: 522 2000\tLoss: 0.3741\n", 1053 | "Train Epoch: 522 3000\tLoss: 0.6958\n", 1054 | "Train Epoch: 522 4000\tLoss: 0.5083\n", 1055 | "---------------------------------------------\n", 1056 | "\n", 1057 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1058 | "\n", 1059 | "Train Epoch: 523 0\tLoss: 0.2817\n", 1060 | "Train Epoch: 523 1000\tLoss: 0.4051\n", 1061 | "Train Epoch: 523 2000\tLoss: 0.3741\n", 1062 | "Train Epoch: 523 3000\tLoss: 0.6958\n", 1063 | "Train Epoch: 523 4000\tLoss: 0.5083\n", 1064 | "---------------------------------------------\n", 1065 | "\n", 1066 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1067 | "\n", 1068 | "Train Epoch: 524 0\tLoss: 0.2817\n", 1069 | "Train Epoch: 524 1000\tLoss: 0.4051\n", 1070 | "Train Epoch: 524 2000\tLoss: 0.3741\n", 1071 | "Train Epoch: 524 3000\tLoss: 0.6958\n", 1072 | "Train Epoch: 524 4000\tLoss: 0.5083\n", 1073 | "---------------------------------------------\n", 1074 | "\n", 1075 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1076 | "\n", 1077 | "Train Epoch: 525 0\tLoss: 0.2817\n", 1078 | "Train Epoch: 525 1000\tLoss: 0.4051\n", 1079 | "Train Epoch: 525 2000\tLoss: 0.3741\n", 1080 | "Train Epoch: 525 3000\tLoss: 0.6958\n", 1081 | "Train Epoch: 525 4000\tLoss: 0.5083\n", 1082 | 
"---------------------------------------------\n", 1083 | "\n", 1084 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1085 | "\n", 1086 | "Train Epoch: 526 0\tLoss: 0.2817\n", 1087 | "Train Epoch: 526 1000\tLoss: 0.4051\n", 1088 | "Train Epoch: 526 2000\tLoss: 0.3741\n", 1089 | "Train Epoch: 526 3000\tLoss: 0.6958\n", 1090 | "Train Epoch: 526 4000\tLoss: 0.5083\n", 1091 | "---------------------------------------------\n", 1092 | "\n", 1093 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1094 | "\n", 1095 | "Train Epoch: 527 0\tLoss: 0.2817\n", 1096 | "Train Epoch: 527 1000\tLoss: 0.4051\n", 1097 | "Train Epoch: 527 2000\tLoss: 0.3741\n", 1098 | "Train Epoch: 527 3000\tLoss: 0.6958\n", 1099 | "Train Epoch: 527 4000\tLoss: 0.5083\n", 1100 | "---------------------------------------------\n", 1101 | "\n", 1102 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1103 | "\n", 1104 | "Train Epoch: 528 0\tLoss: 0.2817\n", 1105 | "Train Epoch: 528 1000\tLoss: 0.4051\n", 1106 | "Train Epoch: 528 2000\tLoss: 0.3741\n", 1107 | "Train Epoch: 528 3000\tLoss: 0.6958\n", 1108 | "Train Epoch: 528 4000\tLoss: 0.5083\n", 1109 | "---------------------------------------------\n", 1110 | "\n", 1111 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1112 | "\n", 1113 | "Train Epoch: 529 0\tLoss: 0.2817\n", 1114 | "Train Epoch: 529 1000\tLoss: 0.4051\n", 1115 | "Train Epoch: 529 2000\tLoss: 0.3741\n", 1116 | "Train Epoch: 529 3000\tLoss: 0.6958\n", 1117 | "Train Epoch: 529 4000\tLoss: 0.5083\n", 1118 | "---------------------------------------------\n", 1119 | "\n", 1120 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1121 | "\n", 1122 | "Train Epoch: 530 0\tLoss: 0.2817\n", 1123 | "Train Epoch: 530 1000\tLoss: 0.4051\n", 1124 | "Train Epoch: 530 2000\tLoss: 0.3741\n", 1125 | "Train Epoch: 530 3000\tLoss: 0.6958\n", 1126 | "Train Epoch: 530 4000\tLoss: 0.5083\n", 1127 | 
"---------------------------------------------\n", 1128 | "\n", 1129 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1130 | "\n", 1131 | "Train Epoch: 531 0\tLoss: 0.2817\n", 1132 | "Train Epoch: 531 1000\tLoss: 0.4051\n", 1133 | "Train Epoch: 531 2000\tLoss: 0.3741\n", 1134 | "Train Epoch: 531 3000\tLoss: 0.6958\n", 1135 | "Train Epoch: 531 4000\tLoss: 0.5083\n", 1136 | "---------------------------------------------\n", 1137 | "\n", 1138 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1139 | "\n", 1140 | "Train Epoch: 532 0\tLoss: 0.2817\n", 1141 | "Train Epoch: 532 1000\tLoss: 0.4051\n", 1142 | "Train Epoch: 532 2000\tLoss: 0.3741\n", 1143 | "Train Epoch: 532 3000\tLoss: 0.6958\n", 1144 | "Train Epoch: 532 4000\tLoss: 0.5083\n", 1145 | "---------------------------------------------\n", 1146 | "\n", 1147 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1148 | "\n", 1149 | "Train Epoch: 533 0\tLoss: 0.2817\n", 1150 | "Train Epoch: 533 1000\tLoss: 0.4051\n", 1151 | "Train Epoch: 533 2000\tLoss: 0.3741\n", 1152 | "Train Epoch: 533 3000\tLoss: 0.6958\n", 1153 | "Train Epoch: 533 4000\tLoss: 0.5083\n", 1154 | "---------------------------------------------\n", 1155 | "\n", 1156 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1157 | "\n", 1158 | "Train Epoch: 534 0\tLoss: 0.2817\n", 1159 | "Train Epoch: 534 1000\tLoss: 0.4051\n", 1160 | "Train Epoch: 534 2000\tLoss: 0.3741\n", 1161 | "Train Epoch: 534 3000\tLoss: 0.6958\n", 1162 | "Train Epoch: 534 4000\tLoss: 0.5083\n", 1163 | "---------------------------------------------\n", 1164 | "\n", 1165 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1166 | "\n", 1167 | "Train Epoch: 535 0\tLoss: 0.2817\n", 1168 | "Train Epoch: 535 1000\tLoss: 0.4051\n", 1169 | "Train Epoch: 535 2000\tLoss: 0.3741\n", 1170 | "Train Epoch: 535 3000\tLoss: 0.6958\n", 1171 | "Train Epoch: 535 4000\tLoss: 0.5083\n", 1172 | 
"---------------------------------------------\n", 1173 | "\n", 1174 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1175 | "\n", 1176 | "Train Epoch: 536 0\tLoss: 0.2817\n", 1177 | "Train Epoch: 536 1000\tLoss: 0.4051\n", 1178 | "Train Epoch: 536 2000\tLoss: 0.3741\n", 1179 | "Train Epoch: 536 3000\tLoss: 0.6958\n", 1180 | "Train Epoch: 536 4000\tLoss: 0.5083\n", 1181 | "---------------------------------------------\n", 1182 | "\n", 1183 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1184 | "\n", 1185 | "Train Epoch: 537 0\tLoss: 0.2817\n", 1186 | "Train Epoch: 537 1000\tLoss: 0.4051\n", 1187 | "Train Epoch: 537 2000\tLoss: 0.3741\n", 1188 | "Train Epoch: 537 3000\tLoss: 0.6958\n", 1189 | "Train Epoch: 537 4000\tLoss: 0.5083\n", 1190 | "---------------------------------------------\n", 1191 | "\n", 1192 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1193 | "\n", 1194 | "Train Epoch: 538 0\tLoss: 0.2817\n", 1195 | "Train Epoch: 538 1000\tLoss: 0.4051\n", 1196 | "Train Epoch: 538 2000\tLoss: 0.3741\n", 1197 | "Train Epoch: 538 3000\tLoss: 0.6958\n", 1198 | "Train Epoch: 538 4000\tLoss: 0.5083\n", 1199 | "---------------------------------------------\n", 1200 | "\n", 1201 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1202 | "\n", 1203 | "Train Epoch: 539 0\tLoss: 0.2817\n", 1204 | "Train Epoch: 539 1000\tLoss: 0.4051\n", 1205 | "Train Epoch: 539 2000\tLoss: 0.3741\n", 1206 | "Train Epoch: 539 3000\tLoss: 0.6958\n", 1207 | "Train Epoch: 539 4000\tLoss: 0.5083\n", 1208 | "---------------------------------------------\n", 1209 | "\n", 1210 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1211 | "\n", 1212 | "Train Epoch: 540 0\tLoss: 0.2817\n", 1213 | "Train Epoch: 540 1000\tLoss: 0.4051\n", 1214 | "Train Epoch: 540 2000\tLoss: 0.3741\n", 1215 | "Train Epoch: 540 3000\tLoss: 0.6958\n", 1216 | "Train Epoch: 540 4000\tLoss: 0.5083\n", 1217 | 
"---------------------------------------------\n", 1218 | "\n", 1219 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1220 | "\n", 1221 | "Train Epoch: 541 0\tLoss: 0.2817\n", 1222 | "Train Epoch: 541 1000\tLoss: 0.4051\n", 1223 | "Train Epoch: 541 2000\tLoss: 0.3741\n", 1224 | "Train Epoch: 541 3000\tLoss: 0.6958\n", 1225 | "Train Epoch: 541 4000\tLoss: 0.5083\n", 1226 | "---------------------------------------------\n", 1227 | "\n", 1228 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1229 | "\n", 1230 | "Train Epoch: 542 0\tLoss: 0.2817\n", 1231 | "Train Epoch: 542 1000\tLoss: 0.4051\n", 1232 | "Train Epoch: 542 2000\tLoss: 0.3741\n", 1233 | "Train Epoch: 542 3000\tLoss: 0.6958\n", 1234 | "Train Epoch: 542 4000\tLoss: 0.5083\n", 1235 | "---------------------------------------------\n", 1236 | "\n", 1237 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1238 | "\n", 1239 | "Train Epoch: 543 0\tLoss: 0.2817\n", 1240 | "Train Epoch: 543 1000\tLoss: 0.4051\n", 1241 | "Train Epoch: 543 2000\tLoss: 0.3741\n", 1242 | "Train Epoch: 543 3000\tLoss: 0.6958\n", 1243 | "Train Epoch: 543 4000\tLoss: 0.5083\n", 1244 | "---------------------------------------------\n", 1245 | "\n", 1246 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1247 | "\n", 1248 | "Train Epoch: 544 0\tLoss: 0.2817\n", 1249 | "Train Epoch: 544 1000\tLoss: 0.4051\n", 1250 | "Train Epoch: 544 2000\tLoss: 0.3741\n", 1251 | "Train Epoch: 544 3000\tLoss: 0.6958\n", 1252 | "Train Epoch: 544 4000\tLoss: 0.5083\n", 1253 | "---------------------------------------------\n", 1254 | "\n", 1255 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1256 | "\n", 1257 | "Train Epoch: 545 0\tLoss: 0.2817\n", 1258 | "Train Epoch: 545 1000\tLoss: 0.4051\n", 1259 | "Train Epoch: 545 2000\tLoss: 0.3741\n", 1260 | "Train Epoch: 545 3000\tLoss: 0.6958\n", 1261 | "Train Epoch: 545 4000\tLoss: 0.5083\n", 1262 | 
"---------------------------------------------\n", 1263 | "\n", 1264 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1265 | "\n", 1266 | "Train Epoch: 546 0\tLoss: 0.2817\n", 1267 | "Train Epoch: 546 1000\tLoss: 0.4051\n", 1268 | "Train Epoch: 546 2000\tLoss: 0.3741\n", 1269 | "Train Epoch: 546 3000\tLoss: 0.6958\n", 1270 | "Train Epoch: 546 4000\tLoss: 0.5083\n", 1271 | "---------------------------------------------\n", 1272 | "\n", 1273 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1274 | "\n", 1275 | "Train Epoch: 547 0\tLoss: 0.2817\n", 1276 | "Train Epoch: 547 1000\tLoss: 0.4051\n", 1277 | "Train Epoch: 547 2000\tLoss: 0.3741\n", 1278 | "Train Epoch: 547 3000\tLoss: 0.6958\n", 1279 | "Train Epoch: 547 4000\tLoss: 0.5083\n", 1280 | "---------------------------------------------\n", 1281 | "\n", 1282 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1283 | "\n", 1284 | "Train Epoch: 548 0\tLoss: 0.2817\n", 1285 | "Train Epoch: 548 1000\tLoss: 0.4051\n", 1286 | "Train Epoch: 548 2000\tLoss: 0.3741\n", 1287 | "Train Epoch: 548 3000\tLoss: 0.6958\n", 1288 | "Train Epoch: 548 4000\tLoss: 0.5083\n", 1289 | "---------------------------------------------\n", 1290 | "\n", 1291 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1292 | "\n", 1293 | "Train Epoch: 549 0\tLoss: 0.2817\n", 1294 | "Train Epoch: 549 1000\tLoss: 0.4051\n", 1295 | "Train Epoch: 549 2000\tLoss: 0.3741\n", 1296 | "Train Epoch: 549 3000\tLoss: 0.6958\n", 1297 | "Train Epoch: 549 4000\tLoss: 0.5083\n", 1298 | "---------------------------------------------\n", 1299 | "\n", 1300 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1301 | "\n", 1302 | "Train Epoch: 550 0\tLoss: 0.2817\n", 1303 | "Train Epoch: 550 1000\tLoss: 0.4051\n", 1304 | "Train Epoch: 550 2000\tLoss: 0.3741\n", 1305 | "Train Epoch: 550 3000\tLoss: 0.6958\n", 1306 | "Train Epoch: 550 4000\tLoss: 0.5083\n", 1307 | 
"---------------------------------------------\n", 1308 | "\n", 1309 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1310 | "\n", 1311 | "Train Epoch: 551 0\tLoss: 0.2817\n", 1312 | "Train Epoch: 551 1000\tLoss: 0.4051\n", 1313 | "Train Epoch: 551 2000\tLoss: 0.3741\n", 1314 | "Train Epoch: 551 3000\tLoss: 0.6958\n", 1315 | "Train Epoch: 551 4000\tLoss: 0.5083\n", 1316 | "---------------------------------------------\n", 1317 | "\n", 1318 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1319 | "\n", 1320 | "Train Epoch: 552 0\tLoss: 0.2817\n", 1321 | "Train Epoch: 552 1000\tLoss: 0.4051\n", 1322 | "Train Epoch: 552 2000\tLoss: 0.3741\n", 1323 | "Train Epoch: 552 3000\tLoss: 0.6958\n", 1324 | "Train Epoch: 552 4000\tLoss: 0.5083\n", 1325 | "---------------------------------------------\n", 1326 | "\n", 1327 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1328 | "\n", 1329 | "Train Epoch: 553 0\tLoss: 0.2817\n", 1330 | "Train Epoch: 553 1000\tLoss: 0.4051\n", 1331 | "Train Epoch: 553 2000\tLoss: 0.3741\n", 1332 | "Train Epoch: 553 3000\tLoss: 0.6958\n", 1333 | "Train Epoch: 553 4000\tLoss: 0.5083\n", 1334 | "---------------------------------------------\n", 1335 | "\n", 1336 | "Test set: Average loss: 0.5481, Accuracy: 232/500 (46.40%)\n", 1337 | "\n", 1338 | "Train Epoch: 554 0\tLoss: 0.2817\n", 1339 | "Train Epoch: 554 1000\tLoss: 0.4051\n", 1340 | "Train Epoch: 554 2000\tLoss: 0.3741\n", 1341 | "Train Epoch: 554 3000\tLoss: 0.6958\n" 1342 | ] 1343 | } 1344 | ] 1345 | }, 1346 | { 1347 | "cell_type": "markdown", 1348 | "metadata": { 1349 | "id": "wM3QoS2buZe7" 1350 | }, 1351 | "source": [ 1352 | "## Exercise 3 (6%) \n", 1353 | "For this experiment, try the following learning rate (lr=0.0001, 0.001, 0.01, 0.1). What do you observed?

\n", 1354 | "For example, at lr=0.001, test accuracy reaches 100% at epoch xx... At lr=0.01, test accuracy reaches 100% at epoch xx. As lr increases / decreases, what happens?\n" 1355 | ] 1356 | }, 1357 | { 1358 | "cell_type": "markdown", 1359 | "metadata": { 1360 | "id": "yH5tn9G1ugD1" 1361 | }, 1362 | "source": [ 1363 | "### Your answer here\n" 1364 | ] 1365 | }, 1366 | { 1367 | "cell_type": "markdown", 1368 | "metadata": { 1369 | "id": "P1d47bLhDMlw" 1370 | }, 1371 | "source": [ 1372 | "# Submission Instructions\n", 1373 | "Once you are finished, follow these steps:\n", 1374 | "\n", 1375 | "Restart the kernel and re-run this notebook from beginning to end by going to Kernel > Restart Kernel and Run All Cells.\n", 1376 | "If this process stops halfway through, that means there was an error. Correct the error and repeat Step 1 until the notebook runs from beginning to end.\n", 1377 | "Double check that there is a number next to each code cell and that these numbers are in order.\n", 1378 | "Then, submit your lab as follows:\n", 1379 | "\n", 1380 | "Go to File > Print > Save as PDF.\n", 1381 | "Double check that the entire notebook, from beginning to end, is in this PDF file. Make sure the solutions for Exercise 5 are included for marking. \n", 1382 | "Upload the PDF to Spectrum. " 1383 | ] 1384 | }, 1385 | { 1386 | "cell_type": "markdown", 1387 | "metadata": { 1388 | "id": "4FBd4KLZwyVB" 1389 | }, 1390 | "source": [ 1391 | "# Acknowledgement\n", 1392 | "\n", 1393 | "Some of this work is inspired by \n", 1394 | "1.
Effect of learning rate on AI model = https://www.commonlounge.com/discussion/5076b2cfb2364594ba608fca3ac606bb" 1395 | ] 1396 | } 1397 | ] 1398 | } -------------------------------------------------------------------------------- /Week3/XOR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week3/XOR.jpg -------------------------------------------------------------------------------- /Week4/MnistExamples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week4/MnistExamples.png -------------------------------------------------------------------------------- /Week4/README: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Week4/W4.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week4/W4.pptx -------------------------------------------------------------------------------- /Week4/WOA7015_Wk4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "WOA7015_Wk4.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyMTgQdYvrjN0mSEo1r267hp", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | 
{ 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "Cw5hkuTyx3VO" 35 | }, 36 | "source": [ 37 | "# Welcome to WOA7015 Advanced Machine Learning Lab - Week 4\n", 38 | "This code is generated for the purpose of the WOA7015 module.\n", 39 | "The code is available on GitHub https://github.com/shiernee/Advanced_ML \n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "INxzH5Bqw5dq" 46 | }, 47 | "source": [ 48 | "## 1.0 Effect of weight and bias on the sigmoid function\n", 49 | "This is the code to generate the figure in slide 6" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "2HctFBauH47o" 56 | }, 57 | "source": [ 58 | "#### 1.1 Effect of weight on sigmoid function" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "ah7KAg-fsGWY" 65 | }, 66 | "source": [ 67 | "import matplotlib.pyplot as plt\n", 68 | "import numpy as np\n", 69 | "import imageio\n", 70 | "\n", 71 | "# create sigmoid function\n", 72 | "f = lambda x, w, b: 1/(1 + np.exp(-(w*x + b)))\n", 73 | "\n", 74 | "x = np.arange(-10, 10, 0.01).reshape([-1, 1])\n", 75 | "\n", 76 | "# effect of weight on sigmoid function\n", 77 | "filenames = []\n", 78 | "for i in np.arange(1, 5, 0.1):\n", 79 | " w = np.ones([1, 1]) * i * 0.5\n", 80 | " b = np.ones([1, 1]) * 0\n", 81 | "\n", 82 | " plt.plot(x, f(x, w, b))\n", 83 | " plt.title('w = %0.1f' % i)\n", 84 | " plt.grid()\n", 85 | " plt.savefig('w %0.1f.png' % i)\n", 86 | " plt.close()\n", 87 | " filenames.append('w %0.1f.png' % i)\n", 88 | "\n", 89 | "# Build GIF\n", 90 | "with imageio.get_writer('w_mygif.gif', mode='I') as writer:\n", 91 | " for filename in filenames:\n", 92 | " image = imageio.imread(filename)\n", 93 | " writer.append_data(image)\n", 94 | "\n" 95 | ], 96 | "execution_count": 1, 97 | "outputs": [] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": { 102 | "id": "JYbm63A9H91J" 103 | }, 104 | "source": [ 105 | "#### 1.2 Effect of bias on sigmoid
function" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "metadata": { 111 | "id": "s1QreLFgHyIi" 112 | }, 113 | "source": [ 114 | "import matplotlib.pyplot as plt\n", 115 | "import numpy as np\n", 116 | "import imageio\n", 117 | "\n", 118 | "# create sigmoid function\n", 119 | "f = lambda x, w, b: 1/(1 + np.exp(-(w*x + b)))\n", 120 | "\n", 121 | "x = np.arange(-10, 10, 0.01).reshape([-1, 1])\n", 122 | "\n", 123 | "# effect of bias on sigmoid function\n", 124 | "filenames = []\n", 125 | "for i in np.arange(1, 5, 0.1):\n", 126 | " w = np.ones([1, 1])\n", 127 | " b = np.ones([1, 1])* i\n", 128 | "\n", 129 | " plt.plot(x, f(x, w, b))\n", 130 | " plt.title('b = %0.1f' % i)\n", 131 | " plt.grid()\n", 132 | " plt.savefig('b %0.1f.png' % i)\n", 133 | " plt.close()\n", 134 | " filenames.append('b %0.1f.png' % i)\n", 135 | "\n", 136 | "# Build GIF\n", 137 | "with imageio.get_writer('b_mygif.gif', mode='I') as writer:\n", 138 | " for filename in filenames:\n", 139 | " image = imageio.imread(filename)\n", 140 | " writer.append_data(image)\n" 141 | ], 142 | "execution_count": 5, 143 | "outputs": [] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": { 148 | "id": "iXAZiEqSEbDH" 149 | }, 150 | "source": [ 151 | "# 2.0 Logistic Regression\n", 152 | "\n", 153 | "In this section, we will learn how to create and train a Logistic Regression Model using PyTorch. We will use MNIST images, as shown below.

\n", 154 | "\n", 155 | "PyTorch (https://pytorch.org/) is an open source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, primarily developed by Facebook's AI Research lab. \n", 156 | "\n", 157 | "\n", 158 | "
\n", 159 | "\n", 160 | "\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "5NXavLptEdGe" 167 | }, 168 | "source": [ 169 | "# 2.1 import library\n", 170 | "import torch\n", 171 | "import torch.nn as nn\n", 172 | "import torchvision\n", 173 | "import torchvision.transforms as transforms\n" 174 | ], 175 | "execution_count": 6, 176 | "outputs": [] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "metadata": { 181 | "id": "crKCbRxdFD86" 182 | }, 183 | "source": [ 184 | "#2.2 Set the Hyper-parameters \n", 185 | "input_size = 28 * 28 # 784\n", 186 | "num_classes = 10\n", 187 | "num_epochs = 5\n", 188 | "batch_size = 100\n", 189 | "learning_rate = 0.001\n" 190 | ], 191 | "execution_count": 7, 192 | "outputs": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "cZXRWEdzFt_Z" 198 | }, 199 | "source": [ 200 | "#2.3 Data loader\n", 201 | "# MNIST dataset (images and labels)\n", 202 | "train_dataset = torchvision.datasets.MNIST(root='../../data', \n", 203 | " train=True, \n", 204 | " transform=transforms.ToTensor())\n", 205 | "\n", 206 | "test_dataset = torchvision.datasets.MNIST(root='../../data', \n", 207 | " train=False, \n", 208 | " transform=transforms.ToTensor())\n", 209 | "\n", 210 | "# Data loader (input pipeline)\n", 211 | "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, \n", 212 | " batch_size=batch_size, \n", 213 | " shuffle=True)\n", 214 | "\n", 215 | "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, \n", 216 | " batch_size=batch_size, \n", 217 | " shuffle=False)\n" 218 | ], 219 | "execution_count": 11, 220 | "outputs": [] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "metadata": { 225 | "colab": { 226 | "base_uri": "https://localhost:8080/", 227 | "height": 587 228 | }, 229 | "id": "PCMGIa6ZLxWX", 230 | "outputId": "8cbe4714-bd53-4744-cacc-b3b2099965de" 231 | }, 232 | "source": [ 233 | "# 2.3.1 Check data \n", 234 | "print(train_dataset)\n", 235 | 
"print('----------------')\n", 236 | "print(test_dataset)\n", 237 | "print()\n", 238 | "\n", 239 | "import matplotlib.pyplot as plt\n", 240 | "print('training data shape: ', train_dataset.data.shape)\n", 241 | "n = np.random.randint(0, 60000)\n", 242 | "plt.imshow(train_dataset.data[n])\n", 243 | "plt.title(f'n = %d label = %d' % (n, train_dataset.train_labels[n].numpy()))\n" 244 | ], 245 | "execution_count": 34, 246 | "outputs": [ 247 | { 248 | "output_type": "stream", 249 | "name": "stdout", 250 | "text": [ 251 | "Dataset MNIST\n", 252 | " Number of datapoints: 60000\n", 253 | " Root location: ../../data\n", 254 | " Split: Train\n", 255 | " StandardTransform\n", 256 | "Transform: ToTensor()\n", 257 | "----------------\n", 258 | "Dataset MNIST\n", 259 | " Number of datapoints: 10000\n", 260 | " Root location: ../../data\n", 261 | " Split: Test\n", 262 | " StandardTransform\n", 263 | "Transform: ToTensor()\n", 264 | "\n", 265 | "training data shape: torch.Size([60000, 28, 28])\n" 266 | ] 267 | }, 268 | { 269 | "output_type": "stream", 270 | "name": "stderr", 271 | "text": [ 272 | "/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets\n", 273 | " warnings.warn(\"train_labels has been renamed targets\")\n" 274 | ] 275 | }, 276 | { 277 | "output_type": "execute_result", 278 | "data": { 279 | "text/plain": [ 280 | "Text(0.5, 1.0, 'n = 47414 label = 5')" 281 | ] 282 | }, 283 | "metadata": {}, 284 | "execution_count": 34 285 | }, 286 | { 287 | "output_type": "display_data", 288 | "data": { 289 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAR2ElEQVR4nO3dfZRcdX3H8fcHsiQQHiMYI0l4MmjTahFWohUsGsvB9PQAUlGqNiqegJpWWyhS7CnYauGoQClUbIA00YMIFSgctSrmaMGHE7OhIYABQml4iJuEh0AS0JDdfPvH3MTJOvOb2XnO/j6vc+bs7P3e397vTvLZe+femfkpIjCzsW+PbjdgZp3hsJtlwmE3y4TDbpYJh90sEw67WSYcdmuYpDWS3lnnuiHpNQ1up+Gx9hsOe4+TtKT4zz6u+H66pC0jbiHpvApjF44MiqT5kgYkbZW0KLHdvy/G1hXm3Y2kH0n6ddlj+HC3e2o3h72HSXo/0Fe+LCKeiIh9d9yA1wPbgVtHjD0BOKrCj/0l8DlgYWK7RwHvAQab+w163vyyx/K13W6m3Rz2USoOXc+XtFLSC5JuljShDds5ALgYuKDGqn8O3B0Ra8rGjgOuBv5i5MoRcVtE/CfwbOJn/ivwaeDlUfR7vKSfSXpe0qCkayTtNWK1OZIek/SMpC9K2qNs/EckrZK0UdL3JB1W77atPg57Y84ETgGOAN4AfKjSSpJOKP7zV7udkNjGPwHXAuuqrSBJlMK+eETpryj9AVhZ/6+082e+B9gaEd8Z5dDhYrsHA28BZgMfH7HO6UA/cCxwKvCRYpunAhcB7wYOAe4Bbqqz3y8nHt9av/+lxR+en0g6qa7fcncWEb6N4gasAT5Q9v0XgK+0eBv9wApgHHA4EMC4CuudCGwB9i1bNg14FDig+D6A11QY+zlg0Yhl+wGrgcPLftd31ngsKtaBTwG3l30fwCll338cWFLc/y/g7LLaHsBLwGGp36HJx3hW8fuOB+YCm4Gjuv3/q50379kbU763fQnYt1U/uDi0/TLwyYgYqrH6XODWiNhStuyfgX+IiBca2PwlwNei7ClBvSQdLelbktZJ2kTpyOTgEas9WXb/ceDVxf3DgKt27JGB5wABh462j3pFxNKI2BwRWyNiMfATYE67ttcLHPY2knRihTPn5bcTKwzbn9Ke/WZJ64BlxfKnyteXtDelk2gjD+FnA18sQrfjj9LPJP1ZHS3PBv6ybOw04BZJn65j7LXAQ8CMiNif0mG5Rqwzrez+dEonC6H0R+CciDiw7LZ3RPy01kYlfSXx+D5YR987RIV+x5Rx3W5gLIuIexj9Xv8FfrPHg1JAfg4cBzxdtvx0YCPwwxHjj2bXP+KDwJ8A98HOk3fjgD2BPYuTi0PFUcRsdj37vwz4a0qH2bXsB2wCtkh6HfCxEf0C/I2kpZQek08CVxTLvwL8o6QVEfFgcXLy5Ij4j1objYhzgXPr6G8nSQdSOoz/b2AIeC/wtqKnMcth7zFRekK582lC2Zn+9SMO6+dSOuSOEeM3lH9fOofHMxHxq2LR31E6y7/DB4DPApdExLMjxg4DG0c8TajmfGABpasH/wPcDLxjxDp3AMuBA4BFwA1Fz7dL2hf4RnEW/gXgLqBm2BvUR+mcxesonVh8CDgtIh5p0/Z6gkb8XzGzMcrP2c0y4bCbZcJhN8uEw26WiY6ejd9L42MCEzu5SbOs/JoXeTm2Vny9QFNhl3QKcBWla7bXR8RlqfUnMJFZmt3MJs0sYWksqVpr+DBe0p6U3h31LmAmcJakmY3+PDNrr2aesx8PPBoRj0XEy8A3KL2Tycx6UDNhP5Rd39jwFBXeuCBpXvHJKAPb2NrE5sysGW0/Gx8RCyKiPyL6+xjf7s2ZWRXNhH0tu76LaWqxzMx6UDNhXwbMkHRE8fFD7wPubE1bZtZqDV96i4ghSfOB71G69LYwIkbz/mEz66CmrrN
H6XPKRvtZZWbWBX65rFkmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZaKjUzbbGKSKswPvtMf46rMArfvIscmx58y/I1k/98D0nCR//NbqUw8O/d/jybFjkffsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmfJ3dkja/783J+tb3P5esXz7zm1VrJ074aUM97TAcTQ3PTlNhl7QG2AwMA0MR0d+Kpsys9VqxZ397RDzTgp9jZm3k5+xmmWg27AF8X9JySfMqrSBpnqQBSQPb2Nrk5sysUc0exp8QEWslvRK4S9JDEXF3+QoRsQBYALC/JvmUilmXNLVnj4i1xdcNwO3A8a1oysxar+GwS5ooab8d94GTgQda1ZiZtVYzh/GTgdtVej/zOODrEfHdlnRlHfPolenr6Lec9i/J+jF7te+lGsu2pp/1XfbknGQ9Nm9pZTu7vYb/pSLiMeD3W9iLmbWRL72ZZcJhN8uEw26WCYfdLBMOu1km/BbXDnj63Lck6xM2pi8x7X/bvcl6bHt51D3tsM/hm5L1Zi+tffiJk6rW7ln52uTYmZeuT9aH1jzRQEf58p7dLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEr7N3wPMztyfrq8+4Nll/w5Hzk/Wplzb+kczTP57+KOgZn/lYsv6qe9JTNh/4vVVVa0c/vyw5dihZtdHynt0sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4Svs+8GLv7wjcn6v1/zhqq17Zs3J8cODa5L1mfMT9drGW5qtLWS9+xmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSZ8nX03cMbEjcn6on32rl6scZ3d8lFzzy5poaQNkh4oWzZJ0l2SVhdfD2pvm2bWrHoO4xcBp4xYdiGwJCJmAEuK782sh9UMe0TcDYz87KJTgcXF/cXAaS3uy8xarNHn7JMjYrC4vw6YXG1FSfOAeQAT2KfBzZlZs5o+Gx8RAVSdmTAiFkREf0T09zG+2c2ZWYMaDft6SVMAiq8bWteSmbVDo2G/E5hb3J8L3NGadsysXWo+Z5d0E3AScLCkp4CLgcuAWySdDTwOnNnOJnd307+b/tx4zuhMH5a3mmGPiLOqlGa3uBczayO/XNYsEw67WSYcdrNMOOxmmXDYzTLht7h2wIS77kvWL1jXn6x/4VUDyfraf3tF1dre3zwyObaWF1+d3h9M/GX6smLfS1VfXMk+ty1tqCdrjPfsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmfJ29A2Lby8n6Q2dMT9av/vbIjwDc1b1vSkzp/Kbk0LbbXv1DjFj+pfTYi+bNS9b7frC8kZay5T27WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJlSZ06Yz9NSlmyR9K22ob576lau3pWcPJsZ+bfWuy/vxwesqurz+RvpD/ugOrzx9y1dQfJMc+OZR+r/z5f/jeZH3o8SeT9bFoaSxhUzynSjXv2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTPg6u3XNS++elaz/6Oprk/X+S+cn66+85qej7ml319R1dkkLJW2Q9EDZskskrZW0orjNaWXDZtZ69RzGLwJOqbD8yog4prh9p7VtmVmr1Qx7RNwNpD8Xycx6XjMn6OZLWlkc5h9UbSVJ8yQNSBrYxtYmNmdmzWg07NcCRwHHAIPA5dVWjIgFEdEfEf19jG9wc2bWrIbCHhHrI2I4IrYD1wHHt7YtM2u1hsIuaUrZt6cDD1Rb18x6Q83PjZd0E3AScLCkp4CLgZMkHQMEsAY4p4092hg1fuO2psa/NLlzrxEZC2qGPSLOqrD4hjb0YmZt5JfLmmXCYTfLhMNulgmH3SwTDrtZJjxls7WVxld/1eRhlz7SwU7Me3azTDjsZplw2M0
y4bCbZcJhN8uEw26WCYfdLBO+zm5NUd9eyfoj1/1u1dq3p13f1Lb7tlT8xGSrwnt2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTvs5uSeOmHpqsP3TpK5P11e9o/Fr6kl+lZxCavujRZH244S2PTd6zm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZqGfK5mnAV4HJlKZoXhARV0maBNwMHE5p2uYzI2Jj+1ptrz0mTkzWV3/29VVrR9/wbHLs8KrVDfXUCYPn/UGyfv38q5L14/bas5Xt7OLCyz+arB+y/mdt2/ZYVM+efQg4LyJmAm8GPiFpJnAhsCQiZgBLiu/NrEfVDHtEDEbEvcX9zcAq4FDgVGBxsdpi4LR2NWlmzRvVc3ZJhwNvBJYCkyNisCito3SYb2Y9qu6wS9oXuBX4VERsKq9FRFB6Pl9p3DxJA5IGtrG1qWbNrHF1hV1SH6Wg3xgRtxWL10uaUtSnABsqjY2IBRHRHxH9faTf2GBm7VMz7JIE3ACsiogrykp3AnOL+3OBO1rfnpm1Sj1vcX0r8EHgfkkrimUXAZcBt0g6G3gcOLM9LXbG9hdfTNan/rD6GyYv+NatybHz7zsrvfGfH5AsH7g6/WbNtXOq1//0jcuTY2885IvJ+kF77J2s17Jh+KWqtdnXXZAcO/36gWS94vNGq6pm2CPix0C1D+ie3dp2zKxd/Ao6s0w47GaZcNjNMuGwm2XCYTfLhMNulgmVXunaGftrUszS2Lta9+IZs5L1j37+9mT9g/uta2U7HfXhJ05K1p/56JSqteEHH25tM8bSWMKmeK7ipXLv2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTPg6ewfUmvZ41d9OTdbfftyDyfqCaXdXrV2wrj859rYVxybrv3PF5mR9eFV62mS2e+LkTvJ1djNz2M1y4bCbZcJhN8uEw26WCYfdLBMOu1kmfJ3dbAzxdXYzc9jNcuGwm2XCYTfLhMNulgmH3SwTDrtZJmqGXdI0ST+U9AtJD0r6ZLH8EklrJa0obnPa366ZNarm/OzAEHBeRNwraT9guaS7itqVEfGl9rVnZq1SM+wRMQgMFvc3S1oFpD96xcx6zqies0s6HHgjsLRYNF/SSkkLJR1UZcw8SQOSBraxtalmzaxxdYdd0r7ArcCnImITcC1wFHAMpT3/5ZXGRcSCiOiPiP4+xregZTNrRF1hl9RHKeg3RsRtABGxPiKGI2I7cB1wfPvaNLNm1XM2XsANwKqIuKJsefn0nKcDD7S+PTNrlXrOxr8V+CBwv6QVxbKLgLMkHQMEsAY4py0dmllL1HM2/sdApffHfqf17ZhZu/gVdGaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTHZ2yWdLTwONliw4GnulYA6PTq731al/g3hrVyt4Oi4hDKhU6Gvbf2rg0EBH9XWsgoVd769W+wL01qlO9+TDeLBMOu1kmuh32BV3efkqv9tarfYF7a1RHeuvqc3Yz65xu79nNrEMcdrNMdCXskk6R9LCkRyVd2I0eqpG0RtL9xTTUA13uZaGkDZIeKFs2SdJdklYXXyvOsdel3npiGu/ENONdfey6Pf15x5+zS9oTeAT4I+ApYBlwVkT8oqONVCFpDdAfEV1/AYaktwFbgK9GxO8Vy74APBcRlxV/KA+KiE/3SG+XAFu6PY13MVvRlPJpxoHTgA/Rxccu0deZdOBx68ae/Xjg0Yh4LCJeBr4BnNqFPnpeRNwNPDdi8anA4uL+Ykr/WTquSm89ISIGI+Le4v5mYMc041197BJ9dUQ3wn4o8GTZ90/RW/O9B/B9Scslzet2MxVMjojB4v46YHI3m6mg5jTenTRimvGeeewamf68WT5B99tOiIhjgXcBnygOV3tSlJ6D9dK107qm8e6UCtOM79TNx67R6c+b1Y2wrwWmlX0/tVjWEyJ
ibfF1A3A7vTcV9fodM+gWXzd0uZ+demka70rTjNMDj103pz/vRtiXATMkHSFpL+B9wJ1d6OO3SJpYnDhB0kTgZHpvKuo7gbnF/bnAHV3sZRe9Mo13tWnG6fJj1/XpzyOi4zdgDqUz8v8LfKYbPVTp60jgvuL2YLd7A26idFi3jdK5jbOBVwBLgNXAD4BJPdTb14D7gZWUgjWlS72dQOkQfSWworjN6fZjl+irI4+bXy5rlgmfoDPLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMvH/dOr4dtbwai4AAAAASUVORK5CYII=\n", 290 | "text/plain": [ 291 | "
" 292 | ] 293 | }, 294 | "metadata": { 295 | "needs_background": "light" 296 | } 297 | } 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "metadata": { 303 | "id": "N6ZE4roJF0HA" 304 | }, 305 | "source": [ 306 | "#2.4 Logistic regression model\n", 307 | "model = nn.Linear(input_size, num_classes)\n" 308 | ], 309 | "execution_count": 35, 310 | "outputs": [] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "metadata": { 315 | "id": "XXCxm64GH27K" 316 | }, 317 | "source": [ 318 | "#2.5 Cross Entropy Loss \n", 319 | "# nn.CrossEntropyLoss() computes softmax internally\n", 320 | "criterion = nn.CrossEntropyLoss() \n" 321 | ], 322 | "execution_count": 36, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "metadata": { 328 | "id": "pm2Q-PDcH34b" 329 | }, 330 | "source": [ 331 | "#2.6 Optimizer Stochastic Gradient Descent \n", 332 | "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) \n" 333 | ], 334 | "execution_count": 37, 335 | "outputs": [] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "metadata": { 340 | "id": "4TO7Egs7F2o9" 341 | }, 342 | "source": [ 343 | "#2.7 Train the model\n", 344 | "total_step = len(train_loader)\n", 345 | "for epoch in range(num_epochs):\n", 346 | " for i, (images, labels) in enumerate(train_loader):\n", 347 | " # Reshape images to (batch_size, input_size)\n", 348 | " images = images.reshape(-1, input_size)\n", 349 | " \n", 350 | " # Forward pass\n", 351 | " outputs = model(images)\n", 352 | " loss = criterion(outputs, labels)\n", 353 | "\n", 354 | " # Backward and optimize\n", 355 | " optimizer.zero_grad()\n", 356 | " loss.backward()\n", 357 | " optimizer.step()\n", 358 | " \n", 359 | " if (i+1) % 100 == 0:\n", 360 | " print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' \n", 361 | " .format(epoch+1, num_epochs, i+1, total_step, loss.item()))" 362 | ], 363 | "execution_count": null, 364 | "outputs": [] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "metadata": { 369 | "id": "fGO_mdMZF4iS", 370 
| "colab": { 371 | "base_uri": "https://localhost:8080/" 372 | }, 373 | "outputId": "7b99b627-0111-4f1a-9103-b7762f59baa8" 374 | }, 375 | "source": [ 376 | "#2.8 Test the model\n", 377 | "# In test phase, we don't need to compute gradients (for memory efficiency)\n", 378 | "with torch.no_grad():\n", 379 | " correct = 0\n", 380 | " total = 0\n", 381 | " for images, labels in test_loader:\n", 382 | " images = images.reshape(-1, input_size)\n", 383 | " outputs = model(images)\n", 384 | " _, predicted = torch.max(outputs.data, 1)\n", 385 | " total += labels.size(0)\n", 386 | " correct += (predicted == labels).sum()\n", 387 | "\n", 388 | " print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))\n" 389 | ], 390 | "execution_count": 39, 391 | "outputs": [ 392 | { 393 | "output_type": "stream", 394 | "name": "stdout", 395 | "text": [ 396 | "Accuracy of the model on the 10000 test images: 82.5 %\n" 397 | ] 398 | } 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "metadata": { 404 | "id": "EBEBeGYsF5t-" 405 | }, 406 | "source": [ 407 | "#2.9 Save the model checkpoint\n", 408 | "torch.save(model.state_dict(), 'model.ckpt')" 409 | ], 410 | "execution_count": 40, 411 | "outputs": [] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": { 416 | "id": "j6nIkwr2YM9l" 417 | }, 418 | "source": [ 419 | "## Exercise 1 (10%): Create custom loss function\n", 420 | "In this section, you will need to create your own Cross Entropy loss function and compare it with PyTorch's Cross Entropy loss. The objective of this exercise is to enable you to design your own loss in the future. \n", 421 | "\n", 422 | "Follow the steps below:\n", 423 | "1. Import libraries - copy section 2.1\n", 424 | "2. Set hyperparameter - copy section 2.2\n", 425 | "3. Data loader - copy section 2.3\n", 426 | "4. Initialize Logistic Regression - copy section 2.4\n", 427 | "5. Create custom_CrossEntropyLoss class - copy the following code. 
Your task is to ***code the log_softmax equation in the log_softmax function.*** \n", 428 | "\n", 429 | "```\n", 430 | "# Custom Loss - Cross Entropy Loss\n", 431 | "class custom_CrossEntropyLoss(nn.Module):\n", 432 | " def __init__(self, weight=None, size_average=True):\n", 433 | " super(custom_CrossEntropyLoss, self).__init__()\n", 434 | " \n", 435 | " def forward(self, inputs, targets, smooth=1): \n", 436 | " num_examples = targets.shape[0]\n", 437 | " batch_size = inputs.shape[0]\n", 438 | " softmax_outputs = self.log_softmax(inputs)\n", 439 | " outputs = softmax_outputs[range(batch_size), targets] \n", 440 | " return -torch.sum(outputs)/num_examples\n", 441 | "\n", 442 | " @staticmethod\n", 443 | " def log_softmax(x):\n", 444 | " return ### put the log_softmax function here ### \n", 445 | "```\n", 446 | "\n", 447 | "6. Initialize custom_CrossEntropyLoss loss as criterion - copy section 2.5. Replace *nn.CrossEntropyLoss* with *custom_CrossEntropyLoss*\n", 448 | "7. Train the model, evaluate it on your testing data. Save your model. \n", 449 | "8. Compare the loss computed by PyTorch's nn.CrossEntropyLoss with the loss computed by your custom loss. \n", 450 | "\n", 451 | "\n", 452 | " 5% will be given if steps 1-4 are done correctly
\n", 453 | " 3% will be given if steps 5-7 are done correctly
\n", 454 | " 2% will be given if the difference between your custom loss and the PyTorch loss is near zero. " 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "metadata": { 460 | "id": "7azur5ciavP2" 461 | }, 462 | "source": [ 463 | "# your code here" 464 | ], 465 | "execution_count": null, 466 | "outputs": [] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "metadata": { 471 | "id": "2d7IdSoak0p2" 472 | }, 473 | "source": [ 474 | "# Submission Instructions\n", 475 | "Once you are finished, follow these steps:\n", 476 | "\n", 477 | "Restart the kernel and re-run this notebook from beginning to end by going to Kernel > Restart Kernel and Run All Cells.\n", 478 | "If this process stops halfway through, that means there was an error. Correct the error and repeat the restart-and-run-all step above until the notebook runs from beginning to end.\n", 479 | "Double check that there is a number next to each code cell and that these numbers are in order.\n", 480 | "Then, submit your lab as follows:\n", 481 | "\n", 482 | "Go to File > Print > Save as PDF.\n", 483 | "Double check that the entire notebook, from beginning to end, is in this PDF file. Make sure your solution for Exercise 5 is included so that it can be marked. \n", 484 | "Upload the PDF to Spectrum. 
" 485 | ] 486 | } 487 | ] 488 | } -------------------------------------------------------------------------------- /Week5/CV.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week5/CV.png -------------------------------------------------------------------------------- /Week5/CV_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week5/CV_test.png -------------------------------------------------------------------------------- /Week5/W5.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week5/W5.pptx -------------------------------------------------------------------------------- /Week5/iris.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week5/iris.PNG -------------------------------------------------------------------------------- /Week5/train_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week5/train_test.png -------------------------------------------------------------------------------- /Week6/W6.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week6/W6.pptx -------------------------------------------------------------------------------- /Week7/W7.pptx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week7/W7.pptx -------------------------------------------------------------------------------- /Week8/W8.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week8/W8.pptx -------------------------------------------------------------------------------- /Week9/Adj_matrix.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week9/Adj_matrix.PNG -------------------------------------------------------------------------------- /Week9/Message_passing.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week9/Message_passing.PNG -------------------------------------------------------------------------------- /Week9/W9.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiernee/Advanced_ML/09ae2103b9ccb9edbe80461ba6dbac148830d1ce/Week9/W9.pptx -------------------------------------------------------------------------------- /Week9/WOA7015_Wk9_Extra.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "WOA7015_Wk9_Extra.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyMuy2xu5jnoqDQLIYuEVjiC", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": 
"text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "source": [ 33 | "# Welcome to WOA7015 Advance Machine Learning Lab - Week 9_extra\n", 34 | "This code is generated for the purpose of WOA7015 module.\n", 35 | "The code is available in github https://github.com/shiernee/Advanced_ML \n" 36 | ], 37 | "metadata": { 38 | "id": "NB-ogNj0uNhs" 39 | } 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "colab": { 46 | "base_uri": "https://localhost:8080/" 47 | }, 48 | "id": "jy2YuQ-suK7m", 49 | "outputId": "6b840de3-ef3c-4c76-f537-29c146f02be1" 50 | }, 51 | "outputs": [ 52 | { 53 | "output_type": "stream", 54 | "name": "stdout", 55 | "text": [ 56 | "[]\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "# Create an empty graph with no nodes and no edges.\n", 62 | "import networkx as nx\n", 63 | "G = nx.Graph()\n", 64 | "print(list(G.nodes)) # examine element in nodes" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "source": [ 70 | "## Add nodes" 71 | ], 72 | "metadata": { 73 | "id": "zM4dl6K8wMhj" 74 | } 75 | }, 76 | { 77 | "cell_type": "code", 78 | "source": [ 79 | "G.add_node(1) # add one node\n", 80 | "print(list(G.nodes)) \n", 81 | "print()\n", 82 | "\n", 83 | "G.add_nodes_from([2, 3]) # add nodes from any iterable container, such as a list\n", 84 | "print(list(G.nodes))\n", 85 | "print()\n", 86 | "\n", 87 | "# add nodes along with node attributes if your container yields 2-tuples of the form (node, node_attribute_dict)\n", 88 | "G.add_nodes_from([\n", 89 | " (4, {\"color\": \"red\"}),\n", 90 | " (5, {\"color\": \"green\"}),\n", 91 | "])\n", 92 | "print(list(G.nodes)) \n", 93 | "print()\n", 94 | "\n", 95 | "# Nodes from one graph can be incorporated into another:\n", 96 | "H = nx.path_graph(10)\n", 97 | "print('H nodes: ', list(H.nodes))\n", 98 | "G.add_nodes_from(H)\n", 99 | "print('G nodes: ', list(G.nodes))\n", 100 | "print()\n", 101 | "\n", 102 | 
"G.add_node(H)\n", 103 | "print('G nodes: ', list(G.nodes))\n" 104 | ], 105 | "metadata": { 106 | "colab": { 107 | "base_uri": "https://localhost:8080/" 108 | }, 109 | "id": "fbLKZrdfujay", 110 | "outputId": "b0cd5a63-e817-4d0f-f545-092103200071" 111 | }, 112 | "execution_count": null, 113 | "outputs": [ 114 | { 115 | "output_type": "stream", 116 | "name": "stdout", 117 | "text": [ 118 | "[1]\n", 119 | "\n", 120 | "[1, 2, 3]\n", 121 | "\n", 122 | "[1, 2, 3, 4, 5]\n", 123 | "\n", 124 | "H nodes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", 125 | "G nodes: [1, 2, 3, 4, 5, 0, 6, 7, 8, 9]\n", 126 | "\n", 127 | "G nodes: [1, 2, 3, 4, 5, 0, 6, 7, 8, 9, ]\n" 128 | ] 129 | } 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "source": [ 135 | "# Add edges" 136 | ], 137 | "metadata": { 138 | "id": "_0Fx_PkVwObo" 139 | } 140 | }, 141 | { 142 | "cell_type": "code", 143 | "source": [ 144 | "G.add_edge(1, 2)\n", 145 | "print(list(G.edges))\n", 146 | "print()\n", 147 | "\n", 148 | "e = (2, 3)\n", 149 | "G.add_edge(*e) # unpack edge tuple*\n", 150 | "print(list(G.edges))\n", 151 | "print()\n", 152 | "\n", 153 | "G.add_edges_from([(1, 2), (1, 3)])\n", 154 | "print(list(G.edges))\n", 155 | "print()\n", 156 | "\n", 157 | "G.add_edges_from(H.edges)\n", 158 | "\n", 159 | "print(list(G.adj[1])) # or list(G.neighbors(1))" 160 | ], 161 | "metadata": { 162 | "colab": { 163 | "base_uri": "https://localhost:8080/" 164 | }, 165 | "id": "p08Fo4XLwPb9", 166 | "outputId": "087921c5-bd2d-40d3-eb39-a200687d36d9" 167 | }, 168 | "execution_count": null, 169 | "outputs": [ 170 | { 171 | "output_type": "stream", 172 | "name": "stdout", 173 | "text": [ 174 | "[(1, 2), (2, 3)]\n", 175 | "\n", 176 | "[(1, 2), (2, 3)]\n", 177 | "\n", 178 | "[1, 2, 3, 4, 5, 0, 6, 7, 8, 9, ]\n" 179 | ] 180 | } 181 | ] 182 | } 183 | ] 184 | } --------------------------------------------------------------------------------