├── README.md └── XGBoostScratch.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # XGBoost-Using-Numpy 2 | The following code is a simple XGBoost model developed using numpy. The main purpose of this code is to unveil the maths behind XGBoost. 3 | -------------------------------------------------------------------------------- /XGBoostScratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## XGBoost using Numpy (from scratch)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "#### The following code is a simple XGBoost model developed using numpy.\n", 15 | "The main purpose of this code is to unveil the maths behind XGBoost. Strictly speaking, it implements plain gradient boosting with squared loss, where each weak learner is a single split placed at a randomly chosen point." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 85, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import numpy as np\n", 25 | "import pandas as pd\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "%matplotlib inline" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "##### Consider the following data, where years of experience is the predictor variable and salary is the target." 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 86, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "year = [5,7,12,23,25,28,29,34,35,40]\n", 44 | "salary = [82,80,103,118,172,127,204,189,99,166]" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "##### Using regression trees as base learners, we can create a model to predict the salary. \n", 52 | "For the sake of simplicity, we use squared loss as the loss function, so the objective is to minimize the squared error.\n", 53 | "##### As the first step, the model should be initialized with a function F0(x). 
F0(x) should be the constant that minimizes the loss function, here the MSE (mean squared error).\n", 54 | "##### For MSE, the loss is minimized by the mean of the target, so F0(x) = mean(salary)\n", 55 | "##### Had we used MAE (mean absolute error), the loss would be minimized by the median instead\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 87, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/html": [ 66 | "
\n", 67 | "\n", 80 | "\n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | "
YearsSalary
0582
1780
212103
323118
425172
\n", 116 | "
" 117 | ], 118 | "text/plain": [ 119 | " Years Salary\n", 120 | "0 5 82\n", 121 | "1 7 80\n", 122 | "2 12 103\n", 123 | "3 23 118\n", 124 | "4 25 172" 125 | ] 126 | }, 127 | "execution_count": 87, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "df = pd.DataFrame(columns=['Years','Salary'])\n", 134 | "df.Years = year\n", 135 | "df.Salary = salary\n", 136 | "df.head()" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 88, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "" 148 | ] 149 | }, 150 | "execution_count": 88, 151 | "metadata": {}, 152 | "output_type": "execute_result" 153 | }, 154 | { 155 | "data": { 156 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAATiElEQVR4nO3dfYxc53me8esuxbjrxu1a0UYQl3SppPIWjpWI6kZR67SV7SSUXCNkhSKVkMaKq5Ztora24axtukBcFzDkhkncGEFV0LEqC3DlKA1DC4VQRlXcqgUsG0tTFiXZrFh/iUtZXENZJYG3CkU//WPOSkNyl/u9M3N4/YDFnHnOmdmHB+DN4XveOW+qCklSu/yFXjcgSVp7hrsktZDhLkktZLhLUgsZ7pLUQpf0ugGAyy67rLZv397rNiRpoBw+fPg7VTUy376+CPft27czOTnZ6zYkaaAk+eZC+xyWkaQWMtwlqYUMd0lqIcNdklrIcJekFuqL2TKSNt7BI1PsO3SMkzOzbBkeYmLnGLt3jPa6La0Rw126CB08MsXeA0eZPX0GgKmZWfYeOApgwLeEwzLSRWjfoWMvB/uc2dNn2HfoWI860loz3KWL0MmZ2WXVNXgMd+kitGV4aFl1DR7DXboITewcY2jzprNqQ5s3MbFzrEcdaa0tGu5JtiX5XJKnkjyZ5F1N/dIkDyV5unl8bVNPko8nOZ7k8STXrvcfQtLy7N4xyp03X83o8BABRoeHuPPmq72Y2iJLmS3zEvDeqvpSktcAh5M8BPwi8HBVfTTJB4APAO8HbgKuan5+ArireZTUR3bvGDXMW2zRT+5V9WxVfanZ/lPgK8AosAv4VHPYp4DdzfYu4N7qeBQYTnLFmncuSVrQssbck2wHdgBfAC6vqmebXd8GLm+2R4Fnul52oqlJkjbIksM9yfcDvw+8u6r+pHtfVRVQy/nFSfYkmUwyOT09vZyXSpIWsaRwT7KZTrB/uqoONOXn5oZbmsdTTX0K2Nb18q1N7SxVtb+qxqtqfGRk3oVEJEkrtJTZMgE+CXylqn6za9cDwG3N9m3AZ7vq72hmzVwPvNA1fCNJ2gBLmS3zJuAXgKNJHmtqHwQ+Ctyf5Hbgm8DPNfseBN4GHAe+C7xzTTuWJC1q0XCvqv8NZIHdb53n+ALuWGVfkqRV8BuqktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLWQ4S5J
LWS4S1ILGe6S1EJLuf2AJPXcwSNT7Dt0jJMzs2wZHmJi55iLjVyA4S6p7x08MsXeA0eZPX0GgKmZWfYeOApgwC/AYRlJfW/foWMvB/uc2dNn2HfoWI866n+Gu6S+d3Jmdll1Ge6SBsCW4aFl1WW4SxoAEzvHGNq86aza0OZNTOwc61FH/c8LqpL63txFU2fLLJ3hLmkg7N4xapgvg8MyktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLXQouGe5O4kp5I80VW7JsmjSR5LMpnkuqaeJB9PcjzJ40muXc/mJUnzW8on93uAG8+p/Rrw4aq6BvjV5jnATcBVzc8e4K61aVOStByLhntVPQI8f24Z+MvN9l8BTjbbu4B7q+NRYDjJFWvVrCRpaVb6DdV3A4eS/DqdfyD+VlMfBZ7pOu5EU3t2xR1KkpZtpRdUfwl4T1VtA94DfHK5b5BkTzNePzk9Pb3CNiRJ81lpuN8GHGi2fw+4rtmeArZ1Hbe1qZ2nqvZX1XhVjY+MjKywDUnSfFYa7ieBv9tsvwV4utl+AHhHM2vmeuCFqnJIRpI22KJj7knuA24ALktyAvgQ8E+B30pyCfD/6MyMAXgQeBtwHPgu8M516FmStIhFw72qbl1g19+Y59gC7lhtU5Kk1fEbqpLUQoa7JLWQ4S5JLWS4S1ILuYaqtIEOHplykWdtCMNd2iAHj0yx98BRZk+fAWBqZpa9B44CGPBacw7LSBtk36FjLwf7nNnTZ9h36FiPOlKbGe7SBjk5M7usurQahru0QbYMDy2rLq2G4S5tkImdYwxt3nRWbWjzJiZ2jvWoI7WZF1SlDTJ30dTZMtoIhru0gXbvGDXMtSEclpGkFjLcJamFHJaRpB5Y728rG+6StME24tvKDstI0gbbiG8rG+6StME24tvKhrskbbCN+Lay4S5JG2wjvq3sBVVJ2mAb8W1lw12SemC9v63ssIwktdCi4Z7k7iSnkjxxTv1fJvlqkieT/FpXfW+S40mOJdm5Hk1Lki5sKcMy9wC/Ddw7V0jyZmAX8GNV9WKSH2zqbwBuAX4E2AL89ySvr6oz572rJGndLPrJvaoeAZ4/p/xLwEer6sXmmFNNfRfwmap6saq+DhwHrlvDfiVJS7DSMffXA387yReS/M8kP97UR4Fnuo470dTOk2RPkskkk9PT0ytsQ5I0n5WG+yXApcD1wARwf5Is5w2qan9VjVfV+MjIyArbkCTNZ6XhfgI4UB1fBL4HXAZMAdu6jtva1CRJG2il4X4QeDNAktcD3wd8B3gAuCXJq5JcCVwFfHEtGpUkLd2is2WS3AfcAFyW5ATwIeBu4O5meuSfA7dVVQFPJrkfeAp4CbjDmTKStPHSyeTeGh8fr8nJyV63IUkDJcnhqhqfb5/fUJWkFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqoUXDPcndSU4leWKefe9NUkkua54nyceTHE/yeJJr16NpSdKFLeWT+z3AjecWk2wDfgb4Vlf5JuCq5mcPcNfqW5QkLdei4V5VjwDPz7PrY8D7gOqq7QLurY5HgeEkV6xJp5KkJVvRmHuSXcBUVX35nF2jwDNdz080tfneY0+SySST09PTK2lDkrSAZYd7klcDHwR+dTW/uKr2V9V4VY2PjIys5q0kSee4ZAWv+WHgSuDLSQC2Al9Kch0wBWzrOnZrU5MkbaBlf3KvqqNV9YNVtb2qttMZerm2qr4NPAC8o5k1cz3wQlU9u7YtS5IWs5SpkPcBnwfGkpxIcvsFDn8Q+BpwHPgE8Mtr0qUkaVkWHZapqlsX2b+9a7uAO1bfliRpNfyGqiS1kOEuSS1kuEtSCxnuktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLWQ4S5JLWS4
S1ILGe6S1EIruZ+7pD538MgU+w4d4+TMLFuGh5jYOcbuHfMuiqaWMtylljl4ZIq9B44ye/oMAFMzs+w9cBTAgL+IOCwjtcy+Q8deDvY5s6fPsO/QsR51pF4w3KWWOTkzu6y62slwl1pmy/DQsupqJ8NdapmJnWMMbd50Vm1o8yYmdo71qCP1ghdUpZaZu2jqbJmLm+EutdDuHaOG+UXOYRlJaiHDXZJayHCXpBZaNNyT3J3kVJInumr7knw1yeNJ/iDJcNe+vUmOJzmWZOd6NS5JWthSPrnfA9x4Tu0h4I1V9aPA/wH2AiR5A3AL8CPNa/5Dkk1IkjbUouFeVY8Az59T+8Oqeql5+iiwtdneBXymql6sqq8Dx4Hr1rBfSdISrMVUyH8M/G6zPUon7OecaGrnSbIH2APwute9bg3akNaXd1rUIFnVBdUk/xp4Cfj0cl9bVfuraryqxkdGRlbThrTu5u60ODUzS/HKnRYPHpnqdWvSvFYc7kl+EXg78PNVVU15CtjWddjWpiYNNO+0qEGzonBPciPwPuBnq+q7XbseAG5J8qokVwJXAV9cfZtSb3mnRQ2apUyFvA/4PDCW5ESS24HfBl4DPJTksST/EaCqngTuB54C/htwR1WdWeCtpYHhnRY1aBa9oFpVt85T/uQFjv8I8JHVNCX1m4mdY2etbgTeaVH9zRuHSUvgnRY1aAx3aYm806IGifeWkaQWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIxToGyMEjU64EJGlJDPcBcfDI1FlreE7NzLL3wFEAA17SeRyWGRD7Dh07a3FmgNnTZ9h36FiPOpLUzwz3AXFyZnZZdUkXN8N9QGwZHlpWXdLFbdFwT3J3klNJnuiqXZrkoSRPN4+vbepJ8vEkx5M8nuTa9Wz+YjKxc4yhzZvOqg1t3sTEzrEedSSpny3lk/s9wI3n1D4APFxVVwEPN88BbgKuan72AHetTZvavWOUO2++mtHhIQKMDg9x581XezFV0rwWnS1TVY8k2X5OeRdwQ7P9KeB/AO9v6vdWVQGPJhlOckVVPbtWDV/Mdu8YNcylBThV+GwrHXO/vCuwvw1c3myPAs90HXeiqZ0nyZ4kk0kmp6enV9iGJL0yVXhqZpbilanCB49M9bq1nln1BdXmU3qt4HX7q2q8qsZHRkZW24aki5hThc+30nB/LskVAM3jqaY+BWzrOm5rU5OkdeNU4fOtNNwfAG5rtm8DPttVf0cza+Z64AXH2yWtN6cKn28pUyHvAz4PjCU5keR24KPATyd5Gvip5jnAg8DXgOPAJ4BfXpeuJamLU4XPt5TZMrcusOut8xxbwB2rbUqSlmNuVoyzZV7hjcMktYJThc/m7QckqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJaaGDXUD14ZMrFcCVpAQMZ7gePTLH3wFFmT58BYGpmlr0HjgIY8JLEgA7L7Dt07OVgnzN7+gz7Dh3rUUeS1F9W9ck9yXuAfwIUcBR4J3AF8BngB4DDwC9U1Z+vss+znJyZXVb9QhzekdRGK/7knmQU+FfAeFW9EdgE3AL8O+BjVfXXgD8Gbl+LRrttGR5aVn0hc8M7UzOzFK8M7xw8MrUGXUpS76x2WOYSYCjJJcCrgWeBtwD/pdn/KWD3Kn/HeSZ2jjG0edNZtaHNm5jYObas93F4R1JbrTjcq2oK+HXgW3RC/QU6wzAzVfVSc9gJYN4xjiR7kkwmmZyenl7W7969Y5Q7b76a0eEhAowOD3HnzVcvezhlLYd3JKmfrHjMPclrgV3AlcAM8HvAjUt9fVXtB/YDjI+P13J//+4do6seG98yPMTUPEG+3OEdSeo3
qxmW+Sng61U1XVWngQPAm4DhZpgGYCvQtwPYazW8I0n9ZjXh/i3g+iSvThLgrcBTwOeAf9Accxvw2dW1uH7WanhHkvpNqpY9IvLKi5MPA/8QeAk4Qmda5CidqZCXNrV/VFUvXuh9xsfHa3JycsV9SNLFKMnhqhqfb9+q5rlX1YeAD51T/hpw3WreV5K0OgP5DVVJ0oUZ7pLUQoa7JLWQ4S5JLbSq2TJr1kQyDXxzhS+/DPjOGraz3gap30HqFQar30HqFQar30HqFVbX71+tqpH5dvRFuK9GksmFpgL1o0Hqd5B6hcHqd5B6hcHqd5B6hfXr12EZSWohw12SWqgN4b6/1w0s0yD1O0i9wmD1O0i9wmD1O0i9wjr1O/Bj7pKk87Xhk7sk6RyGuyS10ECHe5JvJDma5LEkfXdbySR3JzmV5Imu2qVJHkrydPP42l72OGeBXv9Nkqnm/D6W5G297HFOkm1JPpfkqSRPJnlXU+/Xc7tQv313fpP8xSRfTPLlptcPN/Urk3whyfEkv5vk+3rdK1yw33uSfL3r3F7T617nJNmU5EiS/9o8X5dzO9Dh3nhzVV3Tp/Na7+H81ak+ADxcVVcBDzfP+8E9zL+S1sea83tNVT24wT0t5CXgvVX1BuB64I4kb6B/z+1C/UL/nd8XgbdU1Y8B1wA3JrmeDVj4foUW6hdgouvcPta7Fs/zLuArXc/X5dy2Idz7VlU9Ajx/TnkXnYXDYZ0WEF+JBXrtS1X1bFV9qdn+Uzp/UUbp33O7UL99pzr+rHm6ufkpNmDh+5W4QL99KclW4O8Bv9M8D+t0bgc93Av4wySHk+zpdTNLdHlVPdtsfxu4vJfNLMG/SPJ4M2zTF8Mc3ZJsB3YAX2AAzu05/UIfnt9m2OAx4BTwEPB/WeLC971wbr9VNXduP9Kc248leVUPW+z274H3Ad9rnv8A63RuBz3cf7KqrgVuovNf3b/T64aWozrzUPv2UwZwF/DDdP67+yzwG71t52xJvh/4feDdVfUn3fv68dzO029fnt+qOlNV19BZA/k64K/3uKULOrffJG8E9tLp+8fprAr3/h62CECStwOnqurwRvy+gQ73qppqHk8Bf8BgrAD1XJIrAJrHUz3uZ0FV9VzzF+d7wCfoo/ObZDOdoPx0VR1oyn17bufrt5/PL0BVzdBZE/lvMgAL33f1e2MzFFbNEp//if44t28CfjbJN+gsRfoW4LdYp3M7sOGe5C8lec3cNvAzwBMXflVfeIDOwuHQ5wuIzwVl4+/TJ+e3Gaf8JPCVqvrNrl19eW4X6rcfz2+SkSTDzfYQ8NN0rhH05cL3C/T71a5/5ENnDLvn57aq9lbV1qraDtwC/FFV/TzrdG4H9huqSX6Izqd16KwF+5+r6iM9bOk8Se4DbqBzS8/n6Kw3exC4H3gdndsc/1xV9fxC5gK93kBnyKCAbwD/rGtMu2eS/CTwv4CjvDJ2+UE649j9eG4X6vdW+uz8JvlROhf1NtH58Hd/Vf3b5u/bsha+3wgX6PePgBEgwGPAP++68NpzSW4AfqWq3r5e53Zgw12StLCBHZaRJC3McJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphf4/Bfux0f+dmx8AAAAASUVORK5CYII=\n", 157 | "text/plain": [ 158 | "
" 159 | ] 160 | }, 161 | "metadata": { 162 | "needs_background": "light" 163 | }, 164 | "output_type": "display_data" 165 | } 166 | ], 167 | "source": [ 168 | "df.plot.scatter(x='Years', y='Salary', title='Salary vs. years of experience');" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "#### The residual is the difference between y and f0, i.e. y - f0\n", 176 | "##### We can use the residuals from F0(x) to create h1(x). h1(x) is a regression tree that tries to reduce the residuals from the previous step. The output of h1(x) is not a prediction of y; instead, it is added to F0(x) to form the next function F1(x) = F0(x) + h1(x), which brings the residuals down." 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 89, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "df1 = df  # alias, not a copy: df1 is not used below, and later cells add columns to df in place" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "The weak learner h1(x) predicts the mean of the residuals (y – F0) in each leaf of the tree. " 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "A single split is made, and the mean residual of the upper part and of the lower part is computed.\n", 200 | "Here, the split point is chosen at random" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 91, 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/html": [ 211 | "
\n", 212 | "\n", 225 | "\n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | "
YearsSalaryf0y-f0h1f1y-f1h2
0582134.0-52.00.0134.0-52.0-45.666667
1780134.0-54.00.0134.0-54.0-45.666667
212103134.0-31.00.0134.0-31.0-45.666667
323118134.0-16.00.0134.0-16.019.571429
425172134.038.00.0134.038.019.571429
\n", 297 | "
" 298 | ], 299 | "text/plain": [ 300 | " Years Salary f0 y-f0 h1 f1 y-f1 h2\n", 301 | "0 5 82 134.0 -52.0 0.0 134.0 -52.0 -45.666667\n", 302 | "1 7 80 134.0 -54.0 0.0 134.0 -54.0 -45.666667\n", 303 | "2 12 103 134.0 -31.0 0.0 134.0 -31.0 -45.666667\n", 304 | "3 23 118 134.0 -16.0 0.0 134.0 -16.0 19.571429\n", 305 | "4 25 172 134.0 38.0 0.0 134.0 38.0 19.571429" 306 | ] 307 | }, 308 | "execution_count": 91, 309 | "metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "np.random.seed(0)  # fixed seed so the random splits, and therefore the results, are reproducible\n", 315 | "n = df.shape[0]\n", 316 | "f = df.Salary.mean()  # F0: the constant that minimizes squared error is the mean\n", 317 | "for i in range(2):\n", 318 | "    if i > 0:\n", 319 | "        df['f'+str(i)] = df['f'+str(i-1)] + df['h'+str(i)]  # F_i = F_(i-1) + h_i\n", 320 | "    else:\n", 321 | "        df['f'+str(i)] = f\n", 322 | "    df['y-f'+str(i)] = df.Salary - df['f'+str(i)]\n", 323 | "    # Years is sorted ascending, so splitting at a row index is a threshold on Years.\n", 324 | "    # randint excludes its upper bound: draw from 1..n-1 so both leaves are non-empty.\n", 325 | "    splitIndex = np.random.randint(1, n)\n", 326 | "    residual = df['y-f'+str(i)]\n", 327 | "    h_upper = residual.iloc[:splitIndex].mean()\n", 328 | "    h_bottom = residual.iloc[splitIndex:].mean()\n", 329 | "    # each row receives the mean residual of the leaf it falls into\n", 330 | "    df['h'+str(i+1)] = [h_upper] * splitIndex + [h_bottom] * (n - splitIndex)\n", 331 | "df.head()\n" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "#### This is how the dataset looks after 2 iterations" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "If we continue for 100 iterations, the residuals y - F_i, and hence the MSE of F_i, shrink considerably" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 92, 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "data": { 355 | "text/html": [ 356 | "
\n", 357 | "\n", 370 | "\n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | "
YearsSalaryf0y-f0h1f1y-f1h2f2y-f2...h97f97y-f97h98f98y-f98h99f99y-f99h100
0582134.0-52.0-45.66666788.333333-6.333333-9.473903e-1588.333333-6.333333...1.776357e-1585.594638-3.594638-3.08374582.510894-0.5108940.38797882.898871-0.898871-0.898871
1780134.0-54.0-45.66666788.333333-8.333333-9.473903e-1588.333333-8.333333...1.776357e-1584.605788-4.605788-3.08374581.522043-1.5220430.38797881.910021-1.9100210.099875
212103134.0-31.0-45.66666788.33333314.666667-9.473903e-1588.33333314.666667...1.776357e-1597.2853605.714640-3.08374594.2016158.7983850.38797894.5895938.4104070.099875
323118134.0-16.019.571429153.571429-35.571429-1.218073e-14153.571429-35.571429...1.776357e-15127.849193-9.849193-3.083745124.765448-6.7654480.387978125.153425-7.1534250.099875
425172134.038.019.571429153.57142918.428571-1.218073e-14153.57142918.428571...1.776357e-15157.37001614.6299842.055830159.42584612.5741540.387978159.81382312.1861770.099875
\n", 520 | "

5 rows × 302 columns

\n", 521 | "
" 522 | ], 523 | "text/plain": [ 524 | " Years Salary f0 y-f0 h1 f1 y-f1 h2 \\\n", 525 | "0 5 82 134.0 -52.0 -45.666667 88.333333 -6.333333 -9.473903e-15 \n", 526 | "1 7 80 134.0 -54.0 -45.666667 88.333333 -8.333333 -9.473903e-15 \n", 527 | "2 12 103 134.0 -31.0 -45.666667 88.333333 14.666667 -9.473903e-15 \n", 528 | "3 23 118 134.0 -16.0 19.571429 153.571429 -35.571429 -1.218073e-14 \n", 529 | "4 25 172 134.0 38.0 19.571429 153.571429 18.428571 -1.218073e-14 \n", 530 | "\n", 531 | " f2 y-f2 ... h97 f97 y-f97 h98 \\\n", 532 | "0 88.333333 -6.333333 ... 1.776357e-15 85.594638 -3.594638 -3.083745 \n", 533 | "1 88.333333 -8.333333 ... 1.776357e-15 84.605788 -4.605788 -3.083745 \n", 534 | "2 88.333333 14.666667 ... 1.776357e-15 97.285360 5.714640 -3.083745 \n", 535 | "3 153.571429 -35.571429 ... 1.776357e-15 127.849193 -9.849193 -3.083745 \n", 536 | "4 153.571429 18.428571 ... 1.776357e-15 157.370016 14.629984 2.055830 \n", 537 | "\n", 538 | " f98 y-f98 h99 f99 y-f99 h100 \n", 539 | "0 82.510894 -0.510894 0.387978 82.898871 -0.898871 -0.898871 \n", 540 | "1 81.522043 -1.522043 0.387978 81.910021 -1.910021 0.099875 \n", 541 | "2 94.201615 8.798385 0.387978 94.589593 8.410407 0.099875 \n", 542 | "3 124.765448 -6.765448 0.387978 125.153425 -7.153425 0.099875 \n", 543 | "4 159.425846 12.574154 0.387978 159.813823 12.186177 0.099875 \n", 544 | "\n", 545 | "[5 rows x 302 columns]" 546 | ] 547 | }, 548 | "execution_count": 92, 549 | "metadata": {}, 550 | "output_type": "execute_result" 551 | } 552 | ], 553 | "source": [ 554 | "np.random.seed(0)  # fixed seed so the random splits, and therefore the results, are reproducible\n", 555 | "n = df.shape[0]\n", 556 | "f = df.Salary.mean()  # F0: the constant that minimizes squared error is the mean\n", 557 | "for i in range(100):\n", 558 | "    if i > 0:\n", 559 | "        df['f'+str(i)] = df['f'+str(i-1)] + df['h'+str(i)]  # F_i = F_(i-1) + h_i\n", 560 | "    else:\n", 561 | "        df['f'+str(i)] = f\n", 562 | "    df['y-f'+str(i)] = df.Salary - df['f'+str(i)]\n", 563 | "    # Years is sorted ascending, so splitting at a row index is a threshold on Years.\n", 564 | "    # randint excludes its upper bound: draw from 1..n-1 so both leaves are non-empty.\n", 565 | "    splitIndex = np.random.randint(1, n)\n", 566 | "    residual = df['y-f'+str(i)]\n", 567 | "    h_upper = residual.iloc[:splitIndex].mean()\n", 568 | "    h_bottom = residual.iloc[splitIndex:].mean()\n", 569 | "    # each row receives the mean residual of the leaf it falls into\n", 570 | "    df['h'+str(i+1)] = [h_upper] * splitIndex + [h_bottom] * (n - splitIndex)\n", 571 | "df.head()\n" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": {}, 577 | "source": [ 578 | "#### The graph below shows the fitted model after iterations 1, 10 and 99 (columns f1, f10 and f99)\n", 579 | "#### As the number of iterations increases, the loss decreases and the model fits the training data more closely" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 94, 585 | "metadata": {}, 586 | "outputs": [ 587 | { 588 | "data": { 589 | "text/plain": [ 590 | "" 591 | ] 592 | }, 593 | "execution_count": 94, 594 | "metadata": {}, 595 | "output_type": "execute_result" 596 | }, 597 | { 598 | "data": { 599 | "image/png": "\n", 600 | "text/plain": [ 601 | "
" 602 | ] 603 | }, 604 | "metadata": { 605 | "needs_background": "light" 606 | }, 607 | "output_type": "display_data" 608 | } 609 | ], 610 | "source": [ 611 | "fig, ax = plt.subplots(figsize=(15,10))\n", 612 | "ax.scatter(df.Years, df.Salary, color='black', label='data')  # explicit colour keeps the data points distinct from the fitted lines\n", 613 | "for col in ['f1', 'f10', 'f99']:\n", 614 | "    ax.plot(df.Years, df[col], label=col)\n", 615 | "ax.set(xlabel='Years of experience', ylabel='Salary', title='Boosted model after 1, 10 and 99 iterations')\n", 616 | "ax.legend();" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "# END" 624 | ] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": null, 629 | "metadata": {}, 630 | "outputs": [], 631 | "source": [] 632 | } 633 | ], 634 | "metadata": { 635 | "kernelspec": { 636 | "display_name": "Python 3", 637 | "language": "python", 638 | "name": "python3" 639 | }, 640 | "language_info": { 641 | "codemirror_mode": { 642 | "name": "ipython", 643 | "version": 3 644 | }, 645 | "file_extension": ".py", 646 | "mimetype": "text/x-python", 647 | "name": "python", 648 | "nbconvert_exporter": "python", 649 | "pygments_lexer": "ipython3", 650 | "version": "3.6.7" 651 | } 652 | }, 653 | "nbformat": 4, 654 | "nbformat_minor": 2 655 | } 656 | --------------------------------------------------------------------------------