├── .gitignore ├── LICENSE ├── README.md ├── gnns ├── gat.ipynb └── rgcn.ipynb └── linear-regression ├── data └── insurance.csv ├── ex-insurance.ipynb ├── learning.gif ├── theory.ipynb └── todo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | MANIFEST 2 | build 3 | dist 4 | _build 5 | docs/man/*.gz 6 | docs/source/api/generated 7 | docs/source/config.rst 8 | docs/gh-pages 9 | notebook/i18n/*/LC_MESSAGES/*.mo 10 | notebook/i18n/*/LC_MESSAGES/nbjs.json 11 | notebook/static/components 12 | notebook/static/style/*.min.css* 13 | notebook/static/*/js/built/ 14 | notebook/static/*/built/ 15 | notebook/static/built/ 16 | notebook/static/*/js/main.min.js* 17 | notebook/static/lab/*bundle.js 18 | node_modules 19 | *.py[co] 20 | __pycache__ 21 | *.egg-info 22 | *~ 23 | *.bak 24 | .ipynb_checkpoints 25 | .tox 26 | .DS_Store 27 | \#*# 28 | .#* 29 | .coverage 30 | .pytest_cache 31 | src 32 | 33 | *.swp 34 | *.map 35 | .idea/ 36 | Read the Docs 37 | config.rst 38 | *.iml 39 | /.project 40 | /.pydevproject 41 | 42 | package-lock.json 43 | geckodriver.log 44 | *.iml 45 | 46 | # jetbrains IDE stuff 47 | *.iml 48 | .idea/ 49 | 50 | # ms IDE stuff 51 | *.code-workspace 52 | .history 53 | .vscode 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Giuseppe Futia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # notebooks 2 | Machine learning and statistical algorithms implemented both in pure NumPy and with well-known libraries such as scikit-learn, PyTorch, and Statsmodels. 3 | 4 | ## Linear Regression 5 | 1. [Implementation](/linear-regression/theory.ipynb) from scratch. 6 | 7 | ## Graph Neural Networks 8 | 1. [Implementation](/gnns/rgcn.ipynb) of the Relational Graph Convolutional Network (R-GCN) layer from scratch. 9 | 2. [Implementation](/gnns/gat.ipynb) of the Graph Attention Network (GAT) layer with NumPy and DGL.
-------------------------------------------------------------------------------- /gnns/gat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.6.12-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3.6.12 64-bit ('learning': conda)", 19 | "metadata": { 20 | "interpreter": { 21 | "hash": "566c0a97317f6f88d4bc5f478002f1c75c862f0281a52c0ded6c5ead36971532" 22 | } 23 | } 24 | } 25 | }, 26 | "nbformat": 4, 27 | "nbformat_minor": 2, 28 | "cells": [ 29 | { 30 | "source": [ 31 | "# Graph Attention Networks (GATs)\n", 32 | "Original paper: Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., & Bengio, Y. (2017). [Graph attention networks](https://arxiv.org/abs/1710.10903). *arXiv preprint arXiv:1710.10903*. 
\n", 33 | "\n", 34 | "## Math Warm-up\n", 35 | "This section comes from https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html, with minor extensions.\n", 36 | "\n", 37 | "### GCN Layer\n", 38 | "$$h_i^{(l+1)}=\\sigma\\left(\\sum_{j\\in \\mathcal{N}(i)} {\\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\\right)$$\n", 39 | "\n", 40 | "* $\\mathcal{N}(i)$: set of the one-hop neighbors (no self-loop) of $n_i$.\n", 41 | "* $c_{ij}=\\sqrt{|\\mathcal{N}(i)|}\\sqrt{|\\mathcal{N}(j)|}$: normalization constant based on the graph structure.\n", 42 | "* $\\sigma$: activation function.\n", 43 | "* $W^{(l)}$: weight matrix for feature transformation.\n", 44 | "\n", 45 | "A broad explanation of Graph Neural Networks (GNNs) is available in the Medium article entitled \"[Understanding the Building Blocks of Graph Neural Networks](https://towardsdatascience.com/understanding-the-building-blocks-of-graph-neural-networks-intro-56627f0719d5)\".\n", 46 | "\n", 47 | "### GAT Layer\n", 48 | "\n", 49 | "\\begin{align}\n", 50 | "z_i^{(l)}&=W^{(l)}h_i^{(l)}, \\\\\n", 51 | "e_{ij}^{(l)}&=\\text{LeakyReLU}\\left({\\vec a^{(l)}}^T(z_i^{(l)}||z_j^{(l)})\\right),\\\\\n", 52 | "\\alpha_{ij}^{(l)}&=\\frac{\\exp(e_{ij}^{(l)})}{\\sum_{k\\in \\mathcal{N}(i)}\\exp(e_{ik}^{(l)})},\\\\\n", 53 | "h_i^{(l+1)}&=\\sigma\\left(\\sum_{j\\in \\mathcal{N}(i)} {\\alpha^{(l)}_{ij} z^{(l)}_j }\\right),\n", 54 | "\\end{align}\n", 55 | "\n", 56 | "* Equation (1) is a linear transformation of the lower-layer embedding $h_i^{(l)}$, where $W^{(l)}$ is its learnable weight matrix. This transformation provides sufficient expressive power to turn the input features (one-hot vectors in our example) into higher-level, dense features.\n", 57 | "* Equation (2) computes a pair-wise *un-normalized* attention score between two neighbors.
Here, it first concatenates the $z$ embeddings of the two nodes, where $||$ denotes concatenation, then takes a dot product of it and a learnable weight vector $\\vec a^{(l)}$, and applies a LeakyReLU in the end. This form of attention is usually called additive attention, contrast with the dot-product attention in the Transformer model. The attention score indicates the importance of a neighbor node in the message passing framework.\n", 58 | "* Equation (3) applies a softmax to normalize the attention scores on each node’s incoming edges.\n", 59 | "* Equation (4) is similar to GCN. The embeddings from neighbors are aggregated together, scaled by the attention scores.\n", 60 | "\n", 61 | "## Imports" 62 | ], 63 | "cell_type": "markdown", 64 | "metadata": {} 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 1, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "output_type": "stream", 73 | "name": "stderr", 74 | "text": [ 75 | "Using backend: pytorch\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "import numpy as np\n", 81 | "import torch\n", 82 | "import torch.nn as nn\n", 83 | "import torch.nn.functional as F\n", 84 | "import dgl\n", 85 | "\n", 86 | "np.random.seed(1)\n" 87 | ] 88 | }, 89 | { 90 | "source": [ 91 | "## GAT Layer Implementation with NumPy\n", 92 | "### Basic Functions" 93 | ], 94 | "cell_type": "markdown", 95 | "metadata": {} 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 2, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def leaky_relu(z):\n", 104 | " return np.where(z > 0, z, z * 0.01)\n", 105 | "\n", 106 | "def softmax(z):\n", 107 | " if len(z.shape) > 1:\n", 108 | " # Softmax for matrix\n", 109 | " max_matrix = np.max(z, axis=0)\n", 110 | " stable_z = z - max_matrix\n", 111 | " e = np.exp(stable_z)\n", 112 | " a = e / np.sum(e, axis=0, keepdims=True)\n", 113 | " else:\n", 114 | " # Softmax for vector\n", 115 | " vector_max_value = np.max(z)\n", 116 | " a = (np.exp(z - vector_max_value)) / 
sum(np.exp(z - vector_max_value))\n", 117 | "\n", 118 | " assert a.shape == z.shape\n", 119 | "\n", 120 | " return a\n" 121 | ] 122 | }, 123 | { 124 | "source": [ 125 | "### Graph and Weight Matrix Generation" 126 | ], 127 | "cell_type": "markdown", 128 | "metadata": {} 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 3, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "output_type": "stream", 137 | "name": "stdout", 138 | "text": [ 139 | "\n\n----- One-hot vector representation of nodes. Shape(n,n)\n\n[[0. 0. 1. 0. 0.]\n [0. 1. 0. 0. 0.]\n [0. 0. 0. 0. 1.]\n [1. 0. 0. 0. 0.]\n [0. 0. 0. 1. 0.]]\n\n\n----- Embedding dimension\n\n3\n\n\n----- Weight Matrix. Shape(emb, n)\n\n[[-0.4294049 0.57624235 -0.3047382 -0.11941829 -0.12942953]\n [ 0.19600584 0.5029172 0.3998854 -0.21561317 0.02834577]\n [-0.06529497 -0.31225734 0.03973776 0.47800217 -0.04941563]]\n\n\n----- Adjacency Matrix (undirected graph). Shape(n,n)\n\n[[1 1 1 0 1]\n [1 1 1 1 1]\n [1 1 1 1 0]\n [0 1 1 1 1]\n [1 1 0 1 1]]\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "print('\\n\\n----- One-hot vector representation of nodes. Shape(n,n)\\n')\n", 145 | "X = np.eye(5, 5)\n", 146 | "n = X.shape[0]\n", 147 | "np.random.shuffle(X)\n", 148 | "print(X)\n", 149 | "\n", 150 | "print('\\n\\n----- Embedding dimension\\n')\n", 151 | "emb = 3\n", 152 | "print(emb)\n", 153 | "\n", 154 | "print('\\n\\n----- Weight Matrix. Shape(emb, n)\\n')\n", 155 | "W = np.random.uniform(-np.sqrt(1. / emb), np.sqrt(1. / emb), (emb, n))\n", 156 | "print(W)\n", 157 | "\n", 158 | "print('\\n\\n----- Adjacency Matrix (undirected graph). 
Shape(n,n)\\n')\n", 159 | "A = np.random.randint(2, size=(n, n))\n", 160 | "np.fill_diagonal(A, 1) \n", 161 | "A = (A + A.T)\n", 162 | "A[A > 1] = 1\n", 163 | "print(A)" 164 | ] 165 | }, 166 | { 167 | "source": [ 168 | "### Linear Transformation" 169 | ], 170 | "cell_type": "markdown", 171 | "metadata": {} 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 4, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "output_type": "stream", 180 | "name": "stdout", 181 | "text": [ 182 | "\n\n----- Linear Transformation. Shape(n, emb)\n\n[[-0.3047382 0.3998854 0.03973776]\n [ 0.57624235 0.5029172 -0.31225734]\n [-0.12942953 0.02834577 -0.04941563]\n [-0.4294049 0.19600584 -0.06529497]\n [-0.11941829 -0.21561317 0.47800217]]\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "# equation (1)\n", 188 | "print('\\n\\n----- Linear Transformation. Shape(n, emb)\\n')\n", 189 | "z1 = X.dot(W.T)\n", 190 | "print(z1)" 191 | ] 192 | }, 193 | { 194 | "source": [ 195 | "### Additive Attention Mechanism" 196 | ], 197 | "cell_type": "markdown", 198 | "metadata": {} 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 5, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "output_type": "stream", 207 | "name": "stdout", 208 | "text": [ 209 | "\n\n----- Concat hidden features to represent edges.
Shape(len(emb.concat(emb)), number of edges)\n\n[[-0.3047382 0.3998854 0.03973776 -0.3047382 0.3998854 0.03973776]\n [-0.3047382 0.3998854 0.03973776 0.57624235 0.5029172 -0.31225734]\n [-0.3047382 0.3998854 0.03973776 -0.12942953 0.02834577 -0.04941563]\n [-0.3047382 0.3998854 0.03973776 -0.11941829 -0.21561317 0.47800217]\n [ 0.57624235 0.5029172 -0.31225734 -0.3047382 0.3998854 0.03973776]\n [ 0.57624235 0.5029172 -0.31225734 0.57624235 0.5029172 -0.31225734]\n [ 0.57624235 0.5029172 -0.31225734 -0.12942953 0.02834577 -0.04941563]\n [ 0.57624235 0.5029172 -0.31225734 -0.4294049 0.19600584 -0.06529497]\n [ 0.57624235 0.5029172 -0.31225734 -0.11941829 -0.21561317 0.47800217]\n [-0.12942953 0.02834577 -0.04941563 -0.3047382 0.3998854 0.03973776]\n [-0.12942953 0.02834577 -0.04941563 0.57624235 0.5029172 -0.31225734]\n [-0.12942953 0.02834577 -0.04941563 -0.12942953 0.02834577 -0.04941563]\n [-0.12942953 0.02834577 -0.04941563 -0.4294049 0.19600584 -0.06529497]\n [-0.4294049 0.19600584 -0.06529497 0.57624235 0.5029172 -0.31225734]\n [-0.4294049 0.19600584 -0.06529497 -0.12942953 0.02834577 -0.04941563]\n [-0.4294049 0.19600584 -0.06529497 -0.4294049 0.19600584 -0.06529497]\n [-0.4294049 0.19600584 -0.06529497 -0.11941829 -0.21561317 0.47800217]\n [-0.11941829 -0.21561317 0.47800217 -0.3047382 0.3998854 0.03973776]\n [-0.11941829 -0.21561317 0.47800217 0.57624235 0.5029172 -0.31225734]\n [-0.11941829 -0.21561317 0.47800217 -0.4294049 0.19600584 -0.06529497]\n [-0.11941829 -0.21561317 0.47800217 -0.11941829 -0.21561317 0.47800217]]\n\n\n----- Attention coefficients. Shape(1, len(emb.concat(emb)))\n\n[[0.09834683 0.42110763 0.95788953 0.53316528 0.69187711 0.31551563]]\n\n\n----- Edge representations combined with the attention coefficients. 
Shape(1, number of edges)\n\n[[ 0.30322275]\n [ 0.73315639]\n [ 0.11150219]\n [ 0.11445879]\n [ 0.09607946]\n [ 0.52601309]\n [-0.0956411 ]\n [-0.14458757]\n [-0.0926845 ]\n [ 0.07860653]\n [ 0.50854017]\n [-0.11311402]\n [-0.16206049]\n [ 0.53443082]\n [-0.08722337]\n [-0.13616985]\n [-0.08426678]\n [ 0.48206613]\n [ 0.91199976]\n [ 0.2413991 ]\n [ 0.29330217]]\n\n\n----- Leaky Relu. Shape(1, number of edges)\n[[ 3.03222751e-01]\n [ 7.33156386e-01]\n [ 1.11502195e-01]\n [ 1.14458791e-01]\n [ 9.60794571e-02]\n [ 5.26013092e-01]\n [-9.56410988e-04]\n [-1.44587571e-03]\n [-9.26845030e-04]\n [ 7.86065337e-02]\n [ 5.08540169e-01]\n [-1.13114022e-03]\n [-1.62060495e-03]\n [ 5.34430817e-01]\n [-8.72233739e-04]\n [-1.36169846e-03]\n [-8.42667781e-04]\n [ 4.82066128e-01]\n [ 9.11999763e-01]\n [ 2.41399100e-01]\n [ 2.93302168e-01]]\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "# equation (2)\n", 215 | "print('\\n\\n----- Concat hidden features to represent edges. Shape(len(emb.concat(emb)), number of edges)\\n')\n", 216 | "edge_coords = np.where(A==1)\n", 217 | "h_src_nodes = z1[edge_coords[0]]\n", 218 | "h_dst_nodes = z1[edge_coords[1]]\n", 219 | "z2 = np.concatenate((h_src_nodes, h_dst_nodes), axis=1)\n", 220 | "\n", 221 | "# Concatenation tests\n", 222 | "assert len(edge_coords[1]) == z2.shape[0], \"The number of edges in A is not equal to the number of concat edges\"\n", 223 | "test_value = np.array([-0.11941829, -0.12942953, 0.19600584, 0.5029172, 0.3998854, -0.21561317])\n", 224 | "assert z2[4 ,:].tolist().sort() == test_value.tolist().sort(), \"Something went wrong in the concat process\"\n", 225 | "print(z2)\n", 226 | "\n", 227 | "print('\\n\\n----- Attention coefficients. Shape(1, len(emb.concat(emb)))\\n')\n", 228 | "att = np.random.rand(1, z2.shape[1])\n", 229 | "print(att)\n", 230 | "\n", 231 | "print('\\n\\n----- Edge representations combined with the attention coefficients. 
Shape(1, number of edges)\\n')\n", 232 | "z2_att = z2.dot(att.T)\n", 233 | "print(z2_att)\n", 234 | "\n", 235 | "print('\\n\\n----- Leaky Relu. Shape(1, number of edges)')\n", 236 | "e = leaky_relu(z2_att)\n", 237 | "print(e)" 238 | ] 239 | }, 240 | { 241 | "source": [ 242 | "### Normalize the Attention Scores" 243 | ], 244 | "cell_type": "markdown", 245 | "metadata": {} 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 6, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "output_type": "stream", 254 | "name": "stdout", 255 | "text": [ 256 | "\n\n----- Edge scores as matrix. Shape(n,n)\n\n[[ 3.03222751e-01 7.33156386e-01 1.11502195e-01 0.00000000e+00\n 1.14458791e-01]\n [ 9.60794571e-02 5.26013092e-01 -9.56410988e-04 -1.44587571e-03\n -9.26845030e-04]\n [ 7.86065337e-02 5.08540169e-01 -1.13114022e-03 -1.62060495e-03\n 0.00000000e+00]\n [ 0.00000000e+00 5.34430817e-01 -8.72233739e-04 -1.36169846e-03\n -8.42667781e-04]\n [ 4.82066128e-01 9.11999763e-01 0.00000000e+00 2.41399100e-01\n 2.93302168e-01]]\n\n\n----- For each node, normalize the edge (or neighbor) contributions using softmax\n\n[0.26263543 0.21349717 0.20979916 0.31406823 0.21610715 0.17567419\n 0.1726313 0.1771592 0.25842816 0.27167844 0.24278118 0.24273876\n 0.24280162 0.23393014 0.23388927 0.23394984 0.29823075 0.25138555\n 0.22399017 0.22400903 0.30061525]\n\n\n----- Normalized edge score matrix. Shape(n,n)\n\n[[0.26263543 0.21349717 0.20979916 0. 0.31406823]\n [0.21610715 0.17567419 0.1726313 0.1771592 0.25842816]\n [0.27167844 0.24278118 0.24273876 0.24280162 0. ]\n [0. 0.23393014 0.23388927 0.23394984 0.29823075]\n [0.25138555 0.22399017 0. 0.22400903 0.30061525]]\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "# equation (3)\n", 262 | "print('\\n\\n----- Edge scores as matrix. 
Shape(n,n)\\n')\n", 263 | "e_matr = np.zeros(A.shape)\n", 264 | "e_matr[edge_coords[0], edge_coords[1]] = e.reshape(-1,)\n", 265 | "print(e_matr)\n", 266 | "\n", 267 | "print('\\n\\n----- For each node, normalize the edge (or neighbor) contributions using softmax\\n')\n", 268 | "alpha0 = softmax(e_matr[:,0][e_matr[:,0] != 0]) \n", 269 | "alpha1 = softmax(e_matr[:,1][e_matr[:,1] != 0])\n", 270 | "alpha2 = softmax(e_matr[:,2][e_matr[:,2] != 0])\n", 271 | "alpha3 = softmax(e_matr[:,3][e_matr[:,3] != 0])\n", 272 | "alpha4 = softmax(e_matr[:,4][e_matr[:,4] != 0])\n", 273 | "alpha = np.concatenate((alpha0, alpha1, alpha2, alpha3, alpha4))\n", 274 | "print(alpha)\n", 275 | "\n", 276 | "print('\\n\\n----- Normalized edge score matrix. Shape(n,n)\\n')\n", 277 | "A_scaled = np.zeros(A.shape)\n", 278 | "A_scaled[edge_coords[0], edge_coords[1]] = alpha.reshape(-1,)\n", 279 | "print(A_scaled)" 280 | ] 281 | }, 282 | { 283 | "source": [ 284 | "### Neighborhood Diffusion (GCN) Scaled by the Attention Scores (GAT)" 285 | ], 286 | "cell_type": "markdown", 287 | "metadata": {} 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 7, 292 | "metadata": {}, 293 | "outputs": [ 294 | { 295 | "output_type": "stream", 296 | "name": "stdout", 297 | "text": [ 298 | "\n\nNeighborhood aggregation (GCN) scaled with attention scores (GAT). Shape(n, emb)\n\n[[-0.02166863 0.15062515 0.08352843]\n [-0.09390287 0.15866476 0.05716299]\n [-0.07856777 0.28521023 -0.09286313]\n [-0.03154513 0.10583032 0.04267501]\n [-0.07962369 0.19226439 0.069115 ]]\n" 299 | ] 300 | } 301 | ], 302 | "source": [ 303 | "# equation (4)\n", 304 | "print('\\n\\nNeighborhood aggregation (GCN) scaled with attention scores (GAT). 
Shape(n, emb)\\n')\n", 305 | "ND_GAT = A_scaled.dot(z1)\n", 306 | "print(ND_GAT)" 307 | ] 308 | }, 309 | { 310 | "source": [ 311 | "## GAT Layer - DGL Test\n", 312 | "Original layer implementation: https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html " 313 | ], 314 | "cell_type": "markdown", 315 | "metadata": {} 316 | }, 317 | { 318 | "source": [ 319 | "class GATTestLayer(nn.Module):\n", 320 | " def __init__(self, g, in_dim, out_dim):\n", 321 | " super(GATTestLayer, self).__init__()\n", 322 | " self.g = g\n", 323 | " # equation (1)\n", 324 | " self.fc = nn.Linear(in_dim, out_dim, bias=False)\n", 325 | " # equation (2)\n", 326 | " self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)\n", 327 | " self.reset_parameters()\n", 328 | "\n", 329 | " def reset_parameters(self):\n", 330 | " \"\"\"Reinizialitation modified for testing\"\"\"\n", 331 | " gain = nn.init.calculate_gain('relu')\n", 332 | " self.fc.state_dict()['weight'][:] = torch.from_numpy(W)\n", 333 | " self.attn_fc.state_dict()['weight'][:] = torch.from_numpy(att)\n", 334 | "\n", 335 | " def edge_attention(self, edges):\n", 336 | " # edge UDF for equation (2)\n", 337 | " z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)\n", 338 | " a = self.attn_fc(z2)\n", 339 | " return {'e': F.leaky_relu(a)}\n", 340 | "\n", 341 | " def message_func(self, edges):\n", 342 | " # message UDF for equation (3) & (4)\n", 343 | " return {'z': edges.src['z'], 'e': edges.data['e']}\n", 344 | "\n", 345 | " def reduce_func(self, nodes):\n", 346 | " # reduce UDF for equation (3) & (4)\n", 347 | " # equation (3)\n", 348 | " alpha = F.softmax(nodes.mailbox['e'], dim=1)\n", 349 | " # equation (4)\n", 350 | " h = torch.sum(alpha * nodes.mailbox['z'], dim=1)\n", 351 | " return {'h': h}\n", 352 | "\n", 353 | " def forward(self, h):\n", 354 | " # equation (1)\n", 355 | " z = self.fc(h)\n", 356 | " self.g.ndata['z'] = z\n", 357 | " # equation (2)\n", 358 | " self.g.apply_edges(self.edge_attention)\n", 359 | " # equation (3) & 
(4)\n", 360 | " self.g.update_all(self.message_func, self.reduce_func)\n", 361 | " return self.g.ndata.pop('h')" 362 | ], 363 | "cell_type": "code", 364 | "metadata": {}, 365 | "execution_count": 8, 366 | "outputs": [] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 9, 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "output_type": "stream", 375 | "name": "stdout", 376 | "text": [ 377 | "\n\n----- Create a new DGL graph using the NumPy graph\n\nGraph(num_nodes=5, num_edges=21,\n ndata_schemes={}\n edata_schemes={})\n\n\n----- Create a DGL instance of the GAT test layer\n\ntensor([[-0.0217, 0.1506, 0.0835],\n [-0.0939, 0.1587, 0.0572],\n [-0.0786, 0.2852, -0.0929],\n [-0.0315, 0.1058, 0.0427],\n [-0.0796, 0.1923, 0.0691]], grad_fn=)\n\n\n----- Recap of the NumPy GAT layer\n[[-0.0217 0.1506 0.0835]\n [-0.0939 0.1587 0.0572]\n [-0.0786 0.2852 -0.0929]\n [-0.0315 0.1058 0.0427]\n [-0.0796 0.1923 0.0691]]\n" 378 | ] 379 | } 380 | ], 381 | "source": [ 382 | "print('\\n\\n----- Create a new DGL graph using the NumPy graph\\n')\n", 383 | "src_ids = torch.tensor(edge_coords[0])\n", 384 | "dst_ids = torch.tensor(edge_coords[1])\n", 385 | "g = dgl.graph((src_ids, dst_ids))\n", 386 | "print(g)\n", 387 | "\n", 388 | "print('\\n\\n----- Create a DGL instance of the GAT test layer\\n')\n", 389 | "net = GATTestLayer(g,\n", 390 | " in_dim=n,\n", 391 | " out_dim=3)\n", 392 | "print(net.forward(torch.Tensor(X)))\n", 393 | "\n", 394 | "print('\\n\\n----- Recap of the NumPy GAT layer')\n", 395 | "print(np.round(ND_GAT, decimals=4))\n", 396 | "\n" 397 | ] 398 | }, 399 | { 400 | "source": [ 401 | "The resulting matrices from the NumPy implementation and the DGL implementation are equal \\o/." 402 | ], 403 | "cell_type": "markdown", 404 | "metadata": {} 405 | }, 406 | { 407 | "source": [ 408 | "## Math Warm-up on Multi-head Attention\n", 409 | "The multi-head attention is useful to enrich the model capability and to stabilize the learning process. 
The outputs of each attention head can be combined in two different ways:\n", 410 | "\n", 411 | "$\\text{concatenation}: h^{(l+1)}_{i} =\\Big\\Vert_{k=1}^{K}\\sigma\\left(\\sum_{j\\in \\mathcal{N}(i)}\\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\\right)$\n", 412 | "\n", 413 | "or\n", 414 | "\n", 415 | "$\\text{average}: h_{i}^{(l+1)}=\\sigma\\left(\\frac{1}{K}\\sum_{k=1}^{K}\\sum_{j\\in\\mathcal{N}(i)}\\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\\right)$\n", 416 | "\n", 417 | "* $K$ is the number of heads. Concatenation is adopted for the intermediate (hidden) layers. Averaging is employed for the final (prediction) layer, where concatenation is no longer sensible because the output must keep one dimension per class. Note that the average is taken element-wise over the $K$ heads, so it has the same shape as the output of a single head.\n", 418 | "\n", 419 | "\n" 420 | ], 421 | "cell_type": "markdown", 422 | "metadata": {} 423 | }, 424 | { 425 | "source": [ 426 | "## Multi Head GAT Layer Implementation with NumPy\n", 427 | "Multiple attention heads are created by instantiating multiple GAT layers. In this test, both heads reuse the single-layer output computed above (same weights), so their outputs are identical." 428 | ], 429 | "cell_type": "markdown", 430 | "metadata": {} 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 10, 435 | "metadata": {}, 436 | "outputs": [ 437 | { 438 | "output_type": "stream", 439 | "name": "stdout", 440 | "text": [ 441 | "\n", 442 | "\n", 443 | "----- Recap on the output of the GAT layer\n", 444 | "\n", 445 | "Layer 1. Shape(emb,n)\n", 446 | "[[-0.02166863 0.15062515 0.08352843]\n", 447 | " [-0.09390287 0.15866476 0.05716299]\n", 448 | " [-0.07856777 0.28521023 -0.09286313]\n", 449 | " [-0.03154513 0.10583032 0.04267501]\n", 450 | " [-0.07962369 0.19226439 0.069115 ]]\n", 451 | "\n", 452 | "Layer 2. Shape(emb,n)\n", 453 | "[[-0.02166863 0.15062515 0.08352843]\n", 454 | " [-0.09390287 0.15866476 0.05716299]\n", 455 | " [-0.07856777 0.28521023 -0.09286313]\n", 456 | " [-0.03154513 0.10583032 0.04267501]\n", 457 | " [-0.07962369 0.19226439 0.069115 ]]\n", 458 | "\n", 459 | "\n", 460 | "----- Concatenate multiple attentions.
Shape(num_layers*emb, n)\n", 461 | "\n", 462 | "[[-0.02166863 0.15062515 0.08352843 -0.02166863 0.15062515 0.08352843]\n", 463 | " [-0.09390287 0.15866476 0.05716299 -0.09390287 0.15866476 0.05716299]\n", 464 | " [-0.07856777 0.28521023 -0.09286313 -0.07856777 0.28521023 -0.09286313]\n", 465 | " [-0.03154513 0.10583032 0.04267501 -0.03154513 0.10583032 0.04267501]\n", 466 | " [-0.07962369 0.19226439 0.069115 -0.07962369 0.19226439 0.069115 ]]\n", 467 | "\n", 468 | "\n", 469 | "----- Average multiple attentions.\n", 470 | "\n", 471 | "0.04979367027023359\n" 472 | ] 473 | } 474 | ], 475 | "source": [ 476 | "print('\\n\\n----- Recap on the output of the GAT layer')\n", 477 | "print('\\nLayer 1. Shape(emb,n)')\n", 478 | "layer1 = ND_GAT\n", 479 | "print(layer1)\n", 480 | "\n", 481 | "print('\\nLayer 2. Shape(emb,n)')\n", 482 | "layer2 = ND_GAT\n", 483 | "print(layer2)\n", 484 | "\n", 485 | "print('\\n\\n----- Concatenate multiple attentions. Shape(num_layers*emb, n)\\n')\n", 486 | "concat = np.concatenate((layer1, layer2), axis=1)\n", 487 | "print(concat)\n", 488 | "\n", 489 | "print('\\n\\n----- Average multiple attentions.\\n')\n", 490 | "# 30 is the number of parameters: num_layers*emb*n\n", 491 | "average = np.sum((layer1, layer2)) / 30\n", 492 | "print(average)" 493 | ] 494 | }, 495 | { 496 | "source": [ 497 | "## Multi Head GAT Layer - DGL Test\n", 498 | "Original layer implementation: https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html " 499 | ], 500 | "cell_type": "markdown", 501 | "metadata": {} 502 | }, 503 | { 504 | "source": [ 505 | "class MultiHeadGATTestLayer(nn.Module):\n", 506 | " def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):\n", 507 | " super(MultiHeadGATTestLayer, self).__init__()\n", 508 | " self.heads = nn.ModuleList()\n", 509 | " for i in range(num_heads):\n", 510 | " # Use the test layer for consistency with the NumPy implementation\n", 511 | " self.heads.append(GATTestLayer(g, in_dim, out_dim))\n", 512 | " self.merge 
= merge\n", 513 | "\n", 514 | " def forward(self, h):\n", 515 | " head_outs = [attn_head(h) for attn_head in self.heads]\n", 516 | " if self.merge == 'cat':\n", 517 | " # concat on the output feature dimension (dim=1)\n", 518 | " return torch.cat(head_outs, dim=1)\n", 519 | " else:\n", 520 | " # merge using average\n", 521 | " return torch.mean(torch.stack(head_outs))\n", 522 | "\n", 523 | "print('\\n\\n----- Multi head GAT layer (concat operation). Shape(num_layers*emb, n)\\n')\n", 524 | "concat_net = MultiHeadGATTestLayer(g, in_dim=n, out_dim=3, num_heads=2)\n", 525 | "print(concat_net)\n", 526 | "print('\\n----- DGL concat output\\n')\n", 527 | "print(concat_net.forward(torch.Tensor(X)))\n", 528 | "\n", 529 | "print('\\n----- Recap of the NumPy concatenation\\n')\n", 530 | "print(np.round(concat, decimals=4))\n", 531 | "\n", 532 | "print('\\n\\n----- Multi head GAT Layer (average operation). Shape(emb, n)\\n')\n", 533 | "mean_net = MultiHeadGATTestLayer(g, in_dim=n, out_dim=3, num_heads=2, merge='mean')\n", 534 | "print(mean_net)\n", 535 | "print('\\n----- DGL average output\\n')\n", 536 | "print(mean_net.forward(torch.Tensor(X)))\n", 537 | "\n", 538 | "print('\\n----- Recap of the NumPy average\\n')\n", 539 | "print(np.round(average, decimals=4))" 540 | ], 541 | "cell_type": "code", 542 | "metadata": {}, 543 | "execution_count": 11, 544 | "outputs": [ 545 | { 546 | "output_type": "stream", 547 | "name": "stdout", 548 | "text": [ 549 | "\n\n----- Multi head GAT layer (concat operation). 
Shape(num_layers*emb, n)\n\nMultiHeadGATTestLayer(\n (heads): ModuleList(\n (0): GATTestLayer(\n (fc): Linear(in_features=5, out_features=3, bias=False)\n (attn_fc): Linear(in_features=6, out_features=1, bias=False)\n )\n (1): GATTestLayer(\n (fc): Linear(in_features=5, out_features=3, bias=False)\n (attn_fc): Linear(in_features=6, out_features=1, bias=False)\n )\n )\n)\n\n----- DGL concat output\n\ntensor([[-0.0217, 0.1506, 0.0835, -0.0217, 0.1506, 0.0835],\n [-0.0939, 0.1587, 0.0572, -0.0939, 0.1587, 0.0572],\n [-0.0786, 0.2852, -0.0929, -0.0786, 0.2852, -0.0929],\n [-0.0315, 0.1058, 0.0427, -0.0315, 0.1058, 0.0427],\n [-0.0796, 0.1923, 0.0691, -0.0796, 0.1923, 0.0691]],\n grad_fn=)\n\n----- Recap of the NumPy concatenation\n\n[[-0.0217 0.1506 0.0835 -0.0217 0.1506 0.0835]\n [-0.0939 0.1587 0.0572 -0.0939 0.1587 0.0572]\n [-0.0786 0.2852 -0.0929 -0.0786 0.2852 -0.0929]\n [-0.0315 0.1058 0.0427 -0.0315 0.1058 0.0427]\n [-0.0796 0.1923 0.0691 -0.0796 0.1923 0.0691]]\n\n\n----- Multi head GAT Layer (average operation). Shape(emb, n)\n\nMultiHeadGATTestLayer(\n (heads): ModuleList(\n (0): GATTestLayer(\n (fc): Linear(in_features=5, out_features=3, bias=False)\n (attn_fc): Linear(in_features=6, out_features=1, bias=False)\n )\n (1): GATTestLayer(\n (fc): Linear(in_features=5, out_features=3, bias=False)\n (attn_fc): Linear(in_features=6, out_features=1, bias=False)\n )\n )\n)\n\n----- DGL average output\n\ntensor(0.0498, grad_fn=)\n\n----- Recap of the NumPy average\n\n0.0498\n" 550 | ] 551 | } 552 | ] 553 | }, 554 | { 555 | "source": [ 556 | "The resulting matrices from the NumPy implementation and the DGL implementation are equal \\o/." 557 | ], 558 | "cell_type": "markdown", 559 | "metadata": {} 560 | }, 561 | { 562 | "source": [ 563 | "# From Theory to Practice\n", 564 | "After the understanding of math and the implementation of GAT building blocks, we can run some experiments as reported in the original paper. 
Let's recap the DGL modules using a fair parameter initialization. The following implementation is based on the example available here: https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html." 565 | ], 566 | "cell_type": "markdown", 567 | "metadata": {} 568 | }, 569 | { 570 | "source": [ 571 | "## New Imports" 572 | ], 573 | "cell_type": "markdown", 574 | "metadata": {} 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 12, 579 | "metadata": {}, 580 | "outputs": [], 581 | "source": [ 582 | "import time\n", 583 | "from dgl import DGLGraph\n", 584 | "from dgl.data import citation_graph as citegrh\n", 585 | "import networkx as nx" 586 | ] 587 | }, 588 | { 589 | "source": [ 590 | "## GAT Implementation with DGL" 591 | ], 592 | "cell_type": "markdown", 593 | "metadata": {} 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 13, 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [ 601 | "class GATLayer(nn.Module):\n", 602 | " def __init__(self, g, in_dim, out_dim):\n", 603 | " super(GATLayer, self).__init__()\n", 604 | " self.g = g\n", 605 | " # equation (1)\n", 606 | " self.fc = nn.Linear(in_dim, out_dim, bias=False)\n", 607 | " # equation (2)\n", 608 | " self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)\n", 609 | " self.reset_parameters()\n", 610 | "\n", 611 | " def reset_parameters(self):\n", 612 | " \"\"\"Reinitialize learnable parameters.\"\"\"\n", 613 | " gain = nn.init.calculate_gain('relu')\n", 614 | " nn.init.xavier_normal_(self.fc.weight, gain=gain)\n", 615 | " nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)\n", 616 | "\n", 617 | " def edge_attention(self, edges):\n", 618 | " # edge UDF for equation (2)\n", 619 | " z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)\n", 620 | " a = self.attn_fc(z2)\n", 621 | " return {'e': F.leaky_relu(a)}\n", 622 | "\n", 623 | " def message_func(self, edges):\n", 624 | " # message UDF for equation (3) & (4)\n", 625 | " return {'z': edges.src['z'], 'e': 
edges.data['e']}\n", 626 | "\n", 627 | " def reduce_func(self, nodes):\n", 628 | " # reduce UDF for equation (3) & (4)\n", 629 | " # equation (3)\n", 630 | " alpha = F.softmax(nodes.mailbox['e'], dim=1)\n", 631 | " # equation (4)\n", 632 | " h = torch.sum(alpha * nodes.mailbox['z'], dim=1)\n", 633 | " return {'h': h}\n", 634 | "\n", 635 | " def forward(self, h):\n", 636 | " # equation (1)\n", 637 | " z = self.fc(h)\n", 638 | " self.g.ndata['z'] = z\n", 639 | " # equation (2)\n", 640 | " self.g.apply_edges(self.edge_attention)\n", 641 | " # equation (3) & (4)\n", 642 | " self.g.update_all(self.message_func, self.reduce_func)\n", 643 | " return self.g.ndata.pop('h')" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 14, 649 | "metadata": {}, 650 | "outputs": [], 651 | "source": [ 652 | "class MultiHeadGATLayer(nn.Module):\n", 653 | " def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):\n", 654 | " super(MultiHeadGATLayer, self).__init__()\n", 655 | " self.heads = nn.ModuleList()\n", 656 | " for i in range(num_heads):\n", 657 | " self.heads.append(GATLayer(g, in_dim, out_dim))\n", 658 | " self.merge = merge\n", 659 | "\n", 660 | " def forward(self, h):\n", 661 | " head_outs = [attn_head(h) for attn_head in self.heads]\n", 662 | " if self.merge == 'cat':\n", 663 | " # concat on the output feature dimension (dim=1)\n", 664 | " return torch.cat(head_outs, dim=1)\n", 665 | " else:\n", 666 | " # merge using average\n", 667 | " return torch.mean(torch.stack(head_outs))" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 15, 673 | "metadata": {}, 674 | "outputs": [], 675 | "source": [ 676 | "class GAT(nn.Module):\n", 677 | " def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):\n", 678 | " super(GAT, self).__init__()\n", 679 | " self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)\n", 680 | " # Be aware that the input dimension is hidden_dim*num_heads since\n", 681 | " # multiple head 
outputs are concatenated together. Also, only\n", 682 | " # one attention head in the output layer.\n", 683 | " self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)\n", 684 | "\n", 685 | " def forward(self, h):\n", 686 | " h = self.layer1(h)\n", 687 | " h = F.elu(h)\n", 688 | " h = self.layer2(h)\n", 689 | " return h" 690 | ] 691 | }, 692 | { 693 | "source": [ 694 | "## Evaluation Functions" 695 | ], 696 | "cell_type": "markdown", 697 | "metadata": {} 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 16, 702 | "metadata": {}, 703 | "outputs": [], 704 | "source": [ 705 | "def accuracy(logits, labels):\n", 706 | " _, indices = torch.max(logits, dim=1)\n", 707 | " correct = torch.sum(indices == labels)\n", 708 | " return correct.item() * 1.0 / len(labels)\n", 709 | "\n", 710 | "def evaluate(model, features, labels, mask):\n", 711 | " model.eval()\n", 712 | " with torch.no_grad():\n", 713 | " logits = model(features)\n", 714 | " logits = logits[mask]\n", 715 | " labels = labels[mask]\n", 716 | " return accuracy(logits, labels)" 717 | ] 718 | }, 719 | { 720 | "source": [ 721 | "## Load Cora Dataset" 722 | ], 723 | "cell_type": "markdown", 724 | "metadata": {} 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": 17, 729 | "metadata": {}, 730 | "outputs": [ 731 | { 732 | "output_type": "stream", 733 | "name": "stdout", 734 | "text": [ 735 | "Loading from cache failed, re-processing.\n" 736 | ] 737 | }, 738 | { 739 | "output_type": "error", 740 | "ename": "KeyboardInterrupt", 741 | "evalue": "", 742 | "traceback": [ 743 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 744 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 745 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_cora_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'\\n\\n----- Features of CORA dataset'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 746 | "\u001b[0;32m\u001b[0m in \u001b[0;36mload_cora_data\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload_cora_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcitegrh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_cora\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtrain_mask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBoolTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 747 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/citation_graph.py\u001b[0m in \u001b[0;36mload_cora\u001b[0;34m(raw_dir, force_reload, verbose)\u001b[0m\n\u001b[1;32m 715\u001b[0m \u001b[0mCoraGraphDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 716\u001b[0m \"\"\"\n\u001b[0;32m--> 717\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCoraGraphDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mraw_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforce_reload\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 718\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 719\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 748 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/citation_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, raw_dir, force_reload, verbose)\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cora'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 389\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCoraGraphDataset\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mraw_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforce_reload\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 391\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 749 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/citation_graph.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, raw_dir, force_reload, verbose)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0mraw_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mraw_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[0mforce_reload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mforce_reload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m verbose=verbose)\n\u001b[0m\u001b[1;32m 68\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 750 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/dgl_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, url, raw_dir, hash_key, force_reload, verbose)\u001b[0m\n\u001b[1;32m 284\u001b[0m \u001b[0mhash_key\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhash_key\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m 
\u001b[0mforce_reload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mforce_reload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 286\u001b[0;31m verbose=verbose)\n\u001b[0m\u001b[1;32m 287\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 751 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/dgl_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, url, raw_dir, save_dir, hash_key, force_reload, verbose)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_save_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msave_dir\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 752 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/dgl_dataset.py\u001b[0m in \u001b[0;36m_load\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mload_flag\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_download\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 
177\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 178\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 753 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/dgl/data/citation_graph.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mallx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolil\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0mfeatures\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtest_idx_reorder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtest_idx_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 107\u001b[0m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDiGraph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_dict_of_lists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 754 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/scipy/sparse/lil.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_intXint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;31m# Everything else takes the normal path.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mIndexMixin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_asindices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mN\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 755 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/scipy/sparse/_index.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_arrayXint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mslice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_arrayXslice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# row.ndim == 2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mINT_TYPES\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 756 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/scipy/sparse/lil.py\u001b[0m in \u001b[0;36m_get_arrayXslice\u001b[0;34m(self, row, col)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_get_arrayXslice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_row_ranges\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_get_intXarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 757 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/scipy/sparse/lil.py\u001b[0m in \u001b[0;36m_get_row_ranges\u001b[0;34m(self, rows, col_slice)\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0mcol_range\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj_start\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj_stop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj_stride\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0mnj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcol_range\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 289\u001b[0;31m \u001b[0mnew\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlil_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 290\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m _csparsetools.lil_get_row_ranges(self.shape[0], self.shape[1],\n", 758 | "\u001b[0;32m~/anaconda3/envs/learning/lib/python3.6/site-packages/scipy/sparse/lil.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, arg1, shape, dtype, copy)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m 
\u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mM\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrows\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'unrecognized lil_matrix constructor usage'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 759 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 760 | ] 761 | } 762 | ], 763 | "source": [ 764 | "def load_cora_data():\n", 765 | " data = citegrh.load_cora()\n", 766 | " features = torch.FloatTensor(data.features)\n", 767 | " labels = torch.LongTensor(data.labels)\n", 768 | " train_mask = torch.BoolTensor(data.train_mask)\n", 769 | " val_mask = torch.BoolTensor(data.val_mask)\n", 770 | " test_mask = torch.BoolTensor(data.test_mask)\n", 771 | " g = DGLGraph(data.graph)\n", 772 | " return g, features, labels, train_mask, val_mask, test_mask\n", 773 | "\n", 774 | "g, features, labels, train_mask, val_mask, test_mask = load_cora_data()\n", 775 | "print('\\n\\n----- Features of CORA dataset')\n", 776 | "\n", 777 | "print('\\n----- Graph:')\n", 778 | "print(g)\n", 779 | "\n", 780 | "print('\\n----- Features:')\n", 781 | "print(features)\n", 782 | 
"print(features.nonzero(as_tuple=True)[1])\n", 783 | "\n", 784 | "print('\\n----- Labels:')\n", 785 | "print(labels)\n", 786 | "print(labels.size())\n", 787 | "output = torch.unique(labels)\n", 788 | "occs = torch.bincount(labels)\n", 789 | "print('----- Number of unique labels:')\n", 790 | "print(output)\n", 791 | "print('----- Number of label occurrences:')\n", 792 | "print(occs)\n", 793 | "\n", 794 | "print('\\n----- Training mask:')\n", 795 | "train_long = train_mask.long()\n", 796 | "occs = torch.bincount(train_long)\n", 797 | "print(output)\n", 798 | "print(occs)\n", 799 | "\n", 800 | "print('\\n----- Validation mask:')\n", 801 | "val_long = val_mask.long()\n", 802 | "occs = torch.bincount(val_long)\n", 803 | "print(output)\n", 804 | "print(occs)\n", 805 | "\n", 806 | "print('\\n----- Testing mask:')\n", 807 | "test_long = test_mask.long()\n", 808 | "occs = torch.bincount(test_long)\n", 809 | "print(output)\n", 810 | "print(occs)\n" 811 | ] 812 | }, 813 | { 814 | "source": [ 815 | "Analyzing the cora dataset, you can get the following information:\n", 816 | "\n", 817 | "1. Nodes have no features (one-hot encoding vectors)\n", 818 | "2. 
Node labels are uniformly distributed\n" 819 | ], 820 | "cell_type": "markdown", 821 | "metadata": {} 822 | }, 823 | { 824 | "source": [ 825 | "## Training Loop" 826 | ], 827 | "cell_type": "markdown", 828 | "metadata": {} 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": null, 833 | "metadata": {}, 834 | "outputs": [], 835 | "source": [ 836 | "# create the model, 2 heads, each head has hidden size 8\n", 837 | "model = GAT(g,\n", 838 | " in_dim=features.size()[1],\n", 839 | " hidden_dim=8,\n", 840 | " out_dim=7,\n", 841 | " num_heads=2)\n", 842 | "\n", 843 | "# create optimizer\n", 844 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", 845 | "\n", 846 | "# main loop\n", 847 | "dur = []\n", 848 | "for epoch in range(300):\n", 849 | " t0 = time.time()\n", 850 | "\n", 851 | " logits = model(features)\n", 852 | " logp = F.log_softmax(logits, 1)\n", 853 | " loss = F.nll_loss(logp[train_mask], labels[train_mask])\n", 854 | "\n", 855 | " train_acc = accuracy(logp[train_mask], labels[train_mask])\n", 856 | "\n", 857 | " optimizer.zero_grad()\n", 858 | " loss.backward()\n", 859 | " optimizer.step()\n", 860 | "\n", 861 | " dur.append(time.time() - t0)\n", 862 | "\n", 863 | " print(\"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Training Accuracy {:.4f}\".format(\n", 864 | " epoch, np.mean(dur), loss.item(), train_acc))\n", 865 | "\n", 866 | " if epoch % 30==0:\n", 867 | " print(\"\\nEval on validation dataset...\")\n", 868 | " val_acc = evaluate(model, features, labels, val_mask)\n", 869 | " print(\"Validation Accuracy: {:.4f}\\n\".format(val_acc))\n", 870 | "\n", 871 | "print()\n", 872 | "acc = evaluate(model, features, labels, test_mask)\n", 873 | "print(\"Test Accuracy {:.4f}\".format(acc))\n" 874 | ] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "execution_count": null, 879 | "metadata": {}, 880 | "outputs": [], 881 | "source": [] 882 | } 883 | ] 884 | } -------------------------------------------------------------------------------- 
/gnns/rgcn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.8.5-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "Python 3.8.5 64-bit ('torchkgae': conda)", 18 | "display_name": "Python 3.8.5 64-bit ('torchkgae': conda)", 19 | "metadata": { 20 | "interpreter": { 21 | "hash": "334c8f8e89452fc07a078fc466e9ca5cd061fa11063f9d04ce9b4f9d89d22362" 22 | } 23 | } 24 | } 25 | }, 26 | "nbformat": 4, 27 | "nbformat_minor": 2, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# Understanding Relational Graph Convolutional Networks (R-GCNs)\n", 34 | "What happens under the hood of Graph Neural Networks (GNNs) applied to multi-relational data, such as Knowledge Graphs (KGs)? A brief introduction to R-GCNs using pure numpy.\n", 35 | "\n", 36 | "Original Paper: Schlichtkrull, M., Kipf, T. N., Bloem, P., Van Den Berg, R., Titov, I., & Welling, M. (2018, June). *Modeling relational data with graph convolutional networks*. In *European Semantic Web Conference* (pp. 593-607). Springer, Cham. 
\n", 37 | "\n", 38 | "## Requirements\n", 39 | "- NumPy.\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "try:\n", 49 | " import numpy as np\n", 50 | "except ImportError as e:\n", 51 | " print('numpy is not available in you environment: try \"pip install numpy\"')\n", 52 | "\n", 53 | "np.random.seed(1)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Recap of GNNs (and Vanilla GCNs)\n", 61 | "This section recaps the behaviour of a basic GNN layer. \n", 62 | "\n", 63 | "Main ingredients:\n", 64 | "- One-hot vectors (no features) used to represent nodes.\n", 65 | "- Weight matrix representing the learnable parameters (or weights).\n", 66 | "- Adjacency matrix describing undirected edges between nodes. " 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 2, 72 | "metadata": { 73 | "tags": [] 74 | }, 75 | "outputs": [ 76 | { 77 | "output_type": "stream", 78 | "name": "stdout", 79 | "text": "\n\n----- One-hot vector representation of nodes:\n[[0. 0. 1. 0. 0.]\n [0. 1. 0. 0. 0.]\n [0. 0. 0. 0. 1.]\n [1. 0. 0. 0. 0.]\n [0. 0. 0. 1.
0.]]\n\n\n----- Weight Matrix:\n[[-0.4294049 0.57624235 -0.3047382 ]\n [-0.11941829 -0.12942953 0.19600584]\n [ 0.5029172 0.3998854 -0.21561317]\n [ 0.02834577 -0.06529497 -0.31225734]\n [ 0.03973776 0.47800217 -0.04941563]]\n\n\n----- Adjacency Matrix (undirected graph):\n[[1 1 1 0 1]\n [1 1 1 1 1]\n [1 1 1 1 0]\n [0 1 1 1 1]\n [1 1 0 1 1]]\n" 80 | } 81 | ], 82 | "source": [ 83 | "import numpy as np\n", 84 | "\n", 85 | "# One-hot vectors for representing nodes (randomly initialized)\n", 86 | "X = np.eye(5, 5)\n", 87 | "n = X.shape[0]\n", 88 | "np.random.shuffle(X)\n", 89 | "\n", 90 | "print('\\n\\n----- One-hot vector representation of nodes:')\n", 91 | "print(X)\n", 92 | "\n", 93 | "# Low-dimensional vector to represent node embeddings\n", 94 | "emb = 3 \n", 95 | "\n", 96 | "# Weight matrix (randomly inizialized according to Glorot and Bengio (2010))\n", 97 | "W = np.random.uniform(-np.sqrt(1. / emb), np.sqrt(1. / emb), (n, emb))\n", 98 | "\n", 99 | "print('\\n\\n----- Weight Matrix:')\n", 100 | "print(W)\n", 101 | "\n", 102 | "# Adjacency matrix (randomly initialized)\n", 103 | "A = np.random.randint(2, size=(n, n))\n", 104 | "np.fill_diagonal(A, 1) # Include the self loop\n", 105 | "A_und = (A + A.T) # Hack for creating a symmetric adjacency matrix (undirected graph)\n", 106 | "A_und[A_und > 1] = 1\n", 107 | "\n", 108 | "print('\\n\\n----- Adjacency Matrix (undirected graph):')\n", 109 | "print(A_und)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "Considering these ingredients, a \"recursive neighborhood diffusion\" is performed through the so-called “message passing framework”. The main idea behind this framework is that each node representation is updated with its neighbors' features. The neighbors' features are *passed* to the target node as *messages* through the edges. 
\n", 117 | "\n", 118 | "The operations are the following:\n", 119 | "* A linear transformation (or projection) involving the node features and the weight matrix.\n", 120 | "* A neighborhood diffusion that updates each node's representation by aggregating the features of its neighbors. " 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 3, 126 | "metadata": { 127 | "tags": [] 128 | }, 129 | "outputs": [ 130 | { 131 | "output_type": "stream", 132 | "name": "stdout", 133 | "text": "\n\n----- Output of the linear transformation:\n[[ 0.5029172 0.3998854 -0.21561317]\n [-0.11941829 -0.12942953 0.19600584]\n [ 0.03973776 0.47800217 -0.04941563]\n [-0.4294049 0.57624235 -0.3047382 ]\n [ 0.02834577 -0.06529497 -0.31225734]]\n\n\n----- (GNN) Output of the neighborhood diffusion:\n[[ 0.45158244 0.68316307 -0.3812803 ]\n [ 0.02217754 1.25940542 -0.6860185 ]\n [-0.00616823 1.3247004 -0.37376116]\n [-0.48073966 0.85952002 -0.47040533]\n [-0.01756022 0.78140325 -0.63660287]]\n" 134 | } 135 | ], 136 | "source": [ 137 | "# Linear transformation\n", 138 | "L_0 = X.dot(W)\n", 139 | "\n", 140 | "print('\\n\\n----- Output of the linear transformation:')\n", 141 | "print(L_0)\n", 142 | "\n", 143 | "# Neighborhood diffusion\n", 144 | "ND_GNN = A_und.dot(L_0)\n", 145 | "\n", 146 | "print('\\n\\n----- (GNN) Output of the neighborhood diffusion:')\n", 147 | "print(ND_GNN)\n", 148 | "\n", 149 | "# Test on the aggregation:\n", 150 | "assert(ND_GNN[0,0] == L_0[0,0] + L_0[1,0] + L_0[2,0] + L_0[4,0])\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "In the simplest formulation of the GNN, represented by vanilla Graph Convolutional Networks (GCNs), the aggregation/update operation is an isotropic computation, in which the features of all neighbor nodes are treated in the same way. \n", 158 | "\n", 159 | "More precisely, an isotropic *average* computation is performed in the specific case of Vanilla GCNs.
This average operation requires a new ingredient: the indegree of each node, i.e., the number of its incoming edges." 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 4, 165 | "metadata": { 166 | "tags": [] 167 | }, 168 | "outputs": [ 169 | { 170 | "output_type": "stream", 171 | "name": "stdout", 172 | "text": "\n\n----- Degree vector - Each element represents the i-node degree:\n[4 5 4 4 4]\n\n\n----- Reciprocal of the degree (in a diagonal matrix):\n[[0.25 0. 0. 0. 0. ]\n [0. 0.2 0. 0. 0. ]\n [0. 0. 0.25 0. 0. ]\n [0. 0. 0. 0.25 0. ]\n [0. 0. 0. 0. 0.25]]\n\n\n----- (GCN) Isotropic average computation:\n[[ 0.11289561 0.17079077 -0.09532007]\n [ 0.00443551 0.25188109 -0.1372037 ]\n [-0.00154206 0.3311751 -0.09344029]\n [-0.12018491 0.21488001 -0.11760133]\n [-0.00439005 0.19535081 -0.15915072]]\n" 173 | } 174 | ], 175 | "source": [ 176 | "# Degree vector (degree for each node)\n", 177 | "D = A_und.sum(axis=1)\n", 178 | "\n", 179 | "print('\\n\\n----- Degree vector - Each element represents the i-node degree:')\n", 180 | "print(D)\n", 181 | "\n", 182 | "# Reciprocal of the degree to perform the average computation (diagonal matrix)\n", 183 | "D_rec = np.diag(np.reciprocal(D.astype(np.float32))) # Need to convert degree integer values as float\n", 184 | "\n", 185 | "print('\\n\\n----- Reciprocal of the degree (in a diagonal matrix):')\n", 186 | "print(D_rec)\n", 187 | "\n", 188 | "# Isotropic average computation\n", 189 | "ND_GCN = D_rec.dot(ND_GNN)\n", 190 | "\n", 191 | "print('\\n\\n----- (GCN) Isotropic average computation:')\n", 192 | "print(ND_GCN)\n", 193 | "\n", 194 | "# Test on the isotropic average computation:\n", 195 | "assert(ND_GCN[0,0] == ND_GNN[0,0] * D_rec[0,0])" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## From GCNs to R-GCNs for encoding KGs\n", 203 | "The previous example considers an undirected and untyped graph.
As mentioned before, the update process is based on the following steps (the node indegree is not considered for the sake of simplicity):\n", 204 | "\n", 205 | "1. a projection step (or linear transformation), which is achieved by multiplying (i) the one-hot feature matrix with (ii) the weight matrix.\n", 206 | " \n", 207 | " (i). Matrix (n, n) that defines the initial features of the nodes.\n", 208 | " \n", 209 | " (ii). Matrix (n, emb) that describes the model's parameters. The current matrix is able to encode only one type of relation.\n", 210 | "\n", 211 | "2. an aggregation step, which is achieved by multiplying (i) the adjacency matrix with (ii) the matrix resulting from the projection step.\n", 212 | "\n", 213 | " (i). Symmetric matrix (n, n) that describes undirected and untyped edges.\n", 214 | "\n", 215 | " (ii). Matrix (n, emb) that describes the latent representation of the nodes.\n", 216 | "\n", 217 | "In order to extend the GCN layer to encode the structure of a KG, we need to represent our data as a directed and multi-typed graph. The update process is similar to the previous one, but the ingredients are more complex:\n", 218 | "\n", 219 | "1. a projection step, which is achieved by multiplying (i) the one-hot feature matrix with (ii) the weight **tensor**.\n", 220 | " \n", 221 | " (i). Matrix (n, n) that defines the initial features of the nodes.\n", 222 | " \n", 223 | " (ii). **Tensor (r, n, emb)** that describes the model's parameters, which will embed the latent node representations at the end of the training process. This tensor is able to encode different relations by stacking **r** batches of matrices (n, emb). Each of these batches encodes a single relation type.\n", 224 | "\n", 225 | "Tip: the projection step will no longer be a simple multiplication of matrices, but a *batch matrix multiplication*, in which (i) is multiplied with each batch of (ii).\n", 226 | "\n", 227 | "2.
an aggregation step, which is achieved by multiplying (i) the **(directed) adjacency tensor** by (ii) the **tensor** resulting from the projection step.\n", 228 | "\n", 229 | " (i) **Tensor (r, n, n)** that describes directed and **r**-typed edges. This tensor is composed of **r** batches of adjacency matrices (n, n). Each of these matrices describes the edges between nodes for a specific relation type. Moreover, unlike the adjacency matrix of an undirected graph, each of these adjacency matrices is in general not symmetric, because it encodes the edge direction.\n", 230 | " (ii) **Tensor (r, n, emb)** resulting from the projection step.\n", 231 | "\n", 232 | "Tip: as for the projection step, the aggregation step is a *batch matrix multiplication*: the k-th batch of (i) is multiplied by the k-th batch of (ii). This yields a GCN aggregation for each relation type. At the end of the process, the resulting batches are summed (R-GCN) to obtain a node representation that incorporates the neighborhood aggregation over all the different types of relations.\n", 233 | "\n", 234 | "The following example shows the behaviour of an R-GCN layer encoding a directed and multi-typed graph with 2 types of edges (or relations).\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 5, 247 | "metadata": { 248 | "tags": [] 249 | }, 250 | "outputs": [ 251 | { 252 | "output_type": "stream", 253 | "name": "stdout", 254 | "text": "\n\n----- Recall --> One-hot vector representation of nodes:\n[[0. 0. 1. 0. 0.]\n [0. 1. 0. 0. 0.]\n [0. 0. 0. 0. 1.]\n [1. 0. 0. 0. 0.]\n [0. 0. 0. 1.
0.]]\n\n\n----- Number of relation types:\n2\n\n\n----- Weight matrix of relation 1:\n[[-0.46378913 -0.09109707 0.52872529]\n [ 0.03829597 0.22156061 -0.2130242 ]\n [ 0.21535272 0.38639244 -0.55623279]\n [ 0.28884178 0.56448816 0.28655701]\n [-0.25352144 0.334031 -0.45815514]]\n\n\n----- Weight matrix of relation 2:\n[[0.22946783 0.4552118 0.15387093]\n [0.15100992 0.073714 0.01948981]\n [0.34262941 0.11369778 0.14011786]\n [0.25087085 0.03614765 0.29131763]\n [0.081897 0.29875971 0.3528816 ]]\n\n\n-----Tensor including both weight matrices:\n[[[-0.46378913 -0.09109707 0.52872529]\n [ 0.03829597 0.22156061 -0.2130242 ]\n [ 0.21535272 0.38639244 -0.55623279]\n [ 0.28884178 0.56448816 0.28655701]\n [-0.25352144 0.334031 -0.45815514]]\n\n [[ 0.22946783 0.4552118 0.15387093]\n [ 0.15100992 0.073714 0.01948981]\n [ 0.34262941 0.11369778 0.14011786]\n [ 0.25087085 0.03614765 0.29131763]\n [ 0.081897 0.29875971 0.3528816 ]]]\n\n\n----- Linear trasformation (or projection) with batch matrix multiplication:\n[[[ 0.21535272 0.38639244 -0.55623279]\n [ 0.03829597 0.22156061 -0.2130242 ]\n [-0.25352144 0.334031 -0.45815514]\n [-0.46378913 -0.09109707 0.52872529]\n [ 0.28884178 0.56448816 0.28655701]]\n\n [[ 0.34262941 0.11369778 0.14011786]\n [ 0.15100992 0.073714 0.01948981]\n [ 0.081897 0.29875971 0.3528816 ]\n [ 0.22946783 0.4552118 0.15387093]\n [ 0.25087085 0.03614765 0.29131763]]]\n\n\n----- Adjacency matrix of relation 1:\n[[0 1 1 1 1]\n [1 1 0 0 1]\n [1 0 0 1 0]\n [0 0 1 1 1]\n [1 1 0 1 0]]\n\n\n----- Adjacency matrix of relation 2:\n[[0 0 0 1 0]\n [1 0 0 0 0]\n [1 0 0 1 1]\n [0 0 0 0 0]\n [0 1 0 0 0]]\n\n\n----- Tensor including both adjacency matrices:\n[[[0 1 1 1 1]\n [1 1 0 0 1]\n [1 0 0 1 0]\n [0 0 1 1 1]\n [1 1 0 1 0]]\n\n [[0 0 0 1 0]\n [1 0 0 0 0]\n [1 0 0 1 1]\n [0 0 0 0 0]\n [0 1 0 0 0]]]\n\n\n----- (GCN) Output of the neighborhood diffusion (for each typed edge):\n[[[-0.39017282 1.0289827 0.14410296]\n [ 0.54249047 1.17244121 -0.48269997]\n [-0.24843641 
0.29529538 -0.0275075 ]\n [-0.42846879 0.80742209 0.35712716]\n [-0.21014043 0.51685598 -0.2405317 ]]\n\n [[ 0.22946783 0.4552118 0.15387093]\n [ 0.34262941 0.11369778 0.14011786]\n [ 0.82296809 0.60505722 0.58530642]\n [ 0. 0. 0. ]\n [ 0.15100992 0.073714 0.01948981]]]\n\n\n----- (R-GCN) Aggregation of the results of the GCN layer applied to different types of edge:\n[[-0.16070499 1.48419449 0.29797389]\n [ 0.88511988 1.28613899 -0.34258211]\n [ 0.57453168 0.9003526 0.55779892]\n [-0.42846879 0.80742209 0.35712716]\n [-0.05913052 0.59056998 -0.22104189]]\n" 255 | } 256 | ], 257 | "source": [ 258 | "print('\\n\\n----- Recall --> One-hot vector representation of nodes:')\n", 259 | "print(X)\n", 260 | "\n", 261 | "# Number of relation types\n", 262 | "num_rels = 2\n", 263 | "\n", 264 | "print('\\n\\n----- Number of relation types:')\n", 265 | "print(num_rels)\n", 266 | "\n", 267 | "# Weight matrix of relation number 1 (randomly inizialized according to Glorot and Bengio (2010))\n", 268 | "W_rel1 = np.random.uniform(-np.sqrt(1. / emb), np.sqrt(1. 
/ emb), (n, emb))\n", 269 | "print('\\n\\n----- Weight matrix of relation 1:')\n", 270 | "print(W_rel1)\n", 271 | "\n", 272 | "# Weight matrix of relation number 2 (randomly initialized with uniform distribution)\n", 273 | "W_rel2 = np.random.uniform(1/100, 0.5, (n, emb))\n", 274 | "print('\\n\\n----- Weight matrix of relation 2:')\n", 275 | "print(W_rel2)\n", 276 | "\n", 277 | "# Tensor including both weight matrices\n", 278 | "W_rels = np.concatenate((W_rel1, W_rel2))\n", 279 | "W_rels = np.reshape(W_rels,(num_rels, n, emb)) # num_rels is the number of the relations, n is the number of nodes, emb is the low-dimensional representation\n", 280 | "print('\\n\\n-----Tensor including both weight matrices:')\n", 281 | "print(W_rels)\n", 282 | "\n", 283 | "L_0_rels = np.matmul(X, W_rels)\n", 284 | "print('\\n\\n----- Linear trasformation (or projection) with batch matrix multiplication:')\n", 285 | "print(L_0_rels)\n", 286 | "\n", 287 | "# Adjacency matrix of relation number 1\n", 288 | "A_rel1 = np.random.randint(2, size=(n, n))\n", 289 | "np.fill_diagonal(A, 0) # Not consider the self loop (diag values = 0)\n", 290 | "print('\\n\\n----- Adjacency matrix of relation 1:')\n", 291 | "print(A_rel1)\n", 292 | "\n", 293 | "# Adjacency matrix of relation number 2\n", 294 | "A_rel2 = np.random.randint(3,size=(n,n))\n", 295 | "np.fill_diagonal(A_rel2, 0) # Not consider the self loop (diag values = 0)\n", 296 | "A_rel2[A_rel2>1] = 0\n", 297 | "print('\\n\\n----- Adjacency matrix of relation 2:')\n", 298 | "print(A_rel2)\n", 299 | "\n", 300 | "# Tensor including both adjacency matrices\n", 301 | "A_rels = np.concatenate((A_rel1, A_rel2))\n", 302 | "A_rels = np.reshape(A_rels, (num_rels, n, n)) # num_rels is the number of the relations, (n,n) is the dimension of the original adj matrix\n", 303 | "print('\\n\\n----- Tensor including both adjacency matrices:')\n", 304 | "print(A_rels)\n", 305 | "\n", 306 | "# GCN for each typed edge\n", 307 | "ND_GCN = np.matmul(A_rels, 
L_0_rels)\n", 308 | "print('\\n\\n----- (GCN) Output of the neighborhood diffusion (for each typed edge):')\n", 309 | "print(ND_GCN)\n", 310 | "\n", 311 | "# R-GCN\n", 312 | "RGCN = np.sum(ND_GCN, axis=0)\n", 313 | "print('\\n\\n----- (R-GCN) Aggregation of the results of the GCN layer applied to different types of edge:')\n", 314 | "print(RGCN)\n", 315 | "\n", 316 | "# Test of the aggregation\n", 317 | "assert(RGCN[0,0] == L_0_rels[0,1,0] + L_0_rels[0,2,0] + L_0_rels[0,3,0] + L_0_rels[0,4,0] + L_0_rels[1,3,0])\n", 318 | "\n", 319 | "\n", 320 | "\n" 321 | ] 322 | } 323 | ] 324 | } -------------------------------------------------------------------------------- /linear-regression/data/insurance.csv: -------------------------------------------------------------------------------- 1 | age,sex,bmi,children,smoker,region,expenses 19,female,27.9,0,yes,southwest,16884.92 18,male,33.8,1,no,southeast,1725.55 28,male,33.0,3,no,southeast,4449.46 33,male,22.7,0,no,northwest,21984.47 32,male,28.9,0,no,northwest,3866.86 31,female,25.7,0,no,southeast,3756.62 46,female,33.4,1,no,southeast,8240.59 37,female,27.7,3,no,northwest,7281.51 37,male,29.8,2,no,northeast,6406.41 60,female,25.8,0,no,northwest,28923.14 25,male,26.2,0,no,northeast,2721.32 62,female,26.3,0,yes,southeast,27808.73 23,male,34.4,0,no,southwest,1826.84 56,female,39.8,0,no,southeast,11090.72 27,male,42.1,0,yes,southeast,39611.76 19,male,24.6,1,no,southwest,1837.24 52,female,30.8,1,no,northeast,10797.34 23,male,23.8,0,no,northeast,2395.17 56,male,40.3,0,no,southwest,10602.39 30,male,35.3,0,yes,southwest,36837.47 60,female,36.0,0,no,northeast,13228.85 30,female,32.4,1,no,southwest,4149.74 18,male,34.1,0,no,southeast,1137.01 34,female,31.9,1,yes,northeast,37701.88 37,male,28.0,2,no,northwest,6203.9 59,female,27.7,3,no,southeast,14001.13 63,female,23.1,0,no,northeast,14451.84 55,female,32.8,2,no,northwest,12268.63 23,male,17.4,1,no,northwest,2775.19 31,male,36.3,2,yes,southwest,38711 22,male,35.6,0,yes,southwest,35585.58 
18,female,26.3,0,no,northeast,2198.19 19,female,28.6,5,no,southwest,4687.8 63,male,28.3,0,no,northwest,13770.1 28,male,36.4,1,yes,southwest,51194.56 19,male,20.4,0,no,northwest,1625.43 62,female,33.0,3,no,northwest,15612.19 26,male,20.8,0,no,southwest,2302.3 35,male,36.7,1,yes,northeast,39774.28 60,male,39.9,0,yes,southwest,48173.36 24,female,26.6,0,no,northeast,3046.06 31,female,36.6,2,no,southeast,4949.76 41,male,21.8,1,no,southeast,6272.48 37,female,30.8,2,no,southeast,6313.76 38,male,37.1,1,no,northeast,6079.67 55,male,37.3,0,no,southwest,20630.28 18,female,38.7,2,no,northeast,3393.36 28,female,34.8,0,no,northwest,3556.92 60,female,24.5,0,no,southeast,12629.9 36,male,35.2,1,yes,southeast,38709.18 18,female,35.6,0,no,northeast,2211.13 21,female,33.6,2,no,northwest,3579.83 48,male,28.0,1,yes,southwest,23568.27 36,male,34.4,0,yes,southeast,37742.58 40,female,28.7,3,no,northwest,8059.68 58,male,37.0,2,yes,northwest,47496.49 58,female,31.8,2,no,northeast,13607.37 18,male,31.7,2,yes,southeast,34303.17 53,female,22.9,1,yes,southeast,23244.79 34,female,37.3,2,no,northwest,5989.52 43,male,27.4,3,no,northeast,8606.22 25,male,33.7,4,no,southeast,4504.66 64,male,24.7,1,no,northwest,30166.62 28,female,25.9,1,no,northwest,4133.64 20,female,22.4,0,yes,northwest,14711.74 19,female,28.9,0,no,southwest,1743.21 61,female,39.1,2,no,southwest,14235.07 40,male,26.3,1,no,northwest,6389.38 40,female,36.2,0,no,southeast,5920.1 28,male,24.0,3,yes,southeast,17663.14 27,female,24.8,0,yes,southeast,16577.78 31,male,28.5,5,no,northeast,6799.46 53,female,28.1,3,no,southwest,11741.73 58,male,32.0,1,no,southeast,11946.63 44,male,27.4,2,no,southwest,7726.85 57,male,34.0,0,no,northwest,11356.66 29,female,29.6,1,no,southeast,3947.41 21,male,35.5,0,no,southeast,1532.47 22,female,39.8,0,no,northeast,2755.02 41,female,33.0,0,no,northwest,6571.02 31,male,26.9,1,no,northeast,4441.21 45,female,38.3,0,no,northeast,7935.29 22,male,37.6,1,yes,southeast,37165.16 48,female,41.2,4,no,northwest,11033.66 
37,female,34.8,2,yes,southwest,39836.52 45,male,22.9,2,yes,northwest,21098.55 57,female,31.2,0,yes,northwest,43578.94 56,female,27.2,0,no,southwest,11073.18 46,female,27.7,0,no,northwest,8026.67 55,female,27.0,0,no,northwest,11082.58 21,female,39.5,0,no,southeast,2026.97 53,female,24.8,1,no,northwest,10942.13 59,male,29.8,3,yes,northeast,30184.94 35,male,34.8,2,no,northwest,5729.01 64,female,31.3,2,yes,southwest,47291.06 28,female,37.6,1,no,southeast,3766.88 54,female,30.8,3,no,southwest,12105.32 55,male,38.3,0,no,southeast,10226.28 56,male,20.0,0,yes,northeast,22412.65 38,male,19.3,0,yes,southwest,15820.7 41,female,31.6,0,no,southwest,6186.13 30,male,25.5,0,no,northeast,3645.09 18,female,30.1,0,no,northeast,21344.85 61,female,29.9,3,yes,southeast,30942.19 34,female,27.5,1,no,southwest,5003.85 20,male,28.0,1,yes,northwest,17560.38 19,female,28.4,1,no,southwest,2331.52 26,male,30.9,2,no,northwest,3877.3 29,male,27.9,0,no,southeast,2867.12 63,male,35.1,0,yes,southeast,47055.53 54,male,33.6,1,no,northwest,10825.25 55,female,29.7,2,no,southwest,11881.36 37,male,30.8,0,no,southwest,4646.76 21,female,35.7,0,no,northwest,2404.73 52,male,32.2,3,no,northeast,11488.32 60,male,28.6,0,no,northeast,30260 58,male,49.1,0,no,southeast,11381.33 29,female,27.9,1,yes,southeast,19107.78 49,female,27.2,0,no,southeast,8601.33 37,female,23.4,2,no,northwest,6686.43 44,male,37.1,2,no,southwest,7740.34 18,male,23.8,0,no,northeast,1705.62 20,female,29.0,0,no,northwest,2257.48 44,male,31.4,1,yes,northeast,39556.49 47,female,33.9,3,no,northwest,10115.01 26,female,28.8,0,no,northeast,3385.4 19,female,28.3,0,yes,southwest,17081.08 52,female,37.4,0,no,southwest,9634.54 32,female,17.8,2,yes,northwest,32734.19 38,male,34.7,2,no,southwest,6082.41 59,female,26.5,0,no,northeast,12815.44 61,female,22.0,0,no,northeast,13616.36 53,female,35.9,2,no,southwest,11163.57 19,male,25.6,0,no,northwest,1632.56 20,female,28.8,0,no,northeast,2457.21 22,female,28.1,0,no,southeast,2155.68 
19,male,34.1,0,no,southwest,1261.44 22,male,25.2,0,no,northwest,2045.69 54,female,31.9,3,no,southeast,27322.73 22,female,36.0,0,no,southwest,2166.73 34,male,22.4,2,no,northeast,27375.9 26,male,32.5,1,no,northeast,3490.55 34,male,25.3,2,yes,southeast,18972.5 29,male,29.7,2,no,northwest,18157.88 30,male,28.7,3,yes,northwest,20745.99 29,female,38.8,3,no,southeast,5138.26 46,male,30.5,3,yes,northwest,40720.55 51,female,37.7,1,no,southeast,9877.61 53,female,37.4,1,no,northwest,10959.69 19,male,28.4,1,no,southwest,1842.52 35,male,24.1,1,no,northwest,5125.22 48,male,29.7,0,no,southeast,7789.64 32,female,37.1,3,no,northeast,6334.34 42,female,23.4,0,yes,northeast,19964.75 40,female,25.5,1,no,northeast,7077.19 44,male,39.5,0,no,northwest,6948.7 48,male,24.4,0,yes,southeast,21223.68 18,male,25.2,0,yes,northeast,15518.18 30,male,35.5,0,yes,southeast,36950.26 50,female,27.8,3,no,southeast,19749.38 42,female,26.6,0,yes,northwest,21348.71 18,female,36.9,0,yes,southeast,36149.48 54,male,39.6,1,no,southwest,10450.55 32,female,29.8,2,no,southwest,5152.13 37,male,29.6,0,no,northwest,5028.15 47,male,28.2,4,no,northeast,10407.09 20,female,37.0,5,no,southwest,4830.63 32,female,33.2,3,no,northwest,6128.8 19,female,31.8,1,no,northwest,2719.28 27,male,18.9,3,no,northeast,4827.9 63,male,41.5,0,no,southeast,13405.39 49,male,30.3,0,no,southwest,8116.68 18,male,16.0,0,no,northeast,1694.8 35,female,34.8,1,no,southwest,5246.05 24,female,33.3,0,no,northwest,2855.44 63,female,37.7,0,yes,southwest,48824.45 38,male,27.8,2,no,northwest,6455.86 54,male,29.2,1,no,southwest,10436.1 46,female,28.9,2,no,southwest,8823.28 41,female,33.2,3,no,northeast,8538.29 58,male,28.6,0,no,northwest,11735.88 18,female,38.3,0,no,southeast,1631.82 22,male,20.0,3,no,northeast,4005.42 44,female,26.4,0,no,northwest,7419.48 44,male,30.7,2,no,southeast,7731.43 36,male,41.9,3,yes,northeast,43753.34 26,female,29.9,2,no,southeast,3981.98 30,female,30.9,3,no,southwest,5325.65 41,female,32.2,1,no,southwest,6775.96 
29,female,32.1,2,no,northwest,4922.92 61,male,31.6,0,no,southeast,12557.61 36,female,26.2,0,no,southwest,4883.87 25,male,25.7,0,no,southeast,2137.65 56,female,26.6,1,no,northwest,12044.34 18,male,34.4,0,no,southeast,1137.47 19,male,30.6,0,no,northwest,1639.56 39,female,32.8,0,no,southwest,5649.72 45,female,28.6,2,no,southeast,8516.83 51,female,18.1,0,no,northwest,9644.25 64,female,39.3,0,no,northeast,14901.52 19,female,32.1,0,no,northwest,2130.68 48,female,32.2,1,no,southeast,8871.15 60,female,24.0,0,no,northwest,13012.21 27,female,36.1,0,yes,southeast,37133.9 46,male,22.3,0,no,southwest,7147.11 28,female,28.9,1,no,northeast,4337.74 59,male,26.4,0,no,southeast,11743.3 35,male,27.7,2,yes,northeast,20984.09 63,female,31.8,0,no,southwest,13880.95 40,male,41.2,1,no,northeast,6610.11 20,male,33.0,1,no,southwest,1980.07 40,male,30.9,4,no,northwest,8162.72 24,male,28.5,2,no,northwest,3537.7 34,female,26.7,1,no,southeast,5002.78 45,female,30.9,2,no,southwest,8520.03 41,female,37.1,2,no,southwest,7371.77 53,female,26.6,0,no,northwest,10355.64 27,male,23.1,0,no,southeast,2483.74 26,female,29.9,1,no,southeast,3392.98 24,female,23.2,0,no,southeast,25081.77 34,female,33.7,1,no,southwest,5012.47 53,female,33.3,0,no,northeast,10564.88 32,male,30.8,3,no,southwest,5253.52 19,male,34.8,0,yes,southwest,34779.62 42,male,24.6,0,yes,southeast,19515.54 55,male,33.9,3,no,southeast,11987.17 28,male,38.1,0,no,southeast,2689.5 58,female,41.9,0,no,southeast,24227.34 41,female,31.6,1,no,northeast,7358.18 47,male,25.5,2,no,northeast,9225.26 42,female,36.2,1,no,northwest,7443.64 59,female,27.8,3,no,southeast,14001.29 19,female,17.8,0,no,southwest,1727.79 59,male,27.5,1,no,southwest,12333.83 39,male,24.5,2,no,northwest,6710.19 40,female,22.2,2,yes,southeast,19444.27 18,female,26.7,0,no,southeast,1615.77 31,male,38.4,2,no,southeast,4463.21 19,male,29.1,0,yes,northwest,17352.68 44,male,38.1,1,no,southeast,7152.67 23,female,36.7,2,yes,northeast,38511.63 33,female,22.1,1,no,northeast,5354.07 
55,female,26.8,1,no,southwest,35160.13 40,male,35.3,3,no,southwest,7196.87 63,female,27.7,0,yes,northeast,29523.17 54,male,30.0,0,no,northwest,24476.48 60,female,38.1,0,no,southeast,12648.7 24,male,35.9,0,no,southeast,1986.93 19,male,20.9,1,no,southwest,1832.09 29,male,29.0,1,no,northeast,4040.56 18,male,17.3,2,yes,northeast,12829.46 63,female,32.2,2,yes,southwest,47305.31 54,male,34.2,2,yes,southeast,44260.75 27,male,30.3,3,no,southwest,4260.74 50,male,31.8,0,yes,northeast,41097.16 55,female,25.4,3,no,northeast,13047.33 56,male,33.6,0,yes,northwest,43921.18 38,female,40.2,0,no,southeast,5400.98 51,male,24.4,4,no,northwest,11520.1 19,male,31.9,0,yes,northwest,33750.29 58,female,25.2,0,no,southwest,11837.16 20,female,26.8,1,yes,southeast,17085.27 52,male,24.3,3,yes,northeast,24869.84 19,male,37.0,0,yes,northwest,36219.41 53,female,38.1,3,no,southeast,20463 46,male,42.4,3,yes,southeast,46151.12 40,male,19.8,1,yes,southeast,17179.52 59,female,32.4,3,no,northeast,14590.63 45,male,30.2,1,no,southwest,7441.05 49,male,25.8,1,no,northeast,9282.48 18,male,29.4,1,no,southeast,1719.44 50,male,34.2,2,yes,southwest,42856.84 41,male,37.1,2,no,northwest,7265.7 50,male,27.5,1,no,northeast,9617.66 25,male,27.6,0,no,northwest,2523.17 47,female,26.6,2,no,northeast,9715.84 19,male,20.6,2,no,northwest,2803.7 22,female,24.3,0,no,southwest,2150.47 59,male,31.8,2,no,southeast,12928.79 51,female,21.6,1,no,southeast,9855.13 40,female,28.1,1,yes,northeast,22331.57 54,male,40.6,3,yes,northeast,48549.18 30,male,27.6,1,no,northeast,4237.13 55,female,32.4,1,no,northeast,11879.1 52,female,31.2,0,no,southwest,9625.92 46,male,26.6,1,no,southeast,7742.11 46,female,48.1,2,no,northeast,9432.93 63,female,26.2,0,no,northwest,14256.19 59,female,36.8,1,yes,northeast,47896.79 52,male,26.4,3,no,southeast,25992.82 28,female,33.4,0,no,southwest,3172.02 29,male,29.6,1,no,northeast,20277.81 25,male,45.5,2,yes,southeast,42112.24 22,female,28.8,0,no,southeast,2156.75 25,male,26.8,3,no,southwest,3906.13 
18,male,23.0,0,no,northeast,1704.57 19,male,27.7,0,yes,southwest,16297.85 47,male,25.4,1,yes,southeast,21978.68 31,male,34.4,3,yes,northwest,38746.36 48,female,28.9,1,no,northwest,9249.5 36,male,27.6,3,no,northeast,6746.74 53,female,22.6,3,yes,northeast,24873.38 56,female,37.5,2,no,southeast,12265.51 28,female,33.0,2,no,southeast,4349.46 57,female,38.0,2,no,southwest,12646.21 29,male,33.3,2,no,northwest,19442.35 28,female,27.5,2,no,southwest,20177.67 30,female,33.3,1,no,southeast,4151.03 58,male,34.9,0,no,northeast,11944.59 41,female,33.1,2,no,northwest,7749.16 50,male,26.6,0,no,southwest,8444.47 19,female,24.7,0,no,southwest,1737.38 43,male,36.0,3,yes,southeast,42124.52 49,male,35.9,0,no,southeast,8124.41 27,female,31.4,0,yes,southwest,34838.87 52,male,33.3,0,no,northeast,9722.77 50,male,32.2,0,no,northwest,8835.26 54,male,32.8,0,no,northeast,10435.07 44,female,27.6,0,no,northwest,7421.19 32,male,37.3,1,no,northeast,4667.61 34,male,25.3,1,no,northwest,4894.75 26,female,29.6,4,no,northeast,24671.66 34,male,30.8,0,yes,southwest,35491.64 57,male,40.9,0,no,northeast,11566.3 29,male,27.2,0,no,southwest,2866.09 40,male,34.1,1,no,northeast,6600.21 27,female,23.2,1,no,southeast,3561.89 45,male,36.5,2,yes,northwest,42760.5 64,female,33.8,1,yes,southwest,47928.03 52,male,36.7,0,no,southwest,9144.57 61,female,36.4,1,yes,northeast,48517.56 52,male,27.4,0,yes,northwest,24393.62 61,female,31.2,0,no,northwest,13429.04 56,female,28.8,0,no,northeast,11658.38 43,female,35.7,2,no,northeast,19144.58 64,male,34.5,0,no,southwest,13822.8 60,male,25.7,0,no,southeast,12142.58 62,male,27.6,1,no,northwest,13937.67 50,male,32.3,1,yes,northeast,41919.1 46,female,27.7,1,no,southeast,8232.64 24,female,27.6,0,no,southwest,18955.22 62,male,30.0,0,no,northwest,13352.1 60,female,27.6,0,no,northeast,13217.09 63,male,36.8,0,no,northeast,13981.85 49,female,41.5,4,no,southeast,10977.21 34,female,29.3,3,no,southeast,6184.3 33,male,35.8,2,no,southeast,4890 46,male,33.3,1,no,northeast,8334.46 
36,female,29.9,1,no,southeast,5478.04 19,male,27.8,0,no,northwest,1635.73 57,female,23.2,0,no,northwest,11830.61 50,female,25.6,0,no,southwest,8932.08 30,female,27.7,0,no,southwest,3554.2 33,male,35.2,0,no,northeast,12404.88 18,female,38.3,0,no,southeast,14133.04 46,male,27.6,0,no,southwest,24603.05 46,male,43.9,3,no,southeast,8944.12 47,male,29.8,3,no,northwest,9620.33 23,male,41.9,0,no,southeast,1837.28 18,female,20.8,0,no,southeast,1607.51 48,female,32.3,2,no,northeast,10043.25 35,male,30.5,1,no,southwest,4751.07 19,female,21.7,0,yes,southwest,13844.51 21,female,26.4,1,no,southwest,2597.78 21,female,21.9,2,no,southeast,3180.51 49,female,30.8,1,no,northeast,9778.35 56,female,32.3,3,no,northeast,13430.27 42,female,25.0,2,no,northwest,8017.06 44,male,32.0,2,no,northwest,8116.27 18,male,30.4,3,no,northeast,3481.87 61,female,21.1,0,no,northwest,13415.04 57,female,22.2,0,no,northeast,12029.29 42,female,33.2,1,no,northeast,7639.42 26,male,32.9,2,yes,southwest,36085.22 20,male,33.3,0,no,southeast,1391.53 23,female,28.3,0,yes,northwest,18033.97 39,female,24.9,3,yes,northeast,21659.93 24,male,40.2,0,yes,southeast,38126.25 64,female,30.1,3,no,northwest,16455.71 62,male,31.5,1,no,southeast,27000.98 27,female,18.0,2,yes,northeast,15006.58 55,male,30.7,0,yes,northeast,42303.69 55,male,33.0,0,no,southeast,20781.49 35,female,43.3,2,no,southeast,5846.92 44,male,22.1,2,no,northeast,8302.54 19,male,34.4,0,no,southwest,1261.86 58,female,39.1,0,no,southeast,11856.41 50,male,25.4,2,no,northwest,30284.64 26,female,22.6,0,no,northwest,3176.82 24,female,30.2,3,no,northwest,4618.08 48,male,35.6,4,no,northeast,10736.87 19,female,37.4,0,no,northwest,2138.07 48,male,31.4,1,no,northeast,8964.06 49,male,31.4,1,no,northeast,9290.14 46,female,32.3,2,no,northeast,9411.01 46,male,19.9,0,no,northwest,7526.71 43,female,34.4,3,no,southwest,8522 21,male,31.0,0,no,southeast,16586.5 64,male,25.6,2,no,southwest,14988.43 18,female,38.2,0,no,southeast,1631.67 51,female,20.6,0,no,southwest,9264.8 
47,male,47.5,1,no,southeast,8083.92 64,female,33.0,0,no,northwest,14692.67 49,male,32.3,3,no,northwest,10269.46 31,male,20.4,0,no,southwest,3260.2 52,female,38.4,2,no,northeast,11396.9 33,female,24.3,0,no,southeast,4185.1 47,female,23.6,1,no,southwest,8539.67 38,male,21.1,3,no,southeast,6652.53 32,male,30.0,1,no,southeast,4074.45 19,male,17.5,0,no,northwest,1621.34 44,female,20.2,1,yes,northeast,19594.81 26,female,17.2,2,yes,northeast,14455.64 25,male,23.9,5,no,southwest,5080.1 19,female,35.2,0,no,northwest,2134.9 43,female,35.6,1,no,southeast,7345.73 52,male,34.1,0,no,southeast,9140.95 36,female,22.6,2,yes,southwest,18608.26 64,male,39.2,1,no,southeast,14418.28 63,female,27.0,0,yes,northwest,28950.47 64,male,33.9,0,yes,southeast,46889.26 61,male,35.9,0,yes,southeast,46599.11 40,male,32.8,1,yes,northeast,39125.33 25,male,30.6,0,no,northeast,2727.4 48,male,30.2,2,no,southwest,8968.33 45,male,24.3,5,no,southeast,9788.87 38,female,27.3,1,no,northeast,6555.07 18,female,29.2,0,no,northeast,7323.73 21,female,16.8,1,no,northeast,3167.46 27,female,30.4,3,no,northwest,18804.75 19,male,33.1,0,no,southwest,23082.96 29,female,20.2,2,no,northwest,4906.41 42,male,26.9,0,no,southwest,5969.72 60,female,30.5,0,no,southwest,12638.2 31,male,28.6,1,no,northwest,4243.59 60,male,33.1,3,no,southeast,13919.82 22,male,31.7,0,no,northeast,2254.8 35,male,28.9,3,no,southwest,5926.85 52,female,46.8,5,no,southeast,12592.53 26,male,29.5,0,no,northeast,2897.32 31,female,32.7,1,no,northwest,4738.27 33,female,33.5,0,yes,southwest,37079.37 18,male,43.0,0,no,southeast,1149.4 59,female,36.5,1,no,southeast,28287.9 56,male,26.7,1,yes,northwest,26109.33 45,female,33.1,0,no,southwest,7345.08 60,male,29.6,0,no,northeast,12731 56,female,25.7,0,no,northwest,11454.02 40,female,29.6,0,no,southwest,5910.94 35,male,38.6,1,no,southwest,4762.33 39,male,29.6,4,no,southwest,7512.27 30,male,24.1,1,no,northwest,4032.24 24,male,23.4,0,no,southwest,1969.61 20,male,29.7,0,no,northwest,1769.53 
32,male,46.5,2,no,southeast,4686.39 59,male,37.4,0,no,southwest,21797 55,female,30.1,2,no,southeast,11881.97 57,female,30.5,0,no,northwest,11840.78 56,male,39.6,0,no,southwest,10601.41 40,female,33.0,3,no,southeast,7682.67 49,female,36.6,3,no,southeast,10381.48 42,male,30.0,0,yes,southwest,22144.03 62,female,38.1,2,no,northeast,15230.32 56,male,25.9,0,no,northeast,11165.42 19,male,25.2,0,no,northwest,1632.04 30,female,28.4,1,yes,southeast,19521.97 60,female,28.7,1,no,southwest,13224.69 56,female,33.8,2,no,northwest,12643.38 28,female,24.3,1,no,northeast,23288.93 18,female,24.1,1,no,southeast,2201.1 27,male,32.7,0,no,southeast,2497.04 18,female,30.1,0,no,northeast,2203.47 19,female,29.8,0,no,southwest,1744.47 47,female,33.3,0,no,northeast,20878.78 54,male,25.1,3,yes,southwest,25382.3 61,male,28.3,1,yes,northwest,28868.66 24,male,28.5,0,yes,northeast,35147.53 25,male,35.6,0,no,northwest,2534.39 21,male,36.9,0,no,southeast,1534.3 23,male,32.6,0,no,southeast,1824.29 63,male,41.3,3,no,northwest,15555.19 49,male,37.5,2,no,southeast,9304.7 18,female,31.4,0,no,southeast,1622.19 51,female,39.5,1,no,southwest,9880.07 48,male,34.3,3,no,southwest,9563.03 31,female,31.1,0,no,northeast,4347.02 54,female,21.5,3,no,northwest,12475.35 19,male,28.7,0,no,southwest,1253.94 44,female,38.1,0,yes,southeast,48885.14 53,male,31.2,1,no,northwest,10461.98 19,female,32.9,0,no,southwest,1748.77 61,female,25.1,0,no,southeast,24513.09 18,female,25.1,0,no,northeast,2196.47 61,male,43.4,0,no,southwest,12574.05 21,male,25.7,4,yes,southwest,17942.11 20,male,27.9,0,no,northeast,1967.02 31,female,23.6,2,no,southwest,4931.65 45,male,28.7,2,no,southwest,8027.97 44,female,24.0,2,no,southeast,8211.1 62,female,39.2,0,no,southwest,13470.86 29,male,34.4,0,yes,southwest,36197.7 43,male,26.0,0,no,northeast,6837.37 51,male,23.2,1,yes,southeast,22218.11 19,male,30.3,0,yes,southeast,32548.34 38,female,28.9,1,no,southeast,5974.38 37,male,30.9,3,no,northwest,6796.86 22,male,31.4,1,no,northwest,2643.27 
21,male,23.8,2,no,northwest,3077.1 24,female,25.3,0,no,northeast,3044.21 57,female,28.7,0,no,southwest,11455.28 56,male,32.1,1,no,northeast,11763 27,male,33.7,0,no,southeast,2498.41 51,male,22.4,0,no,northeast,9361.33 19,male,30.4,0,no,southwest,1256.3 39,male,28.3,1,yes,southwest,21082.16 58,male,35.7,0,no,southwest,11362.76 20,male,35.3,1,no,southeast,27724.29 45,male,30.5,2,no,northwest,8413.46 35,female,31.0,1,no,southwest,5240.77 31,male,30.9,0,no,northeast,3857.76 50,female,27.4,0,no,northeast,25656.58 32,female,44.2,0,no,southeast,3994.18 51,female,33.9,0,no,northeast,9866.3 38,female,37.7,0,no,southeast,5397.62 42,male,26.1,1,yes,southeast,38245.59 18,female,33.9,0,no,southeast,11482.63 19,female,30.6,2,no,northwest,24059.68 51,female,25.8,1,no,southwest,9861.03 46,male,39.4,1,no,northeast,8342.91 18,male,25.5,0,no,northeast,1708 57,male,42.1,1,yes,southeast,48675.52 62,female,31.7,0,no,northeast,14043.48 59,male,29.7,2,no,southeast,12925.89 37,male,36.2,0,no,southeast,19214.71 64,male,40.5,0,no,southeast,13831.12 38,male,28.0,1,no,northeast,6067.13 33,female,38.9,3,no,southwest,5972.38 46,female,30.2,2,no,southwest,8825.09 46,female,28.1,1,no,southeast,8233.1 53,male,31.4,0,no,southeast,27346.04 34,female,38.0,3,no,southwest,6196.45 20,female,31.8,2,no,southeast,3056.39 63,female,36.3,0,no,southeast,13887.2 54,female,47.4,0,yes,southeast,63770.43 54,male,30.2,0,no,northwest,10231.5 49,male,25.8,2,yes,northwest,23807.24 28,male,35.4,0,no,northeast,3268.85 54,female,46.7,2,no,southwest,11538.42 25,female,28.6,0,no,northeast,3213.62 43,female,46.2,0,yes,southeast,45863.21 63,male,30.8,0,no,southwest,13390.56 32,female,28.9,0,no,southeast,3972.92 62,male,21.4,0,no,southwest,12957.12 52,female,31.7,2,no,northwest,11187.66 25,female,41.3,0,no,northeast,17878.9 28,male,23.8,2,no,southwest,3847.67 46,male,33.4,1,no,northeast,8334.59 34,male,34.2,0,no,southeast,3935.18 35,female,34.1,3,yes,northwest,39983.43 19,male,35.5,0,no,northwest,1646.43 
46,female,20.0,2,no,northwest,9193.84 54,female,32.7,0,no,northeast,10923.93 27,male,30.5,0,no,southwest,2494.02 50,male,44.8,1,no,southeast,9058.73 18,female,32.1,2,no,southeast,2801.26 19,female,30.5,0,no,northwest,2128.43 38,female,40.6,1,no,northwest,6373.56 41,male,30.6,2,no,northwest,7256.72 49,female,31.9,5,no,southwest,11552.9 48,male,40.6,2,yes,northwest,45702.02 31,female,29.1,0,no,southwest,3761.29 18,female,37.3,1,no,southeast,2219.45 30,female,43.1,2,no,southeast,4753.64 62,female,36.9,1,no,northeast,31620 57,female,34.3,2,no,northeast,13224.06 58,female,27.2,0,no,northwest,12222.9 22,male,26.8,0,no,southeast,1665 31,female,38.1,1,yes,northeast,58571.07 52,male,30.2,1,no,southwest,9724.53 25,female,23.5,0,no,northeast,3206.49 59,male,25.5,1,no,northeast,12913.99 19,male,30.6,0,no,northwest,1639.56 39,male,45.4,2,no,southeast,6356.27 32,female,23.7,1,no,southeast,17626.24 19,male,20.7,0,no,southwest,1242.82 33,female,28.3,1,no,southeast,4779.6 21,male,20.2,3,no,northeast,3861.21 34,female,30.2,1,yes,northwest,43943.88 61,female,35.9,0,no,northeast,13635.64 38,female,30.7,1,no,southeast,5976.83 58,female,29.0,0,no,southwest,11842.44 47,male,19.6,1,no,northwest,8428.07 20,male,31.1,2,no,southeast,2566.47 21,female,21.9,1,yes,northeast,15359.1 41,male,40.3,0,no,southeast,5709.16 46,female,33.7,1,no,northeast,8823.99 42,female,29.5,2,no,southeast,7640.31 34,female,33.3,1,no,northeast,5594.85 43,male,32.6,2,no,southwest,7441.5 52,female,37.5,2,no,northwest,33471.97 18,female,39.2,0,no,southeast,1633.04 51,male,31.6,0,no,northwest,9174.14 56,female,25.3,0,no,southwest,11070.54 64,female,39.1,3,no,southeast,16085.13 19,female,28.3,0,yes,northwest,17468.98 51,female,34.1,0,no,southeast,9283.56 27,female,25.2,0,no,northeast,3558.62 59,female,23.7,0,yes,northwest,25678.78 28,male,27.0,2,no,northeast,4435.09 30,male,37.8,2,yes,southwest,39241.44 47,female,29.4,1,no,southeast,8547.69 38,female,34.8,2,no,southwest,6571.54 18,female,33.2,0,no,northeast,2207.7 
34,female,19.0,3,no,northeast,6753.04 20,female,33.0,0,no,southeast,1880.07 47,female,36.6,1,yes,southeast,42969.85 56,female,28.6,0,no,northeast,11658.12 49,male,25.6,2,yes,southwest,23306.55 19,female,33.1,0,yes,southeast,34439.86 55,female,37.1,0,no,southwest,10713.64 30,male,31.4,1,no,southwest,3659.35 37,male,34.1,4,yes,southwest,40182.25 49,female,21.3,1,no,southwest,9182.17 18,male,33.5,0,yes,northeast,34617.84 59,male,28.8,0,no,northwest,12129.61 29,female,26.0,0,no,northwest,3736.46 36,male,28.9,3,no,northeast,6748.59 33,male,42.5,1,no,southeast,11326.71 58,male,38.0,0,no,southwest,11365.95 44,female,39.0,0,yes,northwest,42983.46 53,male,36.1,1,no,southwest,10085.85 24,male,29.3,0,no,southwest,1977.82 29,female,35.5,0,no,southeast,3366.67 40,male,22.7,2,no,northeast,7173.36 51,male,39.7,1,no,southwest,9391.35 64,male,38.2,0,no,northeast,14410.93 19,female,24.5,1,no,northwest,2709.11 35,female,38.1,2,no,northeast,24915.05 39,male,26.4,0,yes,northeast,20149.32 56,male,33.7,4,no,southeast,12949.16 33,male,42.4,5,no,southwest,6666.24 42,male,28.3,3,yes,northwest,32787.46 61,male,33.9,0,no,northeast,13143.86 23,female,35.0,3,no,northwest,4466.62 43,male,35.3,2,no,southeast,18806.15 48,male,30.8,3,no,northeast,10141.14 39,male,26.2,1,no,northwest,6123.57 40,female,23.4,3,no,northeast,8252.28 18,male,28.5,0,no,northeast,1712.23 58,female,33.0,0,no,northeast,12430.95 49,female,42.7,2,no,southeast,9800.89 53,female,39.6,1,no,southeast,10579.71 48,female,31.1,0,no,southeast,8280.62 45,female,36.3,2,no,southeast,8527.53 59,female,35.2,0,no,southeast,12244.53 52,female,25.3,2,yes,southeast,24667.42 26,female,42.4,1,no,southwest,3410.32 27,male,33.2,2,no,northwest,4058.71 48,female,35.9,1,no,northeast,26392.26 57,female,28.8,4,no,northeast,14394.4 37,male,46.5,3,no,southeast,6435.62 57,female,24.0,1,no,southeast,22192.44 32,female,31.5,1,no,northeast,5148.55 18,male,33.7,0,no,southeast,1136.4 64,female,23.0,0,yes,southeast,27037.91 43,male,38.1,2,yes,southeast,42560.43 
49,male,28.7,1,no,southwest,8703.46 40,female,32.8,2,yes,northwest,40003.33 62,male,32.0,0,yes,northeast,45710.21 40,female,29.8,1,no,southeast,6500.24 30,male,31.6,3,no,southeast,4837.58 29,female,31.2,0,no,northeast,3943.6 36,male,29.7,0,no,southeast,4399.73 41,female,31.0,0,no,southeast,6185.32 44,female,43.9,2,yes,southeast,46200.99 45,male,21.4,0,no,northwest,7222.79 55,female,40.8,3,no,southeast,12485.8 60,male,31.4,3,yes,northwest,46130.53 56,male,36.1,3,no,southwest,12363.55 49,female,23.2,2,no,northwest,10156.78 21,female,17.4,1,no,southwest,2585.27 19,male,20.3,0,no,southwest,1242.26 39,male,35.3,2,yes,southwest,40103.89 53,male,24.3,0,no,northwest,9863.47 33,female,18.5,1,no,southwest,4766.02 53,male,26.4,2,no,northeast,11244.38 42,male,26.1,2,no,northeast,7729.65 40,male,41.7,0,no,southeast,5438.75 47,female,24.1,1,no,southwest,26236.58 27,male,31.1,1,yes,southeast,34806.47 21,male,27.4,0,no,northeast,2104.11 47,male,36.2,1,no,southwest,8068.19 20,male,32.4,1,no,northwest,2362.23 24,male,23.7,0,no,northwest,2352.97 27,female,34.8,1,no,southwest,3578 26,female,40.2,0,no,northwest,3201.25 53,female,32.3,2,no,northeast,29186.48 41,male,35.8,1,yes,southeast,40273.65 56,male,33.7,0,no,northwest,10976.25 23,female,39.3,2,no,southeast,3500.61 21,female,34.9,0,no,southeast,2020.55 50,female,44.7,0,no,northeast,9541.7 53,male,41.5,0,no,southeast,9504.31 34,female,26.4,1,no,northwest,5385.34 47,female,29.5,1,no,northwest,8930.93 33,female,32.9,2,no,southwest,5375.04 51,female,38.1,0,yes,southeast,44400.41 49,male,28.7,3,no,northwest,10264.44 31,female,30.5,3,no,northeast,6113.23 36,female,27.7,0,no,northeast,5469.01 18,male,35.2,1,no,southeast,1727.54 50,female,23.5,2,no,southeast,10107.22 43,female,30.7,2,no,northwest,8310.84 20,male,40.5,0,no,northeast,1984.45 24,female,22.6,0,no,southwest,2457.5 60,male,28.9,0,no,southwest,12146.97 49,female,22.6,1,no,northwest,9566.99 60,male,24.3,1,no,northwest,13112.6 51,female,36.7,2,no,northwest,10848.13 
58,female,33.4,0,no,northwest,12231.61 51,female,40.7,0,no,northeast,9875.68 53,male,36.6,3,no,southwest,11264.54 62,male,37.4,0,no,southwest,12979.36 19,male,35.4,0,no,southwest,1263.25 50,female,27.1,1,no,northeast,10106.13 30,female,39.1,3,yes,southeast,40932.43 41,male,28.4,1,no,northwest,6664.69 29,female,21.8,1,yes,northeast,16657.72 18,female,40.3,0,no,northeast,2217.6 41,female,36.1,1,no,southeast,6781.35 35,male,24.4,3,yes,southeast,19362 53,male,21.4,1,no,southwest,10065.41 24,female,30.1,3,no,southwest,4234.93 48,female,27.3,1,no,northeast,9447.25 59,female,32.1,3,no,southwest,14007.22 49,female,34.8,1,no,northwest,9583.89 37,female,38.4,0,yes,southeast,40419.02 26,male,23.7,2,no,southwest,3484.33 23,male,31.7,3,yes,northeast,36189.1 29,male,35.5,2,yes,southwest,44585.46 45,male,24.0,2,no,northeast,8604.48 27,male,29.2,0,yes,southeast,18246.5 53,male,34.1,0,yes,northeast,43254.42 31,female,26.6,0,no,southeast,3757.84 50,male,26.4,0,no,northwest,8827.21 50,female,30.1,1,no,northwest,9910.36 34,male,27.0,2,no,southwest,11737.85 19,male,21.8,0,no,northwest,1627.28 47,female,36.0,1,no,southwest,8556.91 28,male,30.9,0,no,northwest,3062.51 37,female,26.4,0,yes,southeast,19539.24 21,male,29.0,0,no,northwest,1906.36 64,male,37.9,0,no,northwest,14210.54 58,female,22.8,0,no,southeast,11833.78 24,male,33.6,4,no,northeast,17128.43 31,male,27.6,2,no,northeast,5031.27 39,female,22.8,3,no,northeast,7985.82 47,female,27.8,0,yes,southeast,23065.42 30,male,37.4,3,no,northeast,5428.73 18,male,38.2,0,yes,southeast,36307.8 22,female,34.6,2,no,northeast,3925.76 23,male,35.2,1,no,southwest,2416.96 33,male,27.1,1,yes,southwest,19040.88 27,male,26.0,0,no,northeast,3070.81 45,female,25.2,2,no,northeast,9095.07 57,female,31.8,0,no,northwest,11842.62 47,male,32.3,1,no,southwest,8062.76 42,female,29.0,1,no,southwest,7050.64 64,female,39.7,0,no,southwest,14319.03 38,female,19.5,2,no,northwest,6933.24 61,male,36.1,3,no,southwest,27941.29 53,female,26.7,2,no,southwest,11150.78 
44,female,36.5,0,no,northeast,12797.21 19,female,28.9,0,yes,northwest,17748.51 41,male,34.2,2,no,northwest,7261.74 51,male,33.3,3,no,southeast,10560.49 40,male,32.3,2,no,northwest,6986.7 45,male,39.8,0,no,northeast,7448.4 35,male,34.3,3,no,southeast,5934.38 53,male,28.9,0,no,northwest,9869.81 30,male,24.4,3,yes,southwest,18259.22 18,male,41.1,0,no,southeast,1146.8 51,male,36.0,1,no,southeast,9386.16 50,female,27.6,1,yes,southwest,24520.26 31,female,29.3,1,no,southeast,4350.51 35,female,27.7,3,no,southwest,6414.18 60,male,37.0,0,no,northeast,12741.17 21,male,36.9,0,no,northwest,1917.32 29,male,22.5,3,no,northeast,5209.58 62,female,29.9,0,no,southeast,13457.96 39,female,41.8,0,no,southeast,5662.23 19,male,27.6,0,no,southwest,1252.41 22,female,23.2,0,no,northeast,2731.91 53,male,20.9,0,yes,southeast,21195.82 39,female,31.9,2,no,northwest,7209.49 27,male,28.5,0,yes,northwest,18310.74 30,male,44.2,2,no,southeast,4266.17 30,female,22.9,1,no,northeast,4719.52 58,female,33.1,0,no,southwest,11848.14 33,male,24.8,0,yes,northeast,17904.53 42,female,26.2,1,no,southeast,7046.72 64,female,36.0,0,no,southeast,14313.85 21,male,22.3,1,no,southwest,2103.08 18,female,42.2,0,yes,southeast,38792.69 23,male,26.5,0,no,southeast,1815.88 45,female,35.8,0,no,northwest,7731.86 40,female,41.4,1,no,northwest,28476.73 19,female,36.6,0,no,northwest,2136.88 18,male,30.1,0,no,southeast,1131.51 25,male,25.8,1,no,northeast,3309.79 46,female,30.8,3,no,southwest,9414.92 33,female,42.9,3,no,northwest,6360.99 54,male,21.0,2,no,southeast,11013.71 28,male,22.5,2,no,northeast,4428.89 36,male,34.4,2,no,southeast,5584.31 20,female,31.5,0,no,southeast,1877.93 24,female,24.2,0,no,northwest,2842.76 23,male,37.1,3,no,southwest,3597.6 47,female,26.1,1,yes,northeast,23401.31 33,female,35.5,0,yes,northwest,55135.4 45,male,33.7,1,no,southwest,7445.92 26,male,17.7,0,no,northwest,2680.95 18,female,31.1,0,no,southeast,1621.88 44,female,29.8,2,no,southeast,8219.2 60,male,24.3,0,no,northwest,12523.6 
64,female,31.8,2,no,northeast,16069.08 56,male,31.8,2,yes,southeast,43813.87 36,male,28.0,1,yes,northeast,20773.63 41,male,30.8,3,yes,northeast,39597.41 39,male,21.9,1,no,northwest,6117.49 63,male,33.1,0,no,southwest,13393.76 36,female,25.8,0,no,northwest,5266.37 28,female,23.8,2,no,northwest,4719.74 58,male,34.4,0,no,northwest,11743.93 36,male,33.8,1,no,northwest,5377.46 42,male,36.0,2,no,southeast,7160.33 36,male,31.5,0,no,southwest,4402.23 56,female,28.3,0,no,northeast,11657.72 35,female,23.5,2,no,northeast,6402.29 59,female,31.4,0,no,northwest,12622.18 21,male,31.1,0,no,southwest,1526.31 59,male,24.7,0,no,northeast,12323.94 23,female,32.8,2,yes,southeast,36021.01 57,female,29.8,0,yes,southeast,27533.91 53,male,30.5,0,no,northeast,10072.06 60,female,32.5,0,yes,southeast,45008.96 51,female,34.2,1,no,southwest,9872.7 23,male,50.4,1,no,southeast,2438.06 27,female,24.1,0,no,southwest,2974.13 55,male,32.8,0,no,northwest,10601.63 37,female,30.8,0,yes,northeast,37270.15 61,male,32.3,2,no,northwest,14119.62 46,female,35.5,0,yes,northeast,42111.66 53,female,23.8,2,no,northeast,11729.68 49,female,23.8,3,yes,northeast,24106.91 20,female,29.6,0,no,southwest,1875.34 48,female,33.1,0,yes,southeast,40974.16 25,male,24.1,0,yes,northwest,15817.99 25,female,32.2,1,no,southeast,18218.16 57,male,28.1,0,no,southwest,10965.45 37,female,47.6,2,yes,southwest,46113.51 38,female,28.0,3,no,southwest,7151.09 55,female,33.5,2,no,northwest,12269.69 36,female,19.9,0,no,northeast,5458.05 51,male,25.4,0,no,southwest,8782.47 40,male,29.9,2,no,southwest,6600.36 18,male,37.3,0,no,southeast,1141.45 57,male,43.7,1,no,southwest,11576.13 61,male,23.7,0,no,northeast,13129.6 25,female,24.3,3,no,southwest,4391.65 50,male,36.2,0,no,southwest,8457.82 26,female,29.5,1,no,southeast,3392.37 42,male,24.9,0,no,southeast,5966.89 43,male,30.1,1,no,southwest,6849.03 44,male,21.9,3,no,northeast,8891.14 23,female,28.1,0,no,northwest,2690.11 49,female,27.1,1,no,southwest,26140.36 33,male,33.4,5,no,southeast,6653.79 
41,male,28.8,1,no,southwest,6282.24 37,female,29.5,2,no,southwest,6311.95 22,male,34.8,3,no,southwest,3443.06 23,male,27.4,1,no,northwest,2789.06 21,female,22.1,0,no,northeast,2585.85 51,female,37.1,3,yes,northeast,46255.11 25,male,26.7,4,no,northwest,4877.98 32,male,28.9,1,yes,southeast,19719.69 57,male,29.0,0,yes,northeast,27218.44 36,female,30.0,0,no,northwest,5272.18 22,male,39.5,0,no,southwest,1682.6 57,male,33.6,1,no,northwest,11945.13 64,female,26.9,0,yes,northwest,29330.98 36,female,29.0,4,no,southeast,7243.81 54,male,24.0,0,no,northeast,10422.92 47,male,38.9,2,yes,southeast,44202.65 62,male,32.1,0,no,northeast,13555 61,female,44.0,0,no,southwest,13063.88 43,female,20.0,2,yes,northeast,19798.05 19,male,25.6,1,no,northwest,2221.56 18,female,40.3,0,no,southeast,1634.57 19,female,22.5,0,no,northwest,2117.34 49,male,22.5,0,no,northeast,8688.86 60,male,40.9,0,yes,southeast,48673.56 26,male,27.3,3,no,northeast,4661.29 49,male,36.9,0,no,southeast,8125.78 60,female,35.1,0,no,southwest,12644.59 26,female,29.4,2,no,northeast,4564.19 27,male,32.6,3,no,northeast,4846.92 44,female,32.3,1,no,southeast,7633.72 63,male,39.8,3,no,southwest,15170.07 32,female,24.6,0,yes,southwest,17496.31 22,male,28.3,1,no,northwest,2639.04 18,male,31.7,0,yes,northeast,33732.69 59,female,26.7,3,no,northwest,14382.71 44,female,27.5,1,no,southwest,7626.99 33,male,24.6,2,no,northwest,5257.51 24,female,34.0,0,no,southeast,2473.33 43,female,26.9,0,yes,northwest,21774.32 45,male,22.9,0,yes,northeast,35069.37 61,female,28.2,0,no,southwest,13041.92 35,female,34.2,1,no,southeast,5245.23 62,female,25.0,0,no,southwest,13451.12 62,female,33.2,0,no,southwest,13462.52 38,male,31.0,1,no,southwest,5488.26 34,male,35.8,0,no,northwest,4320.41 43,male,23.2,0,no,southwest,6250.44 50,male,32.1,2,no,northeast,25333.33 19,female,23.4,2,no,southwest,2913.57 57,female,20.1,1,no,southwest,12032.33 62,female,39.2,0,no,southeast,13470.8 41,male,34.2,1,no,southeast,6289.75 26,male,46.5,1,no,southeast,2927.06 
39,female,32.5,1,no,southwest,6238.3 46,male,25.8,5,no,southwest,10096.97 45,female,35.3,0,no,southwest,7348.14 32,male,37.2,2,no,southeast,4673.39 59,female,27.5,0,no,southwest,12233.83 44,male,29.7,2,no,northeast,32108.66 39,female,24.2,5,no,northwest,8965.8 18,male,26.2,2,no,southeast,2304 53,male,29.5,0,no,southeast,9487.64 18,male,23.2,0,no,southeast,1121.87 50,female,46.1,1,no,southeast,9549.57 18,female,40.2,0,no,northeast,2217.47 19,male,22.6,0,no,northwest,1628.47 62,male,39.9,0,no,southeast,12982.87 56,female,35.8,1,no,southwest,11674.13 42,male,35.8,2,no,southwest,7160.09 37,male,34.2,1,yes,northeast,39047.29 42,male,31.3,0,no,northwest,6358.78 25,male,29.7,3,yes,southwest,19933.46 57,male,18.3,0,no,northeast,11534.87 51,male,42.9,2,yes,southeast,47462.89 30,female,28.4,1,no,northwest,4527.18 44,male,30.2,2,yes,southwest,38998.55 34,male,27.8,1,yes,northwest,20009.63 31,male,39.5,1,no,southeast,3875.73 54,male,30.8,1,yes,southeast,41999.52 24,male,26.8,1,no,northwest,12609.89 43,male,35.0,1,yes,northeast,41034.22 48,male,36.7,1,no,northwest,28468.92 19,female,39.6,1,no,northwest,2730.11 29,female,25.9,0,no,southwest,3353.28 63,female,35.2,1,no,southeast,14474.68 46,male,24.8,3,no,northeast,9500.57 52,male,36.8,2,no,northwest,26467.1 35,male,27.1,1,no,southwest,4746.34 51,male,24.8,2,yes,northwest,23967.38 44,male,25.4,1,no,northwest,7518.03 21,male,25.7,2,no,northeast,3279.87 39,female,34.3,5,no,southeast,8596.83 50,female,28.2,3,no,southeast,10702.64 34,female,23.6,0,no,northeast,4992.38 22,female,20.2,0,no,northwest,2527.82 19,female,40.5,0,no,southwest,1759.34 26,male,35.4,0,no,southeast,2322.62 29,male,22.9,0,yes,northeast,16138.76 48,male,40.2,0,no,southeast,7804.16 26,male,29.2,1,no,southeast,2902.91 45,female,40.0,3,no,northeast,9704.67 36,female,29.9,0,no,southeast,4889.04 54,male,25.5,1,no,northeast,25517.11 34,male,21.4,0,no,northeast,4500.34 31,male,25.9,3,yes,southwest,19199.94 27,female,30.6,1,no,northeast,16796.41 
20,male,30.1,5,no,northeast,4915.06 44,female,25.8,1,no,southwest,7624.63 43,male,30.1,3,no,northwest,8410.05 45,female,27.6,1,no,northwest,28340.19 34,male,34.7,0,no,northeast,4518.83 24,female,20.5,0,yes,northeast,14571.89 26,female,19.8,1,no,southwest,3378.91 38,female,27.8,2,no,northeast,7144.86 50,female,31.6,2,no,southwest,10118.42 38,male,28.3,1,no,southeast,5484.47 27,female,20.0,3,yes,northwest,16420.49 39,female,23.3,3,no,northeast,7986.48 39,female,34.1,3,no,southwest,7418.52 63,female,36.9,0,no,southeast,13887.97 33,female,36.3,3,no,northeast,6551.75 36,female,26.9,0,no,northwest,5267.82 30,male,23.0,2,yes,northwest,17361.77 24,male,32.7,0,yes,southwest,34472.84 24,male,25.8,0,no,southwest,1972.95 48,male,29.6,0,no,southwest,21232.18 47,male,19.2,1,no,northeast,8627.54 29,male,31.7,2,no,northwest,4433.39 28,male,29.3,2,no,northeast,4438.26 47,male,28.2,3,yes,northwest,24915.22 25,male,25.0,2,no,northeast,23241.47 51,male,27.7,1,no,northeast,9957.72 48,female,22.8,0,no,southwest,8269.04 43,male,20.1,2,yes,southeast,18767.74 61,female,33.3,4,no,southeast,36580.28 48,male,32.3,1,no,northwest,8765.25 38,female,27.6,0,no,southwest,5383.54 59,male,25.5,0,no,northwest,12124.99 19,female,24.6,1,no,northwest,2709.24 26,female,34.2,2,no,southwest,3987.93 54,female,35.8,3,no,northwest,12495.29 21,female,32.7,2,no,northwest,26018.95 51,male,37.0,0,no,southwest,8798.59 22,female,31.0,3,yes,southeast,35595.59 47,male,36.1,1,yes,southeast,42211.14 18,male,23.3,1,no,southeast,1711.03 47,female,45.3,1,no,southeast,8569.86 21,female,34.6,0,no,southwest,2020.18 19,male,26.0,1,yes,northwest,16450.89 23,male,18.7,0,no,northwest,21595.38 54,male,31.6,0,no,southwest,9850.43 37,female,17.3,2,no,northeast,6877.98 46,female,23.7,1,yes,northwest,21677.28 55,female,35.2,0,yes,southeast,44423.8 30,female,27.9,0,no,northeast,4137.52 18,male,21.6,0,yes,northeast,13747.87 61,male,38.4,0,no,northwest,12950.07 54,female,23.0,3,no,southwest,12094.48 22,male,37.1,2,yes,southeast,37484.45 
45,female,30.5,1,yes,northwest,39725.52 22,male,28.9,0,no,northeast,2250.84 19,male,27.3,2,no,northwest,22493.66 35,female,28.0,0,yes,northwest,20234.85 18,male,23.1,0,no,northeast,1704.7 20,male,30.7,0,yes,northeast,33475.82 28,female,25.8,0,no,southwest,3161.45 55,male,35.2,1,no,northeast,11394.07 43,female,24.7,2,yes,northwest,21880.82 43,female,25.1,0,no,northeast,7325.05 22,male,52.6,1,yes,southeast,44501.4 25,female,22.5,1,no,northwest,3594.17 49,male,30.9,0,yes,southwest,39727.61 44,female,37.0,1,no,northwest,8023.14 64,male,26.4,0,no,northeast,14394.56 49,male,29.8,1,no,northeast,9288.03 47,male,29.8,3,yes,southwest,25309.49 27,female,21.5,0,no,northwest,3353.47 55,male,27.6,0,no,northwest,10594.5 48,female,28.9,0,no,southwest,8277.52 45,female,31.8,0,no,southeast,17929.3 24,female,39.5,0,no,southeast,2480.98 32,male,33.8,1,no,northwest,4462.72 24,male,32.0,0,no,southeast,1981.58 57,male,27.9,1,no,southeast,11554.22 59,male,41.1,1,yes,southeast,48970.25 36,male,28.6,3,no,northwest,6548.2 29,female,25.6,4,no,southwest,5708.87 42,female,25.3,1,no,southwest,7045.5 48,male,37.3,2,no,southeast,8978.19 39,male,42.7,0,no,northeast,5757.41 63,male,21.7,1,no,northwest,14349.85 54,female,31.9,1,no,southeast,10928.85 37,male,37.1,1,yes,southeast,39871.7 63,male,31.4,0,no,northeast,13974.46 21,male,31.3,0,no,northwest,1909.53 54,female,28.9,2,no,northeast,12096.65 60,female,18.3,0,no,northeast,13204.29 32,female,29.6,1,no,southeast,4562.84 47,female,32.0,1,no,southwest,8551.35 21,male,26.0,0,no,northeast,2102.26 28,male,31.7,0,yes,southeast,34672.15 63,male,33.7,3,no,southeast,15161.53 18,male,21.8,2,no,southeast,11884.05 32,male,27.8,1,no,northwest,4454.4 38,male,20.0,1,no,northwest,5855.9 32,male,31.5,1,no,southwest,4076.5 62,female,30.5,2,no,northwest,15019.76 39,female,18.3,5,yes,southwest,19023.26 55,male,29.0,0,no,northeast,10796.35 57,male,31.5,0,no,northwest,11353.23 52,male,47.7,1,no,southeast,9748.91 56,male,22.1,0,no,southwest,10577.09 
47,male,36.2,0,yes,southeast,41676.08 55,female,29.8,0,no,northeast,11286.54 23,male,32.7,3,no,southwest,3591.48 22,female,30.4,0,yes,northwest,33907.55 50,female,33.7,4,no,southwest,11299.34 18,female,31.4,4,no,northeast,4561.19 51,female,35.0,2,yes,northeast,44641.2 22,male,33.8,0,no,southeast,1674.63 52,female,30.9,0,no,northeast,23045.57 25,female,34.0,1,no,southeast,3227.12 33,female,19.1,2,yes,northeast,16776.3 53,male,28.6,3,no,southwest,11253.42 29,male,38.9,1,no,southeast,3471.41 58,male,36.1,0,no,southeast,11363.28 37,male,29.8,0,no,southwest,20420.6 54,female,31.2,0,no,southeast,10338.93 49,female,29.9,0,no,northwest,8988.16 50,female,26.2,2,no,northwest,10493.95 26,male,30.0,1,no,southwest,2904.09 45,male,20.4,3,no,southeast,8605.36 54,female,32.3,1,no,northeast,11512.41 38,male,38.4,3,yes,southeast,41949.24 48,female,25.9,3,yes,southeast,24180.93 28,female,26.3,3,no,northwest,5312.17 23,male,24.5,0,no,northeast,2396.1 55,male,32.7,1,no,southeast,10807.49 41,male,29.6,5,no,northeast,9222.4 25,male,33.3,2,yes,southeast,36124.57 33,male,35.8,1,yes,southeast,38282.75 30,female,20.0,3,no,northwest,5693.43 23,female,31.4,0,yes,southwest,34166.27 46,male,38.2,2,no,southeast,8347.16 53,female,36.9,3,yes,northwest,46661.44 27,female,32.4,1,no,northeast,18903.49 23,female,42.8,1,yes,northeast,40904.2 63,female,25.1,0,no,northwest,14254.61 55,male,29.9,0,no,southwest,10214.64 35,female,35.9,2,no,southeast,5836.52 34,male,32.8,1,no,southwest,14358.36 19,female,18.6,0,no,southwest,1728.9 39,female,23.9,5,no,southeast,8582.3 27,male,45.9,2,no,southwest,3693.43 57,male,40.3,0,no,northeast,20709.02 52,female,18.3,0,no,northwest,9991.04 28,male,33.8,0,no,northwest,19673.34 50,female,28.1,3,no,northwest,11085.59 44,female,25.0,1,no,southwest,7623.52 26,female,22.2,0,no,northwest,3176.29 33,male,30.3,0,no,southeast,3704.35 19,female,32.5,0,yes,northwest,36898.73 50,male,37.1,1,no,southeast,9048.03 41,female,32.6,3,no,southwest,7954.52 
52,female,24.9,0,no,southeast,27117.99 39,male,32.3,2,no,southeast,6338.08 50,male,32.3,2,no,southwest,9630.4 52,male,32.8,3,no,northwest,11289.11 60,male,32.8,0,yes,southwest,52590.83 20,female,31.9,0,no,northwest,2261.57 55,male,21.5,1,no,southwest,10791.96 42,male,34.1,0,no,southwest,5979.73 18,female,30.3,0,no,northeast,2203.74 58,female,36.5,0,no,northwest,12235.84 43,female,32.6,3,yes,southeast,40941.29 35,female,35.8,1,no,northwest,5630.46 48,female,27.9,4,no,northwest,11015.17 36,female,22.1,3,no,northeast,7228.22 19,male,44.9,0,yes,southeast,39722.75 23,female,23.2,2,no,northwest,14426.07 20,female,30.6,0,no,northeast,2459.72 32,female,41.1,0,no,southwest,3989.84 43,female,34.6,1,no,northwest,7727.25 34,male,42.1,2,no,southeast,5124.19 30,male,38.8,1,no,southeast,18963.17 18,female,28.2,0,no,northeast,2200.83 41,female,28.3,1,no,northwest,7153.55 35,female,26.1,0,no,northeast,5227.99 57,male,40.4,0,no,southeast,10982.5 29,female,24.6,2,no,southwest,4529.48 32,male,35.2,2,no,southwest,4670.64 37,female,34.1,1,no,northwest,6112.35 18,male,27.4,1,yes,northeast,17178.68 43,female,26.7,2,yes,southwest,22478.6 56,female,41.9,0,no,southeast,11093.62 38,male,29.3,2,no,northwest,6457.84 29,male,32.1,2,no,northwest,4433.92 22,female,27.1,0,no,southwest,2154.36 52,female,24.1,1,yes,northwest,23887.66 40,female,27.4,1,no,southwest,6496.89 23,female,34.9,0,no,northeast,2899.49 31,male,29.8,0,yes,southeast,19350.37 42,female,41.3,1,no,northeast,7650.77 24,female,29.9,0,no,northwest,2850.68 25,female,30.3,0,no,southwest,2632.99 48,female,27.4,1,no,northeast,9447.38 23,female,28.5,1,yes,southeast,18328.24 45,male,23.6,2,no,northeast,8603.82 20,male,35.6,3,yes,northwest,37465.34 62,female,32.7,0,no,northwest,13844.8 43,female,25.3,1,yes,northeast,21771.34 23,female,28.0,0,no,southwest,13126.68 31,female,32.8,2,no,northwest,5327.4 41,female,21.8,1,no,northeast,13725.47 58,female,32.4,1,no,northeast,13019.16 48,female,36.6,0,no,northwest,8671.19 
31,female,21.8,0,no,northwest,4134.08 19,female,27.9,3,no,northwest,18838.7 19,female,30.0,0,yes,northwest,33307.55 41,male,33.6,0,no,southeast,5699.84 40,male,29.4,1,no,northwest,6393.6 31,female,25.8,2,no,southwest,4934.71 37,male,24.3,2,no,northwest,6198.75 46,male,40.4,2,no,northwest,8733.23 22,male,32.1,0,no,northwest,2055.32 51,male,32.3,1,no,northeast,9964.06 18,female,27.3,3,yes,southeast,18223.45 35,male,17.9,1,no,northwest,5116.5 59,female,34.8,2,no,southwest,36910.61 36,male,33.4,2,yes,southwest,38415.47 37,female,25.6,1,yes,northeast,20296.86 59,male,37.1,1,no,southwest,12347.17 36,male,30.9,1,no,northwest,5373.36 39,male,34.1,2,no,southeast,23563.02 18,male,21.5,0,no,northeast,1702.46 52,female,33.3,2,no,southwest,10806.84 27,female,31.3,1,no,northwest,3956.07 18,male,39.1,0,no,northeast,12890.06 40,male,25.1,0,no,southeast,5415.66 29,male,37.3,2,no,southeast,4058.12 46,female,34.6,1,yes,southwest,41661.6 38,female,30.2,3,no,northwest,7537.16 30,female,21.9,1,no,northeast,4718.2 40,male,25.0,2,no,southeast,6593.51 50,male,25.3,0,no,southeast,8442.67 20,female,24.4,0,yes,southeast,26125.67 41,male,23.9,1,no,northeast,6858.48 33,female,39.8,1,no,southeast,4795.66 38,male,16.8,2,no,northeast,6640.54 42,male,37.2,2,no,southeast,7162.01 56,male,34.4,0,no,southeast,10594.23 58,male,30.3,0,no,northeast,11938.26 52,male,34.5,3,yes,northwest,60021.4 20,female,21.8,0,yes,southwest,20167.34 54,female,24.6,3,no,northwest,12479.71 58,male,23.3,0,no,southwest,11345.52 45,female,27.8,2,no,southeast,8515.76 26,male,31.1,0,no,northwest,2699.57 63,female,21.7,0,no,northeast,14449.85 58,female,28.2,0,no,northwest,12224.35 37,male,22.7,3,no,northeast,6985.51 25,female,42.1,1,no,southeast,3238.44 52,male,41.8,2,yes,southeast,47269.85 64,male,37.0,2,yes,southeast,49577.66 22,female,21.3,3,no,northwest,4296.27 28,female,33.1,0,no,southeast,3171.61 18,male,33.3,0,no,southeast,1135.94 28,male,24.3,5,no,southwest,5615.37 45,female,25.7,3,no,southwest,9101.8 
33,male,29.4,4,no,southwest,6059.17 18,female,39.8,0,no,southeast,1633.96 32,male,33.6,1,yes,northeast,37607.53 24,male,29.8,0,yes,northeast,18648.42 19,male,19.8,0,no,southwest,1241.57 20,male,27.3,0,yes,southwest,16232.85 40,female,29.3,4,no,southwest,15828.82 34,female,27.7,0,no,southeast,4415.16 42,female,37.9,0,no,southwest,6474.01 51,female,36.4,3,no,northwest,11436.74 54,female,27.6,1,no,northwest,11305.93 55,male,37.7,3,no,northwest,30063.58 52,female,23.2,0,no,northeast,10197.77 32,female,20.5,0,no,northeast,4544.23 28,male,37.1,1,no,southwest,3277.16 41,female,28.1,1,no,southeast,6770.19 43,female,29.9,1,no,southwest,7337.75 49,female,33.3,2,no,northeast,10370.91 64,male,23.8,0,yes,southeast,26926.51 55,female,30.5,0,no,southwest,10704.47 24,male,31.1,0,yes,northeast,34254.05 20,female,33.3,0,no,southwest,1880.49 45,male,27.5,3,no,southwest,8615.3 26,male,33.9,1,no,northwest,3292.53 25,female,34.5,0,no,northwest,3021.81 43,male,25.5,5,no,southeast,14478.33 35,male,27.6,1,no,southeast,4747.05 26,male,27.1,0,yes,southeast,17043.34 57,male,23.7,0,no,southwest,10959.33 22,female,30.4,0,no,northeast,2741.95 32,female,29.7,0,no,northwest,4357.04 39,male,29.9,1,yes,northeast,22462.04 25,female,26.8,2,no,northwest,4189.11 48,female,33.3,0,no,southeast,8283.68 47,female,27.6,2,yes,northwest,24535.7 18,female,21.7,0,yes,northeast,14283.46 18,male,30.0,1,no,southeast,1720.35 61,male,36.3,1,yes,southwest,47403.88 47,female,24.3,0,no,northeast,8534.67 28,female,17.3,0,no,northeast,3732.63 36,female,25.9,1,no,southwest,5472.45 20,male,39.4,2,yes,southwest,38344.57 44,male,34.3,1,no,southeast,7147.47 38,female,20.0,2,no,northeast,7133.9 19,male,34.9,0,yes,southwest,34828.65 21,male,23.2,0,no,southeast,1515.34 46,male,25.7,3,no,northwest,9301.89 58,male,25.2,0,no,northeast,11931.13 20,male,22.0,1,no,southwest,1964.78 18,male,26.1,0,no,northeast,1708.93 28,female,26.5,2,no,southeast,4340.44 33,male,27.5,2,no,northwest,5261.47 19,female,25.7,1,no,northwest,2710.83 
45,male,30.4,0,yes,southeast,62592.87 62,male,30.9,3,yes,northwest,46718.16 25,female,20.8,1,no,southwest,3208.79 43,male,27.8,0,yes,southwest,37829.72 42,male,24.6,2,yes,northeast,21259.38 24,female,27.7,0,no,southeast,2464.62 29,female,21.9,0,yes,northeast,16115.3 32,male,28.1,4,yes,northwest,21472.48 25,female,30.2,0,yes,southwest,33900.65 41,male,32.2,2,no,southwest,6875.96 42,male,26.3,1,no,northwest,6940.91 33,female,26.7,0,no,northwest,4571.41 34,male,42.9,1,no,southwest,4536.26 19,female,34.7,2,yes,southwest,36397.58 30,female,23.7,3,yes,northwest,18765.88 18,male,28.3,1,no,northeast,11272.33 19,female,20.6,0,no,southwest,1731.68 18,male,53.1,0,no,southeast,1163.46 35,male,39.7,4,no,northeast,19496.72 39,female,26.3,2,no,northwest,7201.7 31,male,31.1,3,no,northwest,5425.02 62,male,26.7,0,yes,northeast,28101.33 62,male,38.8,0,no,southeast,12981.35 42,female,40.4,2,yes,southeast,43896.38 31,male,25.9,1,no,northwest,4239.89 61,male,33.5,0,no,northeast,13143.34 42,female,32.9,0,no,northeast,7050.02 51,male,30.0,1,no,southeast,9377.9 23,female,24.2,2,no,northeast,22395.74 52,male,38.6,2,no,southwest,10325.21 57,female,25.7,2,no,southeast,12629.17 23,female,33.4,0,no,southwest,10795.94 52,female,44.7,3,no,southwest,11411.69 50,male,31.0,3,no,northwest,10600.55 18,female,31.9,0,no,northeast,2205.98 18,female,36.9,0,no,southeast,1629.83 21,female,25.8,0,no,southwest,2007.95 61,female,29.1,0,yes,northwest,29141.36 -------------------------------------------------------------------------------- /linear-regression/ex-insurance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.6.12-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 
17 | "name": "python3", 18 | "display_name": "Python 3.6.12 64-bit ('learning': conda)", 19 | "metadata": { 20 | "interpreter": { 21 | "hash": "566c0a97317f6f88d4bc5f478002f1c75c862f0281a52c0ded6c5ead36971532" 22 | } 23 | } 24 | } 25 | }, 26 | "nbformat": 4, 27 | "nbformat_minor": 2, 28 | "cells": [ 29 | { 30 | "source": [ 31 | "# Linear Regression - Insurance Dataset" 32 | ], 33 | "cell_type": "markdown", 34 | "metadata": {} 35 | }, 36 | { 37 | "source": [ 38 | "## Import" 39 | ], 40 | "cell_type": "markdown", 41 | "metadata": {} 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 17, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import numpy as np\n", 50 | "import pandas as pd\n", 51 | "from matplotlib import pyplot as plt\n", 52 | "import seaborn as sns\n", 53 | "import statsmodels.api as sm\n", 54 | "import scipy.stats as stats\n", 55 | "from sklearn.preprocessing import LabelEncoder\n", 56 | "import copy\n", 57 | "sns.set()" 58 | ] 59 | }, 60 | { 61 | "source": [ 62 | "## Check Missing Values and Outliers" 63 | ], 64 | "cell_type": "markdown", 65 | "metadata": {} 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 18, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "output_type": "execute_result", 74 | "data": { 75 | "text/plain": [ 76 | " age sex bmi children smoker region expenses\n", 77 | "0 19 female 27.9 0 yes southwest 16884.92\n", 78 | "1 18 male 33.8 1 no southeast 1725.55\n", 79 | "2 28 male 33.0 3 no southeast 4449.46\n", 80 | "3 33 male 22.7 0 no northwest 21984.47\n", 81 | "4 32 male 28.9 0 no northwest 3866.86" 82 | ], 83 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
agesexbmichildrensmokerregionexpenses
019female27.90yessouthwest16884.92
118male33.81nosoutheast1725.55
228male33.03nosoutheast4449.46
333male22.70nonorthwest21984.47
432male28.90nonorthwest3866.86
\n
" 84 | }, 85 | "metadata": {}, 86 | "execution_count": 18 87 | } 88 | ], 89 | "source": [ 90 | "# Read the dataset\n", 91 | "insurance_df = pd.read_csv('data/insurance.csv')\n", 92 | "\n", 93 | "# Display the first 10 elements\n", 94 | "insurance_df.head()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 19, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "output_type": "stream", 104 | "name": "stdout", 105 | "text": [ 106 | "\nRangeIndex: 1338 entries, 0 to 1337\nData columns (total 7 columns):\n # Column Non-Null Count Dtype \n--- ------ -------------- ----- \n 0 age 1338 non-null int64 \n 1 sex 1338 non-null object \n 2 bmi 1338 non-null float64\n 3 children 1338 non-null int64 \n 4 smoker 1338 non-null object \n 5 region 1338 non-null object \n 6 expenses 1338 non-null float64\ndtypes: float64(2), int64(2), object(3)\nmemory usage: 73.3+ KB\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "# Info about the data, including shapes and individual columns\n", 112 | "insurance_df.info()" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 20, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "output_type": "execute_result", 122 | "data": { 123 | "text/plain": [ 124 | " age sex bmi children smoker region expenses\n", 125 | "False 1338 1338 1338 1338 1338 1338 1338" 126 | ], 127 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
agesexbmichildrensmokerregionexpenses
False1338133813381338133813381338
\n
" 128 | }, 129 | "metadata": {}, 130 | "execution_count": 20 131 | } 132 | ], 133 | "source": [ 134 | "# Check for the null values \n", 135 | "insurance_df.isna().apply(pd.value_counts)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 21, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "output_type": "display_data", 145 | "data": { 146 | "text/plain": "
", 147 | "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-01-02T17:01:45.377837\n image/svg+xml\n \n \n Matplotlib v3.3.3, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n\n", 148 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABGoAAANiCAYAAAA5Qy/lAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAABCvUlEQVR4nO3dfZzXdZ3v/+dcQIQXDF5woaJWZ+2QZ2vLynXFi7DW8gaI7S3tsMqWbhSsZrbuEa9y03KlzqZsoER62jxp5/bbm+l6dTYrUi7WPHbbzvGqVfOCQUAQBFQQZIbP7w92phkYGARmvm+Y+/0fh/l+vp/P6/OZ9/fLzMPvd6irqqoKAAAAADVXX+sBAAAAANhMqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFaOxug1Wr1mbTpqpHDn7ggftm5co3emTf7PmsD7pjjdAda4TuWCNsj/VBd6wRumON0JX6+roMHrzPNm/vNtRs2lT1WKhp2z9si/VBd6wRumON0B1rhO2xPuiONUJ3rBHeLm99AgAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhGms9AMDe4Pbbb82iRQtrPQZb6NevIRs3tna73Zo1q5MkgwY19exA9KoRI47IhAkTaz0GAMDbItQA7AaLFi1M81NP5ND162s9Ch10n2g2Wz1gQJJk3+d+13PD0KsW/8fXFABgTyPUAOwmh65fnwtfeKHWY7ATpr/rXUni67cXafuaAgDsafyOGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKESfCDULFszNggVzaz0GAABAMfycBGVqrPUAvWH+/IeSJMcff2KNJwEAACiDn5OgTH3iFTUAAAAAewKhBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAdsrq1aty3XVXp7n5xVx33dVZs2Z1++e2/HhH97Uj2+6uuXvjWDuitHl62s6cb1+6RkINAAAAO+Wee+7Ms88+ndmzZ+bZZ5/O3Xf/pP1zW368o/vakW1319y9cawdUdo8P
W1nzrcvXSOhBgAAgLdt9epVmT//oVRVlSVLFqeqqsyb91DmzXtoq4/nz5+73VdCdNxXd9vuzrl7+lh74jw9bWfOt69do8ZaD9Ab1qxZnTVr1mTatGtqPQpvQ79+Ddm4sbXWY1CwktZIc/PC7NfYJ55SYY/wWmNjXm9e2O3f/SU9j1Ae64Pu7OlrpLl5YQYNGrTT97/nnjuzaVPV6XMtLS2pq9v6402bNuXuu3+Sc845t9t9dbftrurNY+2J8/S0nTnfvnaNvKIGAACAt+3hhxektbVli89Wqapqq49bW1vy8MMLdmhf3W27q3rzWHviPD1tZ863r12jPvG/fwcNasqgQU255JIraz0Kb8PBB++XV155vdZjULCS1si0adekdcUrtR4D+A/7t7Rk8OFHdPt3f0nPI5TH+qA7e/oa2dV3HBx33PGZO/fBLWJNXerq8h+B5vcfNzQ05rjjjt+hfXW37a7qzWPtifP0tJ053752jbyiBgAAgLdt7NgzUl9f1+lzjY2NaWho3Orj+vr6jBv36R3aV3fb7qrePNaeOE9P25nz7WvXSKgBAADgbWtqGpxRo05KXV1dDjnk0NTV1eWEE07KCSectNXHo0admEGDmnZoX91tuzvn7ulj7Ynz9LSdOd++do36xFufAAAA2P3Gjj0jixe/lAkTJub222/NuHGfTlVVWbz4pa0+3tF99carJXrzWDuitHl62s6cb1+6RkINAAAAO6WpaXCmTv1akrT/d3sf7+i+elpvHmtHlDZPT9uZ8+1L18hbnwAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFKKx1gP0hlGjTqr1CAAAAEXxcxKUqU+EmuOPP7HWIwAAABTFz0lQJm99AgAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQjTWegCAvcXiAQMy/V3vqvUY7ISXBgxIEl+/vcjiAQNyeK2HAADYCUINwG4wYsQRtR6BLvTr15CNG1u73a5pzeokScOgpp4diF5zeDwuAYA9k1ADsBtMmDCx1iPQhYMP3i+vvPJ6rccAAIAd5nfUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AA
ABAIRq726C+vq5HB+jp/bNnsz7ojjVCd6wRumONsD3WB92xRuiONcKWulsTdVVVVb00CwAAAADb4a1PAAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIVo7I2DTJs2LT/96U+zePHi3HPPPTnqqKOSJKNHj07//v3zjne8I0ly8cUX54QTTuiNkSjIqlWr8t/+239Lc3Nz+vfvnyOOOCJXX311DjjggLzwwguZOnVqVq9enaampkybNi1HHnlkrUeml21vjXgeoc2UKVPy0ksvpb6+PgMHDsyVV16ZkSNHeh4hybbXh+cQtjRjxox897vfbf+e1XMIHW25PjyH0NG21oPnEd62qhc8+uij1ZIlS6qPfexj1dNPP93++S3/TN+0atWq6le/+lX7n6+77rrq0ksvraqqqs4555zqrrvuqqqqqu66667qnHPOqcmM1Nb21ojnEdq89tpr7R//7Gc/q8aPH19VlecRNtvW+vAcQkdPPPFEdd5551Unn3xy+7rwHEKbrtaH5xA62tZ68DzC29Urb3368Ic/nOHDh/fGodgDNTU15dhjj23/8x/90R9lyZIlWblyZZ566qmMGTMmSTJmzJg89dRTefXVV2s1KjWyrTUCHe23337tH7/xxhupq6vzPEK7rtYHdPTWW2/l6quvzlVXXdW+PjyH0Kar9QE7wvMIO6NX3vq0PRdffHGqqsoxxxyTr371q9l///1rPRI1tGnTpvz4xz/O6NGjs3Tp0gwdOjQNDQ1JkoaGhgwZMiRLly7NAQccUONJqZWOa6SN5xHaXH755VmwYEGqqsrNN9/seYROtlwfbTyHkCTTp0/PuHHjMmLEiPbPeQ6hTVfro43nEDracj14HmFn1PSXCd922225++67c8cdd6Sqqlx99dW1HIcCXHPNNRk4cGDOPvvsWo9CobZcI55H6Oib3/xmHnzwwVx00UX51re+VetxKExX68NzCEnym9/8Jo8//ngmTJhQ61Eo0PbWh+cQOrIe2F1qGmra3g7Vv3//TJgwIf/2b/9Wy3GosWnTpmXhwoW54YYbUl9fn+HDh2fZsmVpbW1NkrS2tmb58uXeRteHbblGEs8jdG38+PF55JFHMmzYMM8jbKVtfaxatcpzCEmSRx99NM8//3xOOeWUjB49Oi+//HLOO++8NDc3ew5hm+tj/vz5nkPopKv14GcadkbNQs26devy+uuvJ0mqqsr999+fkSNH1mocauz666/PE088kZkzZ6Z///5JkgMPPDAjR47MvffemyS59957M3LkSC8R7KO6WiOeR2izdu3aLF26tP3Pc+bMyaBBgzyPkGTb6+Md73iH5xCSJJMmTcr8+fMzZ86czJkzJ8OGDcstt9yS0047zXMI21wfH/rQhzyH0G5b35f6XoSdUVdVVdXTB/nGN76RBx54ICtWrMjgwYPT1NSUWbNm5YILLkhra2s2bdqU97znPbniiisyZMiQnh6Hwjz77LMZM2ZMjjzyyAwYMCBJcthhh2XmzJl57rnnMnXq1Lz22mvZf//9M23atLz73e+u8cT0tm2tkalTp3oeIUmyYsWKTJkyJW+++Wbq6+szaNCgXHLJJTn66KM9j7DN9bH//vt7DqFLo0ePzqxZs3LUUUd5DmErbevjne98p+cQ2i1atGib68HzCG9Xr4QaAAAAALpX099RAwAAAMDvCTUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAUbfTo0fnXf/3XXd7PrFmzcvnll++GiQAAek5jrQcAAOgNX/rSl2o9AgBAt7yiBgAAAKAQQg0AULzHH388p512Wj7ykY/k0ksvzYYNG/LII4/kxBNPzPe///0cd9xxGTVqVH7+85/noYceyqmnnpqPfvSjmTVrVvs+vvvd7+biiy+u4VkAAHRPqAEAinfPPffklltuyc9+9rO88MILufHGG5MkK
1asyIYNGzJ37tx8+ctfzhVXXJG77747d9xxR2677bbMnDkzixYtqvH0AAA7TqgBAIr353/+5xk+fHiampoyefLk3HfffUmSxsbGTJ48Of369ctpp52WVatWZeLEidl3333zB3/wB/mDP/iDPP300zWeHgBgxwk1AEDxhg8f3v7xIYcckuXLlydJmpqa0tDQkCQZMGBAkuTAAw9s3/Yd73hH1q5d24uTAgDsGqEGACje0qVL2z9esmRJhgwZUsNpAAB6jlADABTv9ttvz8svv5zVq1fne9/7Xk477bRajwQA0COEGgCgeGPGjMm5556bj3/84xkxYkQmT55c65EAAHpEXVVVVa2HAAAAAMAragAAAACKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhGrvbYNWqtdm0qeqNWejgwAP3zcqVb9R6DOgzPOag93i8Qe/ymIPe4/HGjqivr8vgwfts8/ZuQ82mTZVQUyOuO/QujznoPR5v0Ls85qD3eLyxq7z1CQAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFaKz1AEBt3H77rVm0aGGtxyhGv34N2bixtdZjQDHWrFmdJBk0qGm379vjDXpXTz3mRow4IhMmTNzt+wXo64Qa6KMWLVqYp556MevXH1rrUQqxsdYDQFEGDFiVJHnuuX16YO8eb9C7dv9jbsCAxbt9nwBsJtRAH7Z+/aF54YULaz0GUKB3vWt6kniOALrU9hwBwO7nd9QAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAh+kSoWbBgbhYsmFvrMQAAAICd0Jd+rm+s9QC9Yf78h5Ikxx9/Yo0nAQAAAN6uvvRzfZ94RQ0AAADAnkCoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQjTWeoDesGbN6qxZsybTpl1T61F2WL9+Ddm4sbXWY7AXa25emMbGfWs9BgCwB2psfC3NzW/sUd9fQ2/wc
1zPaW5emEGDBtV6jF7hFTUAAAAAhegTr6gZNKgpgwY15ZJLrqz1KDvs4IP3yyuvvF7rMdiLTZt2TVas2FjrMQCAPVBLy/45/PAD96jvr6E3+Dmu5/SlV/B5RQ0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAoRGOtB+gNo0adVOsRAAAAgJ3Ul36u7xOh5vjjT6z1CAAAAMBO6ks/13vrEwAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFKKx1gMAtTNgwOK8613Taz0GUKABA15KEs8RQJcGDFic5MhajwGwVxJqoI8aMeKIWo9QlH79GrJxY2utx4BirFkzOEkyaFC/3b5vjzfoXT3zmDvS9xIAPUSogT5qwoSJtR6hKAcfvF9eeeX1Wo8BfYLHG/QujzmAPYvfUQMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIVo7G6D+vq63piDLrj20Ls85qD3eLxB7/KYg97j8UZ3ulsjdVVVVb00CwAAAADb4a1PAAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1NbZq1ap84QtfyKmnnpqxY8fm/PPPz6uvvpokeeGFF3LWWWfl1FNPzVlnnZUXX3yxtsPCXmDKlCkZN25cxo8fnwkTJuS3v/1tEo836GkzZszIe9/73jzzzDNJPOagp4wePTqf/OQnc/rpp+f000/PvHnzknjMQU/YsGFDrrrqqvzpn/5pxo4dmyuvvDKJxxu7rq6qqqrWQ/Rlq1evztNPP51jjz02STJt2rSsWbMm1157bSZOnJg/+7M/y+mnn55//ud/zh133JFbb721xhPDnu3111/PfvvtlyT5+c9/npkzZ+bOO+/0eIMe9OSTT+b666/Pc889l+9973s56qijPOagh4wePTqzZs3KUUcd1enzHnOw+33jG99IfX19Lr300tTV1WXFihU56KCDPN7YZV5RU2NNTU3tkSZJ/uiP/ihLlizJypUr89RTT2XMmDFJkjFjxuSpp55qf
7UNsHPaIk2SvPHGG6mrq/N4gx701ltv5eqrr85VV12Vurq6JPGYg17mMQe739q1a3PXXXflwgsvbP/77aCDDvJ4Y7dorPUA/N6mTZvy4x//OKNHj87SpUszdOjQNDQ0JEkaGhoyZMiQLF26NAcccECNJ4U92+WXX54FCxakqqrcfPPNHm/Qg6ZPn55x48ZlxIgR7Z/zmIOedfHFF6eqqhxzzDH56le/6jEHPWDRokVpamrKjBkz8sgjj2SfffbJhRdemAEDBni8scu8oqYg11xzTQYOHJizzz671qPAXu2b3/xmHnzwwVx00UX51re+VetxYK/1m9/8Jo8//ngmTJhQ61Ggz7jtttty991354477khVVbn66qtrPRLslVpaWrJo0aK8733vy09+8pNcfPHFueCCC7Ju3bpaj8ZeQKgpxLRp07Jw4cLccMMNqa+vz/Dhw7Ns2bK0trYmSVpbW7N8+fIMHz68xpPC3mP8+PF55JFHMmzYMI836AGPPvponn/++ZxyyikZPXp0Xn755Zx33nlpbm72mIMe0vY46t+/fyZMmJB/+7d/830l9IBDDjkkjY2N7W9x+sAHPpDBgwdnwIABHm/sMqGmANdff32eeOKJzJw5M/3790+SHHjggRk5cmTuvffeJMm9996bkSNHerkc7IK1a9dm6dKl7X+eM2dOBg0a5PEGPWTSpEmZP39+5syZkzlz5mTYsGG55ZZbctppp3nMQQ9Yt25dXn/99SRJVVW5//77M3LkSH/PQQ844IADcuyxx2bBggVJNv9LTytXrsyRRx7p8cYu868+1dizzz6bMWPG5Mgjj8yAAQOSJIcddlhmzpyZ5557LlOnTs1rr72W/fffP9OmTcu73/3uGk8Me64VK1ZkypQpefPNN1NfX59BgwblkksuydFHH+3xBr2g479G4zEHu9+iRYtywQUXpLW1NZs2bcp73vOeXHHFFRkyZIjHHPSARYsW5bLLLsvq1avT2NiYr3zlKznppJM83thlQg0AAABAIbz1CQAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAo3uzZs/Pxj388H/zgB3PaaaflZz/7WZKktbU11113XY499tiMHj06P/rRj/Le9743LS0tSZLXX389l112WUaNGpUTTjgh119/fVpbW2t5KgAA29VY6wEAALozYsSI3HbbbTn44IPzL//yL/mbv/mbPPDAA/nFL36RuXPn5p//+Z/zzne+MxdeeGGn+11yySU56KCD8sADD+TNN9/MF7/4xQwfPjyf/exna3QmAADb5xU1AEDxPvWpT2Xo0KGpr6/PaaedliOOOCKPPfZY/vf//t+ZOHFihg0blkGDBmXSpEnt91mxYkXmzp2byy67LAMHDsyBBx6Yz33uc7nvvvtqeCYAANvnFTUAQPHuuuuu/OAHP8jixYuTJOvWrcuqVauyfPnyDB8+vH27YcOGtX+8ZMmStLS0ZNSoUe2f27RpU6ftAQBKI9QAAEVbvHhxrrjiivzjP/5jPvjBD6ahoSGnn356kuTggw/Oyy+/3L5tx4+HDRuW/v3751e/+lUaG33LAwDsGbz1CQAo2ptvvpm6uroccMABSZI77rgjzz77bJLNb4m69dZbs2zZsrz22mv5/ve/336/IUOG5Pjjj891112XN954I5s2bUpzc3P+z//5PzU5DwCAHSHUAABF+0//6T/l3HPPzWc/+9n8yZ/8SZ555pl86EMfSpKceeaZOf744zNu3LiMHz8+J510UhobG9PQ0JAk+da3vpWNGzfmtNNOy0c+8pF8+ctfziuvvFLL0wEA2K66qqqqWg8BALA7PPTQQ/nbv/3b/PKXv6z1KAAAO8UragCAPdb69evz0EMPpaWlJcuWLcvMmTPz8Y9/vNZjAQDsNK+oAQD2WG+++WbOPvvsPP/88xkwYEBOPvnkXH755dl3331rPRoAw
E4RagAAAAAK4a1PAAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFKKxuw1WrVqbTZuq3pilSwceuG9WrnyjZseHxDqkDNYhJbAOKYF1SK1Zg5TAOtxz1dfXZfDgfbZ5e7ehZtOmqqahpm0GqDXrkBJYh5TAOqQE1iG1Zg1SAutw7+StTwAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAoRGOtB2D3uv32W7No0cJaj7FHWLNmdZJk0KCmbrft168hGze29uxA7BYjRhyRCRMm1noMAACAnSLU7GUWLVqY5uYncuih62s9SvFWrx6QJNl33991u22rRrNHWLx4QK1HAAAA2CVCzV7o0EPX58ILX6j1GMWbPv1dSeJa7UXavqYAAAB7Kr+jBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACtEnQs2CBXOzYMHcWo8BAOzFfL8BAOwOjbUeoDfMn/9QkuT440+s8SQAwN7K9xsAwO7QJ15RAwAAALAnEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABSisdYDAADsTVavXpUZM65PXV0yfvxnMmPGd5Ikl132txkx4oisXr0q06d/O8uWvZxLL70qI0Yc0em+N9zw7SxfvvVtHfe9YcP6LF++LBs3bsyXvvTlPPDA/amrS84//6upqiqzZn03kyd/OYMGNbXvc9mypRk6dHjOPfeLufXWW9LS0pIk6devsf1+M2ZcnySZOPG89m3q6pLGxs3bDBrU1D7HrFnfzYQJf5Fbb70lra0tSepSVVUaGxtzwQWbt33yycfy939/XQ47bET++q8vbT9GS0tLGhsbM3Hiebn99h+2z7rlNfjMZybkRz/6QQ466OC8/vpr+au/+mruvPP/S11dlZaW1jQ0/H4fEyb8RZf7arsWVVXl7//+77JkyeIMGzY873znwPY5t/d1TNK+XXPzi7nuuqszdOiwfOpTYzJr1owcdtiI/OVfTs7tt9/a6Zp3PO7mazUxt99+a8aOPSM33nhDpk79Wvt66Pj16
mrutvX08Y9/KrNnz8hf//Wled/7/stWc55xxmcyc+b1GTp0eD7/+S/k1lv/R/u66HieXR1zW5qbX8y0addkypSv5J577ux0nyeffCzf+c60DBs2LAMH7rPVGpkx4/q0tra0r5+Oa3PLj6dP/3ZefnlpDj54SOrrGzqto64eW23H2tb163hNbrzxhkyZ8pXceec/pa4uOeecrdfdttbAjl6nXdFTx+mt+XfVnjLnzujpc3v11Vdz3XV/V/S1253XYG9eK1vyihoAgN3onnvuzPPP/y7PPfe73HTTP2TDhg3ZsGFDvve9Ge23L1z4YtavX9/+uY73bW7u+raO+168+KVs3LgxSfL979/Yfry77/5J7rnnzjz77NO5++6fdNrnhg0b0tz8YmbPnpHnn/9dmptfTHPzi53u9/zzv8vzz/+u0zYLF/5+m45zPPvs0+3bLVz4YhYufCHNzS/m+ed/v+1NN303SfLSS4s6HaNtu9mzZ3Sadctr8D//5w9SVVVeeWV51q9fn5tumv4f5/pcFi7svI9t7avtc/fcc2cWL34pVVVl6dIlnebc3tex43azZ8/M+vXrs3Dhi/n+92e1n9vs2TO3uuYdj7t5vs3b3HTTP+TNN9/stB66m7vt63vzzTelqqrceOP0Lue86aZ/+I/5Xsjs2TM7rYstz23LY27L7Nkz8+abb+amm/5hq/vcdNN3/+N6Lu1yjbStja7W5pYfL1y4eY2+9NKirdbRlvvseKxtXb+O16Rt/rb7drVWuvJ2rtOu6Knj9Nb8u2pPmXNn9PS5/a//9b+Kv3a78xrszWtlS0INAMBu0tKyMfPmPdj+53Xr1rZ/vGTJ4jz55BOdbl+yZHEWLVqYZPP/KZw7t+vb2m6fN++hrY65+dUsm82b92DmzXswVVVl/vy5aW5+MXPn/rLT9kuWLN5qH23329428+c/lDVrVmf16lWZP/+hVFXV5Xab9/dQHnnk4U7nP3fuLzsdo+04bbO27bvjNUiqTtt33N+W++hqX21zzpv3UB566Jdb3XfevM3ntKUtr/W8eQ/lyScf63S+Ha97x2M3Ny9sP+78+Q+1fz3atmk7hyVLFueppx7vsG3Xc3e8Hm3HXLdubZ566omt5txyvbVp+9q1nduWx9yW5uYX2/ezbt3aTvd58snHtvp6dFwj8+Y92Om2efM6rs3fX5ctz7HzfTrP3XGf8+c/lObmF7u8fl1dky2vTXfn/3au067oqeP01vy7ak+Zc2f09LmtXr0qv/jFL4q+drvzGuzNa6UrfeKtT2vWrM6aNWsybdo1tR6lxzU3L8x++/WJLyts5bXXGvP66wv32sd6v34N2bixtdZj0MdZh9vW3Lw5qrS0bPv63HTT9Pa3HLX53vdm5Bvf+HbuuefOTj/8d7wtSZe3b2nzvuuSJJs2bcrs2TPT2tr912vLmba1Tdv/xdy0qep225tvvqnT57Y3x6ZNm9r33d05dqfjvtrm3Hx+W8/cdk7nnHNup89vea1bWlraXx3U3bFnz57R6bjVdi7VjTf+Q/u2b2fuzfednmOPPW6HrlfH87znnju3OuaW599m9uyZXZ7j3Xf/JI888vA2j7P549atbmtbmx2vy/bOccu5O+6zpaUls2fP7PL67ega2t75v53rtCt66ji9Nf+u2lPm3Bk9fW6b97+px/a/O+zOa7A3r5WueEUNAMBusn79m9nWD53Jtl8RkiQPP7xgm7e13V5t76f+dpu3aW1t2eYrXnZGVVV5+OEFefjhBTvwg3D1toJLa2tL+753Vcd9/X6GbV23qstjbn2tqy6/dl0de8mSxe3H3byP7a+Htm3f3tyb77uja6Lta5ek0/7bjrktXa2ftvt0dT06rpGuZ6/at/v97dubv/PcHbdte5VSV9dvxx4n2z//t3OddkVPHae35t9Ve8qcO6Onz+3hhxe0R/ZSr93uvAZ781rpSp946cWgQU0ZNKgpl1xyZa1H6XHTpl2T1tZXaj0G1
MT++7dk8OAj9trH+sEH75dXXnm91mPQx1mH2zZt2jV5+eWlWbNmTbb1w+fAgfts9QPuIYccmiQ57rjj88tf/rzL29puf/DBX+zAD6F1Sao0NDRm6NChuy3W1NXV5bjjjk+SzJ37YDchpi4NDQ07HGsaGhrb973lNXi7Ou7r93NuviZdzdm2bUdbX+u6DBw4sNtY03bNly1bltbWltTV1f3HK0e2vR42bNiQ1taWtzn35vsee+xxO7QmOn7tjjvu+Pb9dzxmVw455NCt1k/bfbZ8a9uWx/nlL3/Rxeybz6fzddn2OXb8+mx+fPx+n3V1dRk+/JD2a93xXHbscZLtnv/buU67oqeO01vz76o9Zc6d0dPndtxxx2fevAfT0lLutdud12BvXitd8YoaAIDd5KCDDkpjY8M2b588+cI0Nnb+/2Rf/OL5SZKxY89IQ0PXt23r9i01Nja2H7++vj6TJv1VGhq2PU/n+3W/73HjPp2xY89IfX1dt9v+5V9O7vS5hoaGbR6jvr6+fd/dnWN3Ou6rbc7GxsbU1299HdrOaUtbztHY2JjJky/YoWNPmnR+p+Nubz1MmfLl9m23Nfe2rseUKRfu8PXqeJ4d9992zG2ZNOmvujzHceM+3eX16LhGtjzvjtdiy4+3dQ5bzt1xn42NjZk06a+6vH47uoa2d/5v5zrtip46Tm/Nv6v2lDl3Rk+f2+b91/fY/neH3XkN9ua10hWhBgBgN2ls7JcTTji5/c8DB+7T/vEhhxyao4/+L51uP+SQQ9v/Ce6mpsE58cSub2u7/YQTTtrqmB1/KD3hhJNzwgknp66uLqNGnZjDDz8yJ574sU7bd3yVzpb32942o0adlEGDmtLUNDijRp2Uurq6LrfbvL+Tcuyxx3U6/xNP/FinY7Qdp23Wtn13vAZtv9OkTcf9bbmPrvbVNucJJ5yUk0762Fb3PeGEk7r8J163vNYnnHBSjj76/Z3Ot+N173jsww8/ov24o0ad1P71aNum7RwOOeTQvO99f9hh267n7ng92o45cOA+ed/7/stWc2653tq0fe3azm3LY27L4Ycf2b6fgQP36XSfo49+/1Zfj45rZMuvdee1+fvrsuU5dr5P57k77nPUqJNy+OFHdnn9uromW16b7s7/7VynXdFTx+mt+XfVnjLnzujpc2tqGpxTTjml6Gu3O6/B3rxWutIn3voEANBbxo49IwsXvpi6umT8+M9kxozvJOn8ypnnn/9dli17udMrZtpue+6532X58q1v67jvDRvWZ/nyZdm4cWO+8IUpeeCB+1NXl4wb9+lUVZXFi1/q9EqE5577XZYtW5qhQ4fn3HO/mFtvvaX9dxv069fYfr+FC19MkkyceF77NnV1W7/yZOzYM7J48UuZMOEvcuutt7S/Taeqqk7bTp58Qf7+76/LYYeN6HSMlpaWNDY2ZuLE83L77T/cat9t1+Azn5mQH/3oBznooIPz+uuvZfLkC3Pnnf9f6uqqtLS0pqHh9/uYMOEvutxX27WoqirPPPPvWbJkcYYNG553vnPgdv+PbNu1TtK+3aRJf5Xrrrs6Q4cOy6c+NSazZs3IYYeNyF/+5eTcfvutna55x+NuvlYTc/vtt2bs2DNy4403dFoPHb9eXd2/bT19/OOfyuzZMzJlyoVdznnGGZ/JzJnXZ+jQ4fn857+QW2/9H+3rYstz2/KY2zJp0l9l2rRrMnnyl3PPPXd2us/kyRfkO9+ZlmHDhmXgwH22OoeFC19Ma2tL+5rouDa3/Pj553+Xl19emoMPHpL6+oYuX+3U8bHV1bXe1jW58cYbMnnyl3Pnnf+UurrknHO2XnddeTvXaVf01HF6a/5dtafMuTN6+tw++9nP5rnnXij62u3Oa7A3r5Ut1VXdvIFz5co3uv3N/j1pd7wXvu1fgNlbf29FR5t/R82vc+GFL9R6lOJNn/6uJHGt9iLTp78rDQ0f3msf6343CCWwDretL32/UWvWIbVmDVIC63DPV
V9flwMP3Hfbt/fiLAAAAABsh1ADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAAAABRCqAEAAAAohFADAAAAUIjGWg/QG0aNOqnWIwAAeznfbwAAu0OfCDXHH39irUcAAPZyvt8AAHYHb30CAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCNNZ6AHa/xYsHZPr0d9V6jOK99NKAJHGt9iKLFw/I4YfXegoAAICdJ9TsZUaMOKLWI+wxmppWJ0kaGpq63bZfv4Zs3NjaswOxyw4/3GMAAADYswk1e5kJEybWeoS90sEH75dXXnm91mMAAACwl/M7agAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBBCDQAAAEAhhBoAAACAQgg1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBCN3W1QX1/XG3MUPwNYh5TAOqQE1iElsA6pNWuQEliHe6buvm51VVVVvTQLAAAAANvhrU8AAAAAhRBqAAAAAAoh1AAAAAAUQqgBAAAAKIRQAwAAAFAIoQYAAACgEEINAAAAQCGEGgAAAIBCCDUAAAAAhSg21Lzwwgs566yzcuqpp+ass87Kiy++WOuR2ENNmzYto0ePznvf+94888wz7Z/f3hrridvo21atWpUvfOELOfXUUzN27Nicf/75efXVV5NYi/SuKVOmZNy4cRk/fnwmTJiQ3/72t0msQ2pjxowZnf5+tg7pTaNHj84nP/nJnH766Tn99NMzb968JNYhvWvDhg256qqr8qd/+qcZO3ZsrrzyyiTWYZ9XFeqcc86p7rrrrqqqququu+6qzjnnnBpPxJ7q0UcfrZYsWVJ97GMfq55++un2z29vjfXEbfRtq1atqn71q1+1//m6666rLr300qqqrEV612uvvdb+8c9+9rNq/PjxVVVZh/S+J554ojrvvPOqk08+uf3vZ+uQ3rTl94ZtrEN60zXXXFN985vfrDZt2lRVVVW98sorVVVZh31dkaFmxYoV1THHHFO1tLRUVVVVLS0t1THHHFOtXLmyxpOxJ+v4l/H21lhP3AZb+pd/+ZfqL/7iL6xFaurOO++szjjjDOuQXrdhw4bqzDPPrJqbm9v/frYO6W1dhRrrkN70xhtvVMccc0z1xhtvdPq8dUhjrV/R05WlS5dm6NChaWhoSJI0NDRky
JAhWbp0aQ444IAaT8feYHtrrKqq3X6bdUtHmzZtyo9//OOMHj3aWqQmLr/88ixYsCBVVeXmm2+2Dul106dPz7hx4zJixIj2z1mH1MLFF1+cqqpyzDHH5Ktf/ap1SK9atGhRmpqaMmPGjDzyyCPZZ599cuGFF2bAgAHWYR9X7O+oAdhbXXPNNRk4cGDOPvvsWo9CH/XNb34zDz74YC666KJ861vfqvU49DG/+c1v8vjjj2fChAm1HoU+7rbbbsvdd9+dO+64I1VV5eqrr671SPQxLS0tWbRoUd73vvflJz/5SS6++OJccMEFWbduXa1Ho8aKfEXN8OHDs2zZsrS2tqahoSGtra1Zvnx5hg8fXuvR2Etsb41VVbXbb4M206ZNy8KFCzNr1qzU19dbi9TU+PHj87WvfS3Dhg2zDuk1jz76aJ5//vmccsopSZKXX3455513Xi699FLrkF7Vtg769++fCRMmZPLkydYhveqQQw5JY2NjxowZkyT5wAc+kMGDB2fAgAHWYR9X5CtqDjzwwIwcOTL33ntvkuTee+/NyJEjvSSL3WZ7a6wnboMkuf766/PEE09k5syZ6d+/fxJrkd61du3aLF26tP3Pc+bMyaBBg6xDetWkSZMyf/78zJkzJ3PmzMmwYcNyyy235LTTTrMO6TXr1q3L66+/niSpqir3339/Ro4c6fmQXnXAAQfk2GOPzYIFC5Js/leZVq5cmSOPPNI67OPqqqqqaj1EV5577rlMnTo1r732Wvbff/9MmzYt7373u2s9Fnugb3zjG3nggQeyYsWKDB48OE1NTbnvvvu2u8Z64jb6tmeffTZjxozJkUcemQEDBiRJDjvssMycOdNapNesWLEiU6ZMyZtvvpn6+voMGjQol1xySY4++mjrkJoZPXp0Zs2alaOOOso6pNcsWrQoF1xwQVpbW7Np06a85z3vyRVXXJEhQ4ZYh/SqRYsW5bLLLsvq1avT2NiYr3zlKznppJOswz6u2FADAAAA0NcU+dYnAAAAgL5IqAEAAAAohFADAAAAUAihBgAAAKAQQg0AAABAIYQaAAAAgEIINQAAAACFEGoAAAAACiHUAAA9ZtmyZbngggvyx3/8xxk9enRuvfXWrF69OieeeGLmzJmTJFm7dm0+8YlP5K677kqSTJ06NV/72tfy+c9/Ph/84Adz9tlnZ/Hixe37fO655/L5z38+H/3oR3Pqqafm/vvvb79t6tSp+frXv55Jkyblgx/8YD7zmc+kubk5SVJVVa699tocd9xxOeaYYzJ27Ng888wzSZK33nor06ZNy8knn5w/+ZM/yde+9rWsX78+SfLqq6/mi1/8Yj784Q/nox/9aCZMmJBNmzb1xuUDAPogoQYA6BGbNm3K5MmT8973vjdz587ND3/4w/zwhz/M448/nmuvvTZXXnllVq5cmb/7u7/LyJEjM378+Pb73nPPPZkyZUoeeeSR/Of//J9z8cUXJ0nWrVuXc889N2PGjMm//uu/5jvf+U6+/vWv59lnn22/73333Zfzzz8/jz76aA4//PBcf/31SZL58+fn17/+dX7605/m17/+dW644YY0NTUlSb797W/nhRdeyF133ZUHHnggy5cvz8yZM5MkP/jBDzJ06NA8/PDDWbBgQb761a+mrq6udy4iANDnCDUAQI94/PHH8+qrr+b8889P//79M2LEiJx55pm5//77M2rUqHzyk5/M5z73uTz44IP5+te/3um+J598cj7ykY+kf//+ueiii/J//+//zdKlS/Pggw/m0EMPzZ/92Z+lsbExRx99dE499dT89Kc/bb/vJz7xibz//e9PY2Njxo0bl9/+9rdJksbGxqxduzbPP/98qqrKe97zngwZMiRVVeWf/umfctlll6WpqSn77rtvvvjFL+a+++5rv98rr7ySJUuWpF+/fvnwhz8s1AAAPaax1gMAAHunxYsXZ/ny5fnwhz/c/rnW1tb2P5955pn50Y9+lC996UsZPHhwp
/sOGzas/eN99tkngwYNyvLly7N48eI89thjW+1z3Lhx7X8+6KCD2j8eMGBA1q1blyQ57rjj8ud//ue5+uqrs2TJknziE5/IJZdckg0bNuTNN9/Mpz/96fb7VVXV/vam8847LzNmzMi5556bJDnrrLMyadKkXb4+AABdEWoAgB4xfPjwHHbYYXnggQe2uq21tTVXXXVVxo8fnx//+Mf59Kc/nSOOOKL99pdffrn947Vr12bNmjUZMmRIhg8fno985CP5wQ9+sFMzTZw4MRMnTszKlSvzla98JTfffHO+/OUvZ8CAAbnvvvsydOjQre6z7777ZurUqZk6dWqeffbZTJw4MX/4h3+Y4447bqdmAADYHm99AgB6xPvf//7su+++mT17dtavX5/W1tY888wzeeyxxzJr1qwkybXXXptzzz03l1xySVpbW9vv+9BDD+XXv/513nrrrUyfPj0f+MAHMnz48Jx88sl58cUXc9ddd2Xjxo3ZuHFjHnvssTz33HPdzvPYY4/l//2//5eNGzfmne98Z/r375+GhobU19fnM5/5TK699tqsXLkyyeZfgjxv3rwkyS9/+cssXLgwVVVl3333bb8PAEBP8F0GANAjGhoactNNN+Xf//3fc8opp+SP//iPc8UVV+RXv/pV/vEf/zHTpk1LQ0NDvvCFLyRJZs+e3X7fMWPGZObMmTn22GPz5JNP5tvf/naSza9uueWWW3L//ffnhBNOyKhRo/Lf//t/z1tvvdXtPGvXrs0VV1yRj370o/nYxz6Wpqam9rcz/c3f/E2OOOKInHnmmfnQhz6Uz33uc3nhhReSJAsXLmz/F6jOOuus/Nf/+l9z7LHH7u7LBQCQJKmrqqqq9RAAAG2mTp2aoUOH5qKLLqr1KAAAvc4ragAAAAAKIdQAAAAAFMJbnwAAAAAK4RU1AAAAAIUQagAAAAAKIdQAAAAAFEKoAQAAACiEUAMAAABQCKEGAAAAoBD/P4VocMnjA1UqAAAAAElFTkSuQmCC\n" 149 | }, 150 | "metadata": {} 151 | } 152 | ], 153 | "source": [ 154 | "# Checking for the outliers\n", 155 | "plt.figure(figsize= (20,15))\n", 156 | "plt.subplot(3,1,1)\n", 157 | "sns.boxplot(x= insurance_df.bmi, color='red')\n", 158 | "\n", 159 | "plt.subplot(3,1,2)\n", 160 | "sns.boxplot(x= insurance_df.age, color='blue')\n", 161 | "\n", 162 | "plt.subplot(3,1,3)\n", 163 | "sns.boxplot(x= insurance_df.expenses, color='yellow')\n", 164 | "\n", 165 | "plt.show()" 166 | ] 167 | }, 168 | { 169 | "source": [ 170 | "Interesting output for this analysis:\n", 171 | "* bmi has a few extreme values;\n", 172 | "* expenses as it is highly skewed, there are quiet a lot of extreme values." 
173 | ], 174 | "cell_type": "markdown", 175 | "metadata": {} 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 22, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "output_type": "execute_result", 184 | "data": { 185 | "text/plain": [ 186 | " count mean std min 25% 50% \\\n", 187 | "age 1338.0 39.207025 14.049960 18.00 27.0000 39.00 \n", 188 | "bmi 1338.0 30.665471 6.098382 16.00 26.3000 30.40 \n", 189 | "children 1338.0 1.094918 1.205493 0.00 0.0000 1.00 \n", 190 | "expenses 1338.0 13270.422414 12110.011240 1121.87 4740.2875 9382.03 \n", 191 | "\n", 192 | " 75% max \n", 193 | "age 51.000 64.00 \n", 194 | "bmi 34.700 53.10 \n", 195 | "children 2.000 5.00 \n", 196 | "expenses 16639.915 63770.43 " 197 | ], 198 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
countmeanstdmin25%50%75%max
age1338.039.20702514.04996018.0027.000039.0051.00064.00
bmi1338.030.6654716.09838216.0026.300030.4034.70053.10
children1338.01.0949181.2054930.000.00001.002.0005.00
expenses1338.013270.42241412110.0112401121.874740.28759382.0316639.91563770.43
\n
" 199 | }, 200 | "metadata": {}, 201 | "execution_count": 22 202 | } 203 | ], 204 | "source": [ 205 | "# a brief summary for the dataset\n", 206 | "insurance_df.describe().T" 207 | ] 208 | }, 209 | { 210 | "source": [ 211 | "### Age field\n", 212 | "* The median age is 39: 50% of the people are 39 years old or younger" 213 | ], 214 | "cell_type": "markdown", 215 | "metadata": {} 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 23, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# Plots to see the distribution of the continuous features individually\n", 224 | "\n", 225 | "plt.figure(figsize= (20,15))\n", 226 | "plt.subplot(3,3,1)\n", 227 | "plt.hist(insurance_df.bmi, color='lightblue', edgecolor = 'black', alpha = 0.7)\n", 228 | "plt.xlabel('bmi')\n", 229 | "\n", 230 | "plt.figure(figsize= (20,15))\n", 231 | "plt.subplot(3,3,2)\n", 232 | "plt.hist(insurance_df.bmi, color='lightblue', edgecolor = 'black', alpha = 0.7)\n", 233 | "plt.xlabel('age')\n", 234 | "\n", 235 | "plt.figure(figsize= (20,15))\n", 236 | "plt.subplot(3,3,2)\n", 237 | "plt.hist(insurance_df.bmi, color='lightblue', edgecolor = 'black', alpha = 0.7)\n", 238 | "plt.xlabel('charges')\n", 239 | "\n", 240 | "### Write the code to show age and charges graphs, same as we did for bmi\n", 241 | "\n", 242 | "plt.show()" 243 | ] 244 | }, 245 | { 246 | "source": [ 247 | "Output should include this Analysis:\n", 248 | "\n", 249 | "- bmi looks normally distributed.\n", 250 | "\n", 251 | "- Age looks uniformly distributed.\n", 252 | "\n", 253 | "- As seen in the previous step, charges are highly skewed" 254 | ], 255 | "cell_type": "markdown", 256 | "metadata": {} 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "# showing the skewness of variables\n", 265 | "Skewness = pd.DataFrame({'Skeweness' : [stats.skew(insurance_df.bmi), stats.skew(insurance_df.age), stats.skew(insurance_df.charges)]},\n", 266 | "
index=['bmi','age','charges'])\n", 267 | "# We Will Measure the skeweness of the required columns\n", 268 | "Skewness" 269 | ] 270 | }, 271 | { 272 | "source": [ 273 | "Output should include this Analysis:\n", 274 | "\n", 275 | "- Skewness of bmi is very low as seen in the previous step\n", 276 | "\n", 277 | "- age is uniformly distributed and thus not skewed\n", 278 | "\n", 279 | "- charges are highly skewed" 280 | ], 281 | "cell_type": "markdown", 282 | "metadata": {} 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "# visualizing data to make analysis\n", 291 | "\n", 292 | "plt.figure(figsize=(20,25))\n", 293 | "\n", 294 | "\n", 295 | "x = insurance_df.smoker.value_counts().index #Values for x-axis\n", 296 | "y = [insurance_df['smoker'].value_counts()[i] for i in x] # Count of each class on y-axis\n", 297 | "\n", 298 | "plt.subplot(4,2,1)\n", 299 | "plt.bar(x,y, align='center',color = 'red',edgecolor = 'black',alpha = 0.7) # plot a bar chart\n", 300 | "plt.xlabel('Smoker?')\n", 301 | "plt.ylabel('Count ')\n", 302 | "plt.title('Smoker distribution')\n", 303 | "\n", 304 | "x1 = insurance_df.sex.value_counts().index #Values for x-axis\n", 305 | "y1 = [insurance_df['sex'].value_counts()[j] for j in x1] # Count of each class on y-axis\n", 306 | "\n", 307 | "plt.subplot(4,2,2)\n", 308 | "plt.bar(x,y, align='center',color = 'red',edgecolor = 'black',alpha = 0.7) # plot a bar chart\n", 309 | "plt.xlabel('Gender')\n", 310 | "plt.ylabel('Count ')\n", 311 | "plt.title('Gender distribution')\n", 312 | "\n", 313 | "x2 = insurance_df.region.value_counts().index #Values for x-axis\n", 314 | "y2 = [insurance_df['region'].value_counts()[k] for k in x2] # Count of each class on y-axis\n", 315 | "\n", 316 | "plt.subplot(4,2,3)\n", 317 | "plt.bar(x,y, align='center',color = 'red',edgecolor = 'black',alpha = 0.7) # plot a bar chart\n", 318 | "plt.xlabel('Region')\n", 319 | "plt.ylabel('Count 
')\n", 320 | "plt.title('Region distribution')\n", 321 | "\n", 322 | "x3 = insurance_df.children.value_counts().index #Values for x-axis\n", 323 | "y3 = [insurance_df['children'].value_counts()[l] for l in x3] # Count of each class on y-axis\n", 324 | "\n", 325 | "plt.subplot(4,2,4)\n", 326 | "plt.bar(x,y, align='center',color = 'red',edgecolor = 'black',alpha = 0.7) # plot a bar chart\n", 327 | "plt.xlabel('No. of children')\n", 328 | "plt.ylabel('Count ')\n", 329 | "plt.title('Children distribution')\n", 330 | "\n", 331 | "\n", 332 | "plt.show()" 333 | ] 334 | }, 335 | { 336 | "source": [ 337 | "\n", 338 | "- There are a lot more non-smokers than smokers.\n", 339 | "\n", 340 | "- Instances are distributed evenly across all regions.\n", 341 | "\n", 342 | "- Gender is also distributed evenly.\n", 343 | "\n", 344 | "- Most instances have fewer than 3 children and very few have 4 or 5 children." 345 | ], 346 | "cell_type": "markdown", 347 | "metadata": {} 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "# Label encoding the variables before doing a pairplot because pairplot ignores strings\n", 356 | "\n", 357 | "insurance_df_encoded = copy.deepcopy(insurance_df)\n", 358 | "insurance_df_encoded.loc[:,['sex', 'smoker', 'region']] = insurance_df_encoded.loc[:,['sex', 'smoker', 'region']].apply(LabelEncoder().fit_transform) \n", 359 | "\n", 360 | "sns.pairplot(insurance_df_encoded) # pairplot\n", 361 | "plt.show()" 362 | ] 363 | }, 364 | { 365 | "source": [ 366 | "Output should include this Analysis:\n", 367 | "\n", 368 | "- There is an obvious correlation between 'charges' and 'smoker'\n", 369 | "\n", 370 | "- Looks like smokers claimed more money than non-smokers\n", 371 | "\n", 372 | "- There's an interesting pattern between 'age' and 'charges'.
Notice that older people are charged more than the younger ones" 373 | ], 374 | "cell_type": "markdown", 375 | "metadata": {} 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "# Do charges of people who smoke differ significantly from the people who don't?\n", 384 | "print(\"Do charges of people who smoke differ significantly from the people who don't?\")\n", 385 | "insurance_df.smoker.value_counts()" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "# Scatter plot to look for visual evidence of dependency between attributes smoker and charges accross different ages\n", 395 | "plt.figure(figsize=(8,6))\n", 396 | "sns.scatterplot(insurance_df.age, insurance_df.charges,hue=insurance_df.smoker,palette= ['red','green'] ,alpha=0.6)\n", 397 | "plt.title(\"Difference between charges of smokers and charges of non-smokers is apparent\")\n", 398 | "plt.show()" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "plt.figure(figsize=(8,6))\n", 408 | "sns.scatterplot(insurance_df.age, insurance_df.charges,hue=insurance_df.sex,palette= ['pink','lightblue'] )" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "# T-test to check dependency of smoking on charges\n", 425 | "Ho = \"Charges of smoker and non-smoker are same\" # Stating the Null Hypothesis\n", 426 | "Ha = \"Charges of smoker and non-smoker are not the same\" # Stating the Alternate Hypothesis\n", 427 | "\n", 428 | "x = np.array(insurance_df[insurance_df.smoker ==\"yes\"].charges)\n", 429 | "y = 
np.array(insurance_df[insurance_df.smoker ==\"no\"].charges)\n", 430 | "\n", 431 | "t, p_value = stats.ttest_ind(x,y, axis = 0) # Performing an Independent t-test\n", 432 | "\n", 433 | "if p_value < 0.05: # Setting our significance level at 5%\n", 434 | " print(f'{Ha} as the p_value ({p_value}) < 0.05')\n", 435 | "else:\n", 436 | " print(f'{Ho} as the p_value ({p_value}) > 0.05')" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "#Does bmi of males differ significantly from that of females?\n", 446 | "print (\"Does bmi of males differ significantly from that of females?\")\n", 447 | "insurance_df.sex.value_counts() #Checking the distribution of males and females" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "# T-test to check dependency of bmi on gender\n", 457 | "Ho = \"Gender has no effect on bmi\" # Stating the Null Hypothesis\n", 458 | "Ha = \"Gender has an effect on bmi\" # Stating the Alternate Hypothesis\n", 459 | "\n", 460 | "x = np.array(insurance_df[insurance_df.sex ==\"male\"].bmi)\n", 461 | "y = np.array(insurance_df[insurance_df.sex ==\"female\"].bmi)\n", 462 | "\n", 463 | "t, p_value = stats.ttest_ind(x,y, axis = 0) #Performing an Independent t-test\n", 464 | "\n", 465 | "if p_value < 0.05: # Setting our significance level at 5%\n", 466 | " print(f'{Ha} as the p_value ({p_value.round()}) < 0.05')\n", 467 | "else:\n", 468 | " print(f'{Ho} as the p_value ({p_value.round(3)}) > 0.05')" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 24, 474 | "metadata": {}, 475 | "outputs": [ 476 | { 477 | "output_type": "stream", 478 | "name": "stdout", 479 | "text": [ 480 | "Gender has an effect on smoking habits as the p_value (0.007) < 0.05\n" 481 | ] 482 | }, 483 | { 484 | "output_type": "execute_result", 485 | "data": { 486 | 
"text/plain": [ 487 | "smoker no yes\n", 488 | "sex \n", 489 | "female 547 115\n", 490 | "male 517 159" 491 | ], 492 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
smokernoyes
sex
female547115
male517159
\n
" 493 | }, 494 | "metadata": {}, 495 | "execution_count": 24 496 | } 497 | ], 498 | "source": [ 499 | "#Is the proportion of smokers significantly different in different genders?\n", 500 | "\n", 501 | "\n", 502 | "# Chi_square test to check if smoking habits are different for different genders\n", 503 | "Ho = \"Gender has no effect on smoking habits\" # Stating the Null Hypothesis\n", 504 | "Ha = \"Gender has an effect on smoking habits\" # Stating the Alternate Hypothesis\n", 505 | "\n", 506 | "crosstab = pd.crosstab(insurance_df['sex'], insurance_df['smoker'])\n", 507 | "\n", 508 | "chi, p_value, dof, expected = stats.chi2_contingency(crosstab)\n", 509 | "\n", 510 | "if p_value < 0.05: # Setting our significance level at 5%\n", 511 | " print(f'{Ha} as the p_value ({p_value.round(3)}) < 0.05')\n", 512 | "else:\n", 513 | " print(f'{Ho} as the p_value ({p_value.round(3)}) > 0.05')\n", 514 | "crosstab" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "# Chi_square test to check if smoking habits are different for people of different regions\n", 524 | "Ho = \"Region has no effect on smoking habits\" # Stating the Null Hypothesis\n", 525 | "Ha = \"Region has an effect on smoking habits\" # Stating the Alternate Hypothesis\n", 526 | "\n", 527 | "crosstab = pd.crosstab(insurance_df['smoker'], insurance_df['region'])\n", 528 | "\n", 529 | "chi, p_value, dof, expected = stats.chi2_contingency(crosstab)\n", 530 | "\n", 531 | "if p_value < 0.05: # Setting our significance level at 5%\n", 532 | " print(f'{Ha} as the p_value ({p_value.round(3)}) < 0.05')\n", 533 | "else:\n", 534 | " print(f'{Ho} as the p_value ({p_value.round(3)}) > 0.05')\n", 535 | "crosstab" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 25, 541 | "metadata": {}, 542 | "outputs": [ 543 | { 544 | "output_type": "stream", 545 | "name": "stdout", 546 | "text": [ 547 | "No. 
of children has no effect on bmi as the p_value (0.715) > 0.05\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "# Is the distribution of bmi across women with no children, one child and two children, the same ?\n", 553 | "# Test to see if the distributions of bmi values for females having different number of children, are significantly different\n", 554 | "\n", 555 | "Ho = \"No. of children has no effect on bmi\" # Stating the Null Hypothesis\n", 556 | "Ha = \"No. of children has an effect on bmi\" # Stating the Alternate Hypothesis\n", 557 | "\n", 558 | "\n", 559 | "female_df = copy.deepcopy(insurance_df[insurance_df['sex'] == 'female']) \n", 560 | "\n", 561 | "zero = female_df[female_df.children == 0]['bmi']\n", 562 | "one = female_df[female_df.children == 1]['bmi']\n", 563 | "two = female_df[female_df.children == 2]['bmi']\n", 564 | "\n", 565 | "\n", 566 | "f_stat, p_value = stats.f_oneway(zero,one,two)\n", 567 | "\n", 568 | "\n", 569 | "if p_value < 0.05: # Setting our significance level at 5%\n", 570 | " print(f'{Ha} as the p_value ({p_value.round(3)}) < 0.05')\n", 571 | "else:\n", 572 | " print(f'{Ho} as the p_value ({p_value.round(3)}) > 0.05')" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": null, 578 | "metadata": {}, 579 | "outputs": [], 580 | "source": [] 581 | } 582 | ] 583 | } -------------------------------------------------------------------------------- /linear-regression/learning.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giuseppefutia/notebooks/a21cb4dca09df1abf000e2c25b7a340478a5fce2/linear-regression/learning.gif -------------------------------------------------------------------------------- /linear-regression/todo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | 
"file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": 3 14 | }, 15 | "orig_nbformat": 2 16 | }, 17 | "nbformat": 4, 18 | "nbformat_minor": 2, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Multiple generation of random data (useful to compute the p-value)\n", 27 | "rand_generations = 100000\n", 28 | "mult_x = [x for i in range(rand_generations)] # 100000 batches of features\n", 29 | "mult_y = [np.sin(x) + 0.1*np.power(x,2) + 0.5*np.random.randn(100,1)\n", 30 | " for i in range(rand_generations)] # 100000 batches of targets\n", 31 | "\n", 32 | "# Recap of optimized parameters (from Pytorch)\n", 33 | "W = model.weight.item()\n", 34 | "b = model.bias.item()\n", 35 | "\n", 36 | "# Compute multiple MSEs (learned model)\n", 37 | "mult_y_hat = [mult_x[i].dot(W) + b\n", 38 | " for i in range(rand_generations)]\n", 39 | "\n", 40 | "mult_loss_fit = [mult_y_hat[i] - mult_y[i]\n", 41 | " for i in range(rand_generations)]\n", 42 | "\n", 43 | "mult_mse_fit = [np.sum(np.power(mult_loss_fit[i], 2)) / (2 * n)\n", 44 | " for i in range(rand_generations)]\n", 45 | "\n", 46 | "# Compute multiple MSEs (mean model)\n", 47 | "mult_y_mean = [np.mean(mult_y[i])\n", 48 | " for i in range(rand_generations)]\n", 49 | "\n", 50 | "mult_loss_mean = [mult_y_mean[i] - mult_y[i]\n", 51 | " for i in range(rand_generations)]\n", 52 | "\n", 53 | "mult_mse_mean = [np.sum(np.power(mult_loss_mean[i], 2)) / (2 * n)\n", 54 | " for i in range(rand_generations)]\n", 55 | "\n", 56 | "# Compute multiple F\n", 57 | "mult_f_ratio = [compute_f_ratio(mult_mse_fit[i], mult_mse_mean[i], p_fit, p_mean, n)\n", 58 | " for i in range(rand_generations)]\n", 59 | "\n", 60 | "# Plot the multiple F ratios generated from random data\n", 61 | "plt.figure(figsize=(20,25))\n", 62 | "\n", 63 | "# Emphasize the generated 
F ratio among all generated random data\n", 64 | "plt.subplot(4,2,1)\n", 65 | "heights, bins, patches = plt.hist(x=mult_f_ratio, bins='auto', # raw ratios: same data as the second histogram, so bins and idx agree\n", 66 | " color='#0504aa',\n", 67 | " alpha=0.7, rwidth=0.85)\n", 68 | "\n", 69 | "print(heights)\n", 70 | "print(bins)\n", 71 | "\n", 72 | "idx = np.abs((bins[:-1] + bins[1:]) / 2 - f_ratio).argmin() # Visualization trick: bin whose center is closest to the F ratio (bins holds len(patches) + 1 edges)\n", 73 | "\n", 74 | "patches[idx].set_fc('#ff8c00')\n", 75 | "plt.grid(axis='y', alpha=0.75)\n", 76 | "plt.xlabel('F')\n", 77 | "plt.ylabel('Frequency of F')\n", 78 | "maxfreq = heights.max()\n", 79 | "plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10) # Clean upper y-axis limit.\n", 80 | "\n", 81 | "# Emphasize the bin of the computed F ratio and all bins that are less frequent than it\n", 82 | "# This will show a graphical representation of the p-value\n", 83 | "plt.subplot(4,2,2)\n", 84 | "heights, bins, patches = plt.hist(x=mult_f_ratio, bins='auto',\n", 85 | " color='#0504aa',\n", 86 | " alpha=0.7, rwidth=0.85)\n", 87 | "\n", 88 | "idx_p_value = list(np.argwhere(heights < heights[idx]).reshape(-1,))\n", 89 | "patches[idx].set_fc('#ff8c00') # Coloring f ratio\n", 90 | "for p in idx_p_value:\n", 91 | " patches[p].set_fc('#ffd700') # Coloring bins less frequent than the f ratio bin\n", 92 | "\n", 93 | "plt.grid(axis='y', alpha=0.75)\n", 94 | "plt.xlabel('F')\n", 95 | "plt.ylabel('Frequency of F')\n", 96 | "maxfreq = heights.max()\n", 97 | "plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10) # Clean upper y-axis limit.\n", 98 | "\n", 99 | "# Plot!\n", 100 | "plt.show()\n", 101 | "\n", 102 | "# Recap of number of samples\n", 103 | "num_samples = np.sum(heights)\n", 104 | "\n", 105 | "# p-value is the sum of 3 different probabilities\n", 106 | "\n", 107 | "# probability 1\n", 108 | "p1 = heights[idx] / num_samples # Probability of the sample\n", 109 | "\n", 110 | "# probability 2\n", 111 | "p2 = (heights == heights[idx]).sum() - 1 # Cases with the same probability 
(excluding the sample)\n", 112 | "p2 = p2 * heights[idx]\n", 113 | "p2 = p2 / num_samples # Cases with the same probability of the samples\n", 114 | "\n", 115 | "# probability 3\n", 116 | "idx_p_value = list(np.argwhere(heights < heights[idx]).reshape(-1,)) # Recap of the indices with less probability\n", 117 | "occurrences = 0\n", 118 | "for p in idx_p_value:\n", 119 | " occurrences += heights[p]\n", 120 | "p3 = occurrences / num_samples\n", 121 | "\n", 122 | "p_value = p1 + p2 + p3\n", 123 | "\n", 124 | "print('----- p_value: %.4f\\n' % p_value)\n", 125 | "print()\n", 126 | "if p_value < 0.05: # Setting our significance level at 5%\n", 127 | " print('The computed R_squared is statistically relevant.')\n", 128 | "else:\n", 129 | " print('The computed R_squared is not statistically relevant.')\n", 130 | "\n", 131 | "\n", 132 | "# Reference: https://openclassrooms.com/en/courses/5873596-design-effective-statistical-models-to-understand-your-data/6229141-build-and-interpret-a-univariate-linear-regression-model#:~:text=null%20hypothesis%20anyway.-,R%2DSquared,positive%20and%20lower%20than%201.\n", 133 | "\n" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "# Create grid coordinates for plotting\n", 143 | "B0 = np.linspace(W[0] - 2, W[0] + 2, 50)\n", 144 | "print(B0)\n", 145 | "print(B0.size)\n", 146 | "B1 = np.linspace(W[1] - 2, W[1] + 2, 50)\n", 147 | "print(B1)\n", 148 | "print(B1.size)\n", 149 | "xx, yy = np.meshgrid(B0, B1, indexing='xy')\n", 150 | "print(xx)\n", 151 | "print(xx[0])\n", 152 | "print(xx.shape)\n", 153 | "print(yy)\n", 154 | "print(yy[0])\n", 155 | "print(yy.shape)\n", 156 | "Z = np.zeros(xx.shape) # meshgrid(indexing='xy') returns arrays of shape (B1.size, B0.size)\n", 157 | "print(Z.shape)\n", 158 | "\n", 159 | "# Calculate Z-values (MSE) based on grid of parameters\n", 160 | "for (i, j) , v in np.ndenumerate(Z): # Iterate each element of a multiple array and return the coordinates and the value\n", 161 | " Z[i,j] 
=((y - (xx[i,j] + x*yy[i,j]))**2).sum() / n\n", 162 | "\n", 163 | "print(Z)\n", 164 | "print(Z[25][25])\n", 165 | "\n", 166 | "\n", 167 | "\n", 168 | "\n", 169 | "\n", 170 | "\n", 171 | "\n", 172 | "\n", 173 | "\n", 174 | "\n", 175 | "# Minimized MSE\n", 176 | "min_MSE_label = r'$b$, $W$ that minimize the MSE'\n", 177 | "min_mse = error\n", 178 | "\n", 179 | "fig = plt.figure(figsize=(15,6))\n", 180 | "fig.suptitle('Mean Squared Error - Regression Parameters', fontsize=20)\n", 181 | "\n", 182 | "ax1 = fig.add_subplot(121)\n", 183 | "ax2 = fig.add_subplot(122, projection='3d')\n", 184 | "\n", 185 | "# Left plot\n", 186 | "CS = ax1.contour(xx, yy, Z, cmap=plt.cm.Set1, levels=[2.2, 2.3, 2.5, 3])\n", 187 | "ax1.scatter(W[0], W[1], c='r', label=min_MSE_label)\n", 188 | "ax1.clabel(CS, inline=True, fontsize=10, fmt='%1.1f')\n", 189 | "\n", 190 | "# Right plot\n", 191 | "ax2.plot_surface(xx, yy, Z, rstride=3, cstride=3, alpha=0.3)\n", 192 | "ax2.contour(xx, yy, Z, zdir='z', offset=Z.min(), cmap=plt.cm.Set1,\n", 193 | " alpha=0.4, levels=[2.2, 2.3, 2.5, 3])\n", 194 | "ax2.scatter3D(W[0], W[1], min_mse, c='r', label=min_MSE_label)\n", 195 | "ax2.set_zlabel('MSE')\n", 196 | "ax2.set_zlim(Z.min(),Z.max())\n", 197 | "ax2.set_ylim(8, 12)\n", 198 | "\n", 199 | "# Settings common to both plots\n", 200 | "for ax in fig.axes:\n", 201 | " ax.set_xlabel(r'$b$', fontsize=14)\n", 202 | " ax.set_ylabel(r'$W$', fontsize=14)\n", 203 | " ax.set_yticks([W[1]-2, W[1]-1, W[1], W[1]+1, W[1]+2])\n", 204 | " ax.set_xticks([W[0]-2, W[0]-1, W[0], W[0]+1, W[0]+2])\n", 205 | " ax.legend()\n" 206 | ] 207 | } 208 | ] 209 | } --------------------------------------------------------------------------------