├── README.md
└── XGBoostScratch.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # XGBoost-Using-Numpy
2 | The following code is a simple XGBoost model developed using NumPy. The main purpose of this code is to unveil the maths behind XGBoost.
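In the notebook, F0(x) is initialized with the mean salary, because the mean is the constant that minimizes the MSE. The median would minimize the MAE instead. The sketch below checks both claims with NumPy on the notebook's salary data. The grid of candidate constants is an illustrative assumption; the notebook does not use it.

```python
import numpy as np

# Salary data from the notebook
salary = np.array([82, 80, 103, 118, 172, 127, 204, 189, 99, 166], dtype=float)

# Candidate constants c for F0(x) = c, on a grid with step 0.5 (illustrative)
candidates = np.arange(salary.min(), salary.max() + 0.5, 0.5)

# Loss of each constant prediction: rows = candidates, columns = samples
diff = salary[None, :] - candidates[:, None]
mse = (diff ** 2).mean(axis=1)
mae = np.abs(diff).mean(axis=1)

best_mse_const = candidates[np.argmin(mse)]

# With an even number of samples, every c between the two middle values
# minimizes MAE, so we only check that the median is one of the minimizers.
median = np.median(salary)
mae_at_median = np.abs(salary - median).mean()

print("mean:", salary.mean(), "| MSE-minimizing constant:", best_mse_const)
print("median:", median, "| median attains the minimum MAE:", np.isclose(mae_at_median, mae.min()))
```

The mean (134.0) is exactly the `f0` value in the notebook's tables.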
3 |
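The notebook picks the split point of each tree at random. A regression tree would instead choose the split that minimizes the squared error of the residuals, and gradient boosting implementations usually also scale each tree by a learning rate. The sketch below shows that variant in NumPy on the same data. `fit_stump`, `predict_stump`, `learning_rate` and `n_rounds` are illustrative names and values, not taken from the notebook. This is still plain gradient boosting with squared loss: it does not include the regularization term that XGBoost adds to its objective.

```python
import numpy as np

# Data from the notebook
years = np.array([5, 7, 12, 23, 25, 28, 29, 34, 35, 40], dtype=float)
salary = np.array([82, 80, 103, 118, 172, 127, 204, 189, 99, 166], dtype=float)


def fit_stump(x, r):
    """Fit a one-split regression tree (a stump) to residuals r.

    Tries every threshold between consecutive distinct x values and returns
    (threshold, left_value, right_value) with the lowest squared error,
    where each leaf value is the mean residual in that leaf.
    """
    order = np.argsort(x)
    xs, rs = x[order], r[order]
    best = (np.inf, None, None, None)
    for k in range(1, len(xs)):
        if xs[k] == xs[k - 1]:
            continue  # no threshold separates equal x values
        left, right = rs[:k], rs[k:]
        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
        if sse < best[0]:
            threshold = (xs[k - 1] + xs[k]) / 2
            best = (sse, threshold, left.mean(), right.mean())
    _, threshold, left_value, right_value = best
    return threshold, left_value, right_value


def predict_stump(stump, x):
    threshold, left_value, right_value = stump
    return np.where(x <= threshold, left_value, right_value)


learning_rate = 0.5  # shrinkage applied to every tree (illustrative value)
n_rounds = 100

f = np.full_like(salary, salary.mean())  # F0(x): the mean minimizes the MSE
mse_history = [((salary - f) ** 2).mean()]
for _ in range(n_rounds):
    residuals = salary - f                               # y - F_(m-1)(x)
    stump = fit_stump(years, residuals)                  # h_m(x) fitted to the residuals
    f = f + learning_rate * predict_stump(stump, years)  # F_m = F_(m-1) + lr * h_m
    mse_history.append(((salary - f) ** 2).mean())

print("MSE of F0:", mse_history[0])
print("MSE after", n_rounds, "rounds:", round(mse_history[-1], 3))
```

Each leaf value is the least-squares fit to the residuals in that leaf, so the training MSE never increases from one round to the next. This holds for any learning rate between 0 and 2, and also for the notebook's random split with no learning rate. Choosing the best split makes each round reduce the error as much as a single split can.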
--------------------------------------------------------------------------------
/XGBoostScratch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## XGBoost using Numpy (from scratch)"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "#### The following code is a simple XGBoost model developed using numpy.\n",
15 | "The main purpose of this code is to unveil the maths behind XGBoost."
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 85,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "import numpy as np\n",
25 | "import pandas as pd\n",
26 | "import matplotlib.pyplot as plt\n",
27 | "%matplotlib inline"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "##### Consider the following data, where years of experience is the predictor variable and salary is the target."
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 86,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "year = [5,7,12,23,25,28,29,34,35,40]\n",
44 | "salary = [82,80,103,118,172,127,204,189,99,166]"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "##### Using regression trees as base learners, we can create a model to predict the salary.\n",
52 | "For simplicity, we use the squared loss, so our objective is to minimize the squared error.\n",
53 | "##### As the first step, the model is initialized with a constant function F0(x) that minimizes the loss function, here the MSE (mean squared error).\n",
54 | "##### For MSE, the minimizing constant is the mean of the target.\n",
55 | "##### If we had chosen MAE (mean absolute error), it would be the median.\n"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 87,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "data": {
118 | "text/plain": [
119 | " Years Salary\n",
120 | "0 5 82\n",
121 | "1 7 80\n",
122 | "2 12 103\n",
123 | "3 23 118\n",
124 | "4 25 172"
125 | ]
126 | },
127 | "execution_count": 87,
128 | "metadata": {},
129 | "output_type": "execute_result"
130 | }
131 | ],
132 | "source": [
133 | "df = pd.DataFrame(columns=['Years','Salary'])\n",
134 | "df.Years = year\n",
135 | "df.Salary = salary\n",
136 | "df.head()"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 88,
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "data": {
146 | "text/plain": [
147 | ""
148 | ]
149 | },
150 | "execution_count": 88,
151 | "metadata": {},
152 | "output_type": "execute_result"
153 | },
154 | {
155 | "data": {
156 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAATiElEQVR4nO3dfYxc53me8esuxbjrxu1a0UYQl3SppPIWjpWI6kZR67SV7SSUXCNkhSKVkMaKq5Ztora24axtukBcFzDkhkncGEFV0LEqC3DlKA1DC4VQRlXcqgUsG0tTFiXZrFh/iUtZXENZJYG3CkU//WPOSkNyl/u9M3N4/YDFnHnOmdmHB+DN4XveOW+qCklSu/yFXjcgSVp7hrsktZDhLkktZLhLUgsZ7pLUQpf0ugGAyy67rLZv397rNiRpoBw+fPg7VTUy376+CPft27czOTnZ6zYkaaAk+eZC+xyWkaQWMtwlqYUMd0lqIcNdklrIcJekFuqL2TKSNt7BI1PsO3SMkzOzbBkeYmLnGLt3jPa6La0Rw126CB08MsXeA0eZPX0GgKmZWfYeOApgwLeEwzLSRWjfoWMvB/uc2dNn2HfoWI860loz3KWL0MmZ2WXVNXgMd+kitGV4aFl1DR7DXboITewcY2jzprNqQ5s3MbFzrEcdaa0tGu5JtiX5XJKnkjyZ5F1N/dIkDyV5unl8bVNPko8nOZ7k8STXrvcfQtLy7N4xyp03X83o8BABRoeHuPPmq72Y2iJLmS3zEvDeqvpSktcAh5M8BPwi8HBVfTTJB4APAO8HbgKuan5+ArireZTUR3bvGDXMW2zRT+5V9WxVfanZ/lPgK8AosAv4VHPYp4DdzfYu4N7qeBQYTnLFmncuSVrQssbck2wHdgBfAC6vqmebXd8GLm+2R4Fnul52oqlJkjbIksM9yfcDvw+8u6r+pHtfVRVQy/nFSfYkmUwyOT09vZyXSpIWsaRwT7KZTrB/uqoONOXn5oZbmsdTTX0K2Nb18q1N7SxVtb+qxqtqfGRk3oVEJEkrtJTZMgE+CXylqn6za9cDwG3N9m3AZ7vq72hmzVwPvNA1fCNJ2gBLmS3zJuAXgKNJHmtqHwQ+Ctyf5Hbgm8DPNfseBN4GHAe+C7xzTTuWJC1q0XCvqv8NZIHdb53n+ALuWGVfkqRV8BuqktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLWQ4S5JLWS4S1ILGe6S1EJLuf2AJPXcwSNT7Dt0jJMzs2wZHmJi55iLjVyA4S6p7x08MsXeA0eZPX0GgKmZWfYeOApgwC/AYRlJfW/foWMvB/uc2dNn2HfoWI866n+Gu6S+d3Jmdll1Ge6SBsCW4aFl1WW4SxoAEzvHGNq86aza0OZNTOwc61FH/c8LqpL63txFU2fLLJ3hLmkg7N4xapgvg8MyktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLXQouGe5O4kp5I80VW7JsmjSR5LMpnkuqaeJB9PcjzJ40muXc/mJUnzW8on93uAG8+p/Rrw4aq6BvjV5jnATcBVzc8e4K61aVOStByLhntVPQI8f24Z+MvN9l8BTjbbu4B7q+NRYDjJFWvVrCRpaVb6DdV3A4eS/DqdfyD+VlMfBZ7pOu5EU3t2xR1KkpZtpRdUfwl4T1VtA94DfHK5b5BkTzNePzk9Pb3CNiRJ81lpuN8GHGi2fw+4rtmeArZ1Hbe1qZ2nqvZX1XhVjY+MjKywDUnSfFYa7ieBv9tsvwV4utl+AHhHM2vmeuCFqnJIRpI22KJj7knuA24ALktyAvgQ8E+B30pyCfD/6MyMAXgQeBtwHPgu8M516FmStIhFw72qbl1g19+Y59gC7lhtU5Kk1fEbqpLUQoa7JLWQ4S5JLWS4S1ILuYaqtIEOHplykWdtCMNd2iAHj0yx98BRZk+fAWBqZpa9B44CGPBacw7LSBtk
36FjLwf7nNnTZ9h36FiPOlKbGe7SBjk5M7usurQahru0QbYMDy2rLq2G4S5tkImdYwxt3nRWbWjzJiZ2jvWoI7WZF1SlDTJ30dTZMtoIhru0gXbvGDXMtSEclpGkFjLcJamFHJaRpB5Y728rG+6StME24tvKDstI0gbbiG8rG+6StME24tvKhrskbbCN+Lay4S5JG2wjvq3sBVVJ2mAb8W1lw12SemC9v63ssIwktdCi4Z7k7iSnkjxxTv1fJvlqkieT/FpXfW+S40mOJdm5Hk1Lki5sKcMy9wC/Ddw7V0jyZmAX8GNV9WKSH2zqbwBuAX4E2AL89ySvr6oz572rJGndLPrJvaoeAZ4/p/xLwEer6sXmmFNNfRfwmap6saq+DhwHrlvDfiVJS7DSMffXA387yReS/M8kP97UR4Fnuo470dTOk2RPkskkk9PT0ytsQ5I0n5WG+yXApcD1wARwf5Is5w2qan9VjVfV+MjIyArbkCTNZ6XhfgI4UB1fBL4HXAZMAdu6jtva1CRJG2il4X4QeDNAktcD3wd8B3gAuCXJq5JcCVwFfHEtGpUkLd2is2WS3AfcAFyW5ATwIeBu4O5meuSfA7dVVQFPJrkfeAp4CbjDmTKStPHSyeTeGh8fr8nJyV63IUkDJcnhqhqfb5/fUJWkFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqoUXDPcndSU4leWKefe9NUkkua54nyceTHE/yeJJr16NpSdKFLeWT+z3AjecWk2wDfgb4Vlf5JuCq5mcPcNfqW5QkLdei4V5VjwDPz7PrY8D7gOqq7QLurY5HgeEkV6xJp5KkJVvRmHuSXcBUVX35nF2jwDNdz080tfneY0+SySST09PTK2lDkrSAZYd7klcDHwR+dTW/uKr2V9V4VY2PjIys5q0kSee4ZAWv+WHgSuDLSQC2Al9Kch0wBWzrOnZrU5MkbaBlf3KvqqNV9YNVtb2qttMZerm2qr4NPAC8o5k1cz3wQlU9u7YtS5IWs5SpkPcBnwfGkpxIcvsFDn8Q+BpwHPgE8Mtr0qUkaVkWHZapqlsX2b+9a7uAO1bfliRpNfyGqiS1kOEuSS1kuEtSCxnuktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLWQ4S5JLWS4S1ILGe6S1EIruZ+7pD538MgU+w4d4+TMLFuGh5jYOcbuHfMuiqaWMtylljl4ZIq9B44ye/oMAFMzs+w9cBTAgL+IOCwjtcy+Q8deDvY5s6fPsO/QsR51pF4w3KWWOTkzu6y62slwl1pmy/DQsupqJ8NdapmJnWMMbd50Vm1o8yYmdo71qCP1ghdUpZaZu2jqbJmLm+EutdDuHaOG+UXOYRlJaiHDXZJayHCXpBZaNNyT3J3kVJInumr7knw1yeNJ/iDJcNe+vUmOJzmWZOd6NS5JWthSPrnfA9x4Tu0h4I1V9aPA/wH2AiR5A3AL8CPNa/5Dkk1IkjbUouFeVY8Az59T+8Oqeql5+iiwtdneBXymql6sqq8Dx4Hr1rBfSdISrMVUyH8M/G6zPUon7OecaGrnSbIH2APwute9bg3akNaXd1rUIFnVBdUk/xp4Cfj0cl9bVfuraryqxkdGRlbThrTu5u60ODUzS/HKnRYPHpnqdWvSvFYc7kl+EXg78PNVVU15CtjWddjWpiYNNO+0qEGzonBPciPwPuBnq+q7XbseAG5J8qokVwJXAV9cfZtSb3mnRQ2apUyFvA/4PDCW5ESS24HfBl4DPJTksST/EaCqngTuB54C/htwR1WdWeCtpYHhnRY1aBa9oFpVt85T/uQFjv8I8JHVNCX1m4mdY2etbgTeaVH9zRuHSUvgnRY1aAx3aYm806IGifeWkaQWMtwlqYUMd0lqIcNd
klrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIxToGyMEjU64EJGlJDPcBcfDI1FlreE7NzLL3wFEAA17SeRyWGRD7Dh07a3FmgNnTZ9h36FiPOpLUzwz3AXFyZnZZdUkXN8N9QGwZHlpWXdLFbdFwT3J3klNJnuiqXZrkoSRPN4+vbepJ8vEkx5M8nuTa9Wz+YjKxc4yhzZvOqg1t3sTEzrEedSSpny3lk/s9wI3n1D4APFxVVwEPN88BbgKuan72AHetTZvavWOUO2++mtHhIQKMDg9x581XezFV0rwWnS1TVY8k2X5OeRdwQ7P9KeB/AO9v6vdWVQGPJhlOckVVPbtWDV/Mdu8YNcylBThV+GwrHXO/vCuwvw1c3myPAs90HXeiqZ0nyZ4kk0kmp6enV9iGJL0yVXhqZpbilanCB49M9bq1nln1BdXmU3qt4HX7q2q8qsZHRkZW24aki5hThc+30nB/LskVAM3jqaY+BWzrOm5rU5OkdeNU4fOtNNwfAG5rtm8DPttVf0cza+Z64AXH2yWtN6cKn28pUyHvAz4PjCU5keR24KPATyd5Gvip5jnAg8DXgOPAJ4BfXpeuJamLU4XPt5TZMrcusOut8xxbwB2rbUqSlmNuVoyzZV7hjcMktYJThc/m7QckqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJaaGDXUD14ZMrFcCVpAQMZ7gePTLH3wFFmT58BYGpmlr0HjgIY8JLEgA7L7Dt07OVgnzN7+gz7Dh3rUUeS1F9W9ck9yXuAfwIUcBR4J3AF8BngB4DDwC9U1Z+vss+znJyZXVb9QhzekdRGK/7knmQU+FfAeFW9EdgE3AL8O+BjVfXXgD8Gbl+LRrttGR5aVn0hc8M7UzOzFK8M7xw8MrUGXUpS76x2WOYSYCjJJcCrgWeBtwD/pdn/KWD3Kn/HeSZ2jjG0edNZtaHNm5jYObas93F4R1JbrTjcq2oK+HXgW3RC/QU6wzAzVfVSc9gJYN4xjiR7kkwmmZyenl7W7969Y5Q7b76a0eEhAowOD3HnzVcvezhlLYd3JKmfrHjMPclrgV3AlcAM8HvAjUt9fVXtB/YDjI+P13J//+4do6seG98yPMTUPEG+3OEdSeo3qxmW+Sng61U1XVWngQPAm4DhZpgGYCvQtwPYazW8I0n9ZjXh/i3g+iSvThLgrcBTwOeAf9Accxvw2dW1uH7WanhHkvpNqpY9IvLKi5MPA/8QeAk4Qmda5CidqZCXNrV/VFUvXuh9xsfHa3JycsV9SNLFKMnhqhqfb9+q5rlX1YeAD51T/hpw3WreV5K0OgP5DVVJ0oUZ7pLUQoa7JLWQ4S5JLbSq2TJr1kQyDXxzhS+/DPjOGraz3gap30HqFQar30HqFQar30HqFVbX71+tqpH5dvRFuK9GksmFpgL1o0Hqd5B6hcHqd5B6hcHqd5B6hfXr12EZSWohw12SWqgN4b6/1w0s0yD1O0i9wmD1O0i9wmD1O0i9wjr1O/Bj7pKk87Xhk7sk6RyGuyS10ECHe5JvJDma5LEkfXdbySR3JzmV5Imu2qVJHkrydPP42l72OGeBXv9Nkqnm/D6W5G297HFOkm1JPpfkqSRPJnlXU+/Xc7tQv313fpP8xSRfTPLlptcPN/Urk3whyfEkv5vk+3rdK1yw33uSfL3r3F7T617nJNmU5EiS/9o8X5dzO9Dh3nhzVV3Tp/Na7+H81ak+ADxcVVcBDzfP+8E9zL+S1sea83tNVT24wT0t5CXgvVX1BuB64I4kb6B/z+1C/UL/nd8XgbdU1Y8B1wA3JrmeDVj4foUW6hdgouvcPta7Fs/zLuArXc/X5dy2Idz7VlU9Ajx/TnkXnYXDYZ0W
EF+JBXrtS1X1bFV9qdn+Uzp/UUbp33O7UL99pzr+rHm6ufkpNmDh+5W4QL99KclW4O8Bv9M8D+t0bgc93Av4wySHk+zpdTNLdHlVPdtsfxu4vJfNLMG/SPJ4M2zTF8Mc3ZJsB3YAX2AAzu05/UIfnt9m2OAx4BTwEPB/WeLC971wbr9VNXduP9Kc248leVUPW+z274H3Ad9rnv8A63RuBz3cf7KqrgVuovNf3b/T64aWozrzUPv2UwZwF/DDdP67+yzwG71t52xJvh/4feDdVfUn3fv68dzO029fnt+qOlNV19BZA/k64K/3uKULOrffJG8E9tLp+8fprAr3/h62CECStwOnqurwRvy+gQ73qppqHk8Bf8BgrAD1XJIrAJrHUz3uZ0FV9VzzF+d7wCfoo/ObZDOdoPx0VR1oyn17bufrt5/PL0BVzdBZE/lvMgAL33f1e2MzFFbNEp//if44t28CfjbJN+gsRfoW4LdYp3M7sOGe5C8lec3cNvAzwBMXflVfeIDOwuHQ5wuIzwVl4+/TJ+e3Gaf8JPCVqvrNrl19eW4X6rcfz2+SkSTDzfYQ8NN0rhH05cL3C/T71a5/5ENnDLvn57aq9lbV1qraDtwC/FFV/TzrdG4H9huqSX6Izqd16KwF+5+r6iM9bOk8Se4DbqBzS8/n6Kw3exC4H3gdndsc/1xV9fxC5gK93kBnyKCAbwD/rGtMu2eS/CTwv4CjvDJ2+UE649j9eG4X6vdW+uz8JvlROhf1NtH58Hd/Vf3b5u/bsha+3wgX6PePgBEgwGPAP++68NpzSW4AfqWq3r5e53Zgw12StLCBHZaRJC3McJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphf4/Bfux0f+dmx8AAAAASUVORK5CYII=\n",
157 | "text/plain": [
158 | ""
159 | ]
160 | },
161 | "metadata": {
162 | "needs_background": "light"
163 | },
164 | "output_type": "display_data"
165 | }
166 | ],
167 | "source": [
168 | "plt.scatter(x=df.Years,y=df.Salary)"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "metadata": {},
174 | "source": [
175 | "#### The residual is the difference between y and F0, i.e. (y - F0)\n",
176 | "##### We use the residuals of F0(x) to fit h1(x), a regression tree that tries to predict those residuals. The output of h1(x) is not a prediction of y; instead, it is added to F0(x) to form the next function F1(x) = F0(x) + h1(x), which brings the residuals down."
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 89,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "df1 = df.copy()  # independent copy; plain assignment would only add a second name for the same DataFrame"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "The base learner h1(x) is a regression tree: at each leaf, it predicts the mean of the residuals (y - F0) of the samples that fall into that leaf."
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {},
198 | "source": [
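199 | "The data is split into an upper and a lower part, and the mean residual of each part is calculated.\n",
200 | "Here, I have selected a random split point; a real regression tree would choose the split that minimizes the squared error."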
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 91,
206 | "metadata": {},
207 | "outputs": [
208 | {
209 | "data": {
299 | "text/plain": [
300 | " Years Salary f0 y-f0 h1 f1 y-f1 h2\n",
301 | "0 5 82 134.0 -52.0 0.0 134.0 -52.0 -45.666667\n",
302 | "1 7 80 134.0 -54.0 0.0 134.0 -54.0 -45.666667\n",
303 | "2 12 103 134.0 -31.0 0.0 134.0 -31.0 -45.666667\n",
304 | "3 23 118 134.0 -16.0 0.0 134.0 -16.0 19.571429\n",
305 | "4 25 172 134.0 38.0 0.0 134.0 38.0 19.571429"
306 | ]
307 | },
308 | "execution_count": 91,
309 | "metadata": {},
310 | "output_type": "execute_result"
311 | }
312 | ],
313 | "source": [
314 | "for i in range(2):\n",
315 | "    f = df.Salary.mean()  # F0(x): the mean minimizes the MSE\n",
316 | "    if i > 0:\n",
317 | "        df['f'+str(i)] = df['f'+str(i-1)] + df['h'+str(i)]  # F_i = F_(i-1) + h_i\n",
318 | "    else:\n",
319 | "        df['f'+str(i)] = f\n",
320 | "    df['y-f'+str(i)] = df.Salary - df['f'+str(i)]  # residuals y - F_i\n",
321 | "    splitIndex = np.random.randint(0, df.shape[0]-1)  # random split point\n",
322 | "    a = []\n",
323 | "    h_upper = df['y-f'+str(i)][0:splitIndex].mean()  # mean residual of the upper leaf\n",
324 | "    h_bottom = df['y-f'+str(i)][splitIndex:].mean()  # mean residual of the lower leaf\n",
325 | "    for j in range(splitIndex):\n",
326 | "        a.append(h_upper)\n",
327 | "    for j in range(df.shape[0]-splitIndex):\n",
328 | "        a.append(h_bottom)\n",
329 | "    df['h'+str(i+1)] = a  # h_(i+1)(x) for every row\n",
330 | "    \n",
331 | "df.head()\n"
332 | ]
333 | },
334 | {
335 | "cell_type": "markdown",
336 | "metadata": {},
337 | "source": [
338 | "#### This is how the dataset looks after 2 iterations"
339 | ]
340 | },
341 | {
342 | "cell_type": "markdown",
343 | "metadata": {},
344 | "source": [
345 | "If we continue for 100 iterations, the loss (the MSE of Fi) decreases substantially; compare the residual columns y-f0 and y-f99 below."
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 92,
351 | "metadata": {},
352 | "outputs": [
353 | {
354 | "data": {
523 | "text/plain": [
524 | " Years Salary f0 y-f0 h1 f1 y-f1 h2 \\\n",
525 | "0 5 82 134.0 -52.0 -45.666667 88.333333 -6.333333 -9.473903e-15 \n",
526 | "1 7 80 134.0 -54.0 -45.666667 88.333333 -8.333333 -9.473903e-15 \n",
527 | "2 12 103 134.0 -31.0 -45.666667 88.333333 14.666667 -9.473903e-15 \n",
528 | "3 23 118 134.0 -16.0 19.571429 153.571429 -35.571429 -1.218073e-14 \n",
529 | "4 25 172 134.0 38.0 19.571429 153.571429 18.428571 -1.218073e-14 \n",
530 | "\n",
531 | " f2 y-f2 ... h97 f97 y-f97 h98 \\\n",
532 | "0 88.333333 -6.333333 ... 1.776357e-15 85.594638 -3.594638 -3.083745 \n",
533 | "1 88.333333 -8.333333 ... 1.776357e-15 84.605788 -4.605788 -3.083745 \n",
534 | "2 88.333333 14.666667 ... 1.776357e-15 97.285360 5.714640 -3.083745 \n",
535 | "3 153.571429 -35.571429 ... 1.776357e-15 127.849193 -9.849193 -3.083745 \n",
536 | "4 153.571429 18.428571 ... 1.776357e-15 157.370016 14.629984 2.055830 \n",
537 | "\n",
538 | " f98 y-f98 h99 f99 y-f99 h100 \n",
539 | "0 82.510894 -0.510894 0.387978 82.898871 -0.898871 -0.898871 \n",
540 | "1 81.522043 -1.522043 0.387978 81.910021 -1.910021 0.099875 \n",
541 | "2 94.201615 8.798385 0.387978 94.589593 8.410407 0.099875 \n",
542 | "3 124.765448 -6.765448 0.387978 125.153425 -7.153425 0.099875 \n",
543 | "4 159.425846 12.574154 0.387978 159.813823 12.186177 0.099875 \n",
544 | "\n",
545 | "[5 rows x 302 columns]"
546 | ]
547 | },
548 | "execution_count": 92,
549 | "metadata": {},
550 | "output_type": "execute_result"
551 | }
552 | ],
553 | "source": [
554 | "for i in range(100):\n",
555 | "    f = df.Salary.mean()  # F0(x): the mean minimizes the MSE\n",
556 | "    if i > 0:\n",
557 | "        df['f'+str(i)] = df['f'+str(i-1)] + df['h'+str(i)]  # F_i = F_(i-1) + h_i\n",
558 | "    else:\n",
559 | "        df['f'+str(i)] = f\n",
560 | "    df['y-f'+str(i)] = df.Salary - df['f'+str(i)]  # residuals y - F_i\n",
561 | "    splitIndex = np.random.randint(0, df.shape[0]-1)  # random split point\n",
562 | "    a = []\n",
563 | "    h_upper = df['y-f'+str(i)][0:splitIndex].mean()  # mean residual of the upper leaf\n",
564 | "    h_bottom = df['y-f'+str(i)][splitIndex:].mean()  # mean residual of the lower leaf\n",
565 | "    for j in range(splitIndex):\n",
566 | "        a.append(h_upper)\n",
567 | "    for j in range(df.shape[0]-splitIndex):\n",
568 | "        a.append(h_bottom)\n",
569 | "    df['h'+str(i+1)] = a  # h_(i+1)(x) for every row\n",
570 | "    \n",
571 | "df.head()\n"
572 | ]
573 | },
574 | {
575 | "cell_type": "markdown",
576 | "metadata": {},
577 | "source": [
578 | "#### The following graph shows the fitted functions after iterations 1, 10 and 99 (F1, F10 and F99)\n",
579 | "#### The loss decreases and the model fits the data more closely as the number of iterations increases"
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": 94,
585 | "metadata": {},
586 | "outputs": [
587 | {
588 | "data": {
589 | "text/plain": [
590 | ""
591 | ]
592 | },
593 | "execution_count": 94,
594 | "metadata": {},
595 | "output_type": "execute_result"
596 | },
597 | {
598 | "data": {
599 | "image/png": "\n",
600 | "text/plain": [
601 | ""
602 | ]
603 | },
604 | "metadata": {
605 | "needs_background": "light"
606 | },
607 | "output_type": "display_data"
608 | }
609 | ],
610 | "source": [
611 | "plt.figure(figsize=(15,10))\n",
612 | "plt.scatter(df.Years,df.Salary)\n",
613 | "plt.plot(df.Years,df.f1,label = 'f1')\n",
614 | "plt.plot(df.Years,df.f10,label = 'f10')\n",
615 | "plt.plot(df.Years,df.f99,label = 'f99')\n",
616 | "plt.legend()"
617 | ]
618 | },
619 | {
620 | "cell_type": "markdown",
621 | "metadata": {},
622 | "source": [
623 | "# END"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": null,
629 | "metadata": {},
630 | "outputs": [],
631 | "source": []
632 | }
633 | ],
634 | "metadata": {
635 | "kernelspec": {
636 | "display_name": "Python 3",
637 | "language": "python",
638 | "name": "python3"
639 | },
640 | "language_info": {
641 | "codemirror_mode": {
642 | "name": "ipython",
643 | "version": 3
644 | },
645 | "file_extension": ".py",
646 | "mimetype": "text/x-python",
647 | "name": "python",
648 | "nbconvert_exporter": "python",
649 | "pygments_lexer": "ipython3",
650 | "version": "3.6.7"
651 | }
652 | },
653 | "nbformat": 4,
654 | "nbformat_minor": 2
655 | }
656 |
--------------------------------------------------------------------------------