├── .gitignore
├── LICENSE
├── README.md
├── figs
├── Danmini_Doorbell-heatmap.png
├── Ecobee_Thermostat-heatmap.png
├── Ennio_Doorbell-heatmap.png
├── Philips_B120N10_Baby_Monitor-heatmap.png
├── Provision_PT_737E_Security_Camera-heatmap.png
├── Provision_PT_838_Security_Camera-heatmap.png
├── Samsung_SNH_1011_N_Webcam-heatmap.png
├── SimpleHome_XCS7_1002_WHT_Security_Camera-heatmap.png
└── SimpleHome_XCS7_1003_WHT_Security_Camera-heatmap.png
├── models
├── Danmini_Doorbell
│ ├── Danmini_Doorbell_without_scaling_unbalanced_model.pkl
│ └── report.txt
└── generic_without_scaling_unbalanced_model.pkl
├── reports
├── Botnet-detection-on-IoT-devices.pdf
├── BotnetDetection-IoTDevices-Presentation.pdf
├── benign_profile.html
├── gafgyt_profile.html
├── mirai_profile.html
├── model_training_results.docx
├── pycaret-model.ipynb
└── s3-preprocessing&training.html
├── s1-data-wrangling.ipynb
├── s2-eda.ipynb
├── s3-preprocessing&training.ipynb
├── s4-modeling.ipynb
└── scripts
├── models.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Dineshkumar Sundaram
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Botnet detection on IoT Devices
2 | 
3 | ### Introduction:
4 | Internet of Things (IoT) devices are widely used in modern homes and in every part of our lives. Because they are not very sophisticated, they become easy targets for denial-of-service attacks. IoT devices can be used as bots to launch a distributed DoS attack.
5 |
6 | The rapid growth of IoT devices, which can be more easily compromised than desktop computers, has led to an increase in the occurrence of IoT-based botnet attacks. A botnet attack is a type of DDoS attack in which the attacker uses a large number of IoT devices to overwhelm a specific target. This type of attack is hard to detect, since each device keeps functioning normally and its owner will not notice that it is taking part in an attack; in some cases the device may only suffer delays in its functionality.
7 |
8 | Botnets such as Mirai are typically constructed in several distinct operational steps:
9 | - propagation
10 | - infection
11 | - C&C communication
12 | - execution of attacks.
13 |
14 |
15 |
16 | ### Dataset:
17 | [Download](https://archive.ics.uci.edu/ml/datasets/detection_of_IoT_botnet_attacks_N_BaIoT)
18 | The N-BaIoT dataset was collected from a real network traffic of nine IoT devices. The data contains both benign and attack traffic. The dataset is separated where each device has its files, each file contains a type of traffic such as normal traffic or attacks. There are ten classes of attacks that were generated using two families of botnet attack codes from the github (Mirai, Bashlite). N-BaIoT dataset has 115 features, all of these features are statistical analysis, which is extracted from the packet traffic for various periods.
19 |
20 | The dataset contains normal & attack traffic for the following nine devices.
21 | - Danmini - Doorbell
22 | - Ennio - Doorbell
23 | - Ecobee - Thermostat
24 | - Philips B120N/10 - Baby Monitor
25 | - Provision PT-737E - Security Camera
26 | - Provision PT-838 - Security Camera
27 | - Simple Home XCS7-1002-WHT - Security Camera
28 | - Simple Home XCS7-1003-WHT - Security Camera
29 | - Samsung SNH 1011 N - Web cam
30 |
31 | #### Feature information:
32 | ##### Stream aggregation:
33 | - H: ("Source IP" in N-BaIoT paper) Stats summarizing the recent traffic from this packet's host (IP)
34 | - MI: ("Source MAC-IP" in N-BaIoT paper) Stats summarizing the recent traffic from this packet's host (IP + MAC)
35 | - HH: ("Channel" in N-BaIoT paper) Stats summarizing the recent traffic going from this packet's host (IP) to the packet's destination host.
36 | - HH_jit: ("Channel jitter" in N-BaIoT paper) Stats summarizing the jitter of the traffic going from this packet's host (IP) to the packet's destination host.
37 | - HpHp: ("Socket" in N-BaIoT paper) Stats summarizing the recent traffic going from this packet's host+port (IP) to the packet's destination host+port. Example 192.168.4.2:1242 -> 192.168.4.12:80
38 |
39 | - Time-frame (The decay factor Lambda used in the damped window):
40 | - How much recent history of the stream is captured in these statistics
41 | - L5, L3, L1, L0.1 and L0.01
42 |
43 | - The statistics extracted from the packet stream:
44 | - weight: The weight of the stream (can be viewed as the number of items observed in recent history)
45 | - mean: …
46 | - std: …
47 | - radius: The root squared sum of the two streams' variances
48 | - magnitude: The root squared sum of the two streams' means
49 | - cov: An approximated covariance between two streams
50 | - pcc: An approximated correlation coefficient between two streams
51 |
52 | ### EDA
53 |
54 | | Device | Chart |
55 | | --- | --- |
56 | | Ennio Door bell | 
57 | | Danmini Doorbell | 
58 | | Ecobee Thermostat | 
59 | | Ennio Door bell | 
60 | | Danmini Doorbell | 
61 | | Ecobee Thermostat | 
62 | | Ennio Door bell | 
63 | | Danmini Doorbell | 
64 | | Ecobee Thermostat | 
65 |
66 | ### Pre processing & Training
67 |
68 | ### Modeling
69 |
70 | ### Deployment
71 |
72 | ### Future works
73 |
74 | ### Credits & Links
75 |
76 |
--------------------------------------------------------------------------------
/figs/Danmini_Doorbell-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Danmini_Doorbell-heatmap.png
--------------------------------------------------------------------------------
/figs/Ecobee_Thermostat-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Ecobee_Thermostat-heatmap.png
--------------------------------------------------------------------------------
/figs/Ennio_Doorbell-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Ennio_Doorbell-heatmap.png
--------------------------------------------------------------------------------
/figs/Philips_B120N10_Baby_Monitor-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Philips_B120N10_Baby_Monitor-heatmap.png
--------------------------------------------------------------------------------
/figs/Provision_PT_737E_Security_Camera-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Provision_PT_737E_Security_Camera-heatmap.png
--------------------------------------------------------------------------------
/figs/Provision_PT_838_Security_Camera-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Provision_PT_838_Security_Camera-heatmap.png
--------------------------------------------------------------------------------
/figs/Samsung_SNH_1011_N_Webcam-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/Samsung_SNH_1011_N_Webcam-heatmap.png
--------------------------------------------------------------------------------
/figs/SimpleHome_XCS7_1002_WHT_Security_Camera-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/SimpleHome_XCS7_1002_WHT_Security_Camera-heatmap.png
--------------------------------------------------------------------------------
/figs/SimpleHome_XCS7_1003_WHT_Security_Camera-heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/figs/SimpleHome_XCS7_1003_WHT_Security_Camera-heatmap.png
--------------------------------------------------------------------------------
/models/Danmini_Doorbell/Danmini_Doorbell_without_scaling_unbalanced_model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/models/Danmini_Doorbell/Danmini_Doorbell_without_scaling_unbalanced_model.pkl
--------------------------------------------------------------------------------
/models/Danmini_Doorbell/report.txt:
--------------------------------------------------------------------------------
1 | Classification Report on Test Set
2 |
3 |
4 | precision recall f1-score support
5 |
6 | benign 1.00 1.00 1.00 10449
7 | gafgyt 1.00 1.00 1.00 66325
8 | mirai 1.00 1.00 1.00 137069
9 |
10 | accuracy 1.00 213843
11 | macro avg 1.00 1.00 1.00 213843
12 | weighted avg 1.00 1.00 1.00 213843
13 |
14 |
15 |
16 | Confusion Matrix on Test Set
17 |
18 |
19 | [[ 10447 2 0]
20 | [ 1 66324 0]
21 | [ 0 0 137069]]
22 |
23 |
--------------------------------------------------------------------------------
/models/generic_without_scaling_unbalanced_model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/models/generic_without_scaling_unbalanced_model.pkl
--------------------------------------------------------------------------------
/reports/Botnet-detection-on-IoT-devices.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/reports/Botnet-detection-on-IoT-devices.pdf
--------------------------------------------------------------------------------
/reports/BotnetDetection-IoTDevices-Presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/reports/BotnetDetection-IoTDevices-Presentation.pdf
--------------------------------------------------------------------------------
/reports/model_training_results.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dineshh912/IoT-botnet-attack-detection/93b88f6ff52b6e09324eaf8a0199ab04dadc998d/reports/model_training_results.docx
--------------------------------------------------------------------------------
/reports/pycaret-model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Importing necssary modules\n",
10 | "import pandas as pd\n",
11 | "from datetime import datetime\n",
12 | "from scripts.utils import load_data, load_data_multi_label, load_data_all"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# Data folder path and Extention of the data files\n",
22 | "base_directory = '../rawdata'\n",
23 | "file_extension = \"*.csv\""
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 3,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "danmini_doorbell_df = load_data(base_directory, file_extension, 'Danmini_Doorbell')"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 4,
38 | "metadata": {},
39 | "outputs": [
40 | {
41 | "data": {
42 | "text/plain": [
43 | "(1018298, 117)"
44 | ]
45 | },
46 | "execution_count": 4,
47 | "metadata": {},
48 | "output_type": "execute_result"
49 | }
50 | ],
51 | "source": [
52 | "danmini_doorbell_df.shape"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 6,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "df = danmini_doorbell_df.sample(frac=0.3)"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 7,
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "data": {
71 | "text/plain": [
72 | "(305489, 117)"
73 | ]
74 | },
75 | "execution_count": 7,
76 | "metadata": {},
77 | "output_type": "execute_result"
78 | }
79 | ],
80 | "source": [
81 | "df.shape"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 8,
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/plain": [
92 | "'2.1.2'"
93 | ]
94 | },
95 | "execution_count": 8,
96 | "metadata": {},
97 | "output_type": "execute_result"
98 | }
99 | ],
100 | "source": [
101 | "# check version\n",
102 | "from pycaret.utils import version\n",
103 | "version()"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 9,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "from pycaret.classification import *"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 10,
118 | "metadata": {},
119 | "outputs": [
120 | {
121 | "name": "stdout",
122 | "output_type": "stream",
123 | "text": [
124 | "Setup Succesfully Completed!\n"
125 | ]
126 | },
127 | {
128 | "data": {
129 | "text/html": [
130 | "
| Description | Value |
\n",
132 | " \n",
133 | " 0 | \n",
134 | " session_id | \n",
135 | " 123 | \n",
136 | "
\n",
137 | " \n",
138 | " 1 | \n",
139 | " Target Type | \n",
140 | " Multiclass | \n",
141 | "
\n",
142 | " \n",
143 | " 2 | \n",
144 | " Label Encoded | \n",
145 | " benign: 0, gafgyt: 1, mirai: 2 | \n",
146 | "
\n",
147 | " \n",
148 | " 3 | \n",
149 | " Original Data | \n",
150 | " (305489, 117) | \n",
151 | "
\n",
152 | " \n",
153 | " 4 | \n",
154 | " Missing Values | \n",
155 | " False | \n",
156 | "
\n",
157 | " \n",
158 | " 5 | \n",
159 | " Numeric Features | \n",
160 | " 115 | \n",
161 | "
\n",
162 | " \n",
163 | " 6 | \n",
164 | " Categorical Features | \n",
165 | " 1 | \n",
166 | "
\n",
167 | " \n",
168 | " 7 | \n",
169 | " Ordinal Features | \n",
170 | " False | \n",
171 | "
\n",
172 | " \n",
173 | " 8 | \n",
174 | " High Cardinality Features | \n",
175 | " False | \n",
176 | "
\n",
177 | " \n",
178 | " 9 | \n",
179 | " High Cardinality Method | \n",
180 | " None | \n",
181 | "
\n",
182 | " \n",
183 | " 10 | \n",
184 | " Sampled Data | \n",
185 | " (305489, 117) | \n",
186 | "
\n",
187 | " \n",
188 | " 11 | \n",
189 | " Transformed Train Set | \n",
190 | " (213842, 116) | \n",
191 | "
\n",
192 | " \n",
193 | " 12 | \n",
194 | " Transformed Test Set | \n",
195 | " (91647, 116) | \n",
196 | "
\n",
197 | " \n",
198 | " 13 | \n",
199 | " Numeric Imputer | \n",
200 | " mean | \n",
201 | "
\n",
202 | " \n",
203 | " 14 | \n",
204 | " Categorical Imputer | \n",
205 | " constant | \n",
206 | "
\n",
207 | " \n",
208 | " 15 | \n",
209 | " Normalize | \n",
210 | " False | \n",
211 | "
\n",
212 | " \n",
213 | " 16 | \n",
214 | " Normalize Method | \n",
215 | " None | \n",
216 | "
\n",
217 | " \n",
218 | " 17 | \n",
219 | " Transformation | \n",
220 | " False | \n",
221 | "
\n",
222 | " \n",
223 | " 18 | \n",
224 | " Transformation Method | \n",
225 | " None | \n",
226 | "
\n",
227 | " \n",
228 | " 19 | \n",
229 | " PCA | \n",
230 | " False | \n",
231 | "
\n",
232 | " \n",
233 | " 20 | \n",
234 | " PCA Method | \n",
235 | " None | \n",
236 | "
\n",
237 | " \n",
238 | " 21 | \n",
239 | " PCA Components | \n",
240 | " None | \n",
241 | "
\n",
242 | " \n",
243 | " 22 | \n",
244 | " Ignore Low Variance | \n",
245 | " False | \n",
246 | "
\n",
247 | " \n",
248 | " 23 | \n",
249 | " Combine Rare Levels | \n",
250 | " False | \n",
251 | "
\n",
252 | " \n",
253 | " 24 | \n",
254 | " Rare Level Threshold | \n",
255 | " None | \n",
256 | "
\n",
257 | " \n",
258 | " 25 | \n",
259 | " Numeric Binning | \n",
260 | " False | \n",
261 | "
\n",
262 | " \n",
263 | " 26 | \n",
264 | " Remove Outliers | \n",
265 | " False | \n",
266 | "
\n",
267 | " \n",
268 | " 27 | \n",
269 | " Outliers Threshold | \n",
270 | " None | \n",
271 | "
\n",
272 | " \n",
273 | " 28 | \n",
274 | " Remove Multicollinearity | \n",
275 | " False | \n",
276 | "
\n",
277 | " \n",
278 | " 29 | \n",
279 | " Multicollinearity Threshold | \n",
280 | " None | \n",
281 | "
\n",
282 | " \n",
283 | " 30 | \n",
284 | " Clustering | \n",
285 | " False | \n",
286 | "
\n",
287 | " \n",
288 | " 31 | \n",
289 | " Clustering Iteration | \n",
290 | " None | \n",
291 | "
\n",
292 | " \n",
293 | " 32 | \n",
294 | " Polynomial Features | \n",
295 | " False | \n",
296 | "
\n",
297 | " \n",
298 | " 33 | \n",
299 | " Polynomial Degree | \n",
300 | " None | \n",
301 | "
\n",
302 | " \n",
303 | " 34 | \n",
304 | " Trignometry Features | \n",
305 | " False | \n",
306 | "
\n",
307 | " \n",
308 | " 35 | \n",
309 | " Polynomial Threshold | \n",
310 | " None | \n",
311 | "
\n",
312 | " \n",
313 | " 36 | \n",
314 | " Group Features | \n",
315 | " False | \n",
316 | "
\n",
317 | " \n",
318 | " 37 | \n",
319 | " Feature Selection | \n",
320 | " False | \n",
321 | "
\n",
322 | " \n",
323 | " 38 | \n",
324 | " Features Selection Threshold | \n",
325 | " None | \n",
326 | "
\n",
327 | " \n",
328 | " 39 | \n",
329 | " Feature Interaction | \n",
330 | " False | \n",
331 | "
\n",
332 | " \n",
333 | " 40 | \n",
334 | " Feature Ratio | \n",
335 | " False | \n",
336 | "
\n",
337 | " \n",
338 | " 41 | \n",
339 | " Interaction Threshold | \n",
340 | " None | \n",
341 | "
\n",
342 | " \n",
343 | " 42 | \n",
344 | " Fix Imbalance | \n",
345 | " False | \n",
346 | "
\n",
347 | " \n",
348 | " 43 | \n",
349 | " Fix Imbalance Method | \n",
350 | " SMOTE | \n",
351 | "
\n",
352 | "
"
353 | ],
354 | "text/plain": [
355 | ""
356 | ]
357 | },
358 | "metadata": {},
359 | "output_type": "display_data"
360 | }
361 | ],
362 | "source": [
363 | "clf1 = setup(df, target = 'label', session_id=123, experiment_name='doorbell-1')"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": 11,
369 | "metadata": {},
370 | "outputs": [
371 | {
372 | "data": {
373 | "application/vnd.jupyter.widget-view+json": {
374 | "model_id": "acf36808e95b4f9180d69eb1f167ad83",
375 | "version_major": 2,
376 | "version_minor": 0
377 | },
378 | "text/plain": [
379 | "IntProgress(value=0, description='Processing: ', max=176)"
380 | ]
381 | },
382 | "metadata": {},
383 | "output_type": "display_data"
384 | },
385 | {
386 | "data": {
387 | "text/html": [
388 | "\n",
389 | "\n",
402 | "
\n",
403 | " \n",
404 | " \n",
405 | " | \n",
406 | " | \n",
407 | " | \n",
408 | "
\n",
409 | " \n",
410 | " | \n",
411 | " | \n",
412 | " | \n",
413 | "
\n",
414 | " \n",
415 | " \n",
416 | " \n",
417 | " Initiated | \n",
418 | " . . . . . . . . . . . . . . . . . . | \n",
419 | " 11:21:37 | \n",
420 | "
\n",
421 | " \n",
422 | " Status | \n",
423 | " . . . . . . . . . . . . . . . . . . | \n",
424 | " Finalizing Model | \n",
425 | "
\n",
426 | " \n",
427 | " ETC | \n",
428 | " . . . . . . . . . . . . . . . . . . | \n",
429 | " Almost Finished | \n",
430 | "
\n",
431 | " \n",
432 | "
\n",
433 | "
"
434 | ],
435 | "text/plain": [
436 | " \n",
437 | " \n",
438 | "Initiated . . . . . . . . . . . . . . . . . . 11:21:37\n",
439 | "Status . . . . . . . . . . . . . . . . . . Finalizing Model\n",
440 | "ETC . . . . . . . . . . . . . . . . . . Almost Finished"
441 | ]
442 | },
443 | "metadata": {},
444 | "output_type": "display_data"
445 | },
446 | {
447 | "data": {
448 | "text/html": [
449 | "\n",
450 | "\n",
463 | "
\n",
464 | " \n",
465 | " \n",
466 | " | \n",
467 | " Model | \n",
468 | " Accuracy | \n",
469 | " AUC | \n",
470 | " Recall | \n",
471 | " Prec. | \n",
472 | " F1 | \n",
473 | " Kappa | \n",
474 | " MCC | \n",
475 | " TT (Sec) | \n",
476 | "
\n",
477 | " \n",
478 | " \n",
479 | " \n",
480 | " 0 | \n",
481 | " Random Forest Classifier | \n",
482 | " 1.0000 | \n",
483 | " 0.0 | \n",
484 | " 0.9999 | \n",
485 | " 1.0000 | \n",
486 | " 1.0000 | \n",
487 | " 0.9999 | \n",
488 | " 0.9999 | \n",
489 | " 2.8086 | \n",
490 | "
\n",
491 | " \n",
492 | " 1 | \n",
493 | " Decision Tree Classifier | \n",
494 | " 0.9998 | \n",
495 | " 0.0 | \n",
496 | " 0.9997 | \n",
497 | " 0.9998 | \n",
498 | " 0.9998 | \n",
499 | " 0.9997 | \n",
500 | " 0.9997 | \n",
501 | " 21.2064 | \n",
502 | "
\n",
503 | " \n",
504 | " 2 | \n",
505 | " K Neighbors Classifier | \n",
506 | " 0.9980 | \n",
507 | " 0.0 | \n",
508 | " 0.9935 | \n",
509 | " 0.9980 | \n",
510 | " 0.9980 | \n",
511 | " 0.9960 | \n",
512 | " 0.9960 | \n",
513 | " 25.6996 | \n",
514 | "
\n",
515 | " \n",
516 | " 3 | \n",
517 | " Ridge Classifier | \n",
518 | " 0.9969 | \n",
519 | " 0.0 | \n",
520 | " 0.9958 | \n",
521 | " 0.9969 | \n",
522 | " 0.9969 | \n",
523 | " 0.9936 | \n",
524 | " 0.9936 | \n",
525 | " 1.2116 | \n",
526 | "
\n",
527 | " \n",
528 | " 4 | \n",
529 | " Ada Boost Classifier | \n",
530 | " 0.9245 | \n",
531 | " 0.0 | \n",
532 | " 0.9202 | \n",
533 | " 0.9340 | \n",
534 | " 0.9216 | \n",
535 | " 0.8392 | \n",
536 | " 0.8522 | \n",
537 | " 144.0179 | \n",
538 | "
\n",
539 | " \n",
540 | " 5 | \n",
541 | " Quadratic Discriminant Analysis | \n",
542 | " 0.6834 | \n",
543 | " 0.0 | \n",
544 | " 0.8271 | \n",
545 | " 0.8491 | \n",
546 | " 0.6724 | \n",
547 | " 0.4799 | \n",
548 | " 0.5712 | \n",
549 | " 5.3659 | \n",
550 | "
\n",
551 | " \n",
552 | " 6 | \n",
553 | " Naive Bayes | \n",
554 | " 0.6585 | \n",
555 | " 0.0 | \n",
556 | " 0.3543 | \n",
557 | " 0.7312 | \n",
558 | " 0.5410 | \n",
559 | " 0.0693 | \n",
560 | " 0.1829 | \n",
561 | " 0.8091 | \n",
562 | "
\n",
563 | " \n",
564 | " 7 | \n",
565 | " SVM - Linear Kernel | \n",
566 | " 0.4204 | \n",
567 | " 0.0 | \n",
568 | " 0.3930 | \n",
569 | " 0.4682 | \n",
570 | " 0.3959 | \n",
571 | " 0.0762 | \n",
572 | " 0.1060 | \n",
573 | " 6.0382 | \n",
574 | "
\n",
575 | " \n",
576 | " 8 | \n",
577 | " Logistic Regression | \n",
578 | " 0.0486 | \n",
579 | " 0.0 | \n",
580 | " 0.3333 | \n",
581 | " 0.0024 | \n",
582 | " 0.0045 | \n",
583 | " 0.0000 | \n",
584 | " 0.0000 | \n",
585 | " 4.2906 | \n",
586 | "
\n",
587 | " \n",
588 | "
\n",
589 | "
"
590 | ],
591 | "text/plain": [
592 | " Model Accuracy AUC Recall Prec. F1 \\\n",
593 | "0 Random Forest Classifier 1.0000 0.0 0.9999 1.0000 1.0000 \n",
594 | "1 Decision Tree Classifier 0.9998 0.0 0.9997 0.9998 0.9998 \n",
595 | "2 K Neighbors Classifier 0.9980 0.0 0.9935 0.9980 0.9980 \n",
596 | "3 Ridge Classifier 0.9969 0.0 0.9958 0.9969 0.9969 \n",
597 | "4 Ada Boost Classifier 0.9245 0.0 0.9202 0.9340 0.9216 \n",
598 | "5 Quadratic Discriminant Analysis 0.6834 0.0 0.8271 0.8491 0.6724 \n",
599 | "6 Naive Bayes 0.6585 0.0 0.3543 0.7312 0.5410 \n",
600 | "7 SVM - Linear Kernel 0.4204 0.0 0.3930 0.4682 0.3959 \n",
601 | "8 Logistic Regression 0.0486 0.0 0.3333 0.0024 0.0045 \n",
602 | "\n",
603 | " Kappa MCC TT (Sec) \n",
604 | "0 0.9999 0.9999 2.8086 \n",
605 | "1 0.9997 0.9997 21.2064 \n",
606 | "2 0.9960 0.9960 25.6996 \n",
607 | "3 0.9936 0.9936 1.2116 \n",
608 | "4 0.8392 0.8522 144.0179 \n",
609 | "5 0.4799 0.5712 5.3659 \n",
610 | "6 0.0693 0.1829 0.8091 \n",
611 | "7 0.0762 0.1060 6.0382 \n",
612 | "8 0.0000 0.0000 4.2906 "
613 | ]
614 | },
615 | "metadata": {},
616 | "output_type": "display_data"
617 | },
618 | {
619 | "ename": "KeyboardInterrupt",
620 | "evalue": "",
621 | "output_type": "error",
622 | "traceback": [
623 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
624 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
625 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mbest_model\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompare_models\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
626 | "\u001b[1;32me:\\springboard\\venv\\lib\\site-packages\\pycaret\\classification.py\u001b[0m in \u001b[0;36mcompare_models\u001b[1;34m(exclude, include, fold, round, sort, n_select, budget_time, turbo, verbose)\u001b[0m\n\u001b[0;32m 2455\u001b[0m \u001b[0mtime_start\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2456\u001b[0m \u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Fitting Model\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2457\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mXtrain\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mytrain\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2458\u001b[0m \u001b[0mlogger\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Evaluating Metrics\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2459\u001b[0m \u001b[0mtime_end\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
627 | "\u001b[1;32me:\\springboard\\venv\\lib\\site-packages\\sklearn\\ensemble\\_gb.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, sample_weight, monitor)\u001b[0m\n\u001b[0;32m 496\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 497\u001b[0m \u001b[1;31m# fit the boosting stages\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 498\u001b[1;33m n_stages = self._fit_stages(\n\u001b[0m\u001b[0;32m 499\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mraw_predictions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_rng\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_val\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 500\u001b[0m sample_weight_val, begin_at_stage, monitor, X_idx_sorted)\n",
628 | "\u001b[1;32me:\\springboard\\venv\\lib\\site-packages\\sklearn\\ensemble\\_gb.py\u001b[0m in \u001b[0;36m_fit_stages\u001b[1;34m(self, X, y, raw_predictions, sample_weight, random_state, X_val, y_val, sample_weight_val, begin_at_stage, monitor, X_idx_sorted)\u001b[0m\n\u001b[0;32m 553\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 554\u001b[0m \u001b[1;31m# fit next stage of trees\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 555\u001b[1;33m raw_predictions = self._fit_stage(\n\u001b[0m\u001b[0;32m 556\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mraw_predictions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_mask\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 557\u001b[0m random_state, X_idx_sorted, X_csc, X_csr)\n",
629 | "\u001b[1;32me:\\springboard\\venv\\lib\\site-packages\\sklearn\\ensemble\\_gb.py\u001b[0m in \u001b[0;36m_fit_stage\u001b[1;34m(self, i, X, y, raw_predictions, sample_weight, sample_mask, random_state, X_idx_sorted, X_csc, X_csr)\u001b[0m\n\u001b[0;32m 209\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 210\u001b[0m \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mX_csr\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mX_csr\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32melse\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 211\u001b[1;33m tree.fit(X, residual, sample_weight=sample_weight,\n\u001b[0m\u001b[0;32m 212\u001b[0m check_input=False, X_idx_sorted=X_idx_sorted)\n\u001b[0;32m 213\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
630 | "\u001b[1;32me:\\springboard\\venv\\lib\\site-packages\\sklearn\\tree\\_classes.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, sample_weight, check_input, X_idx_sorted)\u001b[0m\n\u001b[0;32m 1240\u001b[0m \"\"\"\n\u001b[0;32m 1241\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1242\u001b[1;33m super().fit(\n\u001b[0m\u001b[0;32m 1243\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1244\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
631 | "\u001b[1;32me:\\springboard\\venv\\lib\\site-packages\\sklearn\\tree\\_classes.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, sample_weight, check_input, X_idx_sorted)\u001b[0m\n\u001b[0;32m 373\u001b[0m min_impurity_split)\n\u001b[0;32m 374\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 375\u001b[1;33m \u001b[0mbuilder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtree_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_idx_sorted\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 376\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 377\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn_outputs_\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mis_classifier\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
632 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
633 | ]
634 | }
635 | ],
636 | "source": [
637 | "best_model = compare_models()"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": 12,
643 | "metadata": {},
644 | "outputs": [
645 | {
646 | "data": {
647 | "text/html": [
648 | "\n",
649 | "\n",
662 | "
\n",
663 | " \n",
664 | " \n",
665 | " | \n",
666 | " Name | \n",
667 | " Reference | \n",
668 | " Turbo | \n",
669 | "
\n",
670 | " \n",
671 | " ID | \n",
672 | " | \n",
673 | " | \n",
674 | " | \n",
675 | "
\n",
676 | " \n",
677 | " \n",
678 | " \n",
679 | " lr | \n",
680 | " Logistic Regression | \n",
681 | " sklearn.linear_model.LogisticRegression | \n",
682 | " True | \n",
683 | "
\n",
684 | " \n",
685 | " knn | \n",
686 | " K Neighbors Classifier | \n",
687 | " sklearn.neighbors.KNeighborsClassifier | \n",
688 | " True | \n",
689 | "
\n",
690 | " \n",
691 | " nb | \n",
692 | " Naive Bayes | \n",
693 | " sklearn.naive_bayes.GaussianNB | \n",
694 | " True | \n",
695 | "
\n",
696 | " \n",
697 | " dt | \n",
698 | " Decision Tree Classifier | \n",
699 | " sklearn.tree.DecisionTreeClassifier | \n",
700 | " True | \n",
701 | "
\n",
702 | " \n",
703 | " svm | \n",
704 | " SVM - Linear Kernel | \n",
705 | " sklearn.linear_model.SGDClassifier | \n",
706 | " True | \n",
707 | "
\n",
708 | " \n",
709 | " rbfsvm | \n",
710 | " SVM - Radial Kernel | \n",
711 | " sklearn.svm.SVC | \n",
712 | " False | \n",
713 | "
\n",
714 | " \n",
715 | " gpc | \n",
716 | " Gaussian Process Classifier | \n",
717 | " sklearn.gaussian_process.GPC | \n",
718 | " False | \n",
719 | "
\n",
720 | " \n",
721 | " mlp | \n",
722 | " MLP Classifier | \n",
723 | " sklearn.neural_network.MLPClassifier | \n",
724 | " False | \n",
725 | "
\n",
726 | " \n",
727 | " ridge | \n",
728 | " Ridge Classifier | \n",
729 | " sklearn.linear_model.RidgeClassifier | \n",
730 | " True | \n",
731 | "
\n",
732 | " \n",
733 | " rf | \n",
734 | " Random Forest Classifier | \n",
735 | " sklearn.ensemble.RandomForestClassifier | \n",
736 | " True | \n",
737 | "
\n",
738 | " \n",
739 | " qda | \n",
740 | " Quadratic Discriminant Analysis | \n",
741 | " sklearn.discriminant_analysis.QDA | \n",
742 | " True | \n",
743 | "
\n",
744 | " \n",
745 | " ada | \n",
746 | " Ada Boost Classifier | \n",
747 | " sklearn.ensemble.AdaBoostClassifier | \n",
748 | " True | \n",
749 | "
\n",
750 | " \n",
751 | " gbc | \n",
752 | " Gradient Boosting Classifier | \n",
753 | " sklearn.ensemble.GradientBoostingClassifier | \n",
754 | " True | \n",
755 | "
\n",
756 | " \n",
757 | " lda | \n",
758 | " Linear Discriminant Analysis | \n",
759 | " sklearn.discriminant_analysis.LDA | \n",
760 | " True | \n",
761 | "
\n",
762 | " \n",
763 | " et | \n",
764 | " Extra Trees Classifier | \n",
765 | " sklearn.ensemble.ExtraTreesClassifier | \n",
766 | " True | \n",
767 | "
\n",
768 | " \n",
769 | " xgboost | \n",
770 | " Extreme Gradient Boosting | \n",
771 | " xgboost.readthedocs.io | \n",
772 | " True | \n",
773 | "
\n",
774 | " \n",
775 | " lightgbm | \n",
776 | " Light Gradient Boosting Machine | \n",
777 | " github.com/microsoft/LightGBM | \n",
778 | " True | \n",
779 | "
\n",
780 | " \n",
781 | " catboost | \n",
782 | " CatBoost Classifier | \n",
783 | " catboost.ai | \n",
784 | " True | \n",
785 | "
\n",
786 | " \n",
787 | "
\n",
788 | "
"
789 | ],
790 | "text/plain": [
791 | " Name \\\n",
792 | "ID \n",
793 | "lr Logistic Regression \n",
794 | "knn K Neighbors Classifier \n",
795 | "nb Naive Bayes \n",
796 | "dt Decision Tree Classifier \n",
797 | "svm SVM - Linear Kernel \n",
798 | "rbfsvm SVM - Radial Kernel \n",
799 | "gpc Gaussian Process Classifier \n",
800 | "mlp MLP Classifier \n",
801 | "ridge Ridge Classifier \n",
802 | "rf Random Forest Classifier \n",
803 | "qda Quadratic Discriminant Analysis \n",
804 | "ada Ada Boost Classifier \n",
805 | "gbc Gradient Boosting Classifier \n",
806 | "lda Linear Discriminant Analysis \n",
807 | "et Extra Trees Classifier \n",
808 | "xgboost Extreme Gradient Boosting \n",
809 | "lightgbm Light Gradient Boosting Machine \n",
810 | "catboost CatBoost Classifier \n",
811 | "\n",
812 | " Reference Turbo \n",
813 | "ID \n",
814 | "lr sklearn.linear_model.LogisticRegression True \n",
815 | "knn sklearn.neighbors.KNeighborsClassifier True \n",
816 | "nb sklearn.naive_bayes.GaussianNB True \n",
817 | "dt sklearn.tree.DecisionTreeClassifier True \n",
818 | "svm sklearn.linear_model.SGDClassifier True \n",
819 | "rbfsvm sklearn.svm.SVC False \n",
820 | "gpc sklearn.gaussian_process.GPC False \n",
821 | "mlp sklearn.neural_network.MLPClassifier False \n",
822 | "ridge sklearn.linear_model.RidgeClassifier True \n",
823 | "rf sklearn.ensemble.RandomForestClassifier True \n",
824 | "qda sklearn.discriminant_analysis.QDA True \n",
825 | "ada sklearn.ensemble.AdaBoostClassifier True \n",
826 | "gbc sklearn.ensemble.GradientBoostingClassifier True \n",
827 | "lda sklearn.discriminant_analysis.LDA True \n",
828 | "et sklearn.ensemble.ExtraTreesClassifier True \n",
829 | "xgboost xgboost.readthedocs.io True \n",
830 | "lightgbm github.com/microsoft/LightGBM True \n",
831 | "catboost catboost.ai True "
832 | ]
833 | },
834 | "execution_count": 12,
835 | "metadata": {},
836 | "output_type": "execute_result"
837 | }
838 | ],
839 | "source": [
840 | "models()"
841 | ]
842 | },
843 | {
844 | "cell_type": "code",
845 | "execution_count": 13,
846 | "metadata": {},
847 | "outputs": [
848 | {
849 | "data": {
850 | "text/html": [
851 | " | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC |
\n",
855 | " \n",
856 | " 0 | \n",
857 | " 0.9999 | \n",
858 | " 0.0000 | \n",
859 | " 0.9999 | \n",
860 | " 0.9999 | \n",
861 | " 0.9999 | \n",
862 | " 0.9999 | \n",
863 | " 0.9999 | \n",
864 | "
\n",
865 | " \n",
866 | " 1 | \n",
867 | " 0.9999 | \n",
868 | " 0.0000 | \n",
869 | " 0.9998 | \n",
870 | " 0.9999 | \n",
871 | " 0.9999 | \n",
872 | " 0.9998 | \n",
873 | " 0.9998 | \n",
874 | "
\n",
875 | " \n",
876 | " 2 | \n",
877 | " 0.9999 | \n",
878 | " 0.0000 | \n",
879 | " 0.9998 | \n",
880 | " 0.9999 | \n",
881 | " 0.9999 | \n",
882 | " 0.9998 | \n",
883 | " 0.9998 | \n",
884 | "
\n",
885 | " \n",
886 | " 3 | \n",
887 | " 1.0000 | \n",
888 | " 0.0000 | \n",
889 | " 1.0000 | \n",
890 | " 1.0000 | \n",
891 | " 1.0000 | \n",
892 | " 1.0000 | \n",
893 | " 1.0000 | \n",
894 | "
\n",
895 | " \n",
896 | " 4 | \n",
897 | " 1.0000 | \n",
898 | " 0.0000 | \n",
899 | " 1.0000 | \n",
900 | " 1.0000 | \n",
901 | " 1.0000 | \n",
902 | " 1.0000 | \n",
903 | " 1.0000 | \n",
904 | "
\n",
905 | " \n",
906 | " Mean | \n",
907 | " 0.9999 | \n",
908 | " 0.0000 | \n",
909 | " 0.9999 | \n",
910 | " 0.9999 | \n",
911 | " 0.9999 | \n",
912 | " 0.9999 | \n",
913 | " 0.9999 | \n",
914 | "
\n",
915 | " \n",
916 | " SD | \n",
917 | " 0.0000 | \n",
918 | " 0.0000 | \n",
919 | " 0.0001 | \n",
920 | " 0.0000 | \n",
921 | " 0.0000 | \n",
922 | " 0.0001 | \n",
923 | " 0.0001 | \n",
924 | "
\n",
925 | "
"
926 | ],
927 | "text/plain": [
928 | ""
929 | ]
930 | },
931 | "metadata": {},
932 | "output_type": "display_data"
933 | }
934 | ],
935 | "source": [
936 | "rf = create_model('rf', fold=5)"
937 | ]
938 | },
939 | {
940 | "cell_type": "code",
941 | "execution_count": 14,
942 | "metadata": {},
943 | "outputs": [
944 | {
945 | "data": {
946 | "text/html": [
947 | " | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC |
\n",
951 | " \n",
952 | " 0 | \n",
953 | " 0.9998 | \n",
954 | " 0.0000 | \n",
955 | " 0.9998 | \n",
956 | " 0.9998 | \n",
957 | " 0.9998 | \n",
958 | " 0.9995 | \n",
959 | " 0.9995 | \n",
960 | "
\n",
961 | " \n",
962 | " 1 | \n",
963 | " 1.0000 | \n",
964 | " 0.0000 | \n",
965 | " 1.0000 | \n",
966 | " 1.0000 | \n",
967 | " 1.0000 | \n",
968 | " 1.0000 | \n",
969 | " 1.0000 | \n",
970 | "
\n",
971 | " \n",
972 | " 2 | \n",
973 | " 1.0000 | \n",
974 | " 0.0000 | \n",
975 | " 0.9997 | \n",
976 | " 1.0000 | \n",
977 | " 1.0000 | \n",
978 | " 0.9999 | \n",
979 | " 0.9999 | \n",
980 | "
\n",
981 | " \n",
982 | " 3 | \n",
983 | " 0.9997 | \n",
984 | " 0.0000 | \n",
985 | " 0.9998 | \n",
986 | " 0.9997 | \n",
987 | " 0.9997 | \n",
988 | " 0.9994 | \n",
989 | " 0.9994 | \n",
990 | "
\n",
991 | " \n",
992 | " 4 | \n",
993 | " 0.9999 | \n",
994 | " 0.0000 | \n",
995 | " 0.9999 | \n",
996 | " 0.9999 | \n",
997 | " 0.9999 | \n",
998 | " 0.9997 | \n",
999 | " 0.9997 | \n",
1000 | "
\n",
1001 | " \n",
1002 | " 5 | \n",
1003 | " 1.0000 | \n",
1004 | " 0.0000 | \n",
1005 | " 0.9999 | \n",
1006 | " 1.0000 | \n",
1007 | " 1.0000 | \n",
1008 | " 0.9999 | \n",
1009 | " 0.9999 | \n",
1010 | "
\n",
1011 | " \n",
1012 | " 6 | \n",
1013 | " 0.9996 | \n",
1014 | " 0.0000 | \n",
1015 | " 0.9991 | \n",
1016 | " 0.9996 | \n",
1017 | " 0.9996 | \n",
1018 | " 0.9992 | \n",
1019 | " 0.9992 | \n",
1020 | "
\n",
1021 | " \n",
1022 | " 7 | \n",
1023 | " 0.9999 | \n",
1024 | " 0.0000 | \n",
1025 | " 0.9999 | \n",
1026 | " 0.9999 | \n",
1027 | " 0.9999 | \n",
1028 | " 0.9998 | \n",
1029 | " 0.9998 | \n",
1030 | "
\n",
1031 | " \n",
1032 | " 8 | \n",
1033 | " 0.9998 | \n",
1034 | " 0.0000 | \n",
1035 | " 0.9995 | \n",
1036 | " 0.9998 | \n",
1037 | " 0.9998 | \n",
1038 | " 0.9995 | \n",
1039 | " 0.9995 | \n",
1040 | "
\n",
1041 | " \n",
1042 | " 9 | \n",
1043 | " 0.9999 | \n",
1044 | " 0.0000 | \n",
1045 | " 0.9996 | \n",
1046 | " 0.9999 | \n",
1047 | " 0.9999 | \n",
1048 | " 0.9998 | \n",
1049 | " 0.9998 | \n",
1050 | "
\n",
1051 | " \n",
1052 | " Mean | \n",
1053 | " 0.9998 | \n",
1054 | " 0.0000 | \n",
1055 | " 0.9997 | \n",
1056 | " 0.9998 | \n",
1057 | " 0.9998 | \n",
1058 | " 0.9997 | \n",
1059 | " 0.9997 | \n",
1060 | "
\n",
1061 | " \n",
1062 | " SD | \n",
1063 | " 0.0001 | \n",
1064 | " 0.0000 | \n",
1065 | " 0.0002 | \n",
1066 | " 0.0001 | \n",
1067 | " 0.0001 | \n",
1068 | " 0.0002 | \n",
1069 | " 0.0002 | \n",
1070 | "
\n",
1071 | "
"
1072 | ],
1073 | "text/plain": [
1074 | ""
1075 | ]
1076 | },
1077 | "metadata": {},
1078 | "output_type": "display_data"
1079 | }
1080 | ],
1081 | "source": [
1082 | "dt = create_model('dt')"
1083 | ]
1084 | },
1085 | {
1086 | "cell_type": "code",
1087 | "execution_count": 15,
1088 | "metadata": {},
1089 | "outputs": [
1090 | {
1091 | "data": {
1092 | "text/html": [
1093 | " | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC |
\n",
1097 | " \n",
1098 | " 0 | \n",
1099 | " 1.0000 | \n",
1100 | " 0.0000 | \n",
1101 | " 1.0000 | \n",
1102 | " 1.0000 | \n",
1103 | " 1.0000 | \n",
1104 | " 1.0000 | \n",
1105 | " 1.0000 | \n",
1106 | "
\n",
1107 | " \n",
1108 | " 1 | \n",
1109 | " 1.0000 | \n",
1110 | " 0.0000 | \n",
1111 | " 0.9999 | \n",
1112 | " 1.0000 | \n",
1113 | " 1.0000 | \n",
1114 | " 0.9999 | \n",
1115 | " 0.9999 | \n",
1116 | "
\n",
1117 | " \n",
1118 | " 2 | \n",
1119 | " 1.0000 | \n",
1120 | " 0.0000 | \n",
1121 | " 0.9997 | \n",
1122 | " 1.0000 | \n",
1123 | " 1.0000 | \n",
1124 | " 0.9999 | \n",
1125 | " 0.9999 | \n",
1126 | "
\n",
1127 | " \n",
1128 | " 3 | \n",
1129 | " 1.0000 | \n",
1130 | " 0.0000 | \n",
1131 | " 1.0000 | \n",
1132 | " 1.0000 | \n",
1133 | " 1.0000 | \n",
1134 | " 0.9999 | \n",
1135 | " 0.9999 | \n",
1136 | "
\n",
1137 | " \n",
1138 | " 4 | \n",
1139 | " 0.9999 | \n",
1140 | " 0.0000 | \n",
1141 | " 0.9999 | \n",
1142 | " 0.9999 | \n",
1143 | " 0.9999 | \n",
1144 | " 0.9998 | \n",
1145 | " 0.9998 | \n",
1146 | "
\n",
1147 | " \n",
1148 | " 5 | \n",
1149 | " 1.0000 | \n",
1150 | " 0.0000 | \n",
1151 | " 1.0000 | \n",
1152 | " 1.0000 | \n",
1153 | " 1.0000 | \n",
1154 | " 1.0000 | \n",
1155 | " 1.0000 | \n",
1156 | "
\n",
1157 | " \n",
1158 | " 6 | \n",
1159 | " 0.9999 | \n",
1160 | " 0.0000 | \n",
1161 | " 0.9996 | \n",
1162 | " 0.9999 | \n",
1163 | " 0.9999 | \n",
1164 | " 0.9998 | \n",
1165 | " 0.9998 | \n",
1166 | "
\n",
1167 | " \n",
1168 | " 7 | \n",
1169 | " 1.0000 | \n",
1170 | " 0.0000 | \n",
1171 | " 1.0000 | \n",
1172 | " 1.0000 | \n",
1173 | " 1.0000 | \n",
1174 | " 1.0000 | \n",
1175 | " 1.0000 | \n",
1176 | "
\n",
1177 | " \n",
1178 | " 8 | \n",
1179 | " 1.0000 | \n",
1180 | " 0.0000 | \n",
1181 | " 1.0000 | \n",
1182 | " 1.0000 | \n",
1183 | " 1.0000 | \n",
1184 | " 1.0000 | \n",
1185 | " 1.0000 | \n",
1186 | "
\n",
1187 | " \n",
1188 | " 9 | \n",
1189 | " 1.0000 | \n",
1190 | " 0.0000 | \n",
1191 | " 1.0000 | \n",
1192 | " 1.0000 | \n",
1193 | " 1.0000 | \n",
1194 | " 1.0000 | \n",
1195 | " 1.0000 | \n",
1196 | "
\n",
1197 | " \n",
1198 | " Mean | \n",
1199 | " 1.0000 | \n",
1200 | " 0.0000 | \n",
1201 | " 0.9999 | \n",
1202 | " 1.0000 | \n",
1203 | " 1.0000 | \n",
1204 | " 0.9999 | \n",
1205 | " 0.9999 | \n",
1206 | "
\n",
1207 | " \n",
1208 | " SD | \n",
1209 | " 0.0000 | \n",
1210 | " 0.0000 | \n",
1211 | " 0.0001 | \n",
1212 | " 0.0000 | \n",
1213 | " 0.0000 | \n",
1214 | " 0.0001 | \n",
1215 | " 0.0001 | \n",
1216 | "
\n",
1217 | "
"
1218 | ],
1219 | "text/plain": [
1220 | ""
1221 | ]
1222 | },
1223 | "metadata": {},
1224 | "output_type": "display_data"
1225 | }
1226 | ],
1227 | "source": [
1228 | "tuned_rf = tune_model(rf)"
1229 | ]
1230 | },
1231 | {
1232 | "cell_type": "code",
1233 | "execution_count": 16,
1234 | "metadata": {},
1235 | "outputs": [
1236 | {
1237 | "data": {
1238 | "text/html": [
1239 | " | Accuracy | AUC | Recall | Prec. | F1 | Kappa | MCC |
\n",
1243 | " \n",
1244 | " 0 | \n",
1245 | " 0.9999 | \n",
1246 | " 0.0000 | \n",
1247 | " 0.9997 | \n",
1248 | " 0.9999 | \n",
1249 | " 0.9999 | \n",
1250 | " 0.9998 | \n",
1251 | " 0.9998 | \n",
1252 | "
\n",
1253 | " \n",
1254 | " 1 | \n",
1255 | " 1.0000 | \n",
1256 | " 0.0000 | \n",
1257 | " 0.9999 | \n",
1258 | " 1.0000 | \n",
1259 | " 1.0000 | \n",
1260 | " 0.9999 | \n",
1261 | " 0.9999 | \n",
1262 | "
\n",
1263 | " \n",
1264 | " 2 | \n",
1265 | " 0.9999 | \n",
1266 | " 0.0000 | \n",
1267 | " 0.9996 | \n",
1268 | " 0.9999 | \n",
1269 | " 0.9999 | \n",
1270 | " 0.9997 | \n",
1271 | " 0.9997 | \n",
1272 | "
\n",
1273 | " \n",
1274 | " 3 | \n",
1275 | " 0.9998 | \n",
1276 | " 0.0000 | \n",
1277 | " 0.9999 | \n",
1278 | " 0.9998 | \n",
1279 | " 0.9998 | \n",
1280 | " 0.9996 | \n",
1281 | " 0.9996 | \n",
1282 | "
\n",
1283 | " \n",
1284 | " 4 | \n",
1285 | " 0.9996 | \n",
1286 | " 0.0000 | \n",
1287 | " 0.9997 | \n",
1288 | " 0.9996 | \n",
1289 | " 0.9996 | \n",
1290 | " 0.9992 | \n",
1291 | " 0.9992 | \n",
1292 | "
\n",
1293 | " \n",
1294 | " 5 | \n",
1295 | " 1.0000 | \n",
1296 | " 0.0000 | \n",
1297 | " 0.9997 | \n",
1298 | " 1.0000 | \n",
1299 | " 1.0000 | \n",
1300 | " 0.9999 | \n",
1301 | " 0.9999 | \n",
1302 | "
\n",
1303 | " \n",
1304 | " 6 | \n",
1305 | " 0.9999 | \n",
1306 | " 0.0000 | \n",
1307 | " 0.9996 | \n",
1308 | " 0.9999 | \n",
1309 | " 0.9999 | \n",
1310 | " 0.9997 | \n",
1311 | " 0.9997 | \n",
1312 | "
\n",
1313 | " \n",
1314 | " 7 | \n",
1315 | " 0.9999 | \n",
1316 | " 0.0000 | \n",
1317 | " 0.9999 | \n",
1318 | " 0.9999 | \n",
1319 | " 0.9999 | \n",
1320 | " 0.9998 | \n",
1321 | " 0.9998 | \n",
1322 | "
\n",
1323 | " \n",
1324 | " 8 | \n",
1325 | " 0.9999 | \n",
1326 | " 0.0000 | \n",
1327 | " 0.9994 | \n",
1328 | " 0.9999 | \n",
1329 | " 0.9999 | \n",
1330 | " 0.9998 | \n",
1331 | " 0.9998 | \n",
1332 | "
\n",
1333 | " \n",
1334 | " 9 | \n",
1335 | " 0.9997 | \n",
1336 | " 0.0000 | \n",
1337 | " 0.9990 | \n",
1338 | " 0.9997 | \n",
1339 | " 0.9997 | \n",
1340 | " 0.9994 | \n",
1341 | " 0.9994 | \n",
1342 | "
\n",
1343 | " \n",
1344 | " Mean | \n",
1345 | " 0.9999 | \n",
1346 | " 0.0000 | \n",
1347 | " 0.9996 | \n",
1348 | " 0.9999 | \n",
1349 | " 0.9999 | \n",
1350 | " 0.9997 | \n",
1351 | " 0.9997 | \n",
1352 | "
\n",
1353 | " \n",
1354 | " SD | \n",
1355 | " 0.0001 | \n",
1356 | " 0.0000 | \n",
1357 | " 0.0003 | \n",
1358 | " 0.0001 | \n",
1359 | " 0.0001 | \n",
1360 | " 0.0002 | \n",
1361 | " 0.0002 | \n",
1362 | "
\n",
1363 | "
"
1364 | ],
1365 | "text/plain": [
1366 | ""
1367 | ]
1368 | },
1369 | "metadata": {},
1370 | "output_type": "display_data"
1371 | }
1372 | ],
1373 | "source": [
1374 | "tuned_dt = tune_model(dt)"
1375 | ]
1376 | },
1377 | {
1378 | "cell_type": "code",
1379 | "execution_count": 17,
1380 | "metadata": {},
1381 | "outputs": [
1382 | {
1383 | "data": {
1384 | "text/html": [
1385 | "\n",
1386 | "\n",
1399 | "
\n",
1400 | " \n",
1401 | " \n",
1402 | " | \n",
1403 | " Parameters | \n",
1404 | "
\n",
1405 | " \n",
1406 | " \n",
1407 | " \n",
1408 | " bootstrap | \n",
1409 | " True | \n",
1410 | "
\n",
1411 | " \n",
1412 | " ccp_alpha | \n",
1413 | " 0 | \n",
1414 | "
\n",
1415 | " \n",
1416 | " class_weight | \n",
1417 | " None | \n",
1418 | "
\n",
1419 | " \n",
1420 | " criterion | \n",
1421 | " gini | \n",
1422 | "
\n",
1423 | " \n",
1424 | " max_depth | \n",
1425 | " None | \n",
1426 | "
\n",
1427 | " \n",
1428 | " max_features | \n",
1429 | " auto | \n",
1430 | "
\n",
1431 | " \n",
1432 | " max_leaf_nodes | \n",
1433 | " None | \n",
1434 | "
\n",
1435 | " \n",
1436 | " max_samples | \n",
1437 | " None | \n",
1438 | "
\n",
1439 | " \n",
1440 | " min_impurity_decrease | \n",
1441 | " 0 | \n",
1442 | "
\n",
1443 | " \n",
1444 | " min_impurity_split | \n",
1445 | " None | \n",
1446 | "
\n",
1447 | " \n",
1448 | " min_samples_leaf | \n",
1449 | " 1 | \n",
1450 | "
\n",
1451 | " \n",
1452 | " min_samples_split | \n",
1453 | " 2 | \n",
1454 | "
\n",
1455 | " \n",
1456 | " min_weight_fraction_leaf | \n",
1457 | " 0 | \n",
1458 | "
\n",
1459 | " \n",
1460 | " n_estimators | \n",
1461 | " 10 | \n",
1462 | "
\n",
1463 | " \n",
1464 | " n_jobs | \n",
1465 | " -1 | \n",
1466 | "
\n",
1467 | " \n",
1468 | " oob_score | \n",
1469 | " False | \n",
1470 | "
\n",
1471 | " \n",
1472 | " random_state | \n",
1473 | " 123 | \n",
1474 | "
\n",
1475 | " \n",
1476 | " verbose | \n",
1477 | " 0 | \n",
1478 | "
\n",
1479 | " \n",
1480 | " warm_start | \n",
1481 | " False | \n",
1482 | "
\n",
1483 | " \n",
1484 | "
\n",
1485 | "
"
1486 | ],
1487 | "text/plain": [
1488 | " Parameters\n",
1489 | "bootstrap True\n",
1490 | "ccp_alpha 0\n",
1491 | "class_weight None\n",
1492 | "criterion gini\n",
1493 | "max_depth None\n",
1494 | "max_features auto\n",
1495 | "max_leaf_nodes None\n",
1496 | "max_samples None\n",
1497 | "min_impurity_decrease 0\n",
1498 | "min_impurity_split None\n",
1499 | "min_samples_leaf 1\n",
1500 | "min_samples_split 2\n",
1501 | "min_weight_fraction_leaf 0\n",
1502 | "n_estimators 10\n",
1503 | "n_jobs -1\n",
1504 | "oob_score False\n",
1505 | "random_state 123\n",
1506 | "verbose 0\n",
1507 | "warm_start False"
1508 | ]
1509 | },
1510 | "metadata": {},
1511 | "output_type": "display_data"
1512 | }
1513 | ],
1514 | "source": [
1515 | "evaluate_model(rf)"
1516 | ]
1517 | },
1518 | {
1519 | "cell_type": "code",
1520 | "execution_count": 18,
1521 | "metadata": {},
1522 | "outputs": [
1523 | {
1524 | "data": {
1525 | "text/html": [
1526 | "\n",
1527 | "\n",
1540 | "
\n",
1541 | " \n",
1542 | " \n",
1543 | " | \n",
1544 | " Parameters | \n",
1545 | "
\n",
1546 | " \n",
1547 | " \n",
1548 | " \n",
1549 | " ccp_alpha | \n",
1550 | " 0 | \n",
1551 | "
\n",
1552 | " \n",
1553 | " class_weight | \n",
1554 | " None | \n",
1555 | "
\n",
1556 | " \n",
1557 | " criterion | \n",
1558 | " gini | \n",
1559 | "
\n",
1560 | " \n",
1561 | " max_depth | \n",
1562 | " None | \n",
1563 | "
\n",
1564 | " \n",
1565 | " max_features | \n",
1566 | " None | \n",
1567 | "
\n",
1568 | " \n",
1569 | " max_leaf_nodes | \n",
1570 | " None | \n",
1571 | "
\n",
1572 | " \n",
1573 | " min_impurity_decrease | \n",
1574 | " 0 | \n",
1575 | "
\n",
1576 | " \n",
1577 | " min_impurity_split | \n",
1578 | " None | \n",
1579 | "
\n",
1580 | " \n",
1581 | " min_samples_leaf | \n",
1582 | " 1 | \n",
1583 | "
\n",
1584 | " \n",
1585 | " min_samples_split | \n",
1586 | " 2 | \n",
1587 | "
\n",
1588 | " \n",
1589 | " min_weight_fraction_leaf | \n",
1590 | " 0 | \n",
1591 | "
\n",
1592 | " \n",
1593 | " presort | \n",
1594 | " deprecated | \n",
1595 | "
\n",
1596 | " \n",
1597 | " random_state | \n",
1598 | " 123 | \n",
1599 | "
\n",
1600 | " \n",
1601 | " splitter | \n",
1602 | " best | \n",
1603 | "
\n",
1604 | " \n",
1605 | "
\n",
1606 | "
"
1607 | ],
1608 | "text/plain": [
1609 | " Parameters\n",
1610 | "ccp_alpha 0\n",
1611 | "class_weight None\n",
1612 | "criterion gini\n",
1613 | "max_depth None\n",
1614 | "max_features None\n",
1615 | "max_leaf_nodes None\n",
1616 | "min_impurity_decrease 0\n",
1617 | "min_impurity_split None\n",
1618 | "min_samples_leaf 1\n",
1619 | "min_samples_split 2\n",
1620 | "min_weight_fraction_leaf 0\n",
1621 | "presort deprecated\n",
1622 | "random_state 123\n",
1623 | "splitter best"
1624 | ]
1625 | },
1626 | "metadata": {},
1627 | "output_type": "display_data"
1628 | }
1629 | ],
1630 | "source": [
1631 | "evaluate_model(dt)"
1632 | ]
1633 | },
1634 | {
1635 | "cell_type": "code",
1636 | "execution_count": null,
1637 | "metadata": {},
1638 | "outputs": [],
1639 | "source": []
1640 | }
1641 | ],
1642 | "metadata": {
1643 | "kernelspec": {
1644 | "display_name": "Python 3",
1645 | "language": "python",
1646 | "name": "python3"
1647 | },
1648 | "language_info": {
1649 | "codemirror_mode": {
1650 | "name": "ipython",
1651 | "version": 3
1652 | },
1653 | "file_extension": ".py",
1654 | "mimetype": "text/x-python",
1655 | "name": "python",
1656 | "nbconvert_exporter": "python",
1657 | "pygments_lexer": "ipython3",
1658 | "version": "3.8.5"
1659 | }
1660 | },
1661 | "nbformat": 4,
1662 | "nbformat_minor": 4
1663 | }
1664 |
--------------------------------------------------------------------------------
/s4-modeling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
 9 |     "# Importing necessary modules\n",
10 | "from scripts.utils import load_device_data_v2\n",
11 | "from scripts.models import rf_classifier"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
 20 |     "# Data folder path and Extension of the data files\n",
21 | "base_directory = '../rawdata'\n",
22 | "file_extension = \"*.csv\""
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "### Loading Device Data into Dataframe\n",
30 | "###### Door Bells"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "danmini_doorbell_df = load_device_data_v2(base_directory, file_extension, 'Danmini_Doorbell')\n",
40 | "ennio_doorbell_df = load_device_data_v2(base_directory, file_extension, 'Ennio_Doorbell')"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "###### Thermostat"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "ecobee_thermostat_df = load_device_data_v2(base_directory, file_extension, 'Ecobee_Thermostat')"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "###### Web cam"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "samsung_cam_df = load_device_data_v2(base_directory, file_extension, 'Samsung_SNH_1011_N_Webcam')"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "###### Baby monitor"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "baby_monitor_df = load_device_data_v2(base_directory, file_extension, 'Philips_B120N10_Baby_Monitor')"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "###### Security Cam"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "provision_cam1_df = load_device_data_v2(base_directory, file_extension, 'Provision_PT_737E_Security_Camera')\n",
105 | "provision_cam2_df = load_device_data_v2(base_directory, file_extension, 'Provision_PT_838_Security_Camera')\n",
106 | "simplehome_cam1_df = load_device_data_v2(base_directory, file_extension, 'SimpleHome_XCS7_1002_WHT_Security_Camera')\n",
107 | "simplehome_cam2_df = load_device_data_v2(base_directory, file_extension, 'SimpleHome_XCS7_1003_WHT_Security_Camera')"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "##### Model Training"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "dataframe = {\"Danmini_Doorbell\": danmini_doorbell_df, \n",
124 | " \"Ecobee_Thermostat\": ecobee_thermostat_df,\n",
125 | " \"Ennio_Doorbell\": ennio_doorbell_df,\n",
126 | " \"Philips_B120N10_Baby_Monitor\": baby_monitor_df,\n",
127 | " \"Provision_PT_737E_Security_Camera\": provision_cam1_df,\n",
128 | " \"Provision_PT_838_Security_Camera\": provision_cam2_df,\n",
129 | " \"Samsung_SNH_1011_N_Webcam\": samsung_cam_df,\n",
130 | " \"SimpleHome_XCS7_1002_WHT_Security_Camera\": simplehome_cam1_df,\n",
131 | " \"SimpleHome_XCS7_1003_WHT_Security_Camera\": simplehome_cam2_df\n",
132 | " }"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "for k in dataframe:\n",
142 | " print(\"----------------------xxxxxxx----------------------\")\n",
143 | " print(k)\n",
144 | " print(\"----------------------xxxxxxx----------------------\")\n",
145 | " results = rf_classifier(dataframe[k], k)\n",
146 | " print(results)\n",
147 | " print(\"---------------------xxxxxxx-----------------------\")"
148 | ]
149 | }
150 | ],
151 | "metadata": {
152 | "kernelspec": {
153 | "display_name": "Python 3",
154 | "language": "python",
155 | "name": "python3"
156 | },
157 | "language_info": {
158 | "codemirror_mode": {
159 | "name": "ipython",
160 | "version": 3
161 | },
162 | "file_extension": ".py",
163 | "mimetype": "text/x-python",
164 | "name": "python",
165 | "nbconvert_exporter": "python",
166 | "pygments_lexer": "ipython3",
167 | "version": "3.7.8"
168 | }
169 | },
170 | "nbformat": 4,
171 | "nbformat_minor": 4
172 | }
173 |
--------------------------------------------------------------------------------
/scripts/models.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from sklearn.model_selection import train_test_split
4 | from sklearn.ensemble import RandomForestClassifier
5 | from sklearn.preprocessing import StandardScaler
6 | from sklearn.metrics import (
7 | f1_score, classification_report,
8 | confusion_matrix, roc_curve,
9 | roc_auc_score, accuracy_score,
10 | log_loss)
11 | from sklearn import __version__ as sklearn_version
12 | from sklearn.neighbors import KNeighborsClassifier
13 | from imblearn.under_sampling import NearMiss
14 | from datetime import datetime
15 | import os
16 | import pickle
17 |
18 |
19 |
def rf_classifier(data, device_name, scaling=False):
    """Train a RandomForest on one device's traffic data and persist the artifacts.

    Splits off a 30% validation set (saved to CSV, never used for training),
    trains a RandomForestClassifier on the remainder, writes the pickled model
    and a text report of test-set metrics under ``models/<device_name>/``.

    Parameters
    ----------
    data : pandas.DataFrame
        Device traffic with feature columns plus 'label' and 'device' columns.
    device_name : str
        Device name; used for the output folder and model file name.
    scaling : bool, optional
        If True, standardize features with StandardScaler fitted on the
        training split only. Default False.

    Returns
    -------
    str
        Summary message embedding the results dict (shapes, accuracy, F1,
        top-20 feature importances).
    """
    # Hold out 30% of rows as a validation set, untouched by training.
    # NOTE(review): no random_state here, so the validation split (and hence
    # the trained model) differs between runs — confirm this is intended.
    validation_data = data.sample(frac=0.30)
    data_df = data.drop(validation_data.index)

    results = {}

    # Features vs. target: everything except the label/device metadata columns.
    X = data_df.drop(['label', 'device'], axis=1)
    y = data_df['label']
    results['original_shape'] = [X.shape, y.shape]

    # The train/test split is identical with or without scaling, so do it once.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.30, random_state=47)

    if scaling:
        # Fit the scaler on the training split only to avoid test-set leakage.
        scaler = StandardScaler()
        scaler.fit(X_train)
        X_train = scaler.transform(X_train)
        X_test = scaler.transform(X_test)
        model_name = f'{device_name}_with_scaling_unbalanced_model.pkl'
    else:
        model_name = f'{device_name}_without_scaling_unbalanced_model.pkl'

    # Train the classifier (fit mutates and returns clf itself).
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)

    # Evaluate on the held-out test split.
    y_pred = clf.predict(X_test)
    ac = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')
    cm = confusion_matrix(y_test, y_pred)
    cr = classification_report(y_test, y_pred)

    # Feature importances, largest first.
    importances = pd.DataFrame(
        {'feature': X.columns,
         'importance': np.round(clf.feature_importances_, 3)})
    importances = (importances
                   .sort_values('importance', ascending=False)
                   .set_index('feature'))

    results['feature_importance'] = [importances.head(20)]
    results['Accuracy Test Data'] = ac
    results['F1 Score Test Data'] = f1

    # Tag the model with version metadata before pickling so the environment
    # it was built under can be recovered later.
    best_model = clf
    best_model.version = 1.0
    best_model.pandas_version = pd.__version__
    best_model.numpy_version = np.__version__
    best_model.sklearn_version = sklearn_version
    best_model.build_datetime = datetime.now()

    # makedirs (not mkdir) so the parent 'models/' directory is created too,
    # and exist_ok avoids a race/failure when the folder already exists.
    modelpath = f'models/{device_name}'
    os.makedirs(modelpath, exist_ok=True)

    # An existing model file is deliberately left untouched (re-runs do not
    # overwrite a previously saved model).
    iotmodel_path = os.path.join(modelpath, model_name)
    if not os.path.exists(iotmodel_path):
        with open(iotmodel_path, 'wb') as f:
            pickle.dump(best_model, f)

    # Context manager guarantees the report file is closed even on error.
    with open(f'{modelpath}/report.txt', 'w') as f:
        f.write(f'''Classification Report on Test Set
    \n \n {cr}\n \n
    Confusion Matrix on Test Set
    \n \n {cm} \n \n''')

    validation_data.to_csv(f'{modelpath}/{device_name}_validation_data.csv')
    return f'Model trained and saved successfully \n {results}'
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | from glob import glob
4 |
5 |
6 |
def load_device_data_v2(file_path, file_ext, device, labels=3, size=1):
    """
    Create a data frame containing only the data for the specified device.

    ``file_path`` should be the directory where the unzipped data files are
    stored. Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    file_path : str
        The directory in which the data files are stored.
    file_ext : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).
    device : str
        Device name (folder name).
    labels : int
        3 for the benign/mirai/gafgyt classes; any other value produces
        the per-attack-type (11-class) labelling.
    size : float
        Fraction of each file to load; 1 loads everything.

    Returns
    -------
    device_data : pandas.DataFrame
        The concatenated device data, or the error message as a string
        if loading failed.
    """
    try:
        # Collect one frame per matching file, concatenate at the end.
        dfs = []

        # Walk the whole tree and match files against the glob pattern.
        for path, subdir, files in os.walk(file_path):
            for file in glob(os.path.join(path, file_ext)):
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows). parts[1] is the device
                # folder when file_path is a single relative component.
                parts = file.split(os.sep)
                # Keep only files that belong to the requested device.
                if parts[1] != device:
                    continue

                data = pd.read_csv(file)
                data['device'] = parts[1]

                if labels == 3:
                    # 3-class labelling: benign / mirai / gafgyt.
                    data['label'] = parts[2].split('_')[0]
                else:
                    # 11-class labelling. The benign file sits directly in
                    # the device folder (path depth 3); attack files are one
                    # level deeper inside the *_attacks folders.
                    if len(parts) == 3:
                        data['label'] = parts[2].split('_')[0]
                    else:
                        data['label'] = parts[2].split('_')[0] + '_' \
                            + parts[3].split('.')[0]

                # Load everything, or only a random sample of each file.
                if size != 1:
                    dfs.append(data.sample(frac=size))
                else:
                    dfs.append(data)

        device_data = pd.concat(dfs, ignore_index=True)

        return device_data
    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
69 |
70 |
71 |
def load_all_data_v2(file_path, file_ext, device, labels=3, size=1):
    """
    Create a data frame containing the data of every device under ``file_path``.

    ``file_path`` should be the directory where the unzipped data files are
    stored. Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    file_path : str
        The directory in which the data files are stored.
    file_ext : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).
    device : str
        Unused; kept for signature compatibility with
        :func:`load_device_data_v2`.
    labels : int
        3 for the benign/mirai/gafgyt classes; any other value produces
        the per-attack-type (11-class) labelling.
    size : float
        Fraction of each file to load; 1 loads everything.

    Returns
    -------
    device_data : pandas.DataFrame
        The concatenated data of all devices, or the error message as a
        string if loading failed.
    """
    try:
        # Collect one frame per matching file, concatenate at the end.
        dfs = []

        # Walk the whole tree and match files against the glob pattern.
        for path, subdir, files in os.walk(file_path):
            for file in glob(os.path.join(path, file_ext)):
                data = pd.read_csv(file)
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows). parts[1] is the device
                # folder when file_path is a single relative component.
                parts = file.split(os.sep)
                data['device'] = parts[1]

                if labels == 3:
                    # 3-class labelling: benign / mirai / gafgyt.
                    data['label'] = parts[2].split('_')[0]
                else:
                    # 11-class labelling. The benign file sits directly in
                    # the device folder (path depth 3); attack files are one
                    # level deeper inside the *_attacks folders.
                    if len(parts) == 3:
                        data['label'] = parts[2].split('_')[0]
                    else:
                        data['label'] = parts[2].split('_')[0] + '_' \
                            + parts[3].split('.')[0]

                # Load everything, or only a random sample of each file.
                if size != 1:
                    dfs.append(data.sample(frac=size))
                else:
                    # BUGFIX: was `dfs.append(dat)` (NameError), so the
                    # size == 1 path always failed and returned the
                    # exception text instead of a data frame.
                    dfs.append(data)

        device_data = pd.concat(dfs, ignore_index=True)

        return device_data
    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
132 |
133 |
def load_data_labels(PATH, EXT):
    """
    Create three data frames — benign, mirai and gafgyt — from all the
    data files in a given directory.

    The directory should be where the unzipped data files are stored.
    Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    PATH : str
        The directory in which the data files are stored.
    EXT : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).

    Returns
    -------
    benign_data : pandas.DataFrame
        All the benign data.
    mirai_data : pandas.DataFrame
        All the mirai data, labelled ``Mirai_<attack>``.
    gafgyt_data : pandas.DataFrame
        All the gafgyt data, labelled ``Gafgyt_<attack>``.
    """
    try:
        benign_dfs = []
        mirai_dfs = []
        gafgyt_dfs = []
        for path, subdir, files in os.walk(PATH):
            for file in glob(os.path.join(path, EXT)):
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows). parts[1] is the device
                # folder, parts[3] the attack csv inside a *_attacks folder.
                parts = file.split(os.sep)
                if 'benign_traffic' in file:
                    data = pd.read_csv(file)
                    data['label'] = 'Benign'
                    data['device'] = parts[1]
                    benign_dfs.append(data)
                if 'mirai_attacks' in file:
                    data = pd.read_csv(file)
                    data['label'] = 'Mirai_' + parts[3].split('.')[0]
                    data['device'] = parts[1]
                    mirai_dfs.append(data)
                if 'gafgyt_attacks' in file:
                    data = pd.read_csv(file)
                    data['label'] = 'Gafgyt_' + parts[3].split('.')[0]
                    data['device'] = parts[1]
                    gafgyt_dfs.append(data)

        benign_data = pd.concat(benign_dfs, ignore_index=True)
        mirai_data = pd.concat(mirai_dfs, ignore_index=True)
        gafgyt_data = pd.concat(gafgyt_dfs, ignore_index=True)

        return benign_data, mirai_data, gafgyt_data
    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
187 |
188 |
def load_device_data(PATH, EXT, device):
    """
    Create a data frame consisting of one device's data with 3-class labels.

    The directory should be where the unzipped data files are stored.
    Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    PATH : str
        The directory in which the data files are stored.
    EXT : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).
    device : str
        Device name (folder name).

    Returns
    -------
    device_data : pandas.DataFrame
        The device's data labelled benign / mirai / gafgyt, or the error
        message as a string if loading failed.
    """
    try:
        dfs = []
        for path, subdir, files in os.walk(PATH):
            for file in glob(os.path.join(path, EXT)):
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows). parts[1] is the device
                # folder, parts[2] either benign_traffic.csv or an
                # *_attacks folder — its first '_' token is the class.
                parts = file.split(os.sep)
                if parts[1] == device:
                    data = pd.read_csv(file)
                    data['label'] = parts[2].split('_')[0]
                    data['device'] = parts[1]
                    dfs.append(data)

        device_data = pd.concat(dfs, ignore_index=True)

        return device_data

    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
227 |
228 |
229 |
def load_device_data_multi_label(PATH, EXT, device):
    """
    Create a data frame consisting of one device's data with 11-class labels.

    The directory should be where the unzipped data files are stored.
    Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    PATH : str
        The directory in which the data files are stored.
    EXT : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).
    device : str
        Device name (folder name).

    Returns
    -------
    device_data : pandas.DataFrame
        The device's data with 11 different classes (benign plus one class
        per attack type), or the error message as a string if loading failed.
    """
    try:
        dfs = []
        for path, subdir, files in os.walk(PATH):
            for file in glob(os.path.join(path, EXT)):
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows).
                parts = file.split(os.sep)
                if parts[1] == device:
                    data = pd.read_csv(file)
                    # Benign file sits directly in the device folder (depth
                    # 3); attack files are one level deeper, so the label
                    # combines the attack family and the attack csv name.
                    if len(parts) == 3:
                        data['label'] = parts[2].split('_')[0]
                    else:
                        data['label'] = parts[2].split('_')[0] + '_' + parts[3].split('.')[0]
                    data['device'] = parts[1]
                    dfs.append(data)

        device_data = pd.concat(dfs, ignore_index=True)

        return device_data

    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
273 |
274 |
def load_all_data(PATH, EXT):
    """
    Create a data frame consisting of all device data with device-specific
    11-class labels (``<device>_<class>`` / ``<device>_<family>_<attack>``).

    The directory should be where the unzipped data files are stored.
    Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    PATH : str
        The directory in which the data files are stored.
    EXT : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).

    Returns
    -------
    device_data : pandas.DataFrame
        All device data; labels are prefixed with the device name, or the
        error message as a string if loading failed.
    """
    try:
        dfs = []
        for path, subdir, files in os.walk(PATH):
            for file in glob(os.path.join(path, EXT)):
                data = pd.read_csv(file)
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows).
                parts = file.split(os.sep)
                # Benign file sits directly in the device folder (depth 3);
                # attack files are one level deeper in *_attacks folders.
                if len(parts) == 3:
                    data['label'] = parts[1] + '_' + parts[2].split('_')[0]
                else:
                    data['label'] = parts[1] + '_' + parts[2].split('_')[0] + '_' + parts[3].split('.')[0]

                data['device'] = parts[1]
                dfs.append(data)
        device_data = pd.concat(dfs, ignore_index=True)

        return device_data
    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
314 |
315 |
def load_all_data_class(PATH, EXT):
    """
    Create a data frame of a 15% sample of every device's data with the
    3-class labelling (benign / mirai / gafgyt), not device-specific.

    The directory should be where the unzipped data files are stored.
    Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    PATH : str
        The directory in which the data files are stored.
    EXT : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).

    Returns
    -------
    device_data : pandas.DataFrame
        A 15% random sample of each file, concatenated, or the error
        message as a string if loading failed.
    """
    try:
        dfs = []
        for path, subdir, files in os.walk(PATH):
            for file in glob(os.path.join(path, EXT)):
                data = pd.read_csv(file)
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows). The original if/else on
                # len(parts) had byte-identical branches, so it is collapsed:
                # the first '_' token of parts[2] is the class either way.
                parts = file.split(os.sep)
                data['label'] = parts[2].split('_')[0]
                data['device'] = parts[1]
                # Keep only a 15% random sample of each file to bound memory.
                sampled_data = data.sample(frac=0.15)
                dfs.append(sampled_data)
        device_data = pd.concat(dfs, ignore_index=True)
        return device_data

    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
355 |
356 |
def load_all_data_multi_class(PATH, EXT):
    """
    Create a data frame of a 15% sample of every device's data with the
    11-class labelling (benign plus one class per attack type), not
    device-specific.

    The directory should be where the unzipped data files are stored.
    Assumes the file structure is::

        device name (folder)
            mirai_attacks (folder)
            gafgyt_attacks (folder)
            benign_traffic.csv

    Parameters
    ----------
    PATH : str
        The directory in which the data files are stored.
    EXT : str
        Glob pattern of the files to load (e.g. ``'*.csv'``).

    Returns
    -------
    device_data : pandas.DataFrame
        A 15% random sample of each file, concatenated, or the error
        message as a string if loading failed.
    """
    try:
        dfs = []
        for path, subdir, files in os.walk(PATH):
            for file in glob(os.path.join(path, EXT)):
                data = pd.read_csv(file)
                # Split on the OS path separator (was hard-coded to '\\',
                # which only worked on Windows).
                parts = file.split(os.sep)
                # Benign file sits directly in the device folder (depth 3);
                # attack files are one level deeper, so the label combines
                # the attack family and the attack csv name.
                if len(parts) == 3:
                    data['label'] = parts[2].split('_')[0]
                else:
                    data['label'] = parts[2].split('_')[0] + '_' + parts[3].split('.')[0]

                data['device'] = parts[1]
                # Keep only a 15% random sample of each file to bound memory.
                sampled_data = data.sample(frac=0.15)
                dfs.append(sampled_data)
        device_data = pd.concat(dfs, ignore_index=True)

        return device_data
    except Exception as e:
        # NOTE(review): returning the error message as a string preserves
        # the original contract; callers must check the result type.
        return str(e)
--------------------------------------------------------------------------------