├── .gitignore ├── README.md ├── cryptography ├── he-ap14.ipynb ├── homomorphic-encryption.ipynb ├── lwe-reg05.ipynb ├── public-key-encryption.ipynb └── rlwe-lpr10.ipynb ├── jupyter-lab.sh ├── jupyter-notebook.sh ├── swift ├── swift-intro.ipynb └── swift-packages.ipynb ├── word2vec ├── word2vec-intro.ipynb ├── word2vec-intuition.ipynb └── word2vec-sgns-tensorflow.ipynb └── x-from-scratch ├── decision-trees-from-scratch.ipynb ├── gradient-descent-from-scratch.ipynb ├── k-means-clustering-from-scratch.ipynb ├── k-nn-from-scratch.ipynb ├── linear-regression-from-scratch.ipynb ├── logistic-regression-from-scratch.ipynb ├── multiple-regression-from-scratch.ipynb ├── naive-bayes-from-scratch.ipynb └── neural-networks-from-scratch.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | data/ 3 | checkpoints/ 4 | 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | 14 | # Thumbnails 15 | ._* 16 | 17 | # Files that might appear in the root of a volume 18 | .DocumentRevisions-V100 19 | .fseventsd 20 | .Spotlight-V100 21 | .TemporaryItems 22 | .Trashes 23 | .VolumeIcon.icns 24 | .com.apple.timemachine.donotpresent 25 | 26 | # Directories potentially created on remote AFP share 27 | .AppleDB 28 | .AppleDesktop 29 | Network Trash Folder 30 | Temporary Items 31 | .apdisk 32 | 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | share/python-wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into 
it. 64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | .hypothesis/ 82 | .pytest_cache/ 83 | 84 | # Translations 85 | *.mo 86 | *.pot 87 | 88 | # Django stuff: 89 | *.log 90 | local_settings.py 91 | db.sqlite3 92 | 93 | # Flask stuff: 94 | instance/ 95 | .webassets-cache 96 | 97 | # Scrapy stuff: 98 | .scrapy 99 | 100 | # Sphinx documentation 101 | docs/_build/ 102 | 103 | # PyBuilder 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | .python-version 115 | 116 | # celery beat schedule file 117 | celerybeat-schedule 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lab 2 | 3 | Personal lab to play around with algorithms and data. 4 | 5 | **NOTE:** In order to make the implementations as understandable as possible, I sometimes write more expressive code, which could result in poor performance or the disapproval of purists. I strongly believe that readability for such educational endeavors is more important than high-performance or idiomatic code. 
6 | 7 | ## Implementations 8 | 9 | ### X from scratch 10 | 11 | From scratch implementations of various algorithms and models in pure Python. 12 | 13 | | Notebook | nbviewer | Google Colab | Blog post | 14 | | --------------------------------------------- | :----------------------------------: | :-------------------------------: | :------------------------------: | 15 | | [Gradient Descent][gradient-descent-nb] | [Link][gradient-descent-nbviewer] | [Link][gradient-descent-colab] | [Link][gradient-descent-post] | 16 | | [k-NN][k-nn-nb] | [Link][k-nn-nbviewer] | [Link][k-nn-colab] | [Link][k-nn-post] | 17 | | [Naive Bayes][naive-bayes-nb] | [Link][naive-bayes-nbviewer] | [Link][naive-bayes-colab] | [Link][naive-bayes-post] | 18 | | [Linear Regression][linear-regression-nb] | [Link][linear-regression-nbviewer] | [Link][linear-regression-colab] | [Link][linear-regression-post] | 19 | | [Multiple Regression][multiple-regression-nb] | [Link][multiple-regression-nbviewer] | [Link][multiple-regression-colab] | [Link][multiple-regression-post] | 20 | | [Logistic Regression][logistic-regression-nb] | [Link][logistic-regression-nbviewer] | [Link][logistic-regression-colab] | [Link][logistic-regression-post] | 21 | | [Decision Trees][decision-trees-nb] | [Link][decision-trees-nbviewer] | [Link][decision-trees-colab] | [Link][decision-trees-post] | 22 | | [Neural Networks][neural-networks-nb] | [Link][neural-networks-nbviewer] | [Link][neural-networks-colab] | Coming soon | 23 | | [k-means Clustering][k-means-clustering-nb] | [Link][k-means-clustering-nbviewer] | [Link][k-means-clustering-colab] | Coming soon | 24 | 25 | [gradient-descent-nb]: ./x-from-scratch/gradient-descent-from-scratch.ipynb 26 | [gradient-descent-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/gradient-descent-from-scratch.ipynb 27 | [gradient-descent-colab]: 
https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/gradient-descent-from-scratch.ipynb 28 | [gradient-descent-post]: https://philippmuens.com/gradient-descent-from-scratch/ 29 | [k-nn-nb]: ./x-from-scratch/k-nn-from-scratch.ipynb 30 | [k-nn-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/k-nn-from-scratch.ipynb 31 | [k-nn-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/k-nn-from-scratch.ipynb 32 | [k-nn-post]: https://philippmuens.com/k-nearest-neighbors-from-scratch/ 33 | [naive-bayes-nb]: ./x-from-scratch/naive-bayes-from-scratch.ipynb 34 | [naive-bayes-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/naive-bayes-from-scratch.ipynb 35 | [naive-bayes-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/naive-bayes-from-scratch.ipynb 36 | [naive-bayes-post]: https://philippmuens.com/naive-bayes-from-scratch/ 37 | [linear-regression-nb]: ./x-from-scratch/linear-regression-from-scratch.ipynb 38 | [linear-regression-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/linear-regression-from-scratch.ipynb 39 | [linear-regression-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/linear-regression-from-scratch.ipynb 40 | [linear-regression-post]: https://philippmuens.com/linear-and-multiple-regression-from-scratch/ 41 | [multiple-regression-nb]: ./x-from-scratch/multiple-regression-from-scratch.ipynb 42 | [multiple-regression-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/multiple-regression-from-scratch.ipynb 43 | [multiple-regression-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/multiple-regression-from-scratch.ipynb 44 | [multiple-regression-post]: https://philippmuens.com/linear-and-multiple-regression-from-scratch/ 45 | [logistic-regression-nb]: 
./x-from-scratch/logistic-regression-from-scratch.ipynb 46 | [logistic-regression-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/logistic-regression-from-scratch.ipynb 47 | [logistic-regression-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/logistic-regression-from-scratch.ipynb 48 | [logistic-regression-post]: https://philippmuens.com/logistic-regression-from-scratch/ 49 | [decision-trees-nb]: ./x-from-scratch/decision-trees-from-scratch.ipynb 50 | [decision-trees-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/decision-trees-from-scratch.ipynb 51 | [decision-trees-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/decision-trees-from-scratch.ipynb 52 | [decision-trees-post]: https://philippmuens.com/decision-trees-from-scratch/ 53 | [neural-networks-nb]: ./x-from-scratch/neural-networks-from-scratch.ipynb 54 | [neural-networks-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/neural-networks-from-scratch.ipynb 55 | [neural-networks-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/neural-networks-from-scratch.ipynb 56 | [k-means-clustering-nb]: ./x-from-scratch/k-means-clustering-from-scratch.ipynb 57 | [k-means-clustering-nbviewer]: https://nbviewer.jupyter.org/github/pmuens/lab/blob/master/x-from-scratch/k-means-clustering-from-scratch.ipynb 58 | [k-means-clustering-colab]: https://colab.research.google.com/github/pmuens/lab/blob/master/x-from-scratch/k-means-clustering-from-scratch.ipynb 59 | 60 | ## Running it 61 | 62 | **NOTE:** You can pass an optional port number as the first CLI argument (e.g. `./jupyter-lab.sh 3000`). 
63 | 64 | ### Jupyter Lab 65 | 66 | ```sh 67 | ./jupyter-lab.sh 68 | ``` 69 | 70 | ### Jupyter Notebook 71 | 72 | ```sh 73 | ./jupyter-notebook.sh 74 | ``` 75 | -------------------------------------------------------------------------------- /cryptography/he-ap14.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# [AP14](https://web.eecs.umich.edu/~cpeikert/pubs/polyboot.pdf) Homomorphic Encryption Scheme\n", 8 | "\n", 9 | "- [Faster Bootstrapping with Polynomial Error](https://web.eecs.umich.edu/~cpeikert/pubs/polyboot.pdf)\n", 10 | "- [Homomorphic Encryption from Learning with Errors: Conceptually-Simpler, Asymptotically-Faster, Attribute-Based](https://eprint.iacr.org/2013/340.pdf)\n", 11 | "- [Fully Homomorphic Encryption for Machine Learning](https://www.di.ens.fr/~minelli/docs/phd-thesis.pdf)\n", 12 | "- [Building a Fully Homomorphic Encryption Scheme in Python](https://courses.csail.mit.edu/6.857/2019/project/15-Hedglin-Phillips-Reilley.pdf)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from math import log2, ceil, inf\n", 22 | "import numpy as np\n", 23 | "from numpy.testing import assert_array_equal" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# NOTE: Uncomment to simplify debugging\n", 33 | "# np.random.seed(1)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Utility Functions" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def generate_gadget_matrix(n, l, modulus):\n", 50 | " \"\"\"\n", 51 | " Generates the gadget matrix `G` which is a block-diagonal matrix of powers of 2.\n", 52 | " \"\"\"\n", 53 | 
" # NOTE: In the paper the range is `l - 1` but Pythons `range` function already excludes the last entry\n", 54 | " g = np.array([1 << i for i in range(l)])\n", 55 | " I = np.eye(n)\n", 56 | " G = np.kron(I, g).astype(int)\n", 57 | " return G\n", 58 | "\n", 59 | "def test():\n", 60 | " q = 65536\n", 61 | " n = 3\n", 62 | " l = ceil(log2(q))\n", 63 | " result = generate_gadget_matrix(n, l, modulus=q)\n", 64 | " expected = np.array([\n", 65 | " [\n", 66 | " 1, 2, 4, 8, 16, 32, 64, 128, 256,\n", 67 | " 512, 1024, 2048, 4096, 8192, 16384, 32768, 0, 0,\n", 68 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 69 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 70 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 71 | " 0, 0, 0\n", 72 | " ],\n", 73 | " [\n", 74 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 75 | " 0, 0, 0, 0, 0, 0, 0, 1, 2,\n", 76 | " 4, 8, 16, 32, 64, 128, 256, 512, 1024,\n", 77 | " 2048, 4096, 8192, 16384, 32768, 0, 0, 0, 0,\n", 78 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 79 | " 0, 0, 0\n", 80 | " ],\n", 81 | " [\n", 82 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 83 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 84 | " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 85 | " 0, 0, 0, 0, 0, 1, 2, 4, 8,\n", 86 | " 16, 32, 64, 128, 256, 512, 1024, 2048, 4096,\n", 87 | " 8192, 16384, 32768\n", 88 | " ]\n", 89 | " ])\n", 90 | " assert result.shape == (n, n * l)\n", 91 | " assert_array_equal(result, expected)\n", 92 | "\n", 93 | "test()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "def num_to_bin_vector(number, width):\n", 103 | " \"\"\"\n", 104 | " Translates a number to a fixed-width binary vector\n", 105 | " \"\"\"\n", 106 | " return np.array([(int(number) >> i & 1) for i in range(width)])\n", 107 | "\n", 108 | "def test():\n", 109 | " # Integer\n", 110 | " number = 64\n", 111 | " width = 8\n", 112 | " result = num_to_bin_vector(number, width)\n", 113 | " assert_array_equal(result, np.array([0, 0, 0, 0, 0, 0, 1, 0]))\n", 114 | " # Float\n", 115 | " number = 
1024.0\n", 116 | " width = 11\n", 117 | " result = num_to_bin_vector(number, width)\n", 118 | " assert_array_equal(result, np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))\n", 119 | "\n", 120 | "test()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "def bit_decomp(matrix, padding):\n", 130 | " \"\"\"\n", 131 | " Decomposes a given matrix into its corresponding binary representation\n", 132 | " NOTE: The binary numbers are returned in a columnar fashion starting with the LSB\n", 133 | " ASIDE: This funciton is called G^-1 or G_inv in the literature\n", 134 | " \"\"\"\n", 135 | " result_matrix = []\n", 136 | " for column in matrix.T:\n", 137 | " interim_matrix = []\n", 138 | " for value in column:\n", 139 | " interim_matrix.append(num_to_bin_vector(value, padding))\n", 140 | " result_matrix.append(interim_matrix)\n", 141 | " return np.array(result_matrix).reshape(matrix.shape[1], padding * matrix.shape[0]).T\n", 142 | "\n", 143 | "def test():\n", 144 | " # Non-square matrix\n", 145 | " result = bit_decomp(np.array([\n", 146 | " [64, 32, 16],\n", 147 | " [8, 4, 2]]), padding=8)\n", 148 | " expected = np.array([\n", 149 | " [0, 0, 0],\n", 150 | " [0, 0, 0],\n", 151 | " [0, 0, 0],\n", 152 | " [0, 0, 0],\n", 153 | " [0, 0, 1],\n", 154 | " [0, 1, 0],\n", 155 | " [1, 0, 0],\n", 156 | " [0, 0, 0],\n", 157 | " #########\n", 158 | " [0, 0, 0],\n", 159 | " [0, 0, 1],\n", 160 | " [0, 1, 0],\n", 161 | " [1, 0, 0],\n", 162 | " [0, 0, 0],\n", 163 | " [0, 0, 0],\n", 164 | " [0, 0, 0],\n", 165 | " [0, 0, 0],\n", 166 | " ])\n", 167 | " assert_array_equal(result, expected)\n", 168 | " # Square matrix\n", 169 | " result = bit_decomp(np.array([\n", 170 | " [64, 32],\n", 171 | " [16, 8]]), padding=8)\n", 172 | " expected = np.array([\n", 173 | " [0, 0],\n", 174 | " [0, 0],\n", 175 | " [0, 0],\n", 176 | " [0, 0],\n", 177 | " [0, 0],\n", 178 | " [0, 1],\n", 179 | " [1, 0],\n", 180 | " [0, 
0],\n", 181 | " ######\n", 182 | " [0, 0],\n", 183 | " [0, 0],\n", 184 | " [0, 0],\n", 185 | " [0, 1],\n", 186 | " [1, 0],\n", 187 | " [0, 0],\n", 188 | " [0, 0],\n", 189 | " [0, 0],\n", 190 | " ])\n", 191 | " assert_array_equal(result, expected)\n", 192 | "\n", 193 | "test()" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "## Security Parameters" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 6, 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "q: 65536\n", 213 | "l: 16\n", 214 | "n: 3\n", 215 | "m: 48\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "# NOTE: One might need to increase `q` and `n` such that the growing noise when adding / multiplying ciphertexts\n", 221 | "# won't cause the decryption to fail.\n", 222 | "q = pow(2, 16)\n", 223 | "n = 3\n", 224 | "l = ceil(log2(q))\n", 225 | "m = n * l\n", 226 | "\n", 227 | "print(f'q: {q}')\n", 228 | "print(f'l: {l}')\n", 229 | "print(f'n: {n}')\n", 230 | "print(f'm: {m}')" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## Secret Key" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 7, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "[25708 58288]\n", 250 | "(2,)\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "s = np.random.choice(q, n - 1) % q\n", 256 | "\n", 257 | "print(s)\n", 258 | "print(s.shape)\n", 259 | "\n", 260 | "assert s.shape == (n - 1,)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 8, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "name": "stdout", 270 | "output_type": "stream", 271 | "text": [ 272 | "[25708 58288 1]\n", 273 | "(3,)\n" 274 | ] 275 | } 276 | ], 277 | "source": [ 278 | "sk = np.append(s, [1])\n", 279 | "\n", 280 | 
"print(sk)\n", 281 | "print(sk.shape)\n", 282 | "\n", 283 | "assert sk.shape == (n,)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [ 290 | "## Public Key" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 9, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "[65534 65535 0 0 0 65535 65534 0 0 65534 65534 0\n", 303 | " 1 65535 2 0 0 65535 65535 1 0 1 65535 1\n", 304 | " 0 0 0 1 1 0 0 0 2 65533 0 2\n", 305 | " 0 0 1 1 1 65535 65535 0 1 1 0 0]\n", 306 | "(48,)\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "e = np.rint(np.random.normal(0.0, 1.0, m)).astype(int) % q\n", 312 | "\n", 313 | "print(e)\n", 314 | "print(e.shape)\n", 315 | "\n", 316 | "assert e.shape == (m,)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 10, 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "name": "stdout", 326 | "output_type": "stream", 327 | "text": [ 328 | "[[64215 45153 49341 55156 21866 8970 5513 64479 37170 33768 35233 26892\n", 329 | " 21533 12667 14561 32478 55939 6468 21045 17047 30137 49223 39669 57276\n", 330 | " 8985 34606 728 48747 34413 6141 43548 32825 42637 62786 22535 34439\n", 331 | " 46608 12739 23711 15682 51586 39983 57274 63859 63842 11879 11902 22578]\n", 332 | " [26696 4082 42398 46255 55206 35564 38096 46248 5960 37039 26188 27157\n", 333 | " 14038 15815 53570 36947 35058 5331 49211 1027 53443 31511 45654 5497\n", 334 | " 21867 2332 42120 13762 42124 27466 10084 20085 21037 51447 44144 36451\n", 335 | " 13804 318 16071 39140 46208 58346 24600 54678 17659 5866 4189 57839]]\n", 336 | "(2, 48)\n" 337 | ] 338 | } 339 | ], 340 | "source": [ 341 | "A = np.random.choice(q, (n - 1, m)) % q\n", 342 | "\n", 343 | "print(A)\n", 344 | "print(A.shape)\n", 345 | "\n", 346 | "assert A.shape == (n - 1, m)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 11, 352 | 
"metadata": {}, 353 | "outputs": [ 354 | { 355 | "name": "stdout", 356 | "output_type": "stream", 357 | "text": [ 358 | "[22578 55627 6748 39488 58584 30327 22730 35220 41624 60206 44074 36480\n", 359 | " 17757 55731 17998 4024 6052 41663 56299 32453 23836 57029 61563 57729\n", 360 | " 11548 5800 18592 7301 36413 21404 29840 2940 48494 28325 48116 10758\n", 361 | " 25856 65252 52453 59545 25305 28339 25335 420 31465 3669 35864 65128]\n", 362 | "(48,)\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "b = (s.dot(A) + e) % q\n", 368 | "\n", 369 | "print(b)\n", 370 | "print(b.shape)\n", 371 | "\n", 372 | "assert b.shape == (m,)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 12, 378 | "metadata": {}, 379 | "outputs": [ 380 | { 381 | "name": "stdout", 382 | "output_type": "stream", 383 | "text": [ 384 | "[[-64215 -45153 -49341 -55156 -21866 -8970 -5513 -64479 -37170 -33768\n", 385 | " -35233 -26892 -21533 -12667 -14561 -32478 -55939 -6468 -21045 -17047\n", 386 | " -30137 -49223 -39669 -57276 -8985 -34606 -728 -48747 -34413 -6141\n", 387 | " -43548 -32825 -42637 -62786 -22535 -34439 -46608 -12739 -23711 -15682\n", 388 | " -51586 -39983 -57274 -63859 -63842 -11879 -11902 -22578]\n", 389 | " [-26696 -4082 -42398 -46255 -55206 -35564 -38096 -46248 -5960 -37039\n", 390 | " -26188 -27157 -14038 -15815 -53570 -36947 -35058 -5331 -49211 -1027\n", 391 | " -53443 -31511 -45654 -5497 -21867 -2332 -42120 -13762 -42124 -27466\n", 392 | " -10084 -20085 -21037 -51447 -44144 -36451 -13804 -318 -16071 -39140\n", 393 | " -46208 -58346 -24600 -54678 -17659 -5866 -4189 -57839]\n", 394 | " [ 22578 55627 6748 39488 58584 30327 22730 35220 41624 60206\n", 395 | " 44074 36480 17757 55731 17998 4024 6052 41663 56299 32453\n", 396 | " 23836 57029 61563 57729 11548 5800 18592 7301 36413 21404\n", 397 | " 29840 2940 48494 28325 48116 10758 25856 65252 52453 59545\n", 398 | " 25305 28339 25335 420 31465 3669 35864 65128]]\n", 399 | "(3, 48)\n" 400 | ] 401 | } 402 | 
], 403 | "source": [ 404 | "pk = np.row_stack((-A, b))\n", 405 | "\n", 406 | "print(pk)\n", 407 | "print(pk.shape)\n", 408 | "\n", 409 | "assert pk.shape == (n, m)\n", 410 | "assert_array_equal(pk[:n - 1], -A)\n", 411 | "assert_array_equal(pk[n - 1], b)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 13, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "assert_array_equal(sk.dot(pk) % q, e)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "## Encryption" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 14, 433 | "metadata": {}, 434 | "outputs": [ 435 | { 436 | "name": "stdout", 437 | "output_type": "stream", 438 | "text": [ 439 | "42\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "mu = 42\n", 445 | "\n", 446 | "print(mu)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 15, 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "name": "stdout", 456 | "output_type": "stream", 457 | "text": [ 458 | "[[ 1 2 4 8 16 32 64 128 256 512 1024 2048\n", 459 | " 4096 8192 16384 32768 0 0 0 0 0 0 0 0\n", 460 | " 0 0 0 0 0 0 0 0 0 0 0 0\n", 461 | " 0 0 0 0 0 0 0 0 0 0 0 0]\n", 462 | " [ 0 0 0 0 0 0 0 0 0 0 0 0\n", 463 | " 0 0 0 0 1 2 4 8 16 32 64 128\n", 464 | " 256 512 1024 2048 4096 8192 16384 32768 0 0 0 0\n", 465 | " 0 0 0 0 0 0 0 0 0 0 0 0]\n", 466 | " [ 0 0 0 0 0 0 0 0 0 0 0 0\n", 467 | " 0 0 0 0 0 0 0 0 0 0 0 0\n", 468 | " 0 0 0 0 0 0 0 0 1 2 4 8\n", 469 | " 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768]]\n", 470 | "(3, 48)\n" 471 | ] 472 | } 473 | ], 474 | "source": [ 475 | "G = generate_gadget_matrix(n, l, q)\n", 476 | "\n", 477 | "print(G)\n", 478 | "print(G.shape)\n", 479 | "\n", 480 | "assert G.shape == (n, m)" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 16, 486 | "metadata": {}, 487 | "outputs": [ 488 | { 489 | "name": "stdout", 490 | "output_type": "stream", 491 | 
"text": [ 492 | "[[1 1 0 ... 0 0 1]\n", 493 | " [1 0 0 ... 0 0 0]\n", 494 | " [1 0 1 ... 1 1 0]\n", 495 | " ...\n", 496 | " [0 1 1 ... 0 1 1]\n", 497 | " [1 0 1 ... 0 0 1]\n", 498 | " [0 0 1 ... 0 1 1]]\n", 499 | "(48, 48)\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "R = np.random.choice(2, (m, m)) % q\n", 505 | "\n", 506 | "print(R)\n", 507 | "print(R.shape)\n", 508 | "\n", 509 | "assert R.shape == (m, m)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 17, 515 | "metadata": {}, 516 | "outputs": [ 517 | { 518 | "name": "stdout", 519 | "output_type": "stream", 520 | "text": [ 521 | "[[ 5710 58646 63027 60887 20432 15002 29460 49609 25884 60422 10572 59629\n", 522 | " 2726 36781 54479 22299 7826 56710 12694 16278 35086 60355 5784 36685\n", 523 | " 41232 37847 58590 14799 800 2681 54920 4542 1748 53758 2872 42193\n", 524 | " 49190 33652 19464 29554 26236 44235 29130 33487 20935 39222 38041 64979]\n", 525 | " [ 8674 40445 4531 807 15622 21739 10207 25420 27748 1709 9135 20603\n", 526 | " 57437 10352 65079 55336 58944 47881 62907 13295 1305 34742 20109 27045\n", 527 | " 48469 34416 58804 62710 55975 7 24309 56710 25977 7707 51959 49569\n", 528 | " 21407 36200 51714 61164 29839 49131 19509 28358 41080 44484 45471 57315]\n", 529 | " [59241 49465 14669 47931 27041 37756 60226 61172 61287 34691 3742 33395\n", 530 | " 29446 45569 53722 41238 24894 22539 9891 51816 47081 2717 51189 63250\n", 531 | " 9165 43079 45717 13710 50867 5923 55087 11707 16517 36251 54839 62323\n", 532 | " 41032 53458 12223 21163 36451 53199 23391 35467 43404 17015 61736 18928]]\n", 533 | "(3, 48)\n" 534 | ] 535 | } 536 | ], 537 | "source": [ 538 | "C = ((pk @ R) + (mu * G)) % q\n", 539 | "\n", 540 | "print(C)\n", 541 | "print(C.shape)\n", 542 | "\n", 543 | "assert C.shape == (n, m)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "markdown", 548 | "metadata": {}, 549 | "source": [ 550 | "## Decryption" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | 
"execution_count": 18, 556 | "metadata": {}, 557 | "outputs": [ 558 | { 559 | "name": "stdout", 560 | "output_type": "stream", 561 | "text": [ 562 | "[31153 62321 59105 52671 39809 14084 28162 56320 47095 28667 57342 49151\n", 563 | " 32766 65533 65534 65530 23254 46531 27515 55040 44545 23553 47109 28670\n", 564 | " 57341 49147 32765 2 3 65535 65535 3 37 83 167 335\n", 565 | " 672 1346 2687 5379 10755 21507 43015 20479 40960 16383 32772 4]\n", 566 | "(48,)\n" 567 | ] 568 | } 569 | ], 570 | "source": [ 571 | "msg = sk.dot(C) % q\n", 572 | "\n", 573 | "print(msg)\n", 574 | "print(msg.shape)\n", 575 | "\n", 576 | "assert msg.shape == (m,)" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 19, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "name": "stdout", 586 | "output_type": "stream", 587 | "text": [ 588 | "[25708 51416 37296 9056 18112 36224 6912 13824 27648 55296 45056 24576\n", 589 | " 49152 32768 0 0 58288 51040 36544 7552 15104 30208 60416 55296\n", 590 | " 45056 24576 49152 32768 0 0 0 0 1 2 4 8\n", 591 | " 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768]\n", 592 | "(48,)\n" 593 | ] 594 | } 595 | ], 596 | "source": [ 597 | "sg = sk.dot(G) % q\n", 598 | "\n", 599 | "print(sg)\n", 600 | "print(sg.shape)\n", 601 | "\n", 602 | "assert sg.shape == (m,)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 20, 608 | "metadata": {}, 609 | "outputs": [ 610 | { 611 | "name": "stdout", 612 | "output_type": "stream", 613 | "text": [ 614 | "[ 1 1 1 5 2 0 4 4 1 0 1 1 0 1 0 0 0 0 0 7 2 0 0 0\n", 615 | " 1 1 0 0 0 0 0 0 37 41 41 41 42 42 41 42 42 42 42 9 10 1 2 0]\n", 616 | "(48,)\n" 617 | ] 618 | } 619 | ], 620 | "source": [ 621 | "# We might run into \"divide by zero\" RuntimeWarnings here\n", 622 | "with np.errstate(divide='ignore',invalid='ignore'):\n", 623 | " r = (msg // sg)\n", 624 | "\n", 625 | "print(r)\n", 626 | "print(r.shape)\n", 627 | "\n", 628 | "assert r.shape == (m,)" 629 | ] 630 | }, 631 | { 
632 | "cell_type": "code", 633 | "execution_count": 21, 634 | "metadata": {}, 635 | "outputs": [ 636 | { 637 | "name": "stdout", 638 | "output_type": "stream", 639 | "text": [ 640 | "The value is: 42\n" 641 | ] 642 | } 643 | ], 644 | "source": [ 645 | "res = 0\n", 646 | "dist = inf\n", 647 | "\n", 648 | "for val in np.unique(r):\n", 649 | " d = (msg - (val * sg)) % q\n", 650 | " d = np.minimum(d, q - d) % q\n", 651 | " d = int(np.linalg.norm(d)) % q\n", 652 | " if d < dist:\n", 653 | " res = val\n", 654 | " dist = d\n", 655 | "\n", 656 | "print(f'The value is: {res}')\n", 657 | "\n", 658 | "assert res == mu" 659 | ] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "metadata": {}, 664 | "source": [ 665 | "## Homomorphic Addition / Multiplication" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": 22, 671 | "metadata": {}, 672 | "outputs": [], 673 | "source": [ 674 | "def encrypt(mu):\n", 675 | " \"\"\"\n", 676 | " The encryption logic from above in a single function\n", 677 | " \"\"\"\n", 678 | " return ((pk @ R) + (mu * G)) % q" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 23, 684 | "metadata": {}, 685 | "outputs": [], 686 | "source": [ 687 | "def decrypt(C):\n", 688 | " \"\"\"\n", 689 | " The decryption logic from above in a single function\n", 690 | " \"\"\"\n", 691 | " msg = sk.dot(C) % q\n", 692 | " sg = sk.dot(G) % q\n", 693 | " with np.errstate(divide='ignore',invalid='ignore'):\n", 694 | " r = (msg // sg)\n", 695 | " \n", 696 | " res = 0\n", 697 | " dist = inf\n", 698 | "\n", 699 | " for val in np.unique(r):\n", 700 | " d = (msg - (val * sg)) % q\n", 701 | " d = np.minimum(d, q - d) % q\n", 702 | " d = int(np.linalg.norm(d)) % q\n", 703 | " if d < dist:\n", 704 | " res = val\n", 705 | " dist = d\n", 706 | "\n", 707 | " return res" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 24, 713 | "metadata": {}, 714 | "outputs": [ 715 | { 716 | "name": "stdout", 717 | 
"output_type": "stream", 718 | "text": [ 719 | "100\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "res = decrypt((encrypt(42) + encrypt(28) + encrypt(30)) % q)\n", 725 | "\n", 726 | "print(res)\n", 727 | "\n", 728 | "assert res == 42 + 28 + 30" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": 25, 734 | "metadata": {}, 735 | "outputs": [ 736 | { 737 | "name": "stdout", 738 | "output_type": "stream", 739 | "text": [ 740 | "18\n" 741 | ] 742 | } 743 | ], 744 | "source": [ 745 | "res = decrypt(((encrypt(2) + encrypt(4)) % q) @ bit_decomp(encrypt(3), l) % q)\n", 746 | "\n", 747 | "print(res)\n", 748 | "\n", 749 | "assert res == (2 + 4) * 3" 750 | ] 751 | } 752 | ], 753 | "metadata": { 754 | "kernelspec": { 755 | "display_name": "Python 3", 756 | "language": "python", 757 | "name": "python3" 758 | }, 759 | "language_info": { 760 | "codemirror_mode": { 761 | "name": "ipython", 762 | "version": 3 763 | }, 764 | "file_extension": ".py", 765 | "mimetype": "text/x-python", 766 | "name": "python", 767 | "nbconvert_exporter": "python", 768 | "pygments_lexer": "ipython3", 769 | "version": "3.6.9" 770 | } 771 | }, 772 | "nbformat": 4, 773 | "nbformat_minor": 4 774 | } 775 | -------------------------------------------------------------------------------- /cryptography/homomorphic-encryption.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# [Homomorphic Encryption](https://en.wikipedia.org/wiki/Homomorphic_encryption)\n", 8 | "\n", 9 | "Playing around with (fully) homomorphic encryption schemes." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "from math import gcd\n", 20 | "from random import randint, randrange\n", 21 | "from typing import List, Tuple, NamedTuple\n", 22 | "from numpy import ndarray" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Taken from: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm\n", 32 | "def egcd(a: int, b: int) -> Tuple[int, int, int]:\n", 33 | " if a == 0:\n", 34 | " return (b, 0, 1)\n", 35 | " else:\n", 36 | " g, y, x = egcd(b % a, a)\n", 37 | " return (g, x - (b // a) * y, y)\n", 38 | "\n", 39 | "def modinv(a: int, m: int) -> int:\n", 40 | " g, x, y = egcd(a, m)\n", 41 | " if g != 1:\n", 42 | " raise Exception('modular inverse does not exist')\n", 43 | " else:\n", 44 | " return x % m\n", 45 | "\n", 46 | "assert modinv(17, 3120) == 2753\n", 47 | "assert egcd(1071, 462) == (21, -3, 7)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## [El Gamal](https://en.wikipedia.org/wiki/ElGamal_encryption)\n", 55 | "\n", 56 | "El Gamal can be used to perform encrypted multiplications." 
57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "class PublicKey(NamedTuple):\n", 66 | " p: int\n", 67 | " a: int\n", 68 | " b: int\n", 69 | "\n", 70 | "class SecretKey(NamedTuple):\n", 71 | " d: int\n", 72 | "\n", 73 | "class Ciphertext(NamedTuple):\n", 74 | " r: int\n", 75 | " t: int" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "def keygen(p: int) -> Tuple[PublicKey, SecretKey]:\n", 85 | " # a: 1 < a < p - 1\n", 86 | " a: int = randint(1, p - 1)\n", 87 | " # d: 2 <= d <= p - 2\n", 88 | " d: int = randint(2, p - 2)\n", 89 | " b: int = (a ** d) % p\n", 90 | " pk: PublicKey = PublicKey(p, a, b)\n", 91 | " sk: SecretKey = SecretKey(d)\n", 92 | " return (pk, sk)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "def encrypt(message: int, pk: PublicKey) -> Ciphertext:\n", 102 | " k: int = randint(0, 100) \n", 103 | " r: int = (pk.a ** k) % pk.p\n", 104 | " t: int = ((pk.b ** k) * message) % pk.p\n", 105 | " return Ciphertext(r, t)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "def decrypt(c: Ciphertext, pk: PublicKey, sk: SecretKey) -> int:\n", 115 | " # NOTE: This implementation of https://en.wikipedia.org/wiki/Modular_multiplicative_inverse is expensive\n", 116 | " # TODO: One can use the `modinv` function from above but I'll leave this code here as another way to compute it\n", 117 | " return ((c.r ** sk.d) ** (pk.p - 2) * c.t) % pk.p" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 7, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def mult(a: Ciphertext, b: Ciphertext) -> Ciphertext:\n", 127 | " r: int = a.r * b.r\n", 128 | " t: int = a.t * 
b.t\n", 129 | " return Ciphertext(r, t)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 8, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "--- Message Encryption / Decryption ---\n", 142 | "Message (Plaintext): 42\n", 143 | "Message (Ciphertext): Ciphertext(r=41, t=32)\n", 144 | "Message (Decrypted): 42\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "pk, sk = keygen(47)\n", 150 | "\n", 151 | "print('--- Message Encryption / Decryption ---')\n", 152 | "plaintext: int = 42\n", 153 | "print(f'Message (Plaintext): {plaintext}')\n", 154 | " \n", 155 | "ciphertext: Ciphertext = encrypt(plaintext, pk)\n", 156 | "print(f'Message (Ciphertext): {ciphertext}')\n", 157 | "\n", 158 | "decrypted: int = decrypt(ciphertext, pk, sk)\n", 159 | "print(f'Message (Decrypted): {decrypted}')\n", 160 | "\n", 161 | "assert plaintext == decrypted" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 9, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "--- Encrypted Multiplication ---\n", 174 | "Numbers (Plaintext): 6, 5\n", 175 | "Result (Plaintext): 30\n", 176 | "Numbers (Ciphertext): Ciphertext(r=12, t=18), Ciphertext(r=8, t=38)\n", 177 | "Result (Ciphertext): Ciphertext(r=96, t=684)\n", 178 | "Result (Decrypted): 30\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "pk, sk = keygen(47)\n", 184 | "\n", 185 | "print('--- Encrypted Multiplication ---')\n", 186 | "a: int = 6\n", 187 | "b: int = 5\n", 188 | "print(f'Numbers (Plaintext): {a}, {b}')\n", 189 | "print(f'Result (Plaintext): {a * b}')\n", 190 | "\n", 191 | "enc_a: Ciphertext = encrypt(a, pk)\n", 192 | "enc_b: Ciphertext = encrypt(b, pk)\n", 193 | "print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')\n", 194 | "\n", 195 | "result: Ciphertext = mult(enc_a, enc_b)\n", 196 | "print(f'Result (Ciphertext): {result}')\n", 197 | 
"decrypted: int = decrypt(result, pk, sk)\n", 198 | "print(f'Result (Decrypted): {decrypted}')\n", 199 | "\n", 200 | "assert a * b == decrypted" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "## [RSA](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) Cryptosystem\n", 208 | "\n", 209 | "RSA can be used to perform encrypted multiplications." 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "class PublicKey(NamedTuple):\n", 219 | " e: int\n", 220 | " n: int\n", 221 | "\n", 222 | "class SecretKey(NamedTuple):\n", 223 | " d: int\n", 224 | " n: int\n", 225 | "\n", 226 | "class Ciphertext(NamedTuple):\n", 227 | " m: int" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 11, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def keygen(p: int, q: int) -> Tuple[PublicKey, SecretKey]:\n", 237 | " n: int = p * q\n", 238 | " phi: int = (p - 1) * (q - 1)\n", 239 | " # e must be greater than 1 and smaller than phi\n", 240 | " # furthermore gcd(phi, e) must be 1\n", 241 | " e: int = 2\n", 242 | " while gcd(phi, e) != 1:\n", 243 | " e += 1\n", 244 | " d: int = modinv(e, phi)\n", 245 | " pk: PublicKey = PublicKey(e, n)\n", 246 | " sk: SecretKey = SecretKey(d, n)\n", 247 | " return (pk, sk)\n", 248 | " \n", 249 | "assert keygen(61, 53)[0] == PublicKey(7, 3233)\n", 250 | "assert keygen(61, 53)[1] == SecretKey(1783, 3233)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 12, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "def encrypt(message: int, pk: PublicKey) -> Ciphertext:\n", 260 | " return Ciphertext(message ** pk.e % pk.n)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 13, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "def decrypt(c: Ciphertext, sk: SecretKey) -> int:\n", 270 | " 
return c.m ** sk.d % sk.n" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 14, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def mult(a: Ciphertext, b: Ciphertext) -> Ciphertext:\n", 280 | " return Ciphertext(m=(a.m * b.m))" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 15, 286 | "metadata": {}, 287 | "outputs": [ 288 | { 289 | "name": "stdout", 290 | "output_type": "stream", 291 | "text": [ 292 | "--- Message Encryption / Decryption ---\n", 293 | "Message (Plaintext): 42\n", 294 | "Message (Ciphertext): Ciphertext(m=240)\n", 295 | "Message (Decrypted): 42\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "pk, sk = keygen(61, 53)\n", 301 | "\n", 302 | "print('--- Message Encryption / Decryption ---')\n", 303 | "plaintext: int = 42\n", 304 | "print(f'Message (Plaintext): {plaintext}')\n", 305 | " \n", 306 | "ciphertext: Ciphertext = encrypt(plaintext, pk)\n", 307 | "print(f'Message (Ciphertext): {ciphertext}')\n", 308 | "\n", 309 | "decrypted: int = decrypt(ciphertext, sk)\n", 310 | "print(f'Message (Decrypted): {decrypted}')\n", 311 | "\n", 312 | "assert plaintext == decrypted" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 16, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "name": "stdout", 322 | "output_type": "stream", 323 | "text": [ 324 | "--- Encrypted Multiplication ---\n", 325 | "Numbers (Plaintext): 6, 5\n", 326 | "Result (Plaintext): 30\n", 327 | "Numbers (Ciphertext): Ciphertext(m=1898), Ciphertext(m=533)\n", 328 | "Result (Ciphertext): Ciphertext(m=1011634)\n", 329 | "Result (Decrypted): 30\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "pk, sk = keygen(61, 53)\n", 335 | "\n", 336 | "print('--- Encrypted Multiplication ---')\n", 337 | "a: int = 6\n", 338 | "b: int = 5\n", 339 | "print(f'Numbers (Plaintext): {a}, {b}')\n", 340 | "print(f'Result (Plaintext): {a * b}')\n", 341 | "\n", 342 | "enc_a: Ciphertext = encrypt(a, pk)\n", 
343 | "enc_b: Ciphertext = encrypt(b, pk)\n", 344 | "print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')\n", 345 | "\n", 346 | "result: Ciphertext = mult(enc_a, enc_b)\n", 347 | "print(f'Result (Ciphertext): {result}')\n", 348 | "decrypted: int = decrypt(result, sk)\n", 349 | "print(f'Result (Decrypted): {decrypted}')\n", 350 | "\n", 351 | "assert a * b == decrypted" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## [Paillier](https://en.wikipedia.org/wiki/Paillier_cryptosystem) Cryptosystem\n", 359 | "\n", 360 | "Paillier can be used to perform encrypted additions." 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 17, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "class PublicKey(NamedTuple):\n", 370 | " n: int\n", 371 | " g: int\n", 372 | "\n", 373 | "class SecretKey(NamedTuple):\n", 374 | " la: int\n", 375 | " mu: int\n", 376 | "\n", 377 | "class Ciphertext(NamedTuple):\n", 378 | " m: int" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 18, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "def keygen(p: int, q: int) -> Tuple[PublicKey, SecretKey]:\n", 388 | " assert p.bit_length() == q.bit_length()\n", 389 | " n: int = p * q\n", 390 | " g: int = n + 1\n", 391 | " la: int = (p - 1) * (q - 1)\n", 392 | " mu: int = modinv(la, n)\n", 393 | " pk: PublicKey = PublicKey(n, g)\n", 394 | " sk: SecretKey = SecretKey(la, mu)\n", 395 | " return (pk, sk)\n", 396 | "\n", 397 | "assert keygen(61, 53)[0] == PublicKey(3233, 3234)\n", 398 | "assert keygen(61, 53)[1] == SecretKey(3120, 2718)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 19, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "def encrypt(message: int, pk: PublicKey) -> Ciphertext:\n", 408 | " r: int = 0\n", 409 | " while gcd(r, pk.n) != 1:\n", 410 | " r: int = randrange(0, pk.n + 1)\n", 411 | " m: int = ((pk.g ** message 
% pk.n ** 2) * (r ** pk.n % pk.n ** 2)) % pk.n ** 2\n", 412 | " return Ciphertext(m)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 20, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "def decrypt(c: Ciphertext, pk: PublicKey, sk: SecretKey) -> int:\n", 422 | " return ((((c.m ** sk.la) % pk.n ** 2) - 1) // pk.n) * sk.mu % pk.n" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 21, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "def add(a: Ciphertext, b: Ciphertext, pk: PublicKey) -> Ciphertext:\n", 432 | " return Ciphertext(a.m * b.m % (pk.n ** 2))" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 22, 438 | "metadata": {}, 439 | "outputs": [ 440 | { 441 | "name": "stdout", 442 | "output_type": "stream", 443 | "text": [ 444 | "--- Message Encryption / Decryption ---\n", 445 | "Message (Plaintext): 42\n", 446 | "Message (Ciphertext): Ciphertext(m=1548548)\n", 447 | "Message (Decrypted): 42\n" 448 | ] 449 | } 450 | ], 451 | "source": [ 452 | "pk, sk = keygen(61, 53)\n", 453 | "\n", 454 | "print('--- Message Encryption / Decryption ---')\n", 455 | "plaintext: int = 42\n", 456 | "print(f'Message (Plaintext): {plaintext}')\n", 457 | " \n", 458 | "ciphertext: Ciphertext = encrypt(plaintext, pk)\n", 459 | "print(f'Message (Ciphertext): {ciphertext}')\n", 460 | "\n", 461 | "decrypted: int = decrypt(ciphertext, pk, sk)\n", 462 | "print(f'Message (Decrypted): {decrypted}')\n", 463 | "\n", 464 | "assert plaintext == decrypted" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 23, 470 | "metadata": {}, 471 | "outputs": [ 472 | { 473 | "name": "stdout", 474 | "output_type": "stream", 475 | "text": [ 476 | "--- Encrypted Addition (with encrypted values) ---\n", 477 | "Numbers (Plaintext): 6, 5\n", 478 | "Result (Plaintext): 11\n", 479 | "Numbers (Ciphertext): Ciphertext(m=934918), Ciphertext(m=4097088)\n", 480 | "Result 
(Ciphertext): Ciphertext(m=1421243)\n", 481 | "Result (Decrypted): 11\n", 482 | "--- Encrypted Addition (with encrypted and plaintext value) ---\n", 483 | "Numbers (Plaintext): 13, 13\n", 484 | "Result (Plaintext): 33\n", 485 | "Number (Ciphertext): Ciphertext(m=10310040)\n", 486 | "Result (Ciphertext): Ciphertext(m=10436127)\n", 487 | "Result (Decrypted): 33\n", 488 | "--- Encrypted Multiplication (with encrypted and plaintext value) ---\n", 489 | "Numbers (Plaintext): 12, 12\n", 490 | "Result (Plaintext): 24\n", 491 | "Number (Ciphertext): Ciphertext(m=4026668)\n", 492 | "Result (Ciphertext): Ciphertext(m=3272659)\n", 493 | "Result (Decrypted): 24\n" 494 | ] 495 | } 496 | ], 497 | "source": [ 498 | "pk, sk = keygen(61, 53)\n", 499 | "\n", 500 | "print('--- Encrypted Addition (with encrypted values) ---')\n", 501 | "a: int = 6\n", 502 | "b: int = 5\n", 503 | "print(f'Numbers (Plaintext): {a}, {b}')\n", 504 | "print(f'Result (Plaintext): {a + b}')\n", 505 | "\n", 506 | "enc_a: Ciphertext = encrypt(a, pk)\n", 507 | "enc_b: Ciphertext = encrypt(b, pk)\n", 508 | "print(f'Numbers (Ciphertext): {enc_a}, {enc_b}')\n", 509 | "\n", 510 | "result: Ciphertext = add(enc_a, enc_b, pk)\n", 511 | "print(f'Result (Ciphertext): {result}')\n", 512 | "decrypted: int = decrypt(result, pk, sk)\n", 513 | "print(f'Result (Decrypted): {decrypted}')\n", 514 | "\n", 515 | "assert a + b == decrypted\n", 516 | "\n", 517 | "print('--- Encrypted Addition (with encrypted and plaintext value) ---')\n", 518 | "a: int = 20\n", 519 | "b: int = 13\n", 520 | "print(f'Numbers (Plaintext): {b}, {b}')\n", 521 | "print(f'Result (Plaintext): {a + b}')\n", 522 | "\n", 523 | "enc_a: Ciphertext = encrypt(a, pk)\n", 524 | "print(f'Number (Ciphertext): {enc_a}')\n", 525 | "\n", 526 | "# `pk.n + 1` == `g`\n", 527 | "result: Ciphertext = Ciphertext(enc_a.m * (pk.n + 1) ** b % (pk.n ** 2))\n", 528 | "print(f'Result (Ciphertext): {result}')\n", 529 | "decrypted: int = decrypt(result, pk, sk)\n", 530 | 
"print(f'Result (Decrypted): {decrypted}')\n", 531 | "\n", 532 | "assert a + b == decrypted\n", 533 | "\n", 534 | "print('--- Encrypted Multiplication (with encrypted and plaintext value) ---')\n", 535 | "a: int = 2\n", 536 | "b: int = 12\n", 537 | "print(f'Numbers (Plaintext): {b}, {b}')\n", 538 | "print(f'Result (Plaintext): {a * b}')\n", 539 | "\n", 540 | "enc_a: Ciphertext = encrypt(a, pk)\n", 541 | "print(f'Number (Ciphertext): {enc_a}')\n", 542 | "\n", 543 | "result: Ciphertext = Ciphertext(enc_a.m ** b % (pk.n ** 2))\n", 544 | "print(f'Result (Ciphertext): {result}')\n", 545 | "decrypted: int = decrypt(result, pk, sk)\n", 546 | "print(f'Result (Decrypted): {decrypted}')\n", 547 | "\n", 548 | "assert a * b == decrypted" 549 | ] 550 | }, 551 | { 552 | "cell_type": "markdown", 553 | "metadata": {}, 554 | "source": [ 555 | "## [Efficient Homomorphic Encryption on Integer Vectors and Its Applications](https://www.rle.mit.edu/sia/wp-content/uploads/2015/04/2014-zhou-wornell-ita.pdf)\n", 556 | "\n", 557 | "**NOTE:** The code written here was produced by following the blog post [\"Building Safe A.I.\"](http://iamtrask.github.io/2017/03/17/safe-ai/) by Andrew Trask." 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "metadata": {}, 563 | "source": [ 564 | "### Terminology\n", 565 | "\n", 566 | "- **S**: Matrix which represents the secret / private key\n", 567 | "- **M**: Public Key (also used to perform Math operations)\n", 568 | "- **c**: Vector which contains the encrypted data\n", 569 | "- **x**: Plaintext (some papers use the variable **m** instead)\n", 570 | "- ***w***: (Weighting) Scalar used to control signal / noise ratio of **x**\n", 571 | "- **e**: Random noise (e.g. noise added to the data before encrypting it via the public key) which makes decryption without the secret key difficult\n", 572 | "\n", 573 | "Homomorphic Encryption has 4 kinds of operations we care about:\n", 574 | "\n", 575 | "1. Public / private keypair generation\n", 576 | "1. 
One-way encryption\n", 577 | "1. Decryption\n", 578 | "1. Math operations\n", 579 | "\n", 580 | "$$\n", 581 | "\\textit{S}c = \\textit{w}x + e\n", 582 | "$$\n", 583 | "\n", 584 | "$$\n", 585 | "x = \\lceil \\frac{Sc}{\\textit{w}} \\rfloor\n", 586 | "$$" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 24, 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "def generate_key(w: int, m: int, n: int) -> ndarray:\n", 596 | " S: ndarray = (np.random.rand(m, n) * w / (2 ** 16))\n", 597 | " return S\n", 598 | "\n", 599 | "def encrypt(x: ndarray, S: ndarray, m: int, n: int, w: int) -> ndarray:\n", 600 | " assert len(x) == len(S)\n", 601 | " e: ndarray = (np.random.rand(m))\n", 602 | " c: ndarray = np.linalg.inv(S).dot((w * x) + e)\n", 603 | " return c\n", 604 | "\n", 605 | "def decrypt(c: ndarray, S: ndarray, w) -> ndarray:\n", 606 | " return (S.dot(c) / w).astype('int')\n", 607 | "\n", 608 | "def switch_key(c: ndarray, S: ndarray, m: int, n: int, T) -> (ndarray, ndarray):\n", 609 | " l: int = int(np.ceil(np.log2(np.max(np.abs(c)))))\n", 610 | " c_star: ndarray = get_c_star(c, m, l)\n", 611 | " S_star: ndarray = get_S_star(S, m, n, l)\n", 612 | " n_prime = n + 1\n", 613 | " S_prime = np.concatenate((np.eye(m), T.T), 0).T\n", 614 | " A: ndarray = (np.random.rand(n_prime - m, n * l) * 10).astype('int')\n", 615 | " E: ndarray = (1 * np.random.rand(S_star.shape[0], S_star.shape[1])).astype('int')\n", 616 | " M: ndarray = np.concatenate(((S_star - T.dot(A) + E), A), 0)\n", 617 | " c_prime: ndarray = M.dot(c_star)\n", 618 | " return c_prime, S_prime\n", 619 | "\n", 620 | "def get_c_star(c: ndarray, m: int, l: int) -> ndarray:\n", 621 | " c_star: ndarray = np.zeros(l * m, dtype='int')\n", 622 | " for i in range(m):\n", 623 | " b: ndarray = np.array(list(np.binary_repr(np.abs(c[i]))), dtype='int')\n", 624 | " if (c[i] < 0):\n", 625 | " b *= -1\n", 626 | " c_star[(i * l) + (l - len(b)): (i + 1) * l] += b\n", 627 | " return c_star\n", 628 | 
"\n", 629 | "def get_S_star(S: ndarray, m: int, n: int, l: int) -> ndarray:\n", 630 | " S_star: List = list()\n", 631 | " for i in range(l):\n", 632 | " S_star.append(S * 2 ** (l - i - 1))\n", 633 | " S_star: ndarray = np.array(S_star).transpose(1, 2, 0).reshape(m, n * l)\n", 634 | " return S_star\n", 635 | "\n", 636 | "def get_T(n: int) -> ndarray:\n", 637 | " n_prime = n + 1\n", 638 | " T: ndarray = (10 * np.random.rand(n, n_prime - n)).astype('int')\n", 639 | " return T\n", 640 | "\n", 641 | "def encrypt_via_switch(x: ndarray, w: int, m: int, n: int, T: ndarray) -> (ndarray, ndarray):\n", 642 | " c, S = switch_key(x * w, np.eye(m), m, n, T)\n", 643 | " return (c, S)" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 25, 649 | "metadata": {}, 650 | "outputs": [ 651 | { 652 | "data": { 653 | "text/plain": [ 654 | "array([[3.23752234e-05, 8.71385929e-05, 1.31983522e-04, 2.07309940e-05],\n", 655 | " [1.86839421e-04, 6.44340965e-05, 1.80036139e-04, 1.37908927e-04],\n", 656 | " [3.59239655e-05, 1.13418365e-04, 2.11182783e-04, 3.64218508e-05],\n", 657 | " [1.56050588e-04, 2.07357406e-04, 1.54697291e-04, 4.90344554e-05]])" 658 | ] 659 | }, 660 | "execution_count": 25, 661 | "metadata": {}, 662 | "output_type": "execute_result" 663 | } 664 | ], 665 | "source": [ 666 | "x: ndarray = np.array([0, 1, 2, 5])\n", 667 | " \n", 668 | "m: int = len(x)\n", 669 | "n: int = m\n", 670 | "w: int = 16\n", 671 | "\n", 672 | "S: ndarray = generate_key(w, m, n)\n", 673 | "S" 674 | ] 675 | }, 676 | { 677 | "cell_type": "markdown", 678 | "metadata": {}, 679 | "source": [ 680 | "### Basic addition / multiplication" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 26, 686 | "metadata": {}, 687 | "outputs": [ 688 | { 689 | "data": { 690 | "text/plain": [ 691 | "array([-20682572.06481609, 15771016.05581861, -10814309.73560343,\n", 692 | " 34891263.61949188])" 693 | ] 694 | }, 695 | "execution_count": 26, 696 | "metadata": {}, 697 | 
"output_type": "execute_result" 698 | } 699 | ], 700 | "source": [ 701 | "c: ndarray = encrypt(x, S, m, n, w)\n", 702 | "c" 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": 27, 708 | "metadata": {}, 709 | "outputs": [ 710 | { 711 | "data": { 712 | "text/plain": [ 713 | "array([0, 1, 2, 5])" 714 | ] 715 | }, 716 | "execution_count": 27, 717 | "metadata": {}, 718 | "output_type": "execute_result" 719 | } 720 | ], 721 | "source": [ 722 | "decrypt(c, S, w)" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 28, 728 | "metadata": {}, 729 | "outputs": [ 730 | { 731 | "data": { 732 | "text/plain": [ 733 | "array([ 0, 2, 4, 10])" 734 | ] 735 | }, 736 | "execution_count": 28, 737 | "metadata": {}, 738 | "output_type": "execute_result" 739 | } 740 | ], 741 | "source": [ 742 | "decrypt(c + c, S, w)" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": 29, 748 | "metadata": {}, 749 | "outputs": [ 750 | { 751 | "data": { 752 | "text/plain": [ 753 | "array([ 0, 10, 20, 50])" 754 | ] 755 | }, 756 | "execution_count": 29, 757 | "metadata": {}, 758 | "output_type": "execute_result" 759 | } 760 | ], 761 | "source": [ 762 | "decrypt(c * 10, S, w)" 763 | ] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": {}, 768 | "source": [ 769 | "### Key-switching addition / multiplication" 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 30, 775 | "metadata": {}, 776 | "outputs": [], 777 | "source": [ 778 | "T: ndarray = get_T(n)" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 31, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "c, S = encrypt_via_switch(x, w, m, n, T)" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": 32, 793 | "metadata": {}, 794 | "outputs": [ 795 | { 796 | "data": { 797 | "text/plain": [ 798 | "array([0, 1, 2, 5])" 799 | ] 800 | }, 801 | "execution_count": 32, 802 | "metadata": {}, 
803 | "output_type": "execute_result" 804 | } 805 | ], 806 | "source": [ 807 | "decrypt(c, S, w)" 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "execution_count": 33, 813 | "metadata": {}, 814 | "outputs": [ 815 | { 816 | "data": { 817 | "text/plain": [ 818 | "array([ 0, 2, 4, 10])" 819 | ] 820 | }, 821 | "execution_count": 33, 822 | "metadata": {}, 823 | "output_type": "execute_result" 824 | } 825 | ], 826 | "source": [ 827 | "decrypt(c + c, S, w)" 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": 34, 833 | "metadata": {}, 834 | "outputs": [ 835 | { 836 | "data": { 837 | "text/plain": [ 838 | "array([ 0, 10, 20, 50])" 839 | ] 840 | }, 841 | "execution_count": 34, 842 | "metadata": {}, 843 | "output_type": "execute_result" 844 | } 845 | ], 846 | "source": [ 847 | "decrypt(c * 10, S, w)" 848 | ] 849 | } 850 | ], 851 | "metadata": { 852 | "kernelspec": { 853 | "display_name": "Python 3", 854 | "language": "python", 855 | "name": "python3" 856 | }, 857 | "language_info": { 858 | "codemirror_mode": { 859 | "name": "ipython", 860 | "version": 3 861 | }, 862 | "file_extension": ".py", 863 | "mimetype": "text/x-python", 864 | "name": "python", 865 | "nbconvert_exporter": "python", 866 | "pygments_lexer": "ipython3", 867 | "version": "3.6.9" 868 | } 869 | }, 870 | "nbformat": 4, 871 | "nbformat_minor": 4 872 | } 873 | -------------------------------------------------------------------------------- /cryptography/lwe-reg05.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# [Regev's LWE](https://cims.nyu.edu/~regev/papers/qcrypto.pdf) Public Key Cryptosystem" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from numpy.testing import assert_array_equal" 18 | ] 19 | }, 20 | { 21 | 
"cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Modulus\n", 27 | "q = 65536\n", 28 | "# Lattice dimension\n", 29 | "n = 3" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Secret Key" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "array([ 1105, 19041, 47494])" 48 | ] 49 | }, 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "output_type": "execute_result" 53 | } 54 | ], 55 | "source": [ 56 | "# Our LWE secret which is used below to construct the Public Key\n", 57 | "sk = np.random.choice(q, n)\n", 58 | "sk" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "data": { 68 | "text/plain": [ 69 | "array([64431, 46495, 18042, 1])" 70 | ] 71 | }, 72 | "execution_count": 4, 73 | "metadata": {}, 74 | "output_type": "execute_result" 75 | } 76 | ], 77 | "source": [ 78 | "# With `v` (a variant of our Secret Key `sk`) we can basically recover the error we add when creating our Public Key `pk`\n", 79 | "v = np.append(-sk % q, [1])\n", 80 | "v" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Public Key" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "array([ 5851, 12895, 7497])" 99 | ] 100 | }, 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "# This is the error we're introducing to make it hard to recover information from our Public Key `pk`\n", 108 | "e = np.random.choice(q // 4, n)\n", 109 | "e" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | 
"data": { 119 | "text/plain": [ 120 | "array([[43390, 454, 35574],\n", 121 | " [35489, 32362, 33863],\n", 122 | " [27010, 23733, 10572]])" 123 | ] 124 | }, 125 | "execution_count": 6, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "# The matrix `A` is our System of Linear Equations\n", 132 | "A = np.random.choice(q, (n, n))\n", 133 | "A" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 7, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "array([ 6787, 43428, 33736])" 145 | ] 146 | }, 147 | "execution_count": 7, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "# Here we're computing results for our System of Linear equations (`A`) via our `sk` vector\n", 154 | "# Note that we add the error term we've defined above\n", 155 | "b = (A.dot(sk) + e) % q\n", 156 | "b" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 8, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# The public key is the \"plain\" matrix `A` in combination with an evaluation of such matrix via our `sk`\n", 166 | "# We essentially append a \"solution\" column to our System of Linear Equations (which is our matrix `A`)\n", 167 | "pk = np.column_stack((A, b))\n", 168 | "\n", 169 | "# Here we test that we can use `v` (a modified version of our Secret Key `sk`) to recover the error `e`\n", 170 | "# Recovering the error from the Public Key makes it possible to retain encrypted values\n", 171 | "assert_array_equal(pk.dot(v) % q, e)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "## Encryption" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 9, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | "32768" 190 | ] 191 | }, 192 | "execution_count": 9, 193 | "metadata": 
{}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "# Our message is a bit (0 or 1) rather than a number or a string of text\n", 199 | "# `mu` is either `0` or `q // 2` which makes it possible to determine whether the bit was `0` or `1` once\n", 200 | "# the error `e` was removed (when decrypting later on)\n", 201 | "m = 1\n", 202 | "mu = m * q // 2\n", 203 | "mu" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 10, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "data": { 213 | "text/plain": [ 214 | "array([1, 0, 1])" 215 | ] 216 | }, 217 | "execution_count": 10, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "# The `x` acts as a mask which determines which parts of the `pk` / `A` matrix we're about to evaluate\n", 224 | "x = np.random.choice(2, n)\n", 225 | "x" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 11, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "array([ 0, 0, 0, 32768])" 237 | ] 238 | }, 239 | "execution_count": 11, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "# Using `emb` we're embedding the `mu` value into the \"result column\" when we're evaluating the `pk` / `A` matrix\n", 246 | "emb = np.append(np.full(n, 0), mu)\n", 247 | "emb" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 12, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "text/plain": [ 258 | "array([ 4864, 24187, 46146, 7755])" 259 | ] 260 | }, 261 | "execution_count": 12, 262 | "metadata": {}, 263 | "output_type": "execute_result" 264 | } 265 | ], 266 | "source": [ 267 | "# Using our mask `x` we can evaluate our `pk` / `A` matrix and add the `emb` vector to it such that the `mu` value is\n", 268 | "# embedded into the result which is our ciphertext\n", 269 | "c = (x.dot(pk) + emb) % 
q\n", 270 | "c" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "## Decryption" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 13, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "data": { 287 | "text/plain": [ 288 | "46116" 289 | ] 290 | }, 291 | "execution_count": 13, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "# Decryption uses our `sk` variant `v` to cancel the `A`-dependent part of the ciphertext, leaving `mu` plus the accumulated error `x.dot(e)` (mod `q`)\n", 298 | "p = c.dot(v) % q\n", 299 | "p" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 14, 305 | "metadata": {}, 306 | "outputs": [ 307 | { 308 | "name": "stdout", 309 | "output_type": "stream", 310 | "text": [ 311 | "The message is 1\n", 312 | "\n", 313 | "abs((p - (q // 2)) % q): 13348\n", 314 | "(p % q):\t\t 46116\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "# We can now check if the value we're getting is closer to `q // 2` or `0`\n", 320 | "# Closer to `q // 2` --> 1 was embedded\n", 321 | "# Closer to `0` --> 0 was embedded\n", 322 | "if abs((p - (q // 2)) % q) < (p % q):\n", 323 | " print('The message is 1')\n", 324 | "else:\n", 325 | " print('The message is 0')\n", 326 | "\n", 327 | "print()\n", 328 | "print(f'abs((p - (q // 2)) % q): {abs((p - q // 2) % q)}')\n", 329 | "print(f'(p % q):\\t\\t {(p % q)}')" 330 | ] 331 | } 332 | ], 333 | "metadata": { 334 | "kernelspec": { 335 | "display_name": "Python 3", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.6.9" 350 | } 351 | }, 352 | "nbformat": 4, 353 | "nbformat_minor": 4 354 | } 355 | 
-------------------------------------------------------------------------------- /cryptography/public-key-encryption.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Public Key Encryption" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from math import sqrt\n", 17 | "from typing import Tuple" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## [Diffie-Hellman](https://en.wikipedia.org/wiki/Diffie–Hellman_key_exchange) Key Exchange" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# Public, shared information\n", 34 | "p: int = 23 # A prime number\n", 35 | "g: int = 5 # A base number" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# Private, non-shared information\n", 45 | "a: int = 4 # Alices secret exponent\n", 46 | "b: int = 3 # Bobs secret exponent" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "def alice_enc(p: int, g: int) -> int:\n", 56 | " return (g ** a) % p" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 5, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "def bob_enc(p: int, g: int) -> int:\n", 66 | " return (g ** b) % p" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 6, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "j: int = alice_enc(p, g)\n", 76 | "k: int = bob_enc(p, g)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 7, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "def alice_dec(k: int) -> 
int:\n", 86 | " return (k ** a) % p\n", 87 | "\n", 88 | "assert alice_dec(k) == (k ** a % p) == (g ** (b * a) % p)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 8, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "def bob_dec(j: int) -> int:\n", 98 | " return (j ** b) % p\n", 99 | "\n", 100 | "assert bob_dec(j) == (j ** b % p) == (g ** (a * b) % p)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 9, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "Alices number: 18\n", 113 | "Bobs number: 18\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "print(f'Alices number: {alice_dec(k)}')\n", 119 | "print(f'Bobs number: {bob_dec(j)}')" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## [RSA](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) Cryptosystem" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 10, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "# Taken from: https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm\n", 136 | "def egcd(a: int, b: int) -> Tuple[int, int, int]:\n", 137 | " if a == 0:\n", 138 | " return (b, 0, 1)\n", 139 | " else:\n", 140 | " g, y, x = egcd(b % a, a)\n", 141 | " return (g, x - (b // a) * y, y)\n", 142 | "\n", 143 | "def modinv(a: int, m: int) -> int:\n", 144 | " g, x, y = egcd(a, m)\n", 145 | " if g != 1:\n", 146 | " raise Exception('modular inverse does not exist')\n", 147 | " else:\n", 148 | " return x % m\n", 149 | "\n", 150 | "assert modinv(17, 3120) == 2753\n", 151 | "assert egcd(1071, 462) == (21, -3, 7)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 11, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# Private, non-shared information\n", 161 | "p: int = 61\n", 162 | "q: int = 53\n", 163 | "n: 
int = p * q\n", 164 | "phi_n: int = (p - 1) * (q - 1)\n", 165 | "\n", 166 | "# NOTE: We start with a \"high\" guess for e here so that we can control\n", 167 | "# how \"large\" e should be\n", 168 | "e: int = 12\n", 169 | "while egcd(e, phi_n)[0] != 1:\n", 170 | " e += 1\n", 171 | "\n", 172 | "d: int = modinv(e, phi_n)\n", 173 | "\n", 174 | "secret_key: Tuple[int, int] = (d, n)\n", 175 | "\n", 176 | "# Public, shared information\n", 177 | "public_key: Tuple[int, int] = (e, n)\n", 178 | " \n", 179 | "assert secret_key == (2753, 3233)\n", 180 | "assert public_key == (17, 3233)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 12, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "Message (Plaintext): 42\n", 193 | "Message (Encrypted): 2557\n", 194 | "Message (Decrypted): 42\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "e: int = public_key[0]\n", 200 | "n: int = public_key[1]\n", 201 | "\n", 202 | "plaintext: int = 42\n", 203 | "print(f'Message (Plaintext): {plaintext}')\n", 204 | "\n", 205 | "ciphertext: int = (plaintext ** e) % n\n", 206 | "print(f'Message (Encrypted): {ciphertext}')\n", 207 | "\n", 208 | "d: int = secret_key[0]\n", 209 | "n: int = secret_key[1]\n", 210 | "decrypted: int = (ciphertext ** d) % n\n", 211 | " \n", 212 | "print(f'Message (Decrypted): {decrypted}')" 213 | ] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "Python 3", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.6.9" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 4 237 | } 238 | 
-------------------------------------------------------------------------------- /cryptography/rlwe-lpr10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# [LPR's RLWE](https://eprint.iacr.org/2012/230.pdf) Public Key Cryptosystem\n", 8 | "\n", 9 | "- [On Ideal Lattices and Learning with Errors Over Rings](https://eprint.iacr.org/2012/230.pdf)\n", 10 | "- [Learning With Errors and Ring Learning With Errors](https://medium.com/asecuritysite-when-bob-met-alice/learning-with-errors-and-ring-learning-with-errors-23516a502406)\n", 11 | "- [A Homomorphic Encryption Illustrated Primer](https://blog.n1analytics.com/homomorphic-encryption-illustrated-primer/)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np\n", 21 | "from numpy.testing import assert_array_equal" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# NOTE: Uncomment to simplify debugging\n", 31 | "# np.random.seed(1)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## Utility Functions" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "def polymod(poly, poly_mod, coeff_mod):\n", 48 | " \"\"\"\n", 49 | " Computes the remainder after a polynomial division\n", 50 | " Args:\n", 51 | " poly: Polynomial\n", 52 | " poly_mod: Polynomial modulus\n", 53 | " coeff_mod: Coefficient modulus\n", 54 | " Returns:\n", 55 | " The coefficients of the remainder when `poly` is divided by `poly_mod`\n", 56 | " \"\"\"\n", 57 | " return np.poly1d(np.floor(np.polydiv(poly, poly_mod)[1]) % coeff_mod)\n", 58 | "\n", 59 | "def test():\n", 60 | " coeff_mod = 10\n", 61 | " # 
x^16 + 1\n", 62 | " poly_mod = np.poly1d([1] + (15 * [0]) + [1])\n", 63 | " # 2x^14\n", 64 | " a = np.poly1d([2] + (14 * [0]))\n", 65 | " # x^4\n", 66 | " b = np.poly1d([1] + (4 * [0]))\n", 67 | " # 2x^14 * x^4 = 2x^18\n", 68 | " result_mul = np.polymul(a, b)\n", 69 | " assert_array_equal(result_mul, np.poly1d([2] + (18 * [0])))\n", 70 | " # 2x^18 % (x^16 + 1) = -2x^2, i.e. 8x^2 with coefficients mod 10\n", 71 | " result_mod = polymod(result_mul, poly_mod, coeff_mod)\n", 72 | " assert_array_equal(result_mod, np.poly1d([8, 0, 0]))\n", 73 | "\n", 74 | "test()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "def addition(poly_mod, coeff_mod):\n", 84 | " \"\"\"\n", 85 | " Creates a function which performs polynomial addition and automatically applies the polynomial and coefficient moduli\n", 86 | " Args:\n", 87 | " poly_mod: Polynomial modulus\n", 88 | " coeff_mod: Coefficient modulus\n", 89 | " Returns:\n", 90 | " A function which takes polynomials `a` and `b` and adds them together\n", 91 | " \"\"\"\n", 92 | " return lambda a, b: np.poly1d(polymod(np.polyadd(a, b), poly_mod, coeff_mod))\n", 93 | "\n", 94 | "def test():\n", 95 | " coeff_mod = 8\n", 96 | " # x^4 + 1\n", 97 | " poly_mod = np.poly1d([1] + (3 * [0]) + [1])\n", 98 | " a = np.poly1d([1, 2, 3, 4])\n", 99 | " b = np.poly1d([1, 2, 3, 4])\n", 100 | " add = addition(poly_mod, coeff_mod)\n", 101 | " result = add(a, b)\n", 102 | " assert_array_equal(result, np.poly1d([2, 4, 6, 0]))\n", 103 | "\n", 104 | "test()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def multiplication(poly_mod, coeff_mod):\n", 114 | " \"\"\"\n", 115 | " Creates a function which performs polynomial multiplication and automatically applies the polynomial and coefficient moduli\n", 116 | " Args:\n", 117 | " poly_mod: Polynomial modulus\n", 118 | " coeff_mod: Coefficient modulus\n", 119 | " Returns:\n", 120 | " A 
function which takes polynomials `a` and `b` and multiplies them\n", 121 | " \"\"\"\n", 122 | " return lambda a, b: np.poly1d(polymod(np.polymul(a, b), poly_mod, coeff_mod))\n", 123 | "\n", 124 | "def test():\n", 125 | " coeff_mod = 8\n", 126 | " # x^4 + 1\n", 127 | " poly_mod = np.poly1d([1] + (3 * [0]) + [1])\n", 128 | " a = np.poly1d([1, 2, 3, 4])\n", 129 | " b = np.poly1d([1, 2, 3, 4])\n", 130 | " mul = multiplication(poly_mod, coeff_mod)\n", 131 | " result = mul(a, b)\n", 132 | " assert_array_equal(result, np.poly1d([4, 0, 4, 6]))\n", 133 | "\n", 134 | "test()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "## Security Parameters" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "n: 4\n", 154 | "t: 7\n", 155 | "d: 16\n", 156 | "delta: 124\n", 157 | "c_q: 874\n", 158 | "p_q: \n", 159 | " 16\n", 160 | "1 x + 1\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "n = 4\n", 166 | "t = 7\n", 167 | "# Highest coefficient power used\n", 168 | "d = 2 ** n\n", 169 | "# Coefficient modulus\n", 170 | "c_q = 874\n", 171 | "delta = c_q // t\n", 172 | "# Polynomial modulus\n", 173 | "p_q = np.poly1d([1] + ([0] * (d - 1)) + [1])\n", 174 | "\n", 175 | "print(f'n: {n}')\n", 176 | "print(f't: {t}')\n", 177 | "print(f'd: {d}')\n", 178 | "print(f'delta: {delta}')\n", 179 | "print(f'c_q: {c_q}')\n", 180 | "print(f'p_q: \\n{p_q}')" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 7, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "assert c_q == delta * t + (c_q % t)\n", 190 | "assert p_q.order == d" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 8, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "# Creating our polynomial addition and multiplication functions via our security 
parameters\n", 200 | "add = addition(p_q, c_q)\n", 201 | "mul = multiplication(p_q, c_q)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "## Secret Key" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 9, 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "name": "stdout", 218 | "output_type": "stream", 219 | "text": [ 220 | " 15 14 13 12 10 8 6 5 2\n", 221 | "1 x + 1 x + 1 x + 1 x + 1 x + 1 x + 1 x + 1 x + 1 x\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "sk = np.poly1d(np.random.randint(0, 2, d))\n", 227 | "\n", 228 | "print(sk)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "## Public Key" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 10, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | " 15 14 13 12 11 10 9 8\n", 248 | "29 x + 195 x + 692 x + 407 x + 193 x + 237 x + 316 x + 76 x\n", 249 | " 7 6 5 4 3 2\n", 250 | " + 589 x + 292 x + 640 x + 81 x + 201 x + 494 x + 440 x + 248\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "a = np.poly1d(np.random.randint(0, c_q, d) % c_q)\n", 256 | "\n", 257 | "print(a)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 11, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | " 14 13 12 10 8 7 6 5 4\n", 270 | "1 x + 873 x + 871 x + 1 x + 1 x + 873 x + 2 x + 872 x + 873 x\n", 271 | " 3\n", 272 | " + 869 x + 2 x + 3\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "e = np.poly1d(np.random.normal(0, 2, d).astype(int) % c_q)\n", 278 | "\n", 279 | "print(e)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 12, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | " 15 14 13 12 
11 10 9 8\n", 292 | "513 x + 298 x + 458 x + 720 x + 308 x + 823 x + 247 x + 114 x\n", 293 | " 7 6 5 4 3 2\n", 294 | " + 842 x + 179 x + 862 x + 222 x + 805 x + 630 x + 330 x + 464\n" 295 | ] 296 | } 297 | ], 298 | "source": [ 299 | "pk_0 = add(-mul(a, sk), e)\n", 300 | "\n", 301 | "print(pk_0)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 13, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | " 15 14 13 12 11 10 9 8\n", 314 | "29 x + 195 x + 692 x + 407 x + 193 x + 237 x + 316 x + 76 x\n", 315 | " 7 6 5 4 3 2\n", 316 | " + 589 x + 292 x + 640 x + 81 x + 201 x + 494 x + 440 x + 248\n" 317 | ] 318 | } 319 | ], 320 | "source": [ 321 | "pk_1 = a\n", 322 | "\n", 323 | "print(pk_1)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 14, 329 | "metadata": {}, 330 | "outputs": [ 331 | { 332 | "name": "stdout", 333 | "output_type": "stream", 334 | "text": [ 335 | "pk_0:\n", 336 | "\n", 337 | " 15 14 13 12 11 10 9 8\n", 338 | "513 x + 298 x + 458 x + 720 x + 308 x + 823 x + 247 x + 114 x\n", 339 | " 7 6 5 4 3 2\n", 340 | " + 842 x + 179 x + 862 x + 222 x + 805 x + 630 x + 330 x + 464\n", 341 | "\n", 342 | "pk_1:\n", 343 | "\n", 344 | " 15 14 13 12 11 10 9 8\n", 345 | "29 x + 195 x + 692 x + 407 x + 193 x + 237 x + 316 x + 76 x\n", 346 | " 7 6 5 4 3 2\n", 347 | " + 589 x + 292 x + 640 x + 81 x + 201 x + 494 x + 440 x + 248\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "pk = (pk_0, pk_1)\n", 353 | "\n", 354 | "print('pk_0:\\n')\n", 355 | "print(pk[0])\n", 356 | "print()\n", 357 | "print('pk_1:\\n')\n", 358 | "print(pk[1])" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 15, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "# We should be able to extract the error `e` from the public key via the secret key\n", 368 | "# NOTE: Doing so will make it possible to identify the noise when decrypting 
later on\n", 369 | "def test():\n", 370 | " extr_e = add(mul(pk[1], sk), pk[0]) \n", 371 | " assert_array_equal(extr_e, e)\n", 372 | "\n", 373 | "test()" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": {}, 379 | "source": [ 380 | "## Encryption" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 16, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | " 2\n", 393 | "2 x + 5\n" 394 | ] 395 | } 396 | ], 397 | "source": [ 398 | "# 2x^2 + 5\n", 399 | "m = np.poly1d((np.array([0] * (d - 3) + [2] + [0] + [5])) % t)\n", 400 | "\n", 401 | "print(m)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 17, 407 | "metadata": {}, 408 | "outputs": [ 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | " 15 14 11 9 8 5 4 3\n", 414 | "1 x + 1 x + 1 x + 1 x + 1 x + 1 x + 1 x + 1 x + 1\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "u = np.poly1d(np.random.randint(0, 2, d))\n", 420 | "\n", 421 | "print(u)" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 18, 427 | "metadata": {}, 428 | "outputs": [ 429 | { 430 | "name": "stdout", 431 | "output_type": "stream", 432 | "text": [ 433 | " 15 12 11 10 9 7 5\n", 434 | "1 x + 2 x + 1 x + 1 x + 873 x + 873 x + 868 x + 871 x + 3\n" 435 | ] 436 | } 437 | ], 438 | "source": [ 439 | "e_1 = np.poly1d(np.random.normal(0, 2, d).astype(int) % c_q)\n", 440 | "\n", 441 | "print(e_1)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 19, 447 | "metadata": {}, 448 | "outputs": [ 449 | { 450 | "name": "stdout", 451 | "output_type": "stream", 452 | "text": [ 453 | " 15 14 12 10 9 8 7 5\n", 454 | "4 x + 873 x + 873 x + 1 x + 869 x + 2 x + 873 x + 1 x + 3 x + 872\n" 455 | ] 456 | } 457 | ], 458 | "source": [ 459 | "e_2 = np.poly1d(np.random.normal(0, 2, d).astype(int) % c_q)\n", 460 | "\n", 461 | "print(e_2)" 462 | ] 
463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 20, 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "name": "stdout", 471 | "output_type": "stream", 472 | "text": [ 473 | " 15 14 13 12 11 10 8 7\n", 474 | "32 x + 851 x + 797 x + 778 x + 417 x + 378 x + 426 x + 606 x\n", 475 | " 6 5 4 3 2\n", 476 | " + 798 x + 132 x + 809 x + 751 x + 166 x + 372 x + 319\n" 477 | ] 478 | } 479 | ], 480 | "source": [ 481 | "c_0 = add(add(mul(pk[0], u), e_1), mul(delta, m))\n", 482 | "\n", 483 | "print(c_0)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 21, 489 | "metadata": {}, 490 | "outputs": [ 491 | { 492 | "name": "stdout", 493 | "output_type": "stream", 494 | "text": [ 495 | " 15 14 13 12 11 10 9 8\n", 496 | "772 x + 544 x + 564 x + 348 x + 120 x + 316 x + 513 x + 848 x\n", 497 | " 7 6 5 4 3 2\n", 498 | " + 341 x + 556 x + 480 x + 640 x + 746 x + 776 x + 392 x + 211\n" 499 | ] 500 | } 501 | ], 502 | "source": [ 503 | "c_1 = add(mul(pk[1], u), e_2)\n", 504 | "\n", 505 | "print(c_1)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": {}, 511 | "source": [ 512 | "## Decryption" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 22, 518 | "metadata": {}, 519 | "outputs": [ 520 | { 521 | "name": "stdout", 522 | "output_type": "stream", 523 | "text": [ 524 | " 2\n", 525 | "2 x + 5\n" 526 | ] 527 | } 528 | ], 529 | "source": [ 530 | "m_prime = np.poly1d(np.round(add(mul(c_1, sk), c_0) * t / c_q) % t)\n", 531 | "\n", 532 | "print(m_prime)\n", 533 | "\n", 534 | "assert_array_equal(m_prime, m)" 535 | ] 536 | } 537 | ], 538 | "metadata": { 539 | "kernelspec": { 540 | "display_name": "Python 3", 541 | "language": "python", 542 | "name": "python3" 543 | }, 544 | "language_info": { 545 | "codemirror_mode": { 546 | "name": "ipython", 547 | "version": 3 548 | }, 549 | "file_extension": ".py", 550 | "mimetype": "text/x-python", 551 | "name": "python", 552 | "nbconvert_exporter": "python", 
553 | "pygments_lexer": "ipython3", 554 | "version": "3.6.9" 555 | } 556 | }, 557 | "nbformat": 4, 558 | "nbformat_minor": 4 559 | } 560 | -------------------------------------------------------------------------------- /jupyter-lab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | PORT=${1:-8888} 4 | 5 | docker run -it -p "$PORT":8888 --rm --name jupyter-lab \ 6 | -v "$PWD":/home/jovyan/work \ 7 | --ipc=host \ 8 | pmuens/jupyter-lab:latest jupyter lab \ 9 | --ip=0.0.0.0 \ 10 | --no-browser \ 11 | --allow-root \ 12 | --NotebookApp.token=\ 13 | --notebook-dir=/home/jovyan/work 14 | -------------------------------------------------------------------------------- /jupyter-notebook.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | PORT=${1:-8888} 4 | 5 | docker run -it -p "$PORT":8888 --rm --name jupyter-notebook \ 6 | -v "$PWD":/home/jovyan/work \ 7 | --ipc=host \ 8 | pmuens/jupyter-lab:latest jupyter notebook \ 9 | --ip=0.0.0.0 \ 10 | --no-browser \ 11 | --allow-root \ 12 | --NotebookApp.token=\ 13 | --notebook-dir=/home/jovyan/work 14 | -------------------------------------------------------------------------------- /swift/swift-intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Swift Intro" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Python-like syntax" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "Hello World\r\n", 27 | "3\r\n", 28 | "10\r\n", 29 | "35\r\n", 30 | "52\r\n", 31 | "92\r\n", 32 | "88\r\n", 33 | "👋 I'm Mitsy and I have 9 remaining lives 😸\r\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "let greeting = 
\"Hello World\"\n", 39 | "print(greeting)\n", 40 | "// Hello World\n", 41 | "\n", 42 | "let num1 = 1\n", 43 | "let num2 = 2\n", 44 | "print(num1 + num2)\n", 45 | "// 3\n", 46 | "\n", 47 | "let scores = [10, 35, 52, 92, 88]\n", 48 | "for score in scores {\n", 49 | " print(score)\n", 50 | "}\n", 51 | "// 10\n", 52 | "// 35\n", 53 | "// 52\n", 54 | "// 92\n", 55 | "// 88\n", 56 | "\n", 57 | "class Cat {\n", 58 | " var name: String\n", 59 | " var livesRemaining: Int = 9\n", 60 | " \n", 61 | " init(name: String) {\n", 62 | " self.name = name\n", 63 | " }\n", 64 | " \n", 65 | " func describe() -> String {\n", 66 | " return \"👋 I'm \\(self.name) and I have \\(self.livesRemaining) remaining lives 😸\"\n", 67 | " }\n", 68 | "}\n", 69 | "let mitsy = Cat(name: \"Mitsy\")\n", 70 | "print(mitsy.describe())\n", 71 | "// 👋 I'm Mitsy and I have 9 remaining lives 😸" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## Static typing" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "ename": "", 88 | "evalue": "", 89 | "output_type": "error", 90 | "traceback": [ 91 | "error: :17:28: error: cannot convert value of type '[String]' to expected argument type '[Int]'\nlet resultString = sum(xs: stringNumbers)\n ^\n\n:17:28: note: arguments to generic parameter 'Element' ('String' and 'Int') are expected to be equal\nlet resultString = sum(xs: stringNumbers)\n ^\n\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "func sum(xs: [Int]) -> Int {\n", 97 | " var result: Int = 0\n", 98 | " for x: Int in xs {\n", 99 | " result = result + x\n", 100 | " }\n", 101 | " return result\n", 102 | "}\n", 103 | "\n", 104 | "// Using correct types\n", 105 | "let intNumbers: [Int] = [1, 2, 3, 4, 5]\n", 106 | "let resultInt = sum(xs: intNumbers)\n", 107 | "print(resultInt)\n", 108 | "// 15\n", 109 | "\n", 110 | "// Using incorrect types\n", 111 | "let stringNumbers: [String] = [\"one\", 
\"two\", \"three\"]\n", 112 | "let resultString = sum(xs: stringNumbers)\n", 113 | "print(resultString)\n", 114 | "// error: cannot convert value of type '[String]' to expected argument type '[Int]'" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## Hackable" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Protocols" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "Offers methods to aid with matrix-matrix multiplications.\r\n", 141 | "Offers methods to aid with matrix-vector multiplications.\r\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "// One needs to implement `help` when using the `Debugging` Protocol\n", 147 | "protocol Debugging {\n", 148 | " func help() -> String\n", 149 | "}\n", 150 | "\n", 151 | "// Implementing `Debugging` for MatrixMultiply\n", 152 | "class MatrixMultiply: Debugging {\n", 153 | " func help() -> String {\n", 154 | " return \"Offers methods to aid with matrix-matrix multiplications.\"\n", 155 | " }\n", 156 | " \n", 157 | " func multiply() {\n", 158 | " // ...\n", 159 | " }\n", 160 | "}\n", 161 | "var matMult = MatrixMultiply()\n", 162 | "print(matMult.help())\n", 163 | "// Offers methods to aid with matrix-matrix multiplications.\n", 164 | "\n", 165 | "// Implementing `Debugging` for VectorMultiply\n", 166 | "class VectorMultiply: Debugging {\n", 167 | " func help() -> String {\n", 168 | " return \"Offers methods to aid with matrix-vector multiplications.\"\n", 169 | " }\n", 170 | "}\n", 171 | "var vecMult = VectorMultiply()\n", 172 | "print(vecMult.help())\n", 173 | "// Offers methods to aid with matrix-vector multiplications." 
174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### Extensions" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 4, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "🎱\r\n", 193 | "💯\r\n", 194 | "42\r\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "// Makes it possible to emojify an existing type\n", 200 | "protocol Emojifier {\n", 201 | " func emojify() -> String\n", 202 | "}\n", 203 | "\n", 204 | "// Here we're extending Swift's core `Int` type\n", 205 | "extension Int: Emojifier {\n", 206 | " func emojify() -> String {\n", 207 | " if self == 8 {\n", 208 | " return \"🎱\"\n", 209 | " } else if self == 100 {\n", 210 | " return \"💯\"\n", 211 | " }\n", 212 | " return String(self)\n", 213 | " }\n", 214 | "}\n", 215 | "\n", 216 | "\n", 217 | "print(8.emojify())\n", 218 | "// 🎱\n", 219 | "print(100.emojify())\n", 220 | "// 💯\n", 221 | "print(42.emojify())\n", 222 | "// 42" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "## Value semantics" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 5, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "0x7f4f80dbe538\r\n", 242 | "0x7f4f80dbe538\r\n", 243 | "[1, 2, 3, 4, 5]\r\n", 244 | "[1, 2, 3, 4, 5, 6]\r\n", 245 | "0x7f4f80dbe538\r\n", 246 | "0x1c49570\r\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "// As seen on: https://marcosantadev.com/copy-write-swift-value-types/\n", 252 | "\n", 253 | "import Foundation\n", 254 | "\n", 255 | "// Returns the memory address of the given object as a hex string\n", 256 | "func address(of object: UnsafeRawPointer) -> String {\n", 257 | " let addr = Int(bitPattern: object)\n", 258 | " return String(format: \"%p\", addr)\n", 259 | "}\n", 260 | "\n", 261 | "var list1 = [1, 2, 3, 4, 5]\n", 
262 | "print(address(of: list1))\n", 263 | "// 0x7f4f80dbe538\n", 264 | "\n", 265 | "var list2 = list1\n", 266 | "print(address(of: list2))\n", 267 | "// 0x7f4f80dbe538 <-- Both lists share the same address\n", 268 | "\n", 269 | "list2.append(6) // <-- Mutating `list2`\n", 270 | "\n", 271 | "print(list1)\n", 272 | "// [1, 2, 3, 4, 5]\n", 273 | "\n", 274 | "print(list2)\n", 275 | "// [1, 2, 3, 4, 5, 6]\n", 276 | "\n", 277 | "print(address(of: list1))\n", 278 | "// 0x7f4f80dbe538\n", 279 | "print(address(of: list2))\n", 280 | "// 0x1c49570 <-- `list2` has a different address" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Swift", 287 | "language": "swift", 288 | "name": "swift" 289 | }, 290 | "language_info": { 291 | "file_extension": ".swift", 292 | "mimetype": "text/x-swift", 293 | "name": "swift", 294 | "version": "" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 4 299 | } 300 | -------------------------------------------------------------------------------- /swift/swift-packages.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Swift Packages" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "scrolled": true 15 | }, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Installing packages:\n", 22 | "\t.package(url: \"https://github.com/saeta/Just\", from: \"0.7.2\")\n", 23 | "\t\tJust\n", 24 | "With SwiftPM flags: []\n", 25 | "Working in: /tmp/tmpbbdxrwew/swift-install\n", 26 | "Fetching https://github.com/saeta/Just\n", 27 | "Cloning https://github.com/saeta/Just\n", 28 | "Resolving https://github.com/saeta/Just at 0.7.3\n", 29 | "[1/4] Compiling Just Just.swift\n", 30 | "[2/5] Merging module Just\n", 31 | "[3/6] Wrapping AST for Just for debugging\n", 32 | "[4/6] Compiling 
jupyterInstalledPackages jupyterInstalledPackages.swift\n", 33 | "[5/7] Merging module jupyterInstalledPackages\n", 34 | "[6/7] Wrapping AST for jupyterInstalledPackages for debugging\n", 35 | "[7/7] Linking libjupyterInstalledPackages.so\n", 36 | "Initializing Swift...\n", 37 | "Installation complete!\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "%install '.package(url: \"https://github.com/saeta/Just\", from: \"0.7.2\")' Just" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "GET http://example.com 200\n" 54 | ] 55 | }, 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "import Just\n", 63 | "\n", 64 | "Just.get(\"http://example.com\")" 65 | ] 66 | } 67 | ], 68 | "metadata": { 69 | "kernelspec": { 70 | "display_name": "Swift", 71 | "language": "swift", 72 | "name": "swift" 73 | }, 74 | "language_info": { 75 | "file_extension": ".swift", 76 | "mimetype": "text/x-swift", 77 | "name": "swift", 78 | "version": "" 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 4 83 | } 84 | -------------------------------------------------------------------------------- /x-from-scratch/decision-trees-from-scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Decision Trees from scratch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import csv\n", 17 | "from pathlib import Path\n", 18 | "from copy import deepcopy\n", 19 | "from typing import List, Tuple, Dict, NamedTuple, Any\n", 20 | "from collections import Counter, defaultdict" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 
| "# Ensure that we have a `data` directory we use to store downloaded data\n", 30 | "!mkdir -p data\n", 31 | "data_dir: Path = Path('data')" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "--2020-02-23 10:52:54-- https://raw.githubusercontent.com/husnainfareed/Simple-Naive-Bayes-Weather-Prediction/c75b2fa747956ee9b5f9da7b2fc2865be04c618c/new_dataset.csv\n", 44 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.0.133, 151.101.64.133, ...\n", 45 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected.\n", 46 | "HTTP request sent, awaiting response... 200 OK\n", 47 | "Length: 373 [text/plain]\n", 48 | "Saving to: ‘data/golf.csv’\n", 49 | "\n", 50 | "golf.csv 100%[===================>] 373 --.-KB/s in 0s \n", 51 | "\n", 52 | "2020-02-23 10:52:55 (11.5 MB/s) - ‘data/golf.csv’ saved [373/373]\n", 53 | "\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "# Downloading the \"Golf\" data set\n", 59 | "!wget -O \"data/golf.csv\" -nc -P data https://raw.githubusercontent.com/husnainfareed/Simple-Naive-Bayes-Weather-Prediction/c75b2fa747956ee9b5f9da7b2fc2865be04c618c/new_dataset.csv" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "Outlook,Temp,Humidity,Windy,Play\n", 72 | "Rainy,Hot,High,f,no\n", 73 | "Rainy,Hot,High,t,no\n", 74 | "Overcast,Hot,High,f,yes\n", 75 | "Sunny,Mild,High,f,yes\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "!head -n 5 data/golf.csv" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 5, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# Create the Python path pointing to the `golf.csv` file\n", 90 | "golf_data_path: Path = data_dir / 
'golf.csv'" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 6, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# Every entry in our data set is represented as a `DataPoint`\n", 100 | "class DataPoint(NamedTuple):\n", 101 | " outlook: str\n", 102 | " temp: str\n", 103 | " humidity: str\n", 104 | " windy: bool\n", 105 | " play: bool" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 7, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# Open the file, iterate over every row, create a `DataPoint` and append it to a list\n", 115 | "data_points: List[DataPoint] = []\n", 116 | "\n", 117 | "with open(golf_data_path) as csv_file:\n", 118 | " reader = csv.reader(csv_file, delimiter=',')\n", 119 | " next(reader, None)\n", 120 | " for row in reader:\n", 121 | " outlook: str = row[0].lower()\n", 122 | " temp: str = row[1].lower()\n", 123 | " humidty: str = row[2].lower()\n", 124 | " windy: bool = True if row[3].lower() == 't' else False\n", 125 | " play: bool = True if row[4].lower() == 'yes' else False\n", 126 | " data_point: DataPoint = DataPoint(outlook, temp, humidty, windy, play)\n", 127 | " data_points.append(data_point)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 8, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "text/plain": [ 138 | "[DataPoint(outlook='rainy', temp='hot', humidity='high', windy=False, play=False),\n", 139 | " DataPoint(outlook='rainy', temp='hot', humidity='high', windy=True, play=False),\n", 140 | " DataPoint(outlook='overcast', temp='hot', humidity='high', windy=False, play=True),\n", 141 | " DataPoint(outlook='sunny', temp='mild', humidity='high', windy=False, play=True),\n", 142 | " DataPoint(outlook='sunny', temp='cool', humidity='normal', windy=False, play=True)]" 143 | ] 144 | }, 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 
| "data_points[:5]" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 9, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# Calculate the Gini impurity for a list of values\n", 161 | "# See: https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity\n", 162 | "def gini(data: List[Any]) -> float:\n", 163 | " counter: Counter = Counter(data)\n", 164 | " classes: List[Any] = list(counter.keys())\n", 165 | " num_items: int = len(data)\n", 166 | " result: float = 0\n", 167 | " item: Any\n", 168 | " for item in classes:\n", 169 | " p_i: float = counter[item] / num_items\n", 170 | " result += p_i * (1 - p_i)\n", 171 | " return result\n", 172 | "\n", 173 | "assert gini(['one', 'one']) == 0\n", 174 | "assert gini(['one', 'two']) == 0.5\n", 175 | "assert gini(['one', 'two', 'one', 'two']) == 0.5\n", 176 | "assert 0.8 < gini(['one', 'two', 'three', 'four', 'five']) < 0.81" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 10, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "# Helper function to filter down a list of data points by a `feature` and its `value`\n", 186 | "def filter_by_feature(data_points: List[DataPoint], *args) -> List[DataPoint]:\n", 187 | " result: List[DataPoint] = deepcopy(data_points)\n", 188 | " for arg in args:\n", 189 | " feature: str = arg[0]\n", 190 | " value: Any = arg[1]\n", 191 | " result = [data_point for data_point in result if getattr(data_point, feature) == value]\n", 192 | " return result\n", 193 | "\n", 194 | "assert len(filter_by_feature(data_points, ('outlook', 'sunny'))) == 5\n", 195 | "assert len(filter_by_feature(data_points, ('outlook', 'sunny'), ('temp', 'mild'))) == 3\n", 196 | "assert len(filter_by_feature(data_points, ('outlook', 'sunny'), ('temp', 'mild'), ('humidity', 'high'))) == 2" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 11, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "# 
Helper function to extract the distinct values the `feature` in question can assume\n", 206 | "def feature_values(data_points: List[DataPoint], feature: str) -> List[Any]:\n", 207 | " return list(set([getattr(dp, feature) for dp in data_points]))\n", 208 | "\n", 209 | "assert sorted(feature_values(data_points, 'outlook')) == sorted(['sunny', 'overcast', 'rainy'])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 12, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "# Calculate the weighted sum of the Gini impurities for the `feature` in question\n", 219 | "def gini_for_feature(data_points: List[DataPoint], feature: str, label: str = 'play') -> float:\n", 220 | " total: int = len(data_points)\n", 221 | " # Distinct values the `feature` in question can assume\n", 222 | " dist_values: List[Any] = feature_values(data_points, feature)\n", 223 | " # Calculate all the Gini impurities for every possible value a `feature` can assume\n", 224 | " ginis: Dict[str, float] = defaultdict(float)\n", 225 | " ratios: Dict[str, float] = defaultdict(float)\n", 226 | " for value in dist_values:\n", 227 | " filtered: List[DataPoint] = filter_by_feature(data_points, (feature, value))\n", 228 | " labels: List[Any] = [getattr(dp, label) for dp in filtered]\n", 229 | " ginis[value] = gini(labels)\n", 230 | " # We use the ratio when we compute the weighted sum later on\n", 231 | " ratios[value] = len(labels) / total\n", 232 | " # Calculate the weighted sum of the `feature` in question\n", 233 | " weighted_sum: float = sum([ratios[key] * value for key, value in ginis.items()])\n", 234 | " return weighted_sum\n", 235 | "\n", 236 | "assert 0.34 < gini_for_feature(data_points, 'outlook') < 0.35\n", 237 | "assert 0.44 < gini_for_feature(data_points, 'temp') < 0.45\n", 238 | "assert 0.36 < gini_for_feature(data_points, 'humidity') < 0.37\n", 239 | "assert 0.42 < gini_for_feature(data_points, 'windy') < 0.43" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244
| "execution_count": 13, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "# NOTE: We can't use type hinting here due to cyclic dependencies\n", 249 | "\n", 250 | "# A `Node` has a `value` and optional out `Edge`s\n", 251 | "class Node:\n", 252 | " def __init__(self, value):\n", 253 | " self._value = value\n", 254 | " self._edges = []\n", 255 | "\n", 256 | " def __repr__(self):\n", 257 | " if len(self._edges):\n", 258 | " return f'{self._value} --> {self._edges}'\n", 259 | " else:\n", 260 | " return f'{self._value}'\n", 261 | " \n", 262 | " @property\n", 263 | " def value(self):\n", 264 | " return self._value\n", 265 | "\n", 266 | " def add_edge(self, edge):\n", 267 | " self._edges.append(edge)\n", 268 | " \n", 269 | " def find_edge(self, value):\n", 270 | " return next((edge for edge in self._edges if edge.value == value), None) # `None` if no `Edge` matches\n", 271 | "\n", 272 | "# An `Edge` has a value and points to a `Node`\n", 273 | "class Edge:\n", 274 | " def __init__(self, value):\n", 275 | " self._value = value\n", 276 | " self._node = None\n", 277 | "\n", 278 | " def __repr__(self):\n", 279 | " return f'{self._value} --> {self._node}'\n", 280 | " \n", 281 | " @property\n", 282 | " def value(self):\n", 283 | " return self._value\n", 284 | " \n", 285 | " @property\n", 286 | " def node(self):\n", 287 | " return self._node\n", 288 | " \n", 289 | " @node.setter\n", 290 | " def node(self, node):\n", 291 | " self._node = node" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 14, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "# Recursively build a tree via the CART algorithm based on our list of data points\n", 301 | "def build_tree(data_points: List[DataPoint], features: List[str], label: str = 'play') -> Node:\n", 302 | " # Work on a copy of `features` without the `label` so the caller's list isn't mutated\n", 303 | " features = [feature for feature in features if feature != label]\n", 304 | "\n", 305 | " # Compute the weighted Gini impurity for each `feature` given
that we'd split the tree at the `feature` in question\n", 306 | " weighted_sums: Dict[str, float] = defaultdict(float)\n", 307 | " for feature in features:\n", 308 | " weighted_sums[feature] = gini_for_feature(data_points, feature, label)\n", 309 | "\n", 310 | " # Create a final `Node` (leaf) once all data points share the same `label` or no `features` are left to split on\n", 311 | " labels: List[Any] = [getattr(dp, label) for dp in data_points]\n", 312 | " if gini(labels) == 0 or not weighted_sums:\n", 313 | " # Use the most common `label` (for a pure set that's simply its only `label`)\n", 314 | " return Node(Counter(labels).most_common(1)[0][0])\n", 315 | " \n", 316 | " # The `feature` with the lowest weighted Gini impurity is the one we should use for splitting\n", 317 | " min_feature = min(weighted_sums, key=weighted_sums.get)\n", 318 | " node: Node = Node(min_feature)\n", 319 | " \n", 320 | " # Remove the `feature` we've processed from the list of `features` which still need to be processed\n", 321 | " reduced_features: List[str] = deepcopy(features)\n", 322 | " reduced_features.remove(min_feature)\n", 323 | "\n", 324 | " # Next up we build the `Edge`s which are the values our `min_feature` can assume\n", 325 | " for value in feature_values(data_points, min_feature):\n", 326 | " # Create a new `Edge` which contains a potential `value` of our `min_feature`\n", 327 | " edge: Edge = Edge(value)\n", 328 | " # Add the `Edge` to our `Node`\n", 329 | " node.add_edge(edge)\n", 330 | " # Filter down the data points we'll use next since we've just processed the set which includes our `min_feature`\n", 331 | " reduced_data_points: List[DataPoint] = filter_by_feature(data_points, (min_feature, value))\n", 332 | " # This `Edge` points to the new `Node` (subtree) we'll create through recursion\n", 333 | " edge.node = build_tree(reduced_data_points, reduced_features, label)\n", 334 | "\n", 335 | " # Return the `Node` (our `min_feature`)\n", 336 | " return node" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 |
"execution_count": 15, 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "data": { 346 | "text/plain": [ 347 | "outlook --> [overcast --> True, sunny --> windy --> [False --> True, True --> False], rainy --> humidity --> [normal --> True, high --> False]]" 348 | ] 349 | }, 350 | "execution_count": 15, 351 | "metadata": {}, 352 | "output_type": "execute_result" 353 | } 354 | ], 355 | "source": [ 356 | "# Create a new tree based on the loaded data points\n", 357 | "features: List[str] = list(DataPoint._fields)\n", 358 | "\n", 359 | "tree: Node = build_tree(data_points, features)\n", 360 | "tree" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 16, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "# Traverse the tree based on the query trying to find a leaf with the prediction\n", 370 | "def predict(tree: Node, query: List[Tuple[str, str]]) -> Any:\n", 371 | " node: Node = tree # read-only traversal, so the tree doesn't need to be copied\n", 372 | " for item in query:\n", 373 | " feature: str = item[0]\n", 374 | " value: Any = item[1]\n", 375 | " if node.value != feature:\n", 376 | " continue\n", 377 | " edge: Edge = node.find_edge(value)\n", 378 | " if not edge:\n", 379 | " raise Exception(f'Edge with value \"{value}\" not found on Node \"{node}\"')\n", 380 | " node: Node = edge.node\n", 381 | " return node\n", 382 | "\n", 383 | "assert predict(tree, [('outlook', 'overcast')]).value == True\n", 384 | "assert predict(tree, [('outlook', 'sunny'), ('windy', False)]).value == True\n", 385 | "assert predict(tree, [('outlook', 'sunny'), ('windy', True)]).value == False\n", 386 | "assert predict(tree, [('outlook', 'rainy'), ('humidity', 'high')]).value == False\n", 387 | "assert predict(tree, [('outlook', 'rainy'), ('humidity', 'normal')]).value == True\n", 388 | "assert predict(tree, [('outlook', 'rainy'), ('windy', True), ('humidity', 'normal')]).value == True" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 17, 394 | "metadata": {}, 395 | "outputs": [ 396 | { 397 | "data": {
398 | "text/plain": [ 399 | "True" 400 | ] 401 | }, 402 | "execution_count": 17, 403 | "metadata": {}, 404 | "output_type": "execute_result" 405 | } 406 | ], 407 | "source": [ 408 | "predict(tree, [('outlook', 'rainy'), ('humidity', 'normal')])" 409 | ] 410 | } 411 | ], 412 | "metadata": { 413 | "kernelspec": { 414 | "display_name": "Python 3", 415 | "language": "python", 416 | "name": "python3" 417 | }, 418 | "language_info": { 419 | "codemirror_mode": { 420 | "name": "ipython", 421 | "version": 3 422 | }, 423 | "file_extension": ".py", 424 | "mimetype": "text/x-python", 425 | "name": "python", 426 | "nbconvert_exporter": "python", 427 | "pygments_lexer": "ipython3", 428 | "version": "3.6.9" 429 | } 430 | }, 431 | "nbformat": 4, 432 | "nbformat_minor": 4 433 | } 434 | -------------------------------------------------------------------------------- /x-from-scratch/k-means-clustering-from-scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# k-means Clustering from scratch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (3.1.3)\n", 20 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (0.10.0)\n", 21 | "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.8.1)\n", 22 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.1.0)\n", 23 | "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.18.1)\n", 24 | "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 
in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.4.6)\n", 25 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib) (1.14.0)\n", 26 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib) (45.2.0)\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "!pip3 install matplotlib" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import csv\n", 41 | "from random import sample\n", 42 | "from statistics import mean\n", 43 | "from math import sqrt, inf\n", 44 | "from pathlib import Path\n", 45 | "from collections import defaultdict\n", 46 | "from typing import List, Dict, Tuple\n", 47 | "from matplotlib import pyplot as plt" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Ensure that we have a `data` directory we use to store downloaded data\n", 57 | "!mkdir -p data\n", 58 | "data_dir: Path = Path('data')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "File ‘data/iris.data’ already there; not retrieving.\n", 71 | "\n" 72 | ] 73 | } 74 | ], 75 | "source": [ 76 | "# Downloading the Iris data set\n", 77 | "!wget -nc -P data https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "6.9,3.1,5.1,2.3,Iris-virginica\n", 90 | "5.8,2.7,5.1,1.9,Iris-virginica\n", 91 | "6.8,3.2,5.9,2.3,Iris-virginica\n", 92 | "6.7,3.3,5.7,2.5,Iris-virginica\n", 93 | "6.7,3.0,5.2,2.3,Iris-virginica\n", 94 | 
"6.3,2.5,5.0,1.9,Iris-virginica\n", 95 | "6.5,3.0,5.2,2.0,Iris-virginica\n", 96 | "6.2,3.4,5.4,2.3,Iris-virginica\n", 97 | "5.9,3.0,5.1,1.8,Iris-virginica\n", 98 | "\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "# The structure of the Iris data set is as follows:\n", 104 | "# Sepal Length, Sepal Width, Petal Length, Petal Width, Class\n", 105 | "!tail data/iris.data" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# Create the Python path pointing to the `iris.data` file\n", 115 | "data_path: Path = data_dir / 'iris.data'" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# The list in which we store the \"petal length\" and \"sepal width\" as vectors (a vector is a list of floats)\n", 125 | "data_points: List[List[float]] = []\n", 126 | "\n", 127 | "# Indexes according to the data set description\n", 128 | "petal_length_idx: int = 2\n", 129 | "sepal_width_idx: int = 1\n", 130 | "\n", 131 | "# Read the `iris.data` file and parse it line-by-line\n", 132 | "with open(data_path) as csv_file:\n", 133 | " reader = csv.reader(csv_file, delimiter=',')\n", 134 | " for row in reader:\n", 135 | " # Check if the given row is a valid iris data point\n", 136 | " if len(row) == 5:\n", 137 | " label: str = row[-1]\n", 138 | " x1: float = float(row[petal_length_idx])\n", 139 | " x2: float = float(row[sepal_width_idx])\n", 140 | " data_points.append([x1, x2])" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "150" 152 | ] 153 | }, 154 | "execution_count": 8, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "len(data_points)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 9, 166 | 
"metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "[[1.4, 3.5], [1.4, 3.0], [1.3, 3.2], [1.5, 3.1], [1.4, 3.6]]" 172 | ] 173 | }, 174 | "execution_count": 9, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "data_points[:5]" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 10, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEHCAYAAACjh0HiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dfbQddX3v8fcnh1RPgZJSIsUDITYit14DBE4FGq+iXqpWFkYeigq9Uruktip4vdXVKFe9aIuaSnVpfYigF4SLVZ4udlktV0QBAT2JSJQIikUxtRJJgaCphuR7/9hzkpN99sPsPTN7Zs98Xmtl5ZzZ8/CbvWf298zM9/v7KSIwM7PmWlB2A8zMrFwOBGZmDedAYGbWcA4EZmYN50BgZtZwDgRmZg23V9EbkDQBzACbIuKkttfOBtYAm5JJH4qIi3ut74ADDoilS5cW0FIzs/pat27dzyJicafXCg8EwHnARuA3urz+DxHxurQrW7p0KTMzM7k0zMysKST9sNtrhd4aknQw8GKg51/5ZmZWnqKfEbwfeDOws8c8p0q6S9JVkg7pNIOkcyTNSJrZvHlzIQ01M2uqwgKBpJOAByNiXY/ZPgcsjYgjgBuASzvNFBFrI2I6IqYXL+54i8vMzIZU5BXBSuBkSfcDnwaeJ+nyuTNExEMR8cvk14uBYwpsj5mZdVBYIIiI1RFxcEQsBV4G3BgRZ82dR9JBc349mdZDZTMzG6FRZA3tQdIFwExEXA+cK+lk4HFgC3D2qNtjZtZ0GrduqKenp6PK6aPXfXMTa754D//68DaevGiSN73gcFatmCq7WWbWcJLWRcR0p9dGfkVQZ9d9cxOrr9nAtu07ANj08DZWX7MBwMHAzCrLXUzkaM0X79kVBGZt276DNV+8p6QWmZn150CQo399eNtA083MqsCBIEdPXjQ50HQzsypwIMjRm15wOJMLJ/aYNrlwgje94PCSWmRm1p8fFudo9oGws4bMbJw4EORs1Yopf/Gb2VjxrSEzs4ZzIDAzazgHAjOzhnMgMDNrOAcCM7OGcyAwM2s4BwIzs4ZzIDAzazgHAjOzhnNl8ZA8AI2Z1YUDwRA8AI2Z1YlvDQ3BA9CYWZ04EAzBA9CYWZ04EAzBA9CYWZ04EAzBA9CYWZ34YfEQPACNmdWJA8GQPACNmdWFA0HOXF9gZuPGgSBHri8ws3Hkh8U5cn2BmY0jB4Icub7AzMaRA0GOXF9gZuPIgSBHri8ws3Hkh8U5cn2BmY0jB4Kcub7AzMZN4YFA0gQwA2yKiJPaXnsCcBlwDPAQcEZE3F90m3pxHYCZNc0onhGcB2zs8tqfAv8eEU8F/g54zwja09VsHcCmh7cR7K4DuO6bm8pslplZoQoNBJIOBl4MXNxllpcAlyY/XwU8X5KKbFMvrgMwsyYq+org/cCbgZ1dXp8CHgCIiMeBR4Dfap9J0jmSZiTNbN68uai2ug7AzBqpsEAg6STgwYhYl3VdEbE2IqYj
Ynrx4sU5tK4z1wGYWRMVeUWwEjhZ0v3Ap4HnSbq8bZ5NwCEAkvYC9qP10LgUrgMwsyYqLBBExOqIODgilgIvA26MiLPaZrseeGXy82nJPFFUm/pZtWKKC09ZztSiSQRMLZrkwlOWO2vIzGpt5HUEki4AZiLieuAS4FOSvg9soRUwSuU6ADNrmpEEgoi4Cbgp+fltc6b/B3D6KNqQRRG1Ba5XMLOqcGVxH0WMMeBxC8ysStzpXB9F1Ba4XsHMqsSBoI8iagtcr2BmVeJA0EcRtQWuVzCzKnEg6KOI2gLXK5hZlfhhcR9FjDHgcQvMrEpUYv3WUKanp2NmZqbsZpiZjRVJ6yJiutNrviLIWaf6AJj/13+naZ2uCFxvYFZdWc7PKp3bviLIUXt9AMDCBQLB9h273+eFE4KA7Tt3T5tcODGvO4tO6+s0n5mNXpbzs4xzu9cVgR8W56hTfcD2nbFHEIBWUJgbBKBzHYHrDcyqK8v5WbVz24EgR1nrANqXd72BWXVlOT+rdm47EOQoax1A+/KuNzCrriznZ9XObQeCHHWqD1i4QK1nAnOnTaj17GCOTnUErjcwq64s52fVzm1nDeWoW31A2mntD4lcb2BWXVnOz6qd284aMjNrANcRjNCZH7+NW+/bsuv3lcv254pXH19ii8zMevMzghy1BwGAW+/bwpkfv62kFpmZ9edAkKP2INBvuplZFTgQmJk1nAOBmVnDORDkaOWy/QeabmZWBQ4EObri1cfP+9J31pCZVZ3TR3PmL30zGze+IjAza7jGXBHkPQhElQaVMLPh+DxuaUQgaB8EYtPD21h9zQaAoT70vNdnZqPn83i3RtwaynsQiKoNKmFmg/N5vFsjAkHeg0BUbVAJMxucz+PdGhEI8h4EomqDSpjZ4Hwe79aIQJD3IBBVG1TCzAbn83i3RjwsznsQiKoNKmFmg/N5vFuqgWkkTQGHMidwRMRXC2xXVx6YxsxscJkGppH0HuAM4G5g9hF7AD0DgaQnJvM8IdnOVRHx9rZ5zgbWAJuSSR+KiIv7tWnUzr9uA1fe8QA7IpiQePmxhwDMm/auVcs75iWD/+ows+rqe0Ug6R7giIj45UArlgTsHRGPSVoI3AKcFxG3z5nnbGA6Il6Xdr2jviI4/7oNXH77j1LNu3LZ/qz/0SN7pKQtXCAQbN+x+32eXDjBhacsdzAws5HpdUWQ5mHxD4CFg240Wh5Lfl2Y/BuvAZJp/dWf1q33bZmXl7x9Z+wRBKC5ucpmVk1dbw1J+iCtL+5fAHdK+hKw66ogIs7tt3JJE8A64KnA30fEHR1mO1XSs4F7gf8eEfO+eSWdA5wDsGTJkn6bzdWOFM9QhtHEXGUzq6Zezwhm77+sA65vey3Vt2NE7ACOkrQIuFbSMyLi23Nm+RxwZUT8UtKfAZcCz+uwnrXAWmjdGkqz7bxMSIUEgybmKptZNXW9NRQRl0bEpcCi2Z/nTPvNQTYSEQ8DXwZe2Db9oTnPHi4Gjhms+cWbfTCcxspl+8/LS164QCyc0B7TmpqrbGbVlOYZwSs7TDu730KSFidXAkiaBE4Evts2z0Fzfj0Z2JiiPSP1rlXLOeu4JUyo9WU+IXHWcUs6Trvi1cdz4SnLmVo0iYCpRZOsOf1I1px25B7T/KDYzKqka9aQpJcDrwCeBdw856V9gZ0R8fyeK5aOoHWrZ4JWwPlMRFwg6QJgJiKul3QhrQDwOLAF+POI+G7XleI6AjOzYQxbR/A14CfAAcD75kzfCtzVb6MRcRewosP0t835eTWwut+6qqhTbcG7Vi3nzI/fxq33bdk138pl+/OUxfvMm3f60P1T1Ra4v3TrZFyPi7TtHtf9G1epKourpApXBN1qCw7c99f46dZfpVrHAsHOOW99p9qC9v7Su81nzTKux0Xado/r/lXdUHUEkrZKerTbv+KaW33dagvSBgHYMwhA59oC95dunYzrcZG23eO6f+Os662hiNgXQNI7ad0i+hQg
4EzgoG7LNcGoagvcX7p1Mq7HRdp2j+v+jbM0WUMnR8SHI2JrRDwaER8BXlJ0w6psNlsob+21Be4v3ToZ1+MibbvHdf/GWZpA8HNJZ0qakLRA0pnAz4tuWJV1qy04cN9fS72OBW2xpFNtgftLt07G9bhI2+5x3b9xliYQvAL4I+Cnyb/Tk2mN1a224I63nsjKZfvvMe/KZft3nPeiPzqqb23BqhVT8+oS/MDMxvW4SNvucd2/ceasITOzBhiqjkDSmyPivXM6n9tDmk7n6ixrnrPzpG0U6nacdavfSaNu70WeehWUzXb34D+/27TnOW96eBurr9kAkOrAyrq8WRp1O87a63d2ROz6vV8wqNt7kbdezwgekKT2DufmdDzXWFnznJ0nbaNQt+OsW/1OmjFD6vZe5K3XFcHFwO9IWkeru4lbgdsiYutIWlZhWfOcnSdto1C346xb/U6aup66vRd569UN9TRwMPDXtAakORf4vqRvSfrwiNpXSVnznJ0nbaNQt+OsW/1Omrqeur0XeeuZPhoRv4iIm4APAH8H/D2wN23jCjRN1jxn50nbKNTtOOtWv5NmzJC6vRd565U19Arg94GjaF0RfAO4A3hWRPzbaJpXTbMPl4bNQMi6vFkadTvOZh8ID5M1VLf3Im+9xiPYCtwDfBT4akTcO8qGdeM6AjOzwQ07HsEi4EhaVwXvkHQ4rc7nbqP10PjG3FtaAZ1yjWd+uGXo8QS6rdN/iVgWWfr1h2x/GZc5pkCZ+11nqSuLJR1Iq3uJNwBPiYiJPosUosgrgk79oLePGzBrYoHYMeeFbv2lu291y1uWfv0XLhAItu/of+zmve2sx32m/Z4QBGxPcc7W1bDjERwh6TWSLpP0fVrPCJ4FfBA4tpimlqtTrnGnIADsEQSge06y85ctb1n69d++M/YIAt2WLWLbWY/7TPu9I/YIAnm0p0563Rr638AtwD8B50fE/CG5aiZrTnGn5Z2/bHnL2q//IOvMe9tZjvsy97vuetURHB0R50bElU0IApA9p7jT8s5ftrxl7dd/kHXmve0sx32Z+113abqhboxOucbt4wbMmmh7oVtOsvOXLW9Z+vVfuECt++V9li1i21mP+0z7PaHW85Ec21MnvW4NNU63XOMsWUPOX7a8pT2mus2XZtmitp3luC9zv+vO4xGYmTXAsOMRfI4O4xDMioiTc2hb7Zx40U1878HdI3ke9qS9ee1zD/NfIja0LPn4ReT8590e8F/vZetVWfycXgtGxFcKaVEfVb4iaA8C3TQtf9mGlyUfv4ic/7zb0ym/P2utg3U2VB1BRHyl17/imju+0gQBcP6ypZclH7+InP+829Mpvz9rrYMNru/DYkmHARcCTweeODs9In6nwHbVnvOXLY0s+fhF5PwX0Z60fM4UJ0366CeBjwCPA88FLgMuL7JRTeD8ZUsjSz5+ETn/RbQnLZ8zxUkTCCYj4ku0nif8MCLeAby42GaNp8OetHeq+Zy/bGllyccvIuc/7/Z0yu/PWutgg0sTCH4paQHwPUmvk/RSYJ+C2zWWbnjjCfOCwWFP2pv3n3EUU4smETC1aNIPvSy1VSumuPCU5UMdP2mXHWQbebdnzWlHsub0I/ecdvqRrDntSJ8zI9S3jkDS7wEbaXVL/U5gP+C9EXF78c2br8pZQ2ZmVTXseAQARMQ3kpUsAM5NO3i9pCcCXwWekGznqoh4e9s8T6D1zOEY4CHgjIi4P83685Alx/qzMz/i1vu27Jpn5bL9ueLVx2fajtVT2jEuuo20lfcYGYP01Z923kH2p935120Yetluy2d5L4o4N6v+HZDmimCa1gPjfZNJjwCvioh1fZYTsHdEPCZpIa2eTM+beyUh6S+AIyLiNZJeBrw0Is7otd68rgiy5FiLzpV2nYKBxyNotkHGuDjruCXzvgDzHiNjkL76Tz1miqvXbeo7nkH7dnvtT7vzr9vA5bfP79MyzbK9ll8A7Jzz+6jGTOikKt8B
Q9URzPEJ4C8iYmlELAVeSysw9BQtjyW/Lkz+tR8tLwEuTX6+Cnh+EkAKlyXHulvonHuFMOh2rJ4GGePiyjseyLR8mjEyBumr/8o7Hkg1nkGnIACd9yftPGmW7TXfzrbfRzVmQifj8B2QJhDsiIibZ3+JiFtopZL2JWlC0p3Ag8ANEXFH2yxTwAPJeh+ndbXxWx3Wc46kGUkzmzdvTrPpvoro2zzLdqyeBvmcd3S4Os/7+MvankGkWb7bPGm3PUgbRzFmwiDrq9J3QJpA8BVJH5N0gqTnSPowcJOkoyUd3WvBiNgREUcBBwPPlPSMYRoZEWsjYjoiphcvXjzMKuYpom/zLNuxehrkc57ocDGc9/GXtT2DSLN8t3nSbnuQNo5izIRB1lel74A0geBI4GnA24F3AL8LrADeB/xtmo1ExMPAl4EXtr20CTgEQNJetDKSHkqzzqyy5Fh3O/RWLtt/6O1YPQ0yxsXLjz0k0/JpxsgYpK/+lx97SKrxDNq3O6vT/qSdJ82yveZr/2Ib1ZgJnYzDd0CarKHnDrNiSYuB7RHxsKRJ4ETgPW2zXQ+8ErgNOA24MUbUL3bWvs3TZg15PIJmG2SMi04PR/MeI2PQvvo7rTPL/rSbnWfYrKFuy2d5L/I+N8fhOyBN1tCBwN8AT46IF0l6OnB8RFzSZ7kjaD0InqAVoD8TERdIugCYiYjrkxTTT9G6wtgCvCwiftBrva4jMDMbXKY6AlqD2H8SeGvy+73APwA9A0FE3EXrC759+tvm/PwfwOkp2mBmZgVJEwgOiIjPSFoNreweSTv6LVR3VS8QsXTK+hy7bTdtcVXdB3ip2vlVtfbkLc2toZuAU2mlfx4t6TjgPRHRc+CaolTh1lBVCkQsm7I+x27bPXrJfh1rUdqLq+o+wEvVzq+qtWdYWQvK3kjroe4ySbfS6hLi9Tm2b+yMQ4GI9VfW59htu52CAMwvmqr7AC9VO7+q1p4ipMkaWp8MW3k4rczJeyJie+Etq7BxKBCx/sr6HAddf3vRVN0HeKna+VW19hSh6xWBpN+T9Nuwq+r3GOCvgfdJmp8w3yDjUCBi/ZX1OQ66/vaiqboP8FK186tq7SlCr1tDHwN+BSDp2cC7ad0WegRYW3zTqmscCkSsv7I+x27b7VSQCPOLpuo+wEvVzq+qtacIvW4NTUTE7E3LM4C1EXE1cHXSf1BjjUOBiPVX1ufYa7tpsoYGKQorY/+yqtr5VbX2FKFr1pCkbwNHJemi3wXOiYivzr4WEUP1G5RVFbKGzMzGzbAFZVfS6nDuZ8A24OZkZU+ldXvIzNoMMujLKLbTaVqWQW0g+0AyaYxiG1BufUCVahN61hEkNQMHAf8cET9Ppj0N2Cci1o+miXvyFYFVVcf8/gJy+bPUEbQP2LJrettgN93amHUgmTRGsQ0otz6gjG0PXUcQEbdHxLWzQSCZdm9ZQcCsyjrm9xeQy5+ljqBTEID5g910a2PWgWTSGMU2oNz6gKrVJqQpKDOzFAbJK8+Sgz6q/PVO28k6kEwao9gGlFsfULXaBAcCs5wMkleeJQd9VPnrnbaTdSCZNEaxDSi3PqBqtQkOBGY56ZjfX0Auf5Y6gm4nfPvYMt3amHUgmTRGsQ0otz6garUJaXofNbMUBh30ZRTb6TQtS9ZQ1oFk0hjFNqDc+oCq1Sb07X20apw1ZGY2uKwD05hZB3mPHTAu/e2f+fHbUg3TWqU8eevNVwRmQ0ib6z6q2oIsBslpbw8Cs9qDQV368K+TrOMRmFmbtLnuo6otyGKQnPZuYya0T69anrz15kBgNoS0ue6jqi3Iooic9qrlyVtvDgRmQ0ib6z6q2oIsishpr1qevPXmQGA2hLS57qOqLchikJz2bmMmtE+vWp689eZAYDaEd61azlnHLdl1BTAhdewUbdWKKS48ZTlTiyYRMLVokjWnH8ma047cY1qZD1E7tbFbe6549fHzvvQ7ZQ0Nsk4rn7OGzMwawHUE1nij
ymkvYjtNzMdv4j53M4r3woHAaq89p33Tw9tYfc0GgFxPqCK2M6q2V0kT97mbUb0XfkZgtTeqnPYittPEfPwm7nM3o3ovHAis9kaV0+58/Hw0cZ+7GdV74UBgtTeqnHbn4+ejifvczajeCwcCq71R5bQXsZ0m5uM3cZ+7GdV74YfFVnuj6vu9iO1Urd/6UWjiPnczqvfCdQRmZg1QSh2BpEOAy4ADgQDWRsQH2uY5Afi/wL8kk66JiAuKapNV07jmjJc5zkDe71m39aUdc6FM43r8VElhVwSSDgIOioj1kvYF1gGrIuLuOfOcAPxlRJyUdr2+IqiXce23vsxxBvJ+z7qt7+gl+3XsdrpTVxplGdfjpwyljEcQET+JiPXJz1uBjYA/GdvDuOaMlznOQN7vWbf1dRt7oNtYDGUY1+OnakaSNSRpKbACuKPDy8dL+pakf5L0n7ssf46kGUkzmzdvLrClNmrjmjNe5jgDeb9ngy7XbSyGMozr8VM1hQcCSfsAVwNviIhH215eDxwaEUcCHwSu67SOiFgbEdMRMb148eJiG2wjNa4542WOM5D3ezboct3GYijDuB4/VVNoIJC0kFYQuCIirml/PSIejYjHkp8/DyyUdECRbbJqGdec8TLHGcj7Peu2vm5jD3Qbi6EM43r8VE2RWUMCLgE2RsRFXeb5beCnERGSnkkrMD1UVJusesY1Z7xbuztNq3q9Qq/1VT1raFyPn6opMmvoWcDNwAZgZzL5LcASgIj4qKTXAX8OPA5sA94YEV/rtV5nDZmZDa6UOoKIuAXoeTMxIj4EfKioNlj9jar/f/BfnVZf7mLCxtao+v9/01Xfgmilh+a1HbMqcadzNrZG1f//9h2xKwjktR2zKnEgsLE1yv7/896OWZU4ENjYGmX//3lvx6xKHAhsbI2q//+FE2r1I5TjdsyqxA+LbWyNsv//vLdjViUej8DMrAFKqSOoE/d3Xl1V+2xGVdfg48/y5EDQRxG56paPqn02o6pr8PFnefPD4j7c33l1Ve2zGVVdg48/y5sDQR/u77y6qvbZjLKuwcef5cmBoA/3d15dVftsRlnX4OPP8uRA0If7O6+uqn02o6pr8PFnefPD4j7c33l1Ve2zGWVdg48/y5PrCMzMGqBXHYFvDZmZNZxvDZnlaFTFX3XbjpXLgcAsJ6Mq/qrbdqx8vjVklpNRFX/VbTtWPgcCs5yMqvirbtux8jkQmOVkVMVfdduOlc+BwCwnoyr+qtt2rHx+WGyWk1EVf9VtO1Y+F5SZmTWAC8rMzKwrBwIzs4ZzIDAzazgHAjOzhnMgMDNrOAcCM7OGcyAwM2s4BwIzs4YrrLJY0iHAZcCBQABrI+IDbfMI+ADwh8AvgLMjYn1RbbLmcr/6Zt0V2cXE48D/iIj1kvYF1km6ISLunjPPi4DDkn/HAh9J/jfLjfvVN+utsFtDEfGT2b/uI2IrsBFoP+teAlwWLbcDiyQdVFSbrJncr75ZbyN5RiBpKbACuKPtpSnggTm//5j5wQJJ50iakTSzefPmopppNeV+9c16KzwQSNoHuBp4Q0Q8Osw6ImJtRExHxPTixYvzbaDVnvvVN+ut0EAgaSGtIHBFRFzTYZZNwCFzfj84mWaWG/erb9ZbYYEgyQi6BNgYERd1me164L+p5TjgkYj4SVFtsmZatWKKC09ZztSiSQRMLZrkwlOW+0GxWaLIrKGVwB8DGyTdmUx7C7AEICI+CnyeVuro92mlj/5Jge2xBlu1Yspf/GZdFBYIIuIWQH3mCeC1RbXBzMz6c2WxmVnDORCYmTWcA4GZWcM5EJiZNZwDgZlZw6mVuDM+JG0GfphhFQcAP8upOWWr075AvfanTvsC9dqfpu7LoRHRsWuGsQsEWUmaiYjpstuRhzrtC9Rrf+q0L1Cv/fG+zOdbQ2ZmDedAYGbWcE0MBGvLbkCO6rQvUK/9qdO+QL32x/vSpnHPCMzMbE9NvCIwM7M5HAjMzBquMYFA
0ickPSjp22W3JStJh0j6sqS7JX1H0nllt2lYkp4o6euSvpXsy/8qu01ZSZqQ9E1J/1h2W7KSdL+kDZLulDRTdnuykrRI0lWSvitpo6Tjy27TMCQdnnwms/8elfSGodfXlGcEkp4NPAZcFhHPKLs9WUg6CDgoItZL2hdYB6yKiLtLbtrAkgGM9o6Ix5IR7W4BzouI20tu2tAkvRGYBn4jIk4quz1ZSLofmI6IWhRgSboUuDkiLpb0a8CvR8TDZbcrC0kTtEZ2PDYihiq2bcwVQUR8FdhSdjvyEBE/iYj1yc9bgY3AWI66Ei2PJb8uTP6N7V8nkg4GXgxcXHZbbE+S9gOeTWvkRCLiV+MeBBLPB+4bNghAgwJBXUlaCqwA7ii3JcNLbqXcCTwI3BARY7svwPuBNwM7y25ITgL4Z0nrJJ1TdmMyegqwGfhkcuvuYkl7l92oHLwMuDLLChwIxpikfYCrgTdExKNlt2dYEbEjIo4CDgaeKWksb91JOgl4MCLWld2WHD0rIo4GXgS8NrnFOq72Ao4GPhIRK4CfA39VbpOySW5vnQx8Nst6HAjGVHI//Wrgioi4puz25CG5TP8y8MKy2zKklcDJyX31TwPPk3R5uU3KJiI2Jf8/CFwLPLPcFmXyY+DHc644r6IVGMbZi4D1EfHTLCtxIBhDyQPWS4CNEXFR2e3JQtJiSYuSnyeBE4Hvltuq4UTE6og4OCKW0rpcvzEiziq5WUOTtHeSjEByC+UPgLHNuouIfwMekHR4Mun5wNglWLR5ORlvC0GBg9dXjaQrgROAAyT9GHh7RFxSbquGthL4Y2BDcm8d4C0R8fkS2zSsg4BLk8yHBcBnImLs0y5r4kDg2tbfHewF/J+I+EK5Tcrs9cAVyS2VHwB/UnJ7hpYE5xOBP8u8rqakj5qZWWe+NWRm1nAOBGZmDedAYGbWcA4EZmYN50BgZtZwDgRWO5J2JD0yflvSZyX9ep/535JyvfdLOiDt9LxIWiXp6XN+v0lSLQZft2pwILA62hYRRyW9zP4KeE2f+VMFghKtAp7edy6zITkQWN3dDDwVQNJZydgHd0r6WNLZ3buByWTaFcl81yWdrH1n2I7WkqrcTyTb+6aklyTTz5Z0jaQvSPqepPfOWeZPJd2bLPNxSR+S9Pu0+pJZk7RxWTL76cl890r6LxneH7PmVBZb80jai1ZfLF+Q9LvAGcDKiNgu6cPAmRHxV5Jel3R6N+tVEbEl6fLiG5KujoiHBtz8W2l1MfGqpAuNr0v6f8lrR9HqMfaXwD2SPgjsAP4nrb5vtgI3At+KiK9Juh74x4i4KtkvgL0i4pmS/hB4O/BfB2yf2S4OBFZHk3O63riZVr9M5wDH0PpiB5ik1e11J+dKemny8yHAYcCggeAPaHVA95fJ708EliQ/fykiHgGQdDdwKHAA8JWI2JJM/yzwtB7rn+1ocB2wdMC2me3BgcDqaFvbX/izHfVdGhGrey0o6QRaf10fHxG/kHQTrS/xQQk4NSLuaVv/sbSuBGbtYLjzcHYdwy5vtoufEVhTfAk4TdKTACTtL+nQ5LXtSbfeAPsB/54Egf8EHDfk9r4IvD4JQEgkNfsAAADASURBVEha0Wf+bwDPkfSbyS2tU+e8thXYd8h2mPXlQGCNkIznfD6t0bbuAm6g1fMpwFrgruRh8ReAvSRtBN4NpB07+S5JP07+XQS8k9awm3dJ+k7ye6/2bQL+Bvg6cCtwP/BI8vKngTclD52XdV6D2fDc+6hZRUjaJyIeS64IrgU+ERHXlt0uqz9fEZhVxzuSh9zfBv4FuK7k9lhD+IrAzKzhfEVgZtZwDgRmZg3nQGBm1nAOBGZmDedAYGbWcP8fAThEasSNtkIAAAAASUVORK5CYII=\n", 191 | "text/plain": [ 192 | "
" 193 | ] 194 | }, 195 | "metadata": { 196 | "needs_background": "light" 197 | }, 198 | "output_type": "display_data" 199 | } 200 | ], 201 | "source": [ 202 | "# Plot the `data_points`\n", 203 | "plt.scatter([item[0] for item in data_points], [item[1] for item in data_points])\n", 204 | "plt.xlabel('Petal Length')\n", 205 | "plt.ylabel('Sepal Width');" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 11, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "# Function to compute the Euclidean distance\n", 215 | "# See: https://en.wikipedia.org/wiki/Euclidean_distance\n", 216 | "def distance(a: List[float], b: List[float]) -> float:\n", 217 | " assert len(a) == len(b)\n", 218 | " return sqrt(sum((a_i - b_i) ** 2 for a_i, b_i in zip(a, b)))\n", 219 | "\n", 220 | "assert distance([1, 2], [1, 2]) == 0\n", 221 | "assert distance([1, 2, 3, 4], [5, 6, 7, 8]) == 8" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 12, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "# Function which computes the element-wise average of a list of vectors (a vector is a list of floats)\n", 231 | "def vector_mean(xs: List[List[float]]) -> List[float]:\n", 232 | " # Check that all arrays have the same number of dimensions\n", 233 | " for prev, curr in zip(xs, xs[1:]):\n", 234 | " assert len(prev) == len(curr)\n", 235 | " num_items: int = len(xs)\n", 236 | " # Figure out how many dimensions we have to support\n", 237 | " num_dims: int = len(xs[0])\n", 238 | " # Dynamically create a list which contains lists for each dimension\n", 239 | " # to simplify the mean calculation later on\n", 240 | " dim_values: List[List[float]] = [[] for _ in range(num_dims)]\n", 241 | " for x in xs:\n", 242 | " for dim, val in enumerate(x):\n", 243 | " dim_values[dim].append(val)\n", 244 | " # Calculate the mean across the dimensions\n", 245 | " return [mean(item) for item in dim_values]\n", 246 | "\n", 247 | "assert 
vector_mean([[1], [2], [3]]) == [2]\n", 248 | "assert vector_mean([[1, 2], [3, 4], [5, 6]]) == [3, 4]\n", 249 | "assert vector_mean([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) == [4, 5, 6]" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 13, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "class KMeans:\n", 259 | " def __init__(self, k: int) -> None:\n", 260 | " self._k: int = k\n", 261 | " self._centroids: Dict[int, List[float]] = defaultdict(list)\n", 262 | " self._clusters: Dict[int, List[List[float]]] = defaultdict(list)\n", 263 | "\n", 264 | " def train(self, data_points: List[List[float]]) -> None:\n", 265 | " # Pick `k` random samples from the `data_points` and use them as the initial centroids\n", 266 | " centroids: List[List[float]] = sample(data_points, self._k)\n", 267 | " # Initialize the `_centroids` lookup dict with such centroids\n", 268 | " for i, centroid in enumerate(centroids):\n", 269 | " self._centroids[i] = centroid\n", 270 | " # Start the training process\n", 271 | " while True:\n", 272 | " # Starting a new round, removing all previous `cluster` associations (if any)\n", 273 | " self._clusters.clear() \n", 274 | " # Iterate over all items in the `data_points` and compute their distances to all `centroids`\n", 275 | " item: List[float]\n", 276 | " for item in data_points:\n", 277 | " smallest_distance: float = inf\n", 278 | " closest_centroid_idx: int = None\n", 279 | " # Identify the closest `centroid`\n", 280 | " centroid_idx: int\n", 281 | " centroid: List[float]\n", 282 | " for centroid_idx, centroid in self._centroids.items():\n", 283 | " current_distance: float = distance(item, centroid)\n", 284 | " if current_distance < smallest_distance:\n", 285 | " smallest_distance: float = current_distance\n", 286 | " closest_centroid_idx: int = centroid_idx\n", 287 | " # Append the current `item` to the `Cluster` with the nearest `centroid`\n", 288 | " self._clusters[closest_centroid_idx].append(item)\n", 289
| " # The `vector_mean` of all items in a `cluster` becomes that `cluster`'s new centroid (empty clusters keep their old one)\n", 290 | " old_centroid: List[float]\n", 291 | " centroids_to_update: List[Tuple[int, List[float]]] = []\n", 292 | " for old_centroid_idx, old_centroid in self._centroids.items():\n", 293 | " items: List[List[float]] = self._clusters[old_centroid_idx]\n", 294 | " new_centroid: List[float] = vector_mean(items) if items else old_centroid\n", 295 | " if new_centroid != old_centroid:\n", 296 | " centroids_to_update.append((old_centroid_idx, new_centroid))\n", 297 | " # Update centroids if they changed\n", 298 | " if len(centroids_to_update):\n", 299 | " idx: int\n", 300 | " centroid: List[float]\n", 301 | " for idx, centroid in centroids_to_update:\n", 302 | " self._centroids[idx] = centroid\n", 303 | " # If nothing changed, we're done\n", 304 | " else:\n", 305 | " break\n", 306 | " \n", 307 | " @property\n", 308 | " def centroids(self) -> Dict[int, List[float]]:\n", 309 | " return self._centroids\n", 310 | " \n", 311 | " @property\n", 312 | " def clusters(self) -> Dict[int, List[List[float]]]:\n", 313 | " return self._clusters" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 14, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "The clusters centroids are: [[1.4941176470588236, 3.4], [4.925252525252525, 2.875757575757576]]\n", 326 | "The number of elements in each cluster are: [51, 99]\n" 327 | ] 328 | } 329 | ], 330 | "source": [ 331 | "# Create a new KMeans instance and train it\n", 332 | "km: KMeans = KMeans(2)\n", 333 | "km.train(data_points)\n", 334 | "\n", 335 | "print(f'The clusters centroids are: {list(km.centroids.values())}')\n", 336 | "print(f'The number of elements in each cluster are: {[len(items) for items in km.clusters.values()]}')" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 15, 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "data": { 346 |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEHCAYAAACjh0HiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzdd5hcZfXA8e+502dLeggJJaFKDSX03kR6UYqAFKWJ/AAVUEBBUIqoFEGaIE2aNIEIQhQiTZAE6SCChJqQns3uzk675/fHnezu1J2dndnZcj7Pkyc7d+6975nZnfvOvfe87xFVxRhjzPDl1DsAY4wx9WUdgTHGDHPWERhjzDBnHYExxgxz1hEYY8wwZx2BMcYMc/5aNyAiPmAW8Lmq7pvz3LHAr4DPM4uuVdWbS+1v7NixOnny5BpEaowxQ9fs2bMXquq4Qs/VvCMATgfeBZqLPH+fqp5a7s4mT57MrFmzqhKYMcYMFyLycbHnanppSERWAfYBSn7LN8YYUz+1vkdwFXA24JZY5+si8oaIPCAiqxZaQUROFJFZIjJrwYIFNQnUGGOGq5p1BCKyLzBfVWeXWO0xYLKqbgzMAG4vtJKq3qSq01R12rhxBS9xGWOMqVAtzwi2A/YXkTnAvcCuIvLH7iuo6iJVjWce3gxsXsN4jDHGFFCzjkBVz1HVVVR1MnA48LSqHtV9HRFZudvD/fFuKhtjjOlH/ZE1lEVELgJmqeqjwGkisj+QAhYDx/Z3PNWWSqZY9MUSRoxrJhwN1TscY4zpkQy2aainTZumAzV99OFrHuf28+8jnUzjqrLXd3blu1cci8/vq3doxphhTkRmq+q0Qs/1+xnBUDXzvhe45Zy7ibfHO5f99Q9P4w/4OPk3x9YvMGOM6YFNMVElf/zFg1mdAEC8PcH0G2eQTCTrFJUxxvTMOoIqWfT54oLL3bTS3hLr52iMMaZ81hFUyTrT1ii4vGFElKbRjf0cjTHGlM86gio5/rKjCEVDiHQtC0WDnPSbo3Ece5uNMQOXHaGqZO3N1uCq53/OlntvxqgJI1lv63W44MGz2P3IHesdmjHGlGTpo8YYMwyUSh+1MwJjjBnmrCMwxphhzjoCY4wZ5qwjMMaYYc46AmOMGeasIzDGmGHOOgJjjBnmrCMwxphhzqahrtAbz77DnRfez2fvf8Gam07hmJ8dytqbFZ5vyBhjBjLrCCrwz8dmcfE3ryTengBg0ReLee3pN7l8xvmsv826dY7OGGN6xy4N9ZKqct3pt3Z2At4yr/bAjWfeUcfIjDGmMtYR9FIynmT+pwsLPvfha3P6NxhjjKkC6wh6yR/0E4oECz43cvyIfo7GGGP6zjqCXnIchwNP24tQNLszCEVDHH7OQXWKyhhjKmc3iytwzIWHEW+LM/2mv+HzOagqh/3oQPY5Yfd6h2aMMb1m9Qj6INbWweK5Sxi3yhiC4cKXi4wxZiAoVY/Azgj6INIQZtJaK3c+XvLlUv58zRO89cJ7rPaVSRx8xj6suu6kOkZojDE9s46gSubNmc8p035ER1sHyXiKt55/jxl3PsvF089h6s4b1Ds8Y4wpym4WV8kt59xF29I2kvEUAG7aJd4e54oTb2CwXX4zxgwv1hFUyewZb+C6+Qf8+Z8sZPni1jpEZIwx5bGOoEoamiMFl4uQl2pqjDEDiXUEVXLg/+WPLQgE/Wx7wJaEIqE6RWWMMT2zjqBKDjxtb3b95vYEQgEaRkQIRYKsv806fP+mk+odmjHGlGTjCKps4eeL+OitT5kweZyljhpjBgwbR9ALcz/6kr/d+SytS9vYet/N2WSXDRGRsrcfO2kMYyeNqWGExhhTXTXvCETEB8wCPlfVfXOeCwF3AJsDi4DDVHVOrWMq5pn7XuDX374ON5UmlUzz+O//xuZfncr59/8Qx7GraMaYoak/jm6nA+8Wee47wBJVXQu4EvhlP8RTUKw1xm
++cz2JWIJUMg1AR1uc2U+9zouPvFKvsIwxpuZq2hGIyCrAPsDNRVY5ALg98/MDwG7Sm+swVfT6zHfw+fPfjo62OE/f/VwdIjLGmP5R6zOCq4CzAbfI85OATwFUNQUsA/IusIvIiSIyS0RmLViwoCaB+gK+os8FQjYOwBgzdNWsIxCRfYH5qjq7r/tS1ZtUdZqqThs3blwVoss3decNCt4UDjeE2PO4XWrSpjHGDAS1PCPYDthfROYA9wK7isgfc9b5HFgVQET8wAi8m8b9LhgKcOGfzybSGCbSGCYYCRIMB9j35K+y2W4b1SMkY4zpF/0yjkBEdgbOLJA19D1gI1U9WUQOBw5W1UNL7avW4wjal8d48ZFXaG+JsflXN86aZtoYYwarATWOQEQuAmap6qPALcCdIvIBsBg4vL/jyRVtirD7UTtmLUsmkrzy19dYPHcpG2y7DlM2Wr3P7cz/dCGznnydcDTI1vtNI9pUeK4iY4ypNRtZ3IPP3v+CH+x0PvH2BOl0GhS22ndzzr37dHy+4jeYS7n7kgf54y8exOdzEEdQV7nw4bPZbPeNqxy9McZ4Sp0R2CipHvzs679m6fwW2pfHiLcniMcSvPyXV/nrLU9XtL/3/vVf7r7kIZIdSTra4sSWd9DRFudnB/+KWFtHlaM3xpieWUdQwtz/fcm8/32ZV1gm3h5n+o0zKtrnU7fPJNGRzFsuIsx68vWK9mmMMX1hHUEJyUQKcQqPb0vG8w/m5Uh0JNECBWy0D/s0xpi+sI6ghFXWWZnGkQ15y4ORILsesX1F+9zpkG0IN4TzlqeSKaZ9dWpF+zTGmL6wjqAEx3E4567TCTeECIQCAIQbw6y67kQOOn2fivY5bc9N2Hq/zQk3eMVqfH6HUCTIKVcdR/OYpqrFbowx5bKsoTIsmruEp257hvmfLmSTnTdku4O2xB+oPPNWVXntmbd44c//ItIYZvdv7cTq661SxYiNMSbbgBpHMBiNWXkU3zzn4LLWXTxvCXPe/owJk8cxcc0JALQta+P92f9j5Lhmpmy0OiLCBtuui+NzCEdDrLruxKL7U1X+++r/iC3vYN0t1yIctbKXxgwUml4IqffBNwnx9258kWoakm8AaQhMRSRQmyDLYB1BlbiuyzWn3sKTtz5DMBwgGU+y4fZfYcMd1uPeSx/GH/STTrlMWmsC+538VW46+87OMQSNoxq4ePo5eQPVPnv/C87d+2KWzG/BcQQ37XLqNd9hz2Nt7iNj6knVRVt+DrH7QUKgCTS4GTLyd4jT2PP2iVfRJacA8cwSH4y8GgltV9O4i7FLQ1Xy8G//wi3n3kO8Pd65zBdwUBfcdNfkq47PQV2X3Ld9xLhm7v3sxs5LTq7rctSUU1j42aKsdUPRIFc99wvW2nRKTV+PMaY4t+0uWH45EOu2NAih3XBGXV1yW3Vb0QU7gLZlPyERZOzfEF9tJta0AWX94KGrH8/qBADSSTerEwCvUyjU9ybjSWbPeKPz8VvPv0fr0ra8dZMdSR69/smqxW2MqUD7rWR3AgAJiP8dddtLb9vxFAUPAupCx/RqRdgr1hFUSevStp5XKkFdZfni1s7Hyxe3FpwW23WVpV8u61Nbxpg+cpcXf05zO4jc55cChcYMxVF3cV+iqph1BFWy2e4b4xQafFZmvbVUMs3GO63f+XiD7dYlGU/lrRduCLHN/gXP7owx/SW0HQUPn77x4IwuvW1wa6DAPGUSRYL1uUdgHUGVHH/pkURHRAkEvWv8js8hGAkyavwIghEvG0DEu8Y/YY3xhLpl/4QbQhz4f3sxftWxnctGjhvBEece1DneALxtJ641gd2O3KGfXpUxphBp/AFIE7CieqEPiCDNPy94Jp+1bWB9CH8N6D7jcASCW0Jwq9oE3AO7WVxFC79YzMNXP87bL77HqutO4hs/2Jexk0bz2A0zeGn6bMZOGsVBp+3D2puvwVO3zeSZe58n2hRh35P2YMu9Nyv4BzR7xus8et
2TLF/cyo6HbMPXvr2rpZAaMwBoej7afgckZoF/ChI9DgmsU9626kL8SbT9fiCNRA6E8H549blqo9TNYusIqqx1WSv//ttbrLPFmqy0Wm3u/htjTG/ZgLJ+cuImP+SjNz7pfNw0upG75lxHpNGKzhhjBi67R1Al5+59cVYnAF7mz9FrnlqniIwxpjzWEVTJK0++VnD50gUtLFtcItXMGGPqzDqCailxq+Xz97/ovziMMaaXrCOoEp+/+Fu51mY2HYQxZuCyjqBKvnXBoQWXb7LrhgSDwYLPGWPMQGAdQZUced7X+c6lR+DPDCgTR9jjmJ341d8uqHNkxhhTmo0jMMaYYWDYzz76xYfzOHevi/la6HD2azqKq797E7HWHiaG6sGTtz3DEaudzJ6Bwzh23dN44c//qlK0xpj+oBrHbbkE98tNceetj7v4aDT1Qb3Dqoshf0bQsng5x657Gq1L2lDXe62BUIB1t1iTK5/9eUUxTL9pBjf84PasaadDkSDn3ft9ttnPJoQzZjBwF58AiZfoKg4jIA3I2CcQ30r1DK0mhvUZwV//8AyJ9kRnJwDe3P8f/Psj3p/9Ya/3p6rc9tN782oPxGMJ/nDu3X2O1xhTe5r6KKcTAFCv0lj78PscD/mO4INX/0c8lshbLiJ8/PZnvd5foiORVTeguy8+nNfr/Rlj6iD1ARSsEZyA5Fv9Hk69DfmOYK1NpxCK5KdvqiqrrTep1/sLhoM0jmoo+NyEKUPvdNKYIcm/Bmh+vQ8IQmD9AsuHtiHfEXzt27sSjASzpngOhPxM2Wh11pm2Zq/3JyIc87ND86aCDkWDfOeSI/ocrzGm9sS/JgSnATlTuksAiR5Zl5jqach3BM1jmvjtixezya4bdhaL2e2oHbnsyZ/0WECimP2+uycnXXEMo1ceBQITpoznrFtPZdsDtqhy9MaYWpFR10HkGyARwIHAFsjoexHfhHqH1u+GfNZQd6pa8cG/P/dpjOlfw+Fz3Od6BCIyCVi9+/qq+mx1wus/lf6ibzr7Th67/kkSHUnGrTqGH978Xea8/Sl3XPAn2lvaaR7bxMlXHMsaG63G7Rf8if+88gET1liJb/30G6yyzsrcceH9vDrjDUaMb+bQMw9g58O2HfJ/dMYMJsP989jjGYGI/BI4DHgHSGcWq6ru38N2YeBZvItwfuABVb0gZ51jgV8Bn2cWXauqN5fab3+PLP7pAZfx0mOzy1rXH/STTqZZ8Z4GwwEcn0OiI4mbdgGvPvGhZx3At84/pGYxG2NMrr6eERwIrKuq8R7XzBYHdlXVVhEJAM+LyBOq+lLOevep6oCs3tKyeHnZnQBAKpGdhZDoSOat09EW597L/szBZ+xDQ3O0zzEaY0xflXOz+H9AoYTbktSzIuE+kPk3qG5IvD7z7Zrs1x/08fHbn9Zk38YY01tFzwhE5Bq8A3c78JqI/J1uw/BU9bSedi4iPmA2sBbwO1V9ucBqXxeRHYH3ge+rat4RUkROBE4EWG211XpqtmpWW2+Vmuw3lUgxdtLomuzbGGN6q9QZwSy8g/ijwM+BFzOPZ2ee65GqplV1E2AVYEsR2TBnlceAyaq6MTADuL3Ifm5S1WmqOm3cuHHlNF0Vq6+3CmMmjip7fceX/XYGQv68gjWBkJ+NdliP8av13+swxphSinYEqnq7qt4OjFzxc7dl5R8dvX0tBZ4BvpazfFG3ew83A5v3Lvzau/7Vyxm36piuBQJ7fnsX1tpkctZ6m+y6Id+57EgijWHCDSGC4QB7HrsLZ99+KiPGNhFuCBEI+dnia5vy0/t/2L8vwhhjSigna+hVVd0sZ9m/VXXTHrYbByRVdamIRICngF+q6vRu66ysqnMzPx8E/EhVty6133rVI/jy4/l88eGXbLDtugTD3pQVS+YvZc5bn7L2ZlNoHNkIQCKeZOFnixi10ggijREA0uk08z9eSOOoBppGNfZ77MYYU1HWkIh8EzgCmCIij3Z7qglYXEa7Kw
O3Z+4TOMCfVHW6iFwEzFLVR4HTRGR/IJXZ57HlvKB6WGn18ay0+vjOx/9++k2u+/5tLPhkIausO5HTrjuB2PIYlx31WxbPW0q0Kczxlx/FVntvzlUn38hbz79H44gGjjr/G+x25A78/Y/P8cy9zxNpirDPiXuwxZ6bFGz3tWfe4tHrnqR1aRs7fH1r9jx2586OyAxPqgmIPYp2TAeJItHDkdCO9Q6rLBp/GW2/C3QZhPZEogfjZZrnrJf6BG27HVL/geDGSPToYTnit78UPSMQkdWBKcClwI+7PbUceEO14IxNNTcQKpQ98YenueL46ws+92udCcCZsrO3QMjLlRo5vpmOtjgdbd5VsXBDiINO35tv/yJ7rqJ7LnuYu37xYOeU16FoiNW+MpGrXriYYKjXiVxmCFBNoYuPhtTboJniShKByLdwms+sb3A9cFtvgtbfASuKQkXAPwUZcx8iXXP+aOJ1dMkxoAm874gBkLC3nn+tOkQ+NFRUj0BVP1bVmaq6jar+o9u/V+vVCQwU155acsxbtgL97NL5LZ2dAHhjCx64YjoLPlvUuWzZwhbuvPD+rLoH8fY4n/7nC5655/mK4jZDQPzvkHqnqxMA7+f229H03PrF1QN1l0Drb+nqBPB+Tn0Escey1225ALQdrxMASIK2oi2X9VO0w0/RjkBElotIS7F//RnkQNLW0l5woNivdSa/1plMZSFTWdj5uFx+vy9r3MJbz79HIJR/5a6jLc4LD1tZzOFKO57OHCRziC9TaGWASrwKUuiSZgzteKrzkWoCUu8VWE8hUSj73FRD0XsEqtoEICI/B+YCd+Jd6DgS7/r/sBQI1+aSjDhCU7c6B42jGih02c5xhBHjm2sSgxkEnFF4H9vck3IHZGQdAiqT00zh8aQO+LqPqfHjjT0tMJGBFK4DYvqunJHF+6vqdaq6XFVbVPV64IBaBzZQBYMBVllnYt7yM2VnzpSdeZ2xvM7Yzsfl8gf9bLbHxp2PN9z+K0Sbo+TOhRUIB9jv5K9WGr4Z5CR6CIW/v/khtF1/h1O+wOYgTXjfJbsLItGue2MiDkQOJK9OAGGIWr2PWimnI2gTkSNFxCcijogcCbTVOrCB7IpnL6R5THYa6NhVRhNpjmQtc3wOX9l67axl/qCfoy88lHBDiGhzhGhThDETR3H5jPMJBLvONnw+H7986qeMW3UskcYwDc1RQpEgp1x1HOts3vuCOmZoEP+aMOJi7waxNHrfkp1xyOjbkIKXXgYGEQcZfSs4K4NEvdgJQ9O5SGDj7HWbz4XgVkAo03kEIbwb0vjdeoQ+LJQzjmAycDWwHd653QvAGao6p8axFTQQsoZWePXvb/DuP99n0902Yv1t1gXgX4/PZub9/2SD7dZln+P3AODzD+cy874XmThlJXY6bFscx6GjPc47L/6HUDTEeluvjeMU7pNd1+U/r3xIe0s762+zTufYBDO8qcYy190jENjE+yY9CKgqJN8AbYXAVMQpPq5GUx9D+mPwr4X48s/CTe+UyhoaVoVpqqmtpZ2n73qOj9/9nLU3m8LOh21LKJJ7OluYqvLmc+/y4qOvEGkIs9uROxS83GRMX2nqUzT2CGgLEtoJgoO3FobrtsLy30BiFvgmQ/PZOP5Vy95ek++hHX8BTSORvZDARrULdgCqqCMQkbNV9fJuk89lKWfSuVoYCB3B5x/M5bRtzyMRS9DRFifcGKZpVAPXvnwpoyeUnn1DVbn82Gt5/qGXibfHcXw+fAEfp17zbfb69m799ArMcODGHodlP8YrI5IEohDaFhl57aA5g1jBTX0MC/ci7yb5iGtxIj3fM3Nbb8yMYUhkloQgegRO84+qHeqAVdE4AuDdzP8rJp/L/TdsXXHCDSxf1No5FqCjtYPFc5dy41l39rjt7Blv8PxDL9PRFkcV0qk0iViCa0+9hZbFy2sduhkm1G2HZecAHXidAEA7xF+A+Iw6RlahJd8jP1MKWNbzvF2a+gRar8
V7L9zMvxi034Um36lunINUqcI0n4qIZCaZMxnJRJK3nn8vL7UznUrzz0df6XH7f/zphazBZCv4Aj5mP/UGuxw+gDM/zOCR/Jc3tiDvXD6Gxh5FwnvWI6rKpf9b5Ik4burT0peI4jOLPJFAO/6GBNbvY3CDX6mO4GZgDRGZjTcF9QvAP1V1WH9tdRyn6DVWn8/X4/b+oB9xBHWzP6GC4A/0vL0x5Skx3mUAZxcVV2Culk49jO0RP/lpq+CNvbCpWqD0FBPT8OoIXIw3uuM04AMReV1Eruun+AYcn9/Hlntvis+ffdAOhPzsdtQOPW6/x9GFJ41zXZdpXys88ZwxvRbcgoIfb4kgkW/0ezh9Figy2bE04fh7mIwutAeFOxEfEt67r5ENCSXvGKlqu6rOxEsfvRL4HdBATl2B4eb7N53MhCnjiTSFCYYDRBrDTNl4db5zSc8DXtbfeh0OO/sAguEAoUiQSGOYUDTE+Q+cSaQhfxZGYyohEkRGXe+NM5AoEAZCEDkcgtvWO7zeG3VjZkxBdz4YdWuPm4pvHIy4BG9cQgSIeD83/Rjxr16DYAefUllDRwDbApvgnRG8AryMd3loXr9FmGMgZA2BV2Pg1Rlv8Pl/5zFl49XYeMf1e5WWN2/OfF554t+EG8Jse8A0GkbY8HlTfeq2eRPV6XIIbj+oD3yu60Lsfki8AP41oeEkHKf8L0/qLoaOvwNpCO2K+Mb3uM1QUlE9AuBG4D/ADcCzqvp+LYIbaGJtHbz02Gxiy2NsuvtGrDxlJdpbY9x7ycN8+ckCtj94K3Y4eGtEBJ/fS/3MvUyU6+N3P+PNZ99lxLhmttpnM4KhABMmj2e/7w6yG3ZmwND0XIg/533DDe1SYmCWAD68j7p3AUA1CfFnwV3gDUYLfKV3bac+hMQr3rxHoV2KjmhWd4l3o1ZdCO2E+Mb2qp28VyICga+AuOCbVPL6vqbnea9Rwt5B32nEuyfgB3Uob1KF4aPUGYEPmIp3VrAtsC7e5HP/xDsreLq/guyulmcEbz3/LufucykAbtpFXZdt9p/Gsw+8lHVzd/TKo2gYEWHh54txUy7iCFM2Wo1fzjg/6/KO67pccfz1zLzvRcC7v+AP+fn13y9gykaD95uZqS+39YZMOqQPbzIqRUZeh+TMNaSJ2eiSEwAFTXv/hw+AxNPe1NWa9lYM7YyMvBLvI1+cqqIt50JsOiBeVhIBZPQdeZ2JN4bhR5kY1esMmn+CEz2sotesGkeXnAiJ17zXIT5wRiGj784rWOPVPbgG78DvAC5EjoH22zLL+h7PYFSVkcUishJwCHAGMEVV65LiUquOIJlIcuiEE2hdWt40SiKSlUIaCAXY96Q9OOWq4zqXPX33c1x50o156aITpoznjg+uHbQjPE39aPJNdNGReDnx3UgUGfci4kS99TSBzt/OqwSWveKKPXVbFoGmH+M0fLN027Hp6LLzyK4pADgTkXHPdP49a3ohumAX8mcQDSFj/4L4Vyv9Igtwl/8W2n6fs08fBLfAGX1HV4zF3p+CKo9nMKpoQJmIbCwiJ4vIHSLyAd49gu2Ba4CtahNq/bw+8x3vGmSZcjvQZDzJjDv/kbVs+o0zCo4ZWDp/GXPe+qSyQM2wpu0P0TU6tjuBxHNdDxOv4I0oztsD+Rk0MYjdU0bb95LXCQDo0uwaAvGnKJyumUY7nuixnYJiD5DfsaQhMQt1W7vF+DCF359C0mjH45XFM8SUukdwG/A88ATwE1Ud0keuZDy/2ExvpZPZH7xEkX2K45CID+sib6ZiCbyRsTmUTGnHFY/LPRj2Zv1i60hO20kKxogLWqDOQFmKfT6F7BHHRd6fgvoSz9BSahzBZqp6mqreM9Q7AYCpO2+QdyDvDZ/fYev9Ns9attsR2xOK5t9ICwT9rLXJ5IrbMsOXhPfCS3/MlcyuRxDcsuseQPYeCiwLQWTfnhsP71ek7QAENui2u12KtBNEwrv33E4hoT0p+L3VvwbidB
XkkfDXisRYSB/iGWLs1nlGtCnCGTedRCgS7BzhG24IsfoG+UPXA6EAjSMbOg/y4YYQI1cayUm/PiZrvX1O3IM1Np5MuDGc2c5PKBrinLtO7zHTyJiCgttBePdMPvyKjKAQNJ2HOF2VvsRpgBG/wBs/kDmAShQCUzPbrviCEvUKyEe/3WPTEj0MAutlxiWAN6I3jIz8DSJdB2nxrwaNJ2XaXpGhE4boYRVP5yBNp4NvQre2wyCNyIjLs1cMbgfhPfLfn+DOOfFEMvFsgLFpqPN89t+5zLhjJm1L29h6vy3YbPeN+N8bH3Pzj//Igs8WM22PjTnu4m+SSqSZcec/+OSdz1h78zXZ5ZvbFRwQlk6l+edjs3j1b28wZuVR7HHMzoxftW9pdGZ4U/Xq92rHDHCiSOQAxL9W4XVTH6Oxh8BtQcK7egdKd6G3LD0XCW4J4a8iZU61oJqC+DNo/AVwxiLRryO+wpVrNfk2GpsOpJHw3kiwbyPnVePQ8TiaeA18k5HogYiTP9tvsfdHk++gsceqFs9gY/UI+tGzD7zIX//wDJvuthGH/HB/ANqWtfGfWf9j1PhmSxs1vabucki+Cc4Y8K/Tq2wzVYXU2+C2QmDjzqyiguumPssUglkzLyWzcDyjwb9uVeLR1AeQng+B9Tsv9ZQbjylPpfUIHqP4LE+o6v7VCa93BmpH0NHRwUEjjyOVyL4JfMCpX+OJm/9OIBQgnUozcc0JXPz4uYydOLrInozp4rbeAq1XZSaKS4FvNWTU78s6MGpqDrrkeHAXAo53z6D5fJzo17PX0w506RneFNUS9G78hvdCRlySdcmnJvE0fd8bl5D6b2awVwKix0D6g5x4voaMuDQvHlO+SjuCnUrtVFX/Uer5WhmoHcHhk05g0dylPa7n+BzWnLo61826vMd1zfCm8efRJd8jO2XTB/51cMY+Unpbdb1cfnce2d/nwsiYu5HAhp1L3GUXQOwhstMzw9B4Ik7jqVWMZ1dw5+bE4+Bdy+9+Y9tHfjZQfjymdyoaR6Cq/yj1r3bhDk7ldALgjVj+5N3P+ey/c2sckRnstO028vP205D6CE39r/TGyVdBW8g/qU+g7Xd3taFugU4AoAPa/ljleJYViMclf2hvNhwAACAASURBVLxDmvwiNPnxmOrpMWtIRNYWkQdE5B0R+d+Kf/0R3FDlC/hoWTSsyzqYcriLCy8XP7i5I4Zzt11G4RROF9ILuz1OUTRHX3NG2ZeMp4cvQkXj6YXceEzVlJM+eitwPd5fzC7AHYB1zTnEKf+P3HWVNafaTWPTg/CuQKjAE2kvjbOU4KaZgV25Ilm58yJB8K9TYD2BYM5VhJLx9JAWGty094PceorHVE05HUFEVf+Odz/hY1X9GbBPbcMafE769dEFlwfCAYIRLzVPBELRIKdcdSyhSKEPlDFdJHo0+MbRdfAVvHmBzkOk9PTL4oyGxlPIHlwVBv9qEMnO85DmizJ59yvGtvhBGpDm88qM59wy4/lefjzOxMz+VhyKQiAjved6iMdUTzm34OMi4gD/FZFTgc+BYnPeDltfP2NfRk8YyRUn3EBHWxxfwMdhZ+3PoWcdwGM3zOCl6bMZO2kUB5++D+tvs269wzWDgDjNMOYRtP0eiD8Dznik4RgkuFlZ2zuN30UDG6Ptd3mXbsJ7IdFv5B20JbiJ107brd6cQYGpSMNxeeMDahUPqQ+9ttOfQ3A7pOEocJflxHMs4ptY3htneq3HcQQisgXwLjAS+DkwArhcVV+qfXj5qp01tGT+MtR1GT0hf2BKd+3LY7QuaWXMpNGdtYkXfL6I9176LxvtuB4jx40ouq3ruiz8bBENI6JWgGaYUrfVu3nrrNQ53bPrLoXUh+BfF6doPYHMDV13HkgT4jRltnW9fHynCcc/uWvd9CLA9apy9TIe1Q5wF4EzrrPGgGoS3Pkgo7rNbFp+POVyU1+C+wX4N8BxeldT2XVTkHoTnJVw/BMzMSq4X4KECg
46607dZd603M5KNZsRuDfx1EqlhWkAUNVXMjtxgNPKLV4v3teOZ/HO+/zAA6p6Qc46Ibx7DpsDi4DDVHVOOfvvq8/e/4JLjriaOW9/AgiT1l6Zc+8+nSkbZk9J29Ee58oTb+C5B1/G8QmhSIjjf3UkfzjnHpZ+2XXDbtLaE7jlnavyCtg/99DL/PaU3xNbHsN1Xbbebxpn3nIK0aZy50Mxg5lqB7rsJ9DxV7y58MNo448hdpd38Mpwg9vByFtwnOyrtW5sBiy/wBuAhYuGdvGulS//JSsya1xphOZLoe0GLx8fUN/qyMgrkED22acXz0+h44mueJrOhdT70P7HTH0DQRtOBiLQdjWQBnXRyNchuBUs/3l2PIFp0JoTz6jbcYIb9fj+uO5SWHgwuJ9llghu+DCckReV9f66y34FsZtZkY3kOuOh+UJYfimkv/RiDEz13oucsQ7qLkaXngmJl733whkNIy5FQtUt5amJ19BlZ/YYTz2Vc0YwDe+G8YqCocuAb6vq7B62E6BBVVvFG7/+PHB69zMJETkF2FhVTxaRw4GDVLVkpYhqnBHEY3GOnHwKLQuXd04nLQINIxu4a871WQfpiw79DS9Pn02io+fZSdfbem1+++IlnY/fffm/nLXbhcTbu1LzAiE/m+yyIZc8btc7hwN3yekQf5rs9EyvmEye0P44o37d+VCTb6CLjiJ7bn0/+amVRfYrzci4mVnVy9ylZ2TKNXaPZ0X1su43c4N4qZ2pnGVpstM9i8Xjh/Fv4Dilv2u683fKjC3I0fgjnMbvlN429nCm+E1PfOCbhIx9Cu/7bKbIzqKDvA4wK/4IMvYRpIKzmkI0/SW6cE/Q9px4JiJjZ3TG0x8qGkfQzR+AU1R1sqpOBr6H1zGUpJ4VE4UHMv9y//oPAG7P/PwAsJv0Q7WWF/78ColYIqumgCqkEqnOamIASxcs46XHyusEAN596b9Zj//0q0dIxLLzs5PxFK/PfJv5ny7EDG3qLvbqBefl6Bf58hX/S/ZarTcX2LbU9OU5+9UkdEzvFs8S6PhbkX3mZvQkCrSVID/nv1g8KWi/vchzHjc1p3AnAN7ZTU+WX9nzOgCkvUteiZe7hfcupD8iP/4k2n5nmfvtmcbuB81tI+2l4naPp87K6QjSqtpZ8UJVn6f0X2MnEfGJyGvAfGCGqua+8knAp5n9pvDONsYU2M+JIjJLRGYtWLCgnKZLmv/JwoIH9462OPM/6dr/4rlLCQQrH9I+98MvKXTC5Q/6WfjZoor3awaJ9MKSdXULbJDz8BNKzPJShhia/qLb/hb0Mp4+Svc0yOyD4s+VM2agp7EL2TuEdLdOJ/0FXVlJ3aUg9XEv9tuD1CcUruOQE0+dldMR/ENEbhSRnUVkJxG5DpgpIpuJSMl0AVVNq+omwCrAliKyYan1S+znJlWdpqrTxo0rfROsHF/Zci2C4fwPRKQxzFe2XLvz8cS1JvSqalnuWIKNd1q/c0rr7lKJFKuvv0ovIjaDkn81yi+SQiaFs5vglpSX2Fdsf1EkMDUnnn6cZDK4S+nnQyXGBTiFZzTN4l+j/FjUhUC3exaBDYqMswhn3vfqkOAWFKyPkBtPnZXTEUwF1gEuAH4GrAdsCvwG+HXxzbqo6lLgGeBrOU99DqwKIN5sUiPwbhrX1NSdN2CNjVfvzO8HCIYDTFp7ZbbYq2tq2nA0xFHnH0Io2pXz7ziC4yv8th18RvbwikPO3J9wYxinWwcRbghxyFn7W/bQMCAShsbTyD4QOBQ9uDeelb19w7dBGsj+mIYpPEI3QPZgryD4VofQzjnx/F+BeMJ01SdYIZRZ1r2tcGbbMuJxxuNEShd9cZyRxTuLET8vuS0AzRcXeSKI9350izG0IxLo+pInvpUz4ym6vxd+cJq8ugvVEtkXfGMLxLNDVjz1VrNpqEVkHJBU1aUiEgGeAn6pqtO7rfM9YKNuN4sPVtVDS+23Wumj8VicP/3qUZ667RlcV9n9qB
05/McHEmnM773/cf8/ufeyh1k8bykb7bAex/78cB644lGeuPlp3LSLL+Djmz8+kGMuPDxv23lz5nPb+ffx77+/yYhxTRx65gHsduQOVrh+GNGOv6KtN4C7AAJbIE2nox3PQNu13iUQGQFNZ+NEv5G/beoztPW3kHgBnFFIw/GofyosOwNS/wEcCO4II34JsTuh/SHAhcj+SMNJXoGakvFMQ5rOgPR8tPXqzLTP6yKNp4MEvWXJN72brY3fA9+amXheBGdkJp5NYNnp2fGMvArHKT3IbAW35VJovweIg7MSjLgYJ7RDedvGZ3k3jN3PAT9EDoKG70P79dDxJEgIIod74x1yZi5Vdb06zO13eL+H0K5I46k9pt72lrpL0dbf9RhPrfWpHoGIrARcAkxU1b1EZH1gG1W9pYftNsa7EezD+wrxJ1W9SEQuAmap6qOZFNM78c4wFgOHq2rJC4sDdfZRY4wZyPo0jgCviP2twIp8x/eB+4CSHYGqvoF3gM9dfn63nzuAQ8qIYcB58rZnuP38+1j4xWImrjmBE355FNsdWL1ri6b2NP4s2nKplz3ijIGGU5DoEf1ytua2PwStV3uDjHyrQOPZ3hz8rb/DmwTOgdDXYMQVeWMLNPUJ2nKR961cAhDeHxpOhtYrvW+dqFc3uOls79tu7AHQDghugTSfX7Sa2UBS6P1xIl+tYzwPe3UYBkg81VbOGcErqrqFiPxbVTfNLHstcxO43w2EM4LpN83ghh/cRry9KxsgFAly3r3fZ5v9bGKswUDj/0SXnER2jn4EGk/FaTyhpm277fdByyVkT+lcJB8/uBPO6N93PlR3Gbrgq5kpnVfciF5xLT9FV+aRj655/Veki4o3Z8/YvyK+8dV7QVVW+P0Je4Ow6lBs3m3/E7RcPGDiqVRfxxG0icgYMukGIrI1XprnsKSq3PaTe7M6AYB4LMEt595Vp6hMb2nrFWR3AgAxaLveq8tbq3ZVvW+WefP6F2kz8Q9ctytObX/Qmw4hKxspgXew755+mu62vHNr0Ox6BAXtvLP3rw689+dq8t+fDnT5b+oUT6HfVwe6vKxcmUGhnEtDPwAeBdYUkReAcUD+Xa1hItGRYPmS1oLPzf3wy36OxlQsNafwck16c+f78oazVEkC3CW92yT9KTiZDJPU2+R3YL1sP/lWH7avtWTxugfpT/s3FMCLp0giY/qzwssHoXLmGno1U7ZyXbzzzP+oFkzAHRaC4SCNoxpoWZg/5dKEKSvVISJTEf9kSL6ev1z84BSfQLDvgt40y1rkYFeIb9Wun/3rAzOovDMIejn0haw4C/jHP7Ifz5xZYVuVCIAzqvDB11ePsTcBbw6igvFM6v9waqTopSER2UJEJkDnqN/NgYuB34jIsK28LiIc87NDs8YWgFdn4DuXHFGnqExvSeMZeDnw3UWg4eSapvWJCDSeTv4goyJtBrfPSsOU6Ne9FMSsj24A7z5B98GLvsyynLoXEkSiA/fvVESgodD7E0aafjDs46mVUvcIbiQzNlpEdgQuw5spdBlwU+1DG7j2++6efPfKYxi98igQmDBlPGfdeirbHrBFvUMzZZLQdsjIq8E32VvgjIGmHyANJ9a8bafhm9B8HjiZG7a+STDi19DwPbo6BIHg7jDy5uy4nZHImAcguA3exzfkDYwa+wSE9sTrFPzeQLKxf4HINzIjlh1vDMPoexFfkTPXmTO9fzvt5P1b8bifOQ2HF3h/LkPC9cnSKR7PnnWJpxaKZg2JyOuqOjXz8++ABZnqZMM+a6g7VbXBYYNcPX+Hhdp2XTcvZbTcbbtm081fXvZrrMslocIG2udroMXTG5WOI/CJiD9zWWg3oPtXpf4dEjeADdY/CtOlWr9Db6Tq/RC7A9wVI1W/B8nZaNuN3qRvwS2RxtMQ/2pF2+6pE1B3GdrqjZxVCUP0mxA5BNpvg5g3sljDB0D029DxOMRuB7cNzcSjbX/wag8Q9+b0GXExIgG09VpI/Q
8eWg9pOq3gRBZu4t/eSN70J0DAa7fpJ2V1XL3htt3pZQ9pKyojoflHOJGDqtqG9/t6oPP9WfH7khKJAtX7W0l4v4fYg4AL4f2RhhM6i//0t1JnBOcBewMLgdWAzVRVRWQt4HZV3a7/wuwy0M4IjFnBXfYTiD1GV6qhH+8+RIqum7uOl8s/5s+If9VCuylJtQNduC+k59E1q2XYu/yjsW7thECi3kCyrHhy6w6sEOy2XIAwMvrWrDKUbuIdWHwQeRPXBabhjOkhJbUX3NZroPWa/CeafobTUL37G+6y8yH2CFnvjzMaGfu4V5azRlQVXXx0Jlmh2+/LvyYy5sHOinHVVtE4AlW9GPgh3sji7bWrx3CA/6t2kMYMZpqem3NQAa8DaCU7w8cFbUfbrqusodh078wi62DeAbokp514ZlluPIU6AXKWKxBDl1+avUrLeRScvTQ5Czf1Rf7ySrVeX2T55VVrQtPzMmdPOe+P2+Kd1dVSclamOl3O7ys9B+Iza9t2ESXP51T1JVV9WLVrcnBVfV9VX619aMYMIsm3ejHXfxoSr1TUjCZeJn9wU40k38t+nPqw+LpVOoC5bitFB9dlVfnqo+RbIIVqI3dA4p/Va6dg22+AFuiQtR1N/ru2bRfRf3XSjBnKfBPoVe2BSnPQfauSP2V0jTg5WeKlxlf416lSo6WukVfxcOWbAJpbbQ28MpKrFVheRc6ETApwrghSp7EJ1hEYUw3+DTMH6dw8Cof8A3cEaTipomYkeijkXUNeMcFvLofCVbgKyY07ArkxNp5eJKgROKWKzPSC4zgQKLKv0F5VaQMA/wbgX5381x1AGo6qXjuFhHfHG9+Rc+NZ/BDet7ZtF2EdgTFVICLIqNsgOA3vwB/25tYfeR2E98gsi3i1B5ovQELbVtaObwIy6hZwJtFZPCawMYy6Hfzr0TmIzLcmjLoVgltkxzPit/nVv0IHQ/Rb3joS9W48N56ARL+ZtZoT/QZEjiPrAOaMhTGPVPRaihp1m3eg7i6wFYyo3lxD3u8r//2RUdchval8VlHbIWTMPeD/Cl2/rzWQ0XciTlNN2y4aU60K09SKZQ2ZgU7dxeC2e8VcMumG6i735hjyTazKyGVVBfcLIJyV7qjpLwFFfBNKxuOmvvBSQAObdI5cVrcd3IXgWwkpeOnC47oJSL4GvpVxKsh8KpebXgSp/0JgfZxaZvEUeH/6i6bnA27W76tW+lqPwBiTw3VdaP0VxB4GFML7QNOPcZwg4ozOur6uyXfQ9tsh9TmEtoPoEeAuQ9tuh9R7EJyKRI/u1cFARAreZyg0ajg3HgDHPxH8E7tiTH2SFQ9F4nETH8HSU8CdAwRxo8fiNH8/bz1VhfhT3myppJHIQRDeq1epkY5vTA0n/+tS6P3pLwNlOnA7IzCmAu6C3TODqrpxxsHY57IGV7mxJ2HZWXjpmS5efn8DEMsUT08BAZAwMuZPiH/NfnsNK2jidXTJMZlMluLxuIn3YXGBa9j+zXHG3pO1yF32Y4g9QWeGk0QguB0y8nc2CLNO+lqPwBjTjRubnt8JgFcDONZVk0I1BS0/wcsXX5FRFPdmHtUYXWmSSW8EbUtO3n4/0ZbzM6mZufFckr3i0u8W3kFqNm56Xtf+ku9C7HGy0lw15lVUS9qXuIHIOgJjeqvjseLPxR7v+jk9B6/sZDkUEi/3IajKqCYyRecLxfOv7EVuifn32+7s+jnxItlFclbssh2NP1dJmKbGrCMwprdkVPHnfN2ek0boTbUzaag8por58WYsLSAvnhLX97vXCpARRfYZQpwS752pG+sIjOmtxhIzrDSe1vmj+CZ4qZ15ORm+AsvC3k3kfibiQORA8uoWFIqnaI67eFNrd663Z16K/Ir1CO9TcaymdqwjMKaXHP8kaPoJeUe7xjNwAl/JWiQjf5sZdRsBacKrH3BMpp5AKLMsCOHdkcYi1+BrTJrPheBWPcbjjLwcnAKjbkdkz5skThMy6vfemY
E0dv6TUdcOmCwZk82yhoypkOu2Q+wB7/JP9DAcp/ilHU3+B9z5ENjAS1cENPUxpD8G/1qIb2LRbfuLpj7x7mv0EI+beAfa7/DKfUZOwPEVvmSkmvTGG2gagpshBef2Mf2lVNaQdQRmWFC3FY09CqkPkcAGENkbkdxSlVVoJ/UpGnsEtAUJ7QTBbfuULqmahvizaOIFcMYikQP7ZfBRvXm/r8cg9UHm97UXIrnlIocHTS9CY3+G9FwktIVXN6HsCQ67WEdghjVNfYwuOrRrbn6JgjR7c7/7xlWtHTf2OCz7MV7GTBKIQmhbZOS13rX43satCXTxsZB8B2jHK3zvQ0Zeh4TqUg6kX2jqE3TRId1qKUTBaUbGPDDsLi1p4t/okuMyE+TFvb9d3+rI6Ht6XcTGxhGYYU2X/QR0GZ157epNpaDLL6teG247LDsHb8zAipTRdoi/APEZle2z/X5vumRWTL+cAI2hS7/vjVEYorQl5/dFO7gLqvr7GgxUFV36/cwYj3hmYTuk/oe231bVtqwjMEOad536FfKniE5B/O/VayjxcoFZQQFi3iWpSnQ8QnbxkhWSkHqnsn0OcKqpTK2G3N9Xurq/r8Eg/bE3P1WeOFT6N1WEdQRmiBOK/5lXsSRgqWu2Fd8kLbZPZehOEyYUyT1l6L7mIiRA0RoXFdwjKMU6AjOkifghtAv5B5EghA+oXkPBLSl4AJMIEvlGRbuU6GHeHD15T4zMTDk99Ij4ILQrheoEEN6/HiHVjfgmZYrk5NYtiEDk0Kq2ZR2BGfKk+SJv5Ks00FnU3b8O0vSD6rUhQWTUDZkb0VG8ovUhiHwTgpXVHiC8L4S+2rUvaQAZ4c2ZP4QnbpMRF3lFfvJ+Xz+sd2j9TkZd682MuuK9IOJlouXUiuhzO5Y1ZIYDVRcSL0BqDgTWhcAWNTmYqtsK8adBW73ZNv2r932fyfe9eX+c0RDetSZprwNN1u/Lvw4EtxzSnV8pqgmvJnR6PgQ39dJpK2D1CMyA9vG7n/Hms+8yYlwzW+2zGcFQda9/QmYqhdAO3r8qUXeJ9wFVF0I7ZwrErCgPWax8ZO9JYB0IZNcE1vRciD/nXSYI7YI4jX1qwzvwvpQZ4LY2BDb3agp0PORNIudfC6LH4zj9MyisN78vTX3o3WB2RnnvxRAbuCYShPBXa9pGzToCEVkVuANYCe/u1k2qenXOOjsDjwAfZRY9pKoX1SomM7C4rssVx1/PzPteBMDx+wgE/fz66QuYslHfv0nXkhv7S2bMgA9EoeVnaPRbELsX0Ezet6LRo3Gaz6pu2603QOu1mbbFa2/k75DQ9hXtT93F6KKjwJ3rxS0OyBTQT0CXd63Yeg3u6D/hBDeqyuvoK1VFW86D2GOAZLK2AjD6DiRnqg9TWs0uDYnIysDKqvqqiDQBs4EDVfWdbuvsDJypqmVXbLZLQ0PH03c/x5Un3UhHWzxr+UqTx3HnhwO3gImmF6ILdqEzt7ukCDLqBiS0TXXaTr6JLjqSvLRSiSLjXuz1ICMAd8n/ZVIzyxibIM04Kw2Mz5/GpqPLziOr7gGAMxEZ98yA/fupl7oMKFPVuar6aubn5cC7QH5tPTNs/eWmGXmdAMCyBS189GaBwi8DRfwpiqc45oqhsfur1rS2P4RX7SyXQKL3c/17U1iU2QkAaAtual7P6/UDbb+PvE4AQJd6JTdN2fola0hEJgObAoUqb2wjIq+LyBMiUvAuiIicKCKzRGTWggULahip6U+JjsJFW8QRkvFyC7rUga4oO1nu+uWcOZSrSNtKJq7ecgvvr6Rqvp6+KBaHVPheDF817whEpBF4EDhDVVtynn4VWF1VpwLXAH8utA9VvUlVp6nqtHHjqjc3jKmvXY/cgVAk/8aeP+BnrU2n1CGiMoV2oewzAokikbKvfPa8u/BeQKHJ15JQwfxDIgEIbE75ZzghnCpkQlVFeH+81Npcfqgws2a4qmlHIN4UeQ8Cd6nqQ7nPq2qLqr
Zmfn4cCIjI2FrGZAaOfU7cgzU3mUyk0fswB0J+QtEQ59x1Oj5/FUf9Vpn4V4fGk/AOQg7eQTQCgR3xcr0zORgS9eb5D1Ux4yO4HYR3zww0E7zspDA0ndc5vXVvyYiLM1XFVnQwUWA0BUdej/hlRW3UgkQPhcD6mXEbAEEgjIz8jTeQ0JStljeLBbgdWKyqZxRZZwLwpaqqiGwJPIB3hlA0KLtZPLSkU2lemj6b2X97gzETRrLHMTszftXB8V1Ak2+jselAGgnvjQQ38WY6jT0M7jIkvKs3lqCCmUdLtqtefWPt+Bs4ESRyAOJfq2/7dJd702en3kcCG0J4X5QULL8cEq+Cb3Vo/hGOf3J1XkSVePc4nkHjz4NvHBI5GPGtXO+wBqS6TEMtItsDzwFv0nUR8lxgNQBVvUFETgW+i3enKgb8QFVfLLVf6whMd6oKqbfBbYXAxhVlzeTt010OyTfBGeONaBVBNQ6J10DCENio6gd3Y2qtLgPKVPV5erjwqKrXAtfWKgYztGnqI3TJ8eAuAhzQNNr8U5xoZXP7ALitN0Pr1d5EcZoC/6po5AhovdxrA9cr5zjq95arboYM+1pjBiVVF118DKQ/8+Zo11YgBi0Xocm3K9tn/DlovQaIZwZSxSD1ASy/sKsNbQf3S3TxMd4U18YMAdYRmMEpOTtzsM69tJlA2++uaJfadhv5eelugTa8dog/X1E7xgw01hGYwcldRuErjy64Cyvc5+JerKyZKlrGDH7WEZjBKbhZkUFDESS0W2X7DO+Gl/5ZBk1nahAYM/hZR2AGJXFGQ+MpZA+uCoN/dYhUVnBGokeDbxxdnYF4+3RWzWknAtGjEN/EitoxZqCxURdm0HIaT0EDU9H2u8BdCuG9kOghiJT5rT6HOM0w5hHvHkN8JjjjkYZjIbAB2v4wdDwG0oBED4fQztV8KcbUlRWmKVP78hitS1oZM2k0Pt/AHfU6HKm7DDQGzkoDYsbJasej6oI7D6QJcZqqEKEZjqwwTR90tMe58sQbeO7Bl3F8QigS4pSrj2O3I6pX4MRURt3F6NIfetW7cLwKXiMuq9qUzwMhHjc2A5Zf4A2Yw0VDuyAjLu1zIRpjurN7BD24/Nhref6hl0nGk8TbE7QsWs6VJ97A6zMry1U31aGq6OLjIPEykATi4M5Fl5yMpj4eEvFo8g1Y9sNMFlQHXsrqM+jS06oYuTHWEZS0dMEyXnpsdt50yfH2BPdcmjeHnulPqbch9TH58+gn0fY76xNPurrxaOvN5E+1nIDEK2j684r2aUwh1hGUsHjuUgLBwlfP5s2Z38/RmCzpL7ySinlSkKpDUZv0FxT+OKUyHVYl+/yEgoPZJAjpLyvbpzEFWEdQwsS1JuC6+UU7HJ/DRtuvV4eITKfABkXGEYS9qZ/7W2DDEvFsXdk+g1tS8DaeJrxi8sZUiXUEJYSjIY46/xBC0a50RMcRwg0hjjjv63WMzIhvEkT2Jzu/3w9OExI9pA7xTKx6PNLwbZAGsj+mEWg4zkt1NaZKLGuoB4eddQATJo/n3sseZvG8pWy0w3oc+/PDWXmNleod2rAnzb9A/RtC+x2gbRDaFWn8v7odJAvHc2rF8YhvAox5GG39LSReAGcU0nA8hCsbMGdMMTaOwBhjhoFS4wjs0pAxVaLagdvyC9wvN8Wdtx7u4m+hqQ+q307qE9zFx+POWx/3y6m4y36Kum3Vbyf+LO6CvXDnfQV3/na4bX9ksH1xNOWxjsCYKtEl34P2+7zLQqQh8S900WFounoZZuouQxcdAonngZQ3gjn2MLrkuKoepDX+ErrkVEh/iDej6wJY/iu07fdVa8MMHNYRGFMFmvoQEq+QnfevoPGK6yMUbKf9Qe/gT/dstgQk/wPJN6rXTusVeIPYuotB2w2o5o6VMIOddQTGVEPqQ5BCuRcJSL5VxXbeJv8ADYh4MVStnY8KL9ekN8GfGVKsIzCmGvxreDWO8wS9MQ9Va2d9IFzkuT
Wr2M7kwsvFD86I6rVjBgTrCIypAvGvBcHNyStsI0EkekT12ol+HSRE9kc3AL61IbBx9dppMaAW7wAACLdJREFU/D75HU4EGk5CJFC1dszAYB2BMVUio66HyCEgEcCBwBbI6PsQX/XGnIgzEhnzAAS38dogBJEDkNG3VnUKbglti4y8GnxTvAXOGGj6AdJwUtXaMAOHjSMwpgZUtea1Efqjjf5sx9SWjSMwpp/1x4Gzvw7O1gkMfdYRGGPMMGcdgTHGDHPWERhjzDBnHYExxgxz1hEYY8wwZx2BMcYMc9YRGGPMMGcVysywoKlP0LbbIfUeBKci0aO9CmDGmNqdEYjIqiLyjIi8IyJvi8jpBdYREfmtiHwgIm+IyGa1iscMX5p4HV20H8TugeQr0HY7unAfb+poY0xNLw2lgB+q6vrA1sD3RGT9nHX2AtbO/DsRuL6G8ZhhSlvOz8zhv2J20CRoK9pyaT3DMmbAqFlHoKpzVfXVzM/LgXeBSTmrHQDcoZ6XgJEisnKtYjLDj2oCUv8p9Awk/tXv8RgzEPXLzWIRmQxsCryc89Qk4NNujz8jv7NARE4UkVkiMmvBggW1CtMMST6gyLTJEu3XSIwZqGreEYhII/AgcIaqtlSyD1W9SVWnqeq0cePGVTdAM6SJ+CByAHl1AghD9Mh6hGTMgFPTjkC8ChYPAnep6kMFVvkcWLXb41Uyy4ypGmk+D4JbASGQJu//8O5I43frHZoxA0LN0kfFm7v2FuBdVb2iyGqPAqeKyL3AVsAyVZ1bq5jM8CQSQUbfjKY+gfQc8K+N+OxWlDEr1HIcwXbAt4A3ReS1zLJzgdUAVPUG4HFgb+ADoB04robxmGFO/KuBf7V6h2HMgFOzjkBVnwdKVrRQrzza92oVgzHGmJ7ZFBPGGDPMWUdgjDHDnHUExhgzzFlHYIwxw5x1BMYYM8yJl7gzeIjIAuDjPuxiLLCwSuHU21B6LTC0Xs9Qei0wtF7PcH0tq6tqwakZBl1H0FciMktVp9U7jmoYSq8FhtbrGUqvBYbW67HXks8uDRljzDBnHYExxgxzw7EjuKneAVTRUHotMLRez1B6LTC0Xo+9lhzD7h6BMcaYbMPxjMAYY0w31hEYY8wwN2w6AhH5g4jMF5G36h1LX4nIqiLyjIi8IyJvi8jp9Y6pUiISFpF/icjrmddyYb1j6isR8YnIv0Vker1j6SsRmSMib4rIayIyq97x9JWI/H979x9r9RzHcfz54jK3H4qFNVEWovmj0koiTWqiJYthspEtNqIZVmHZbNbYmo0xVJZJTT9nZpESYdRK0m9Dmxoy0Q+aWl7++H5uu7XW7X7PtW/nfN+P7axzvud7Pud1a933+X6+3/P+tJc0V9ImSRsl9Ss6Ux6SuqV/k4bbbknjco9XlnMEkgYAe4E3bV9WdJ5KSOoIdLS9WlJbYBUwwvaGgqM1W1rAqLXtvWlFu8+Ah21/WXC03CQ9AvQGTrc9rOg8lZC0Fehtuya+gCVpBrDc9lRJpwKtbP9ZdK5KSDqZbGXHvrZzfdm2NEcEtj8FdhadoyXY/tn26nR/D7AROLfYVPk4szc9PCXdqvbTiaROwI3A1KKzhMNJagcMIFs5Edv7q70IJIOA7/MWAShRIahVkroAPYGvik2SX5pKWQPsABbbrtqfBXgBeBz4t+ggLcTAh5JWSRpTdJgKXQD8BryRpu6mSmpddKgWcDswq5IBohBUMUltgHnAONu7i86Tl+2DtnsAnYA+kqpy6k7SMGCH7VVFZ2lBV9nuBQwFHkhTrNWqDugFvGK7J/AXML7YSJVJ01vDgTmVjBOFoEql+fR5wEzb84vO0xLSYfrHwPVFZ8mpPzA8zavPBq6V9FaxkSpje3v6cwewAOhTbKKKbAO2NTrinEtWGKrZUGC17V8rGSQKQRVKJ1inARttTyk6TyUknSWpfbpfDwwGNhWbKh/bE2x3st2F7HB9qe1RBcfKTVLrdDECaQplCFC1V9
3Z/gX4SVK3tGkQUHUXWBzhDiqcFoL/cfH6E42kWcBAoIOkbcAk29OKTZVbf+Au4Ns0tw4w0fb7BWbKqyMwI135cBLwju2qv+yyRpwDLMg+d1AHvG17UbGRKjYWmJmmVH4A7ik4T26pOA8G7qt4rLJcPhpCCOHoYmoohBBKLgpBCCGUXBSCEEIouSgEIYRQclEIQgih5KIQhJoj6WDqyLhO0hxJrZrYf+JxjrtVUofj3d5SJI2Q1L3R42WSamLx9XBiiEIQatE+2z1Sl9n9wP1N7H9chaBAI4DuTe4VQk5RCEKtWw5cCCBpVFr7YI2kV1Ozu8lAfdo2M+23MDVZW5+30Vr6Vu709H5fS7opbb9b0nxJiyR9J+m5Rq+5V9KW9JrXJb0k6UqyXjLPp4xd0+63pv22SLq6gr+fEMrzzeJQPpLqyHqxLJJ0KXAb0N/2AUkvA3faHi/pwdT0rsFo2ztTy4uVkubZ/r2Zb/8EWYuJ0amFxgpJH6XnepB1jP0H2CzpReAg8BRZ75s9wFLgG9tfSHoXeM/23PRzAdTZ7iPpBmAScF0z84VwSBSCUIvqG7XeWE7Wl2kMcDnZL3aAerK210fzkKSb0/3zgIuA5haCIWQN6B5Nj08Dzk/3l9jeBSBpA9AZ6AB8Yntn2j4HuPgY4zc0GlwFdGlmthAOE4Ug1KJ9R3zCb2jUN8P2hGO9UNJAsk/X/Wz/LWkZ2S/x5hIw0vbmI8bvS3Yk0OAg+f4fNoyR9/UhHBLnCEJZLAFukXQ2gKQzJXVOzx1Ibb0B2gF/pCJwCXBFzvf7ABibChCSejax/0rgGklnpCmtkY2e2wO0zZkjhCZFIQilkNZzfpJsta21wGKyzqcArwFr08niRUCdpI3AZOB4105eK2lbuk0BniFbdnOtpPXp8bHybQeeBVYAnwNbgV3p6dnAY+mkc9ejjxBCftF9NIQThKQ2tvemI4IFwHTbC4rOFWpfHBGEcOJ4Op3kXgf8CCwsOE8oiTgiCCGEkosjghBCKLkoBCGEUHJRCEIIoeSiEIQQQslFIQghhJL7D3acu5LIA2dRAAAAAElFTkSuQmCC\n", 347 | "text/plain": [ 348 | "
" 349 | ] 350 | }, 351 | "metadata": { 352 | "needs_background": "light" 353 | }, 354 | "output_type": "display_data" 355 | } 356 | ], 357 | "source": [ 358 | "# Plot the `clusters` and their `centroids`\n", 359 | "# Gather all the necessary data to plot the `clusters`\n", 360 | "xs: List[float] = []\n", 361 | "ys: List[float] = []\n", 362 | "cs: List[int] = []\n", 363 | "for cluster_idx, items in km.clusters.items():\n", 364 | " for item in items:\n", 365 | " cs.append(cluster_idx)\n", 366 | " xs.append(item[0])\n", 367 | " ys.append(item[1])\n", 368 | "\n", 369 | "fig = plt.figure()\n", 370 | "ax = fig.add_subplot()\n", 371 | "ax.scatter(xs, ys, c=cs)\n", 372 | "\n", 373 | "# Add the centroids\n", 374 | "for c in km.centroids.values():\n", 375 | " ax.scatter(c[0], c[1], c='red', marker='+')\n", 376 | "\n", 377 | "# Set labels\n", 378 | "ax.set_xlabel('Petal Length')\n", 379 | "ax.set_ylabel('Sepal Width');" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 16, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "# Function which quantifies how far apart two values are\n", 389 | "# We'll use it to calculate errors later on\n", 390 | "def squared_error(a: float, b: float) -> float:\n", 391 | " return (a - b) ** 2\n", 392 | "\n", 393 | "assert squared_error(2, 2) == 0\n", 394 | "assert squared_error(1, 2) == 1\n", 395 | "assert squared_error(1, 10) == 81" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 17, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "data": { 405 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAEGCAYAAACKB4k+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAfJklEQVR4nO3dfXAcd53n8fd3ZvQsyxpL8qOkkWOnCCF7tjUiFx6WTeJjLwQ2yVGQJXcsKS5XOSBwcNwVG/hna/fqrpaqrWWP2iVcIECyx1NIoBI4ikouJMDekhDZTpwnAo5jyZYdW7YlWZZkSTPzvT+mNR7JsiXZmukZzedVNTXdv26Nvnbi/kz/ft2/NndHREQEIBJ2ASIiUjoUCiIikqNQEBGRHIWCiIjkKBRERCQnFnYBl6K1tdW7urrCLkNEpKzs2rXruLu3zbetrEOhq6uL3t7esMsQESkrZtZ3vm3qPhIRkRyFgoiI5BQ0FMzsgJm9YGbPmVlv0LbGzB43s98H7/Gg3czsy2a2z8z2mll3IWsTEZFzFeNM4Tp33+7uPcH63cAT7n458ESwDvAe4PLgdSdwTxFqExGRPGF0H90M3B8s3w/cktf+gGc9DTSb2YYQ6hMRqViFDgUHHjOzXWZ2Z9C2zt2PBMtvAOuC5U3AwbyfPRS0zWJmd5pZr5n1Dg4OFqpuEZGKVOhLUt/p7gNmthZ43Mx+m7/R3d3MljRNq7vfC9wL0NPToyleRUSWUUHPFNx9IHg/BvwIuBo4OtMtFLwfC3YfADryfrw9aFt2u/pO8sWf/RZNGy4iMlvBQsHMGsxs1cwy8MfAi8CjwO3BbrcDjwTLjwIfCa5CugYYyetmWlYvDpzinqde4/DImUJ8vIhI2Spk99E64EdmNvN7vuPuPzOzZ4EHzewOoA+4Ndj/p8CNwD5gHPhooQpLJuIA7OobYlNzXaF+jYhI2SlYKLj7fmDbPO0ngJ3ztDtwV6HqyXfF+lXUVUXZ3TfETds2FuNXioiUhYq8ozkWjbC9o5ldfUNhlyIiUlIqMhQAuhPNvHzkFONTqbBLEREpGRUbCslEnHTG2XtoJOxSRERKRsWGwo6Os4PNIiKSVbGhEG+oZktbA7sVCiIiORUbCpDtQtrVP6Sb2EREAhUdCt2dcYbHp9l/fCzsUkRESkJFh8LMTWzqQhIRyaroUNjS1khTbYzd/QoFERGo8FCIRIzuRFxXIImIBCo6FACSnXF+d/Q0IxPTYZciIhK6ig+F7mBcYY+6kEREFArbOpqJGOzuHw67FBGR0FV8KDTWxLhifZOuQBIRQaEAZC9N3dM/RDqjm9hEpLIpFMiGwthUmlffGA27FBGRUCkUyHsSmwabRaTCKRSA9ngdrY01GlcQkYqnUADMjGSiWXc2i0jFUygEkok4fSfGGRydDLsUEZHQKBQCucnxdLYgIhVMoRB4y8bVVEcjGlcQkYqmUAjUVkV5y6YmTY4nIhVNoZAn2Rln78AIU6lM2KWIiIRCoZAnmYgzlcrw0uGRsEsREQmFQiHPzIyp6kISkUqlUMizrqmW9nidrkASkYqlUJgjGTyJzV2T44lI5VEozNHdGefoqUkGhifCLkVEpOgUCnOcvYlND90RkcqjUJjjivWrqKuK6iY2EalICoU5YtEI2zuadQWSiFSkgoeCmUXNbI+Z/SRY32xmz5jZPjP7vplVB+01wfq+YHtXoWs7n2QizstHTjE+lQqrBBGRUBTjTOHTwCt5618EvuTuW4Eh4I6g/Q5gKGj/UrBfKLoTzaQzzvMHdRObiFSWgoaCmbUD7wW+HqwbcD3wULDL/cAtwfLNwTrB9p3B/kW3o0MzpopIZSr0mcLfAZ8DZiYTagGG3X2mX+YQsClY3gQcBAi2jwT7z2Jmd5pZr5n1Dg4OFqToeEM1W9oaNNgsIhWnYKFgZu8Djrn7ruX8XHe/19173L2nra1tOT96lmQizq5+3cQmIpW
lkGcK7wBuMrMDwPfIdhv9T6DZzGLBPu3AQLA8AHQABNtXAycKWN8FJRNxhsen2X98LKwSRESKrmCh4O6fd/d2d+8CPgT83N3/HfAk8IFgt9uBR4LlR4N1gu0/9xC/pic1OZ6IVKAw7lP4c+CzZraP7JjBfUH7fUBL0P5Z4O4Qasu5rLWRptqYxhVEpKLEFt7l0rn7U8BTwfJ+4Op59jkDfLAY9SxGJGJ0J+K6AklEKoruaL6AZGec3x09zcjEdNiliIgUhULhAmbGFfbobEFEKoRC4QK2dTQTMTSuICIVQ6FwAQ01Ma5Y38QunSmISIVQKCwgmYjzXP8w6YxuYhORlU+hsIBkIs7YVJpX3xgNuxQRkYJTKCwgdxObupBEpAIoFBbQHq+jbVWNBptFpCIoFBZgZiQ745ruQkQqgkJhEboTzfSfHGdwdDLsUkRECkqhsAgz4wqa8kJEVjqFwiK8ZeNqqqMRjSuIyIqnUFiE2qooV21q0riCiKx4CoVFSibi7B0YYTKVDrsUEZGCUSgsUndnnKlUhpcOnwq7FBGRglEoLFL3zGCzupBEZAVTKCzSuqZa2uN1ugJJRFY0hcISJBPZm9hCfHS0iEhBKRSWIJmIc/TUJAPDE2GXIiJSEAqFJejuDCbH07iCiKxQCoUluGL9KuqqouzpHw67FBGRglAoLEEsGmF7R7POFERkxVIoLFEyEeflI6cYn0qFXYqIyLJTKCxRMhEnnXGePzgSdikiIstOobBEOzqbAc2YKiIrk0JhiZrrq9nS1qBxBRFZkRQKFyGZiLO7XzexicjKo1C4CMlEnOHxafYfHwu7FBGRZaVQuAgzT2JTF5KIrDQKhYtwWWsjq+uqNGOqiKw4CoWLEIkYOzp1E5uIrDwKhYuU7Izz+2OnGZmYDrsUEZFlo1C4SDPjCnt0v4KIrCAFCwUzqzWz35jZ82b2kpn9ZdC+2cyeMbN9ZvZ9M6sO2muC9X3B9q5C1bYctnU0EzE9iU1EVpZCnilMAte7+zZgO3CDmV0DfBH4krtvBYaAO4L97wCGgvYvBfuVrIaaGG/e0MQunSmIyApSsFDwrNPBalXwcuB64KGg/X7glmD55mCdYPtOM7NC1bcckok4z/UPk0pnwi5FRGRZFHRMwcyiZvYccAx4HHgNGHb3mSlGDwGbguVNwEGAYPsI0DLPZ95pZr1m1js4OFjI8hfU3RlnbCrNq0dHQ61DRGS5FDQU3D3t7tuBduBq4Ipl+Mx73b3H3Xva2touucZLMTPYvFsP3RGRFaIoVx+5+zDwJPA2oNnMYsGmdmAgWB4AOgCC7auBE8Wo72K1x+toW1WjwWYRWTEKefVRm5k1B8t1wLuBV8iGwweC3W4HHgmWHw3WCbb/3Et8xjkzI9kZ101sIrJiFPJMYQPwpJntBZ4FHnf3nwB/DnzWzPaRHTO4L9j/PqAlaP8scHcBa1s2yUSc/pPjHBs9E3YpIiKXLLbwLhfH3fcCO+Zp3092fGFu+xngg4Wqp1C6E8FDd/qGueGq9SFXIyJyaXRH8yV6y8bVVEcjurNZRFYEhcIlqq2KctWmJo0riMiKoFBYBslEnL0DI0ym0mGXIiJySRYMheAGtL8pRjHlKpmIM5XK8NLhU2GXIiJySRYMBXdPA+8sQi1lq7szuIlNXUgiUuYWe/XRHjN7FPgBkHswsbv/sCBVlZm1TbW0x+vY1TfEf/jDsKsREbl4iw2FWrJ3F1+f1+aAQiGQTMR5ev8J3J0Sn8dPROS8FhUK7v7RQhdS7pKJOI88d5iB4Qna4/VhlyMiclEWdfWRmbWb2Y/M7FjwetjM2gtdXDmZGVfQpakiUs4We0nqN8nOTbQxeP04aJPAFetXUV8d1WCziJS1xYZCm7t/091TwetbQLjzVpeYWDTCtvZmPYlNRMraYkPhhJl9OLhnIWpmH6bEp7UOQzIR55Ujo4xPpRbeWUSkBC02FP49cCvwBnCE7NTWGnyeI5mIk844zx8
cCbsUEZGLsuDVR2YWBd7v7jcVoZ6ytqMzmDG1f4i3bTnnSaIiIiVvsXc031aEWspec301W9c26gokESlbi7157f+Z2d8D32f2Hc27C1JVGevubOaxl4+SyTiRiG5iE5HysthQ2B68/1VemzP7DmchO67wYO8h9h8fY+vaxrDLERFZksWMKUSAe9z9wSLUU/aSiWByvP4hhYKIlJ3FjClkgM8VoZYV4bLWRlbXVekmNhEpS4u9JPX/mtl/NbMOM1sz8ypoZWUqEjG6O5s12CwiZWmxYwp/GrzfldfmwGXLW87KkEzEefLVQUbGp1ldXxV2OSIii7bYWVI3F7qQlST30J2DQ1z3prUhVyMisngX7D4ys8/lLX9wzrb/Uaiiyt22jmYiBnvUhSQiZWahMYUP5S1/fs62G5a5lhWjoSbGmzc0aXI8ESk7C4WCnWd5vnXJk0zEea5/mFQ6E3YpIiKLtlAo+HmW51uXPMlEnLGpNK8eHQ27FBGRRVtooHmbmZ0ie1ZQFywTrNcWtLIylxts7hviLRtXh1yNiMjiXPBMwd2j7t7k7qvcPRYsz6zrWssLaI/X0baqRvcriEhZWezNa7JEZkayM87u/uGwSxERWTSFQgElE3H6T45zbPRM2KWIiCyKQqGAumcmx+vT2YKIlAeFQgFdtamJ6miE3bpfQUTKRMFCIZg870kze9nMXjKzTwfta8zscTP7ffAeD9rNzL5sZvvMbK+ZdReqtmKpiUW5alOTBptFpGwU8kwhBfwXd78SuAa4y8yuBO4GnnD3y4EngnWA9wCXB687gXsKWFvRJBNxXhgYYTKVDrsUEZEFFSwU3P3IzOM63X0UeAXYBNwM3B/sdj9wS7B8M/CAZz0NNJvZhkLVVyzJRJypVIaXDp9aeGcRkZAVZUzBzLqAHcAzwDp3PxJsegNYFyxvAg7m/dihoG3uZ91pZr1m1js4OFiwmpdL/k1sIiKlruChYGaNwMPAZ9x91tdld3eWOF2Gu9/r7j3u3tPW1raMlRbG2qZaOtbUaVxBRMpCQUPBzKrIBsK33f2HQfPRmW6h4P1Y0D4AdOT9eHvQVva6O+Ps6hsim4EiIqWrkFcfGXAf8Iq7/23epkeB24Pl24FH8to/ElyFdA0wktfNVNaSiTjHRicZGJ4IuxQRkQta7OM4L8Y7gD8DXjCz54K2LwB/DTxoZncAfcCtwbafAjcC+4Bx4KMFrK2oZsYVdvUN0R6vD7kaEZHzK1gouPs/cf5nLuycZ39n9jOgV4wr1q+ivjrK7r4hbt5+zti5iEjJ0B3NRRCLRtje0awnsYlIyVMoFEkyEeeVI6OMTabCLkVE5LwUCkXS3RknnXGeP6TJ8USkdCkUimRHZzMAe/R8BREpYQqFImmur2br2kbdxCYiJU2hUETZJ7ENkcnoJjYRKU0KhSJKJuIMj0+z//hY2KWIiMxLoVBE3YnsuIImxxORUqVQKKLLWhtZXVelJ7GJSMlSKBRRJGJ0dzZrsFlESpZCociSiTi/P3aakfHpsEsRETmHQqHIuhPBQ3cO6mxBREqPQqHItrU3E42YBptFpCQpFIqsoSbGFetXaVxBREqSQiEEyUSc5w8Ok0pnwi5FRGQWhUIIkok4Y1NpXj06GnYpIiKzKBRCMPMkNo0riEipUSiEoD1ex9pVNRpXEJGSo1AIgZnR3RnXk9hEpOQoFEKSTMQ5eHKCY6Nnwi5FRCRHoRCS3E1sfXrojoiUDoVCSK7a1ER1NKLJ8USkpCgUQlITi/IH7as12CwiJUWhEKLuzmZeODTCZCoddikiIoBCIVTJRJypdIYXB06FXYqICKBQCNXMTWx7NK4gIiVCoRCitU21dKyp07iCiJQMhULIkp1xevuGcPewSxERUSiELZmIMzg6yaGhibBLERFRKIRtx8zkeBpXEJESoFAI2RXrV1FfHdWMqSJSEhQKIYtFI2zvaNbkeCJSEhQKJSCZiPPKkVHGJlNhlyIiFa5goWB
m3zCzY2b2Yl7bGjN73Mx+H7zHg3Yzsy+b2T4z22tm3YWqqxR1J+KkM87zhzQ5noiEq5BnCt8CbpjTdjfwhLtfDjwRrAO8B7g8eN0J3FPAukpOd4eexCYipaFgoeDuvwROzmm+Gbg/WL4fuCWv/QHPehpoNrMNhaqt1Kyur2Lr2kbdxCYioSv2mMI6dz8SLL8BrAuWNwEH8/Y7FLSdw8zuNLNeM+sdHBwsXKVFluyMs+fgMJmMbmITkfCENtDs2Vt4l3wEdPd73b3H3Xva2toKUFk4kok4w+PT7D8+FnYpIlLBih0KR2e6hYL3Y0H7ANCRt1970FYxzj6JTV1IIhKeYofCo8DtwfLtwCN57R8JrkK6BhjJ62aqCJe1NtBcX6VxBREJVaxQH2xm3wWuBVrN7BDwF8BfAw+a2R1AH3BrsPtPgRuBfcA48NFC1VWqIhFjh25iE5GQFSwU3P2282zaOc++DtxVqFrKRU/XGp589VX+6scv87E/uoy1TbVhlyQiFaZgoSBL95G3Jdg/OMb9vz7At5/p47arO/n4tVtYp3AQkSKxcp7Hv6enx3t7e8MuY9n1nRjj73++jx/uGSAaMf7t1Z187I+2sH61wkFELp2Z7XL3nnm3KRRKV/+Jcf7hyX08vPsQkYhx21s7+Pi1WxUOInJJFApl7uDJbDg8tOsQETP+9K0dfPzaLWxsrgu7NBEpQwqFFeLgyXG+8tQ+ftCbDYdb39rOJ67dqnAQkSVRKKwwh4bG+cpTr/GD3uzMILf2dPCJ67aySeEgIougUFihBoYn+MqT+3gwCIcP9nTwiWu30B6vD7kyESllCoUVbmB4gnue2seDzx7CcT6QzHYrdaxROIjIuRQKFeLw8AT3PPUa33/2IBnPhsNd1ykcRGQ2hUKFOTIywVefeo3vPnuQTMZ5f/cmPnnd5XS2KBxERKFQsd4YOcNXf/Ea3/lNP+mM8/4dm/jk9VtJtDSEXZqIhEihUOGOngrC4Zl+Uhnn3+zYxCev20pXq8JBpBIpFASAY6fO8NVf7Ofbz/SRyjg3b9/Ip66/nM0KB5GKolCQWY6NnuF/BeEwlcpwy/Zst9JlbY1hlyYiRaBQkHkNjk5y7y9f4x+fzobDTds28qmdl7NF4SCyoikU5IIGRyf52q/284+/7mMyleZPtmW7lbauVTiIrEQKBVmU46cn+dov9/PAr/s4k0rzJ/9iI/9p51a2rl0VdmkisowUCrIkJ05P8rVfvc4Dvz7AxHSa9/7BBj5+7RbevL6JSMTCLk9ELpFCQS7KybEpvvar/TzwzwcYm0pTHYuQWFNPV2sDm1sb6GppoKu1ns2tDaxvqsVMgSFSDhQKckmGxqb42Utv8PrxMV4/PsaB42P0nRxnKpXJ7VNbFaGrJQiL1gY2t2Tfu1rraWusUWCIlJALhYKe0SwLijdUc9vVnbPa0hnn8PAEB05kQ+L14+McODHGq2+M8vjLR0llzn7ZaKiOBgFxNiw2t9bT1dLAmobqsguMM9NpTo5NcXJsihNjU5wcm+TE6alc2/HTU4xNptjYXMfm1no2tzbmzqjqq/VPTkqbzhRk2aXSGQaGJ3JnFQdOjGeXT4xxaGiCdF5grKqN5XVFnQ2Lza0NNNdXF6Xeiak0J8Ymzx7kT5894J84ndcevE5Ppub9nFjEWNNQzZqGahpqYgwMTfDGqTOz9lnXVENXSwOXtTXk/pybWxvobKmnJhYtxh9XRGcKUlyxaIRES0N2jqU3zd42lcpwaCh7VvH68fEgNMbY3T/Ej/ceJv87SnN9VfYA2tow50yjnlW1VfP+bndnfCp9zrf4mYN69hv97AP9+FR63s+qihotDTWsaaimpbGaREs9axqqaW2syR38W3LvNTTVxc456xmfSnHg+Myf9+zrsZeOcmJsKrdfxAjOLM4GRVdr9s++qbmOWDRycf8xRJZIZwpSMiZTaQ6eHM+FxetB19SB42McHpn9jbu1sZqulgbWr65l9Ew
qr+tmksm8sY581bEILcEBfk1DTe6Anj3QZ9tyB/rGalbVnHuQX04jE9NB19vZ14ETY7w+OMZo3tlIVdToiNefHa/Je61vqtUVYbJkGmiWsjcxlabvZN74RRAaR0+doam2KjjQz3xzr8k7+Ge/xa9prKahOloW4xfuzomxKQ4cH2P/8bFZwXHgxBhnps+GXk0sMqv77bK84GhtLL/xGikOhYLICpHJOEdHz/D6YDYUXx8cy3VN9Z8cZzp99t9zY01sztlFPYmWBtbUV9NUV8Wq2hhV6paqSBpTEFkhIhFjw+o6Nqyu4+1bW2dtS6UzHB4+w/7jp8+eXZwY57mDQ/yfvYfJzPP9r746SlNtFU11MVbVVtFUG6OprirXln2fbz27f3VMoXIp0hlnMpVmcjrDZCqTXU5lgvX02bY526dSGd6+pZUrNzYte00KBZEVIhaN0NlSn33C3pwB/pnxmv6T44xMTHNqIsWpiWlOnQmWz2SXj5+eYv/xsWBbataVYvOprYrMCoqlBEpTXWxZrrjKZJyMO2l3MhlIu5POOJnMTJvntWW3Z87Tnp75rLyfn0pl5j84B8tT6QyT0zMH8PkP4rMP8mf3Ty3w93sh/+2WqxQKInJxamJRtq5dtaR5rGau5JoJjtEzc0IkCI78cDkZjIXMtC900KuJRbJdWTXZQ9HZA3b2W/S5B/VzD/5hMoPaWJSaqgg1sQg1sWj2verscmNNNvwutE9NLEL1edprquZfrqsqzCXMCgURmZeZ0VATo6EmxobVS/95d2diOs3omXnOSuYGypkUBkQjRtSMSMSIWHY9Yjbr/ewyuX3P/ky2/dx9z+4zs/3cz2WefY3q6PkP1LGIrbjBfIWCiBSEmVFfHaO+Osa6ptqwy5FF0iiRiIjklFQomNkNZvaqme0zs7vDrkdEpNKUTCiYWRT4B+A9wJXAbWZ2ZbhViYhUlpIJBeBqYJ+773f3KeB7wM0h1yQiUlFKKRQ2AQfz1g8FbSIiUiSlFAqLYmZ3mlmvmfUODg6GXY6IyIpSSqEwAHTkrbcHbbO4+73u3uPuPW1tbUUrTkSkEpRSKDwLXG5mm82sGvgQ8GjINYmIVJSSmiXVzG4E/g6IAt9w9/++wP6DQN9F/rpW4PhF/mwhqa6lUV1LV6q1qa6luZS6Eu4+b1dLSYVCMZlZ7/mmjg2T6loa1bV0pVqb6lqaQtVVSt1HIiISMoWCiIjkVHIo3Bt2AeehupZGdS1dqdamupamIHVV7JiCiIicq5LPFEREZA6FgoiI5FRcKJjZN8zsmJm9GHYt+cysw8yeNLOXzewlM/t02DUBmFmtmf3GzJ4P6vrLsGvKZ2ZRM9tjZj8Ju5YZZnbAzF4ws+fMrDfsemaYWbOZPWRmvzWzV8zsbSVQ05uCv6eZ1ykz+0zYdQGY2X8O/p9/0cy+a2Yl8aQgM/t0UNNLhfi7qrgxBTN7F3AaeMDdrwq7nhlmtgHY4O67zWwVsAu4xd1fDrkuAxrc/bSZVQH/BHza3Z8Os64ZZvZZoAdocvf3hV0PZEMB6HH3krrhyczuB37l7l8PZg2od/fhsOuaEUyfPwD8S3e/2JtSl6uWTWT/X7/S3SfM7EHgp+7+rZDruorsDNJXA1PAz4CPufu+5fodFXem4O6/BE6GXcdc7n7E3XcHy6PAK5TALLGedTpYrQpeJfFNwszagfcCXw+7llJnZquBdwH3Abj7VCkFQmAn8FrYgZAnBtSZWQyoBw6HXA/Am4Fn3H3c3VPAL4D3L+cvqLhQKAdm1gXsAJ4Jt5KsoIvmOeAY8Li7l0RdZKdE+RyQCbuQORx4zMx2mdmdYRcT2AwMAt8Mutu+bmYNYRc1x4eA74ZdBIC7DwB/A/QDR4ARd38s3KoAeBH4QzNrMbN64EZmTyR6yRQKJcbMGoGHgc+4+6mw6wFw97S7byc7c+3VwSlsqMzsfcAxd98Vdi3zeKe7d5N9iuBdQZdl2GJAN3CPu+8AxoCSeeR
t0J11E/CDsGsBMLM42Yd8bQY2Ag1m9uFwqwJ3fwX4IvAY2a6j54D0cv4OhUIJCfrsHwa+7e4/DLueuYLuhieBG8KuBXgHcFPQf/894Hoz+9/hlpQVfMvE3Y8BPyLb/xu2Q8ChvLO8h8iGRKl4D7Db3Y+GXUjgXwGvu/ugu08DPwTeHnJNALj7fe6edPd3AUPA75bz8xUKJSIY0L0PeMXd/zbsemaYWZuZNQfLdcC7gd+GWxW4++fdvd3du8h2O/zc3UP/JmdmDcGFAgTdM39M9pQ/VO7+BnDQzN4UNO0EQr2IYY7bKJGuo0A/cI2Z1Qf/NneSHecLnZmtDd47yY4nfGc5Pz+2nB9WDszsu8C1QKuZHQL+wt3vC7cqIPvN98+AF4L+e4AvuPtPQ6wJYANwf3BlSAR40N1L5vLPErQO+FH2OEIM+I67/yzcknI+BXw76KrZD3w05HqAXHi+G/iPYdcyw92fMbOHgN1ACthD6Ux38bCZtQDTwF3LfcFAxV2SKiIi56fuIxERyVEoiIhIjkJBRERyFAoiIpKjUBARkRyFgsgyMrPTecs3mtnvzCwRZk0iS1Fx9ymIFIOZ7QS+DPzrEprgTWRBCgWRZRbMdfQ14EZ3fy3sekSWQjeviSwjM5sGRoFr3X1v2PWILJXGFESW1zTwz8AdYRcicjEUCiLLKwPcSnaK8S+EXYzIUmlMQWSZufu4mb0X+JWZHS2RCRdFFkWhIFIA7n7SzG4Afmlmg+7+aNg1iSyGBppFRCRHYwoiIpKjUBARkRyFgoiI5CgUREQkR6EgIiI5CgUREclRKIiISM7/BzdSFx7LBl6XAAAAAElFTkSuQmCC\n", 406 | "text/plain": [ 407 | "
" 408 | ] 409 | }, 410 | "metadata": { 411 | "needs_background": "light" 412 | }, 413 | "output_type": "display_data" 414 | } 415 | ], 416 | "source": [ 417 | "# Create an \"Elbow chart\" to find the \"best\" `k`\n", 418 | "# See: https://en.wikipedia.org/wiki/Elbow_method_(clustering)\n", 419 | "\n", 420 | "# Lists to record the `k` values and the computed `error` sums\n", 421 | "# which are used for plotting later on\n", 422 | "ks: List[int] = []\n", 423 | "error_sums: List[float] = []\n", 424 | "\n", 425 | "# Create clusterings for the range of `k` values\n", 426 | "for k in range(1, 10):\n", 427 | " # Create and train a new KMeans instance for the current `k`\n", 428 | " km: KMeans = KMeans(k)\n", 429 | " km.train(data_points)\n", 430 | " # List to keep track of the individual KMeans errors\n", 431 | " errors: List[float] = []\n", 432 | " # Iterate over all `clusters` and extract their `centroid_idx`s and `items`\n", 433 | " centroid_idx: int\n", 434 | " items: List[List[float]]\n", 435 | " for centroid_idx, items in km.clusters.items():\n", 436 | " # Look up `centroid` coordinates based on its index\n", 437 | " centroid: List[float] = km.centroids[centroid_idx]\n", 438 | " # Iterate over each `item` in the cluster\n", 439 | " item: List[float]\n", 440 | " for item in items:\n", 441 | " # Calculate how far the current cluster's `item` is from the `centroid`\n", 442 | " dist: float = distance(centroid, item)\n", 443 | " # The closer the `item` in question, the better (less error)\n", 444 | " # (the closest one can be is `0`)\n", 445 | " error: float = squared_error(dist, 0)\n", 446 | " # Record the `error` value\n", 447 | " errors.append(error)\n", 448 | " # Append the current `k` and the sum of all `errors`\n", 449 | " ks.append(k)\n", 450 | " error_sums.append(sum(errors))\n", 451 | "\n", 452 | "# Plot the `k` and error values to see which `k` is \"best\"\n", 453 | "plt.plot(ks, error_sums)\n", 454 | "plt.xlabel('K')\n", 455 | "plt.ylabel('Error');" 456
| ] 457 | } 458 | ], 459 | "metadata": { 460 | "kernelspec": { 461 | "display_name": "Python 3", 462 | "language": "python", 463 | "name": "python3" 464 | }, 465 | "language_info": { 466 | "codemirror_mode": { 467 | "name": "ipython", 468 | "version": 3 469 | }, 470 | "file_extension": ".py", 471 | "mimetype": "text/x-python", 472 | "name": "python", 473 | "nbconvert_exporter": "python", 474 | "pygments_lexer": "ipython3", 475 | "version": "3.6.9" 476 | } 477 | }, 478 | "nbformat": 4, 479 | "nbformat_minor": 4 480 | } 481 | -------------------------------------------------------------------------------- /x-from-scratch/naive-bayes-from-scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Naive Bayes from scratch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import re\n", 17 | "import glob\n", 18 | "from pathlib import Path\n", 19 | "from random import shuffle\n", 20 | "from math import exp, log\n", 21 | "from collections import defaultdict, Counter\n", 22 | "from typing import NamedTuple, List, Set, Tuple" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Ensure that we have a `data` directory we use to store downloaded data\n", 32 | "!mkdir -p data\n", 33 | "data_dir: Path = Path('data')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "--2020-02-09 12:03:06-- http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz\n", 46 | "Resolving nlp.cs.aueb.gr (nlp.cs.aueb.gr)... 195.251.248.252\n", 47 | "Connecting to nlp.cs.aueb.gr (nlp.cs.aueb.gr)|195.251.248.252|:80... 
connected.\n", 48 | "HTTP request sent, awaiting response... 200 OK\n", 49 | "Length: 1802573 (1.7M) [application/x-gzip]\n", 50 | "Saving to: ‘data/enron1.tar.gz’\n", 51 | "\n", 52 | "enron1.tar.gz 100%[===================>] 1.72M 920KB/s in 1.9s \n", 53 | "\n", 54 | "2020-02-09 12:03:08 (920 KB/s) - ‘data/enron1.tar.gz’ saved [1802573/1802573]\n", 55 | "\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "# We're using the \"Enron Spam\" data set\n", 61 | "!wget -nc -P data http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "!tar -xzf data/enron1.tar.gz -C data" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# The data set has 2 directories: One for `spam` messages, one for `ham` messages\n", 80 | "spam_data_path: Path = data_dir / 'enron1' / 'spam'\n", 81 | "ham_data_path: Path = data_dir / 'enron1' / 'ham'" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 6, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# Our data container for `spam` and `ham` messages\n", 91 | "class Message(NamedTuple):\n", 92 | " text: str\n", 93 | " is_spam: bool" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "['data/enron1/spam/4743.2005-06-25.GP.spam.txt',\n", 105 | " 'data/enron1/spam/1309.2004-06-08.GP.spam.txt',\n", 106 | " 'data/enron1/spam/0726.2004-03-26.GP.spam.txt',\n", 107 | " 'data/enron1/spam/0202.2004-01-13.GP.spam.txt',\n", 108 | " 'data/enron1/spam/3988.2005-03-06.GP.spam.txt']" 109 | ] 110 | }, 111 | "execution_count": 7, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "# Globbing for all the `.txt` files in 
both (`spam` and `ham`) directories\n", 118 | "spam_message_paths: List[str] = glob.glob(str(spam_data_path / '*.txt'))\n", 119 | "ham_message_paths: List[str] = glob.glob(str(ham_data_path / '*.txt'))\n", 120 | "\n", 121 | "message_paths: List[str] = spam_message_paths + ham_message_paths\n", 122 | "message_paths[:5]" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 8, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "# The list which eventually contains all the parsed Enron `spam` and `ham` messages\n", 132 | "messages: List[Message] = []" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 9, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# Open every file individually, turn it into a `Message` and append it to our `messages` list\n", 142 | "for path in message_paths:\n", 143 | " with open(path, errors='ignore') as file:\n", 144 | " is_spam: bool = True if 'spam' in path else False\n", 145 | " # We're only interested in the subject for the time being \n", 146 | " text: str = file.readline().replace('Subject:', '').strip()\n", 147 | " messages.append(Message(text, is_spam))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 10, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "[Message(text='january production estimate', is_spam=False),\n", 159 | " Message(text='re : your code # 5 g 6878', is_spam=True),\n", 160 | " Message(text='account # 20367 s tue , 28 jun 2005 11 : 41 : 41 - 0800', is_spam=True),\n", 161 | " Message(text='congratulations', is_spam=True),\n", 162 | " Message(text='fw : hpl imbalance payback', is_spam=False)]" 163 | ] 164 | }, 165 | "execution_count": 10, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 | "source": [ 171 | "shuffle(messages)\n", 172 | "messages[:5]" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 11, 178 | 
"metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/plain": [ 183 | "5172" 184 | ] 185 | }, 186 | "execution_count": 11, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "len(messages)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 12, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# Given a string, extract the set of unique, lowercased words (letters, digits, apostrophes) with a length of at least 2\n", 202 | "def tokenize(text: str) -> Set[str]:\n", 203 | " words: List[str] = []\n", 204 | " for word in re.findall(r'[A-Za-z0-9\\']+', text):\n", 205 | " if len(word) >= 2:\n", 206 | " words.append(word.lower())\n", 207 | " return set(words)\n", 208 | "\n", 209 | "assert tokenize('Is this a text? If so, Tokenize this text!...') == {'is', 'this', 'text', 'if', 'so', 'tokenize'}" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 13, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "{'estimate', 'january', 'production'}" 221 | ] 222 | }, 223 | "execution_count": 13, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "tokenize(messages[0].text)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 14, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# Split the list of messages into a `train` and `test` set (defaults to an 80/20 train/test split; note: shuffles `messages` in place)\n", 239 | "def train_test_split(messages: List[Message], pct=0.8) -> Tuple[List[Message], List[Message]]:\n", 240 | " shuffle(messages)\n", 241 | " num_train = int(round(len(messages) * pct, 0))\n", 242 | " return messages[:num_train], messages[num_train:]\n", 243 | "\n", 244 | "assert len(train_test_split(messages)[0]) + len(train_test_split(messages)[1]) == len(messages)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 15, 250 | "metadata": {}, 251 | "outputs":
[], 252 | "source": [ 253 | "# The Naive Bayes classifier\n", 254 | "class NaiveBayes:\n", 255 | " def __init__(self, k=1) -> None:\n", 256 | " # `k` is the additive smoothing factor (pseudo-count added to every word count)\n", 257 | " self._k: int = k\n", 258 | " self._num_spam_messages: int = 0\n", 259 | " self._num_ham_messages: int = 0\n", 260 | " self._num_word_in_spam: Dict[str, int] = defaultdict(int)\n", 261 | " self._num_word_in_ham: Dict[str, int] = defaultdict(int)\n", 262 | " self._spam_words: Set[str] = set()\n", 263 | " self._ham_words: Set[str] = set()\n", 264 | " self._words: Set[str] = set()\n", 265 | "\n", 266 | " # Iterate through the given messages and gather the necessary statistics\n", 267 | " def train(self, messages: List[Message]) -> None:\n", 268 | " msg: Message\n", 269 | " token: str\n", 270 | " for msg in messages:\n", 271 | " tokens: Set[str] = tokenize(msg.text)\n", 272 | " self._words.update(tokens)\n", 273 | " if msg.is_spam:\n", 274 | " self._num_spam_messages += 1\n", 275 | " self._spam_words.update(tokens)\n", 276 | " for token in tokens:\n", 277 | " self._num_word_in_spam[token] += 1\n", 278 | " else:\n", 279 | " self._num_ham_messages += 1\n", 280 | " self._ham_words.update(tokens)\n", 281 | " for token in tokens:\n", 282 | " self._num_word_in_ham[token] += 1 \n", 283 | " \n", 284 | " # Smoothed estimate of P(`word` appears in a message | the message is spam)\n", 285 | " def _p_word_spam(self, word: str) -> float:\n", 286 | " return (self._k + self._num_word_in_spam[word]) / ((2 * self._k) + self._num_spam_messages)\n", 287 | " \n", 288 | " # Smoothed estimate of P(`word` appears in a message | the message is ham)\n", 289 | " def _p_word_ham(self, word: str) -> float:\n", 290 | " return (self._k + self._num_word_in_ham[word]) / ((2 * self._k) + self._num_ham_messages)\n", 291 | " \n", 292 | " # Given a `text`, how likely is it spam? (Spam and ham are weighted equally; no class prior is applied)\n", 293 | " def predict(self, text: str) -> float:\n", 294 | " text_words: Set[str] = tokenize(text)\n", 295 | " log_p_spam: float = 0.0\n", 296 | " log_p_ham: float = 0.0\n", 297 | "\n", 298 | " for word in self._words:\n", 299 | "
p_spam: float = self._p_word_spam(word)\n", 300 | " p_ham: float = self._p_word_ham(word)\n", 301 | " if word in text_words:\n", 302 | " log_p_spam += log(p_spam)\n", 303 | " log_p_ham += log(p_ham)\n", 304 | " else:\n", 305 | " log_p_spam += log(1 - p_spam)\n", 306 | " log_p_ham += log(1 - p_ham)\n", 307 | "\n", 308 | " p_if_spam: float = exp(log_p_spam)\n", 309 | " p_if_ham: float = exp(log_p_ham)\n", 310 | " return p_if_spam / (p_if_spam + p_if_ham)\n", 311 | "\n", 312 | "# Tests\n", 313 | "def test_naive_bayes():\n", 314 | " messages: List[Message] = [\n", 315 | " Message('Spam message', is_spam=True),\n", 316 | " Message('Ham message', is_spam=False),\n", 317 | " Message('Ham message about Spam', is_spam=False)]\n", 318 | " \n", 319 | " nb: NaiveBayes = NaiveBayes()\n", 320 | " nb.train(messages)\n", 321 | " \n", 322 | " assert nb._num_spam_messages == 1\n", 323 | " assert nb._num_ham_messages == 2\n", 324 | " assert nb._spam_words == {'spam', 'message'}\n", 325 | " assert nb._ham_words == {'ham', 'message', 'about', 'spam'}\n", 326 | " assert nb._num_word_in_spam == {'spam': 1, 'message': 1}\n", 327 | " assert nb._num_word_in_ham == {'ham': 2, 'message': 2, 'about': 1, 'spam': 1}\n", 328 | " assert nb._words == {'spam', 'message', 'ham', 'about'}\n", 329 | "\n", 330 | " # Our test message\n", 331 | " text: str = 'A spam message'\n", 332 | " \n", 333 | " # Reminder: The `_words` we iterate over are: {'spam', 'message', 'ham', 'about'}\n", 334 | " \n", 335 | " # Calculate how spammy the `text` might be\n", 336 | " p_if_spam: float = exp(sum([\n", 337 | " log( (1 + 1) / ((2 * 1) + 1)), # `spam` (also in `text`)\n", 338 | " log( (1 + 1) / ((2 * 1) + 1)), # `message` (also in `text`)\n", 339 | " log(1 - ((1 + 0) / ((2 * 1) + 1))), # `ham` (NOT in `text`)\n", 340 | " log(1 - ((1 + 0) / ((2 * 1) + 1))), # `about` (NOT in `text`)\n", 341 | " ]))\n", 342 | " \n", 343 | " # Calculate how hammy the `text` might be\n", 344 | " p_if_ham: float = exp(sum([\n", 345 | "
log( (1 + 1) / ((2 * 1) + 2)), # `spam` (also in `text`)\n", 346 | " log( (1 + 2) / ((2 * 1) + 2)), # `message` (also in `text`)\n", 347 | " log(1 - ((1 + 2) / ((2 * 1) + 2))), # `ham` (NOT in `text`)\n", 348 | " log(1 - ((1 + 1) / ((2 * 1) + 2))), # `about` (NOT in `text`)\n", 349 | " ]))\n", 350 | " \n", 351 | " p_spam: float = p_if_spam / (p_if_spam + p_if_ham)\n", 352 | " \n", 353 | " assert p_spam == nb.predict(text)\n", 354 | "\n", 355 | "test_naive_bayes()" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 16, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "train: List[Message]\n", 365 | "test: List[Message]\n", 366 | "\n", 367 | "# Splitting our Enron messages into a `train` and `test` set\n", 368 | "train, test = train_test_split(messages)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 17, 374 | "metadata": {}, 375 | "outputs": [ 376 | { 377 | "name": "stdout", 378 | "output_type": "stream", 379 | "text": [ 380 | "Spam messages in training data: 1227\n", 381 | "Ham messages in training data: 2911\n", 382 | "Most spammy words: [('you', 115), ('the', 104), ('your', 104), ('for', 86), ('to', 83), ('re', 81), ('on', 56), ('and', 51), ('get', 48), ('is', 48), ('in', 43), ('with', 40), ('of', 38), ('it', 35), ('at', 35), ('online', 34), ('all', 33), ('from', 33), ('this', 32), ('new', 31)]\n" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "# Train our Naive Bayes classifier with the `train` set\n", 388 | "nb: NaiveBayes = NaiveBayes()\n", 389 | "nb.train(train)\n", 390 | "\n", 391 | "print(f'Spam messages in training data: {nb._num_spam_messages}')\n", 392 | "print(f'Ham messages in training data: {nb._num_ham_messages}')\n", 393 | "print(f'Most spammy words: {Counter(nb._num_word_in_spam).most_common(20)}')" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 18, 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "data": { 403 | "text/plain": [ 404 | 
"[Message(text=\"a witch . i don ' t\", is_spam=True),\n", 405 | " Message(text='active and strong', is_spam=True),\n", 406 | " Message(text='get great prices on medications', is_spam=True),\n", 407 | " Message(text='', is_spam=True),\n", 408 | " Message(text='popular software at low low prices . misunderstand developments', is_spam=True)]" 409 | ] 410 | }, 411 | "execution_count": 18, 412 | "metadata": {}, 413 | "output_type": "execute_result" 414 | } 415 | ], 416 | "source": [ 417 | "# Grabbing all the spam messages from our `test` set\n", 418 | "spam_messages: List[Message] = [item for item in test if item.is_spam]\n", 419 | "spam_messages[:5]" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 19, 425 | "metadata": {}, 426 | "outputs": [ 427 | { 428 | "name": "stdout", 429 | "output_type": "stream", 430 | "text": [ 431 | "Predicting likelihood of \"get your hand clock repliacs todday carson\" being spam.\n" 432 | ] 433 | }, 434 | { 435 | "data": { 436 | "text/plain": [ 437 | "0.9884313222593173" 438 | ] 439 | }, 440 | "execution_count": 19, 441 | "metadata": {}, 442 | "output_type": "execute_result" 443 | } 444 | ], 445 | "source": [ 446 | "# Using our trained Naive Bayes classifier to classify a spam message\n", 447 | "message: str = spam_messages[10].text\n", 448 | " \n", 449 | "print(f'Predicting likelihood of \"{message}\" being spam.')\n", 450 | "nb.predict(message)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 20, 456 | "metadata": {}, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "text/plain": [ 461 | "[Message(text='new update for buybacks', is_spam=False),\n", 462 | " Message(text='enron and blockbuster to launch entertainment on - demand service', is_spam=False),\n", 463 | " Message(text='re : astros web site comments', is_spam=False),\n", 464 | " Message(text='re : formosa meter # : 1000', is_spam=False),\n", 465 | " Message(text='re : deal extension for 11 / 21 / 2000 for 98 - 439', 
is_spam=False)]" 466 | ] 467 | }, 468 | "execution_count": 20, 469 | "metadata": {}, 470 | "output_type": "execute_result" 471 | } 472 | ], 473 | "source": [ 474 | "# Grabbing all the ham messages from our `test` set\n", 475 | "ham_messages: List[Message] = [item for item in test if not item.is_spam]\n", 476 | "ham_messages[:5]" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 21, 482 | "metadata": {}, 483 | "outputs": [ 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "Predicting likelihood of \"associate & analyst mid - year 2001 prc process\" being spam.\n" 489 | ] 490 | }, 491 | { 492 | "data": { 493 | "text/plain": [ 494 | "5.3089147140900964e-05" 495 | ] 496 | }, 497 | "execution_count": 21, 498 | "metadata": {}, 499 | "output_type": "execute_result" 500 | } 501 | ], 502 | "source": [ 503 | "# Using our trained Naive Bayes classifier to classify a ham message\n", 504 | "message: str = ham_messages[10].text\n", 505 | "\n", 506 | "print(f'Predicting likelihood of \"{message}\" being spam.')\n", 507 | "nb.predict(message)" 508 | ] 509 | } 510 | ], 511 | "metadata": { 512 | "kernelspec": { 513 | "display_name": "Python 3", 514 | "language": "python", 515 | "name": "python3" 516 | }, 517 | "language_info": { 518 | "codemirror_mode": { 519 | "name": "ipython", 520 | "version": 3 521 | }, 522 | "file_extension": ".py", 523 | "mimetype": "text/x-python", 524 | "name": "python", 525 | "nbconvert_exporter": "python", 526 | "pygments_lexer": "ipython3", 527 | "version": "3.6.9" 528 | } 529 | }, 530 | "nbformat": 4, 531 | "nbformat_minor": 2 532 | } 533 | --------------------------------------------------------------------------------