├── 0.1RF_CWRU_EvaluationFramework.ipynb
├── ANN_CWRU_EvaluationFramework.ipynb
├── CWRU_EvaluationFramework_DE.ipynb
├── KFold_RF_CWRU_EvaluationFramework.ipynb
├── KFold_SVM_CWRU_EvaluationFramework.ipynb
├── KNN_CWRU_EvaluationFramework.ipynb
├── README.md
├── RF_CWRU_EvaluationFramework.ipynb
├── SVM_CWRU_EvaluationFramework.ipynb
├── cwru_evaluation_by_severity.ipynb
└── cwru_segmentation.ipynb
/CWRU_EvaluationFramework_DE.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "CWRU-EvaluationFramework-bkp.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true,
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "bSSOMru17Z6c",
33 | "colab_type": "code",
34 | "colab": {}
35 | },
36 | "source": [
37 | "debug = False  # intended to select the reduced '#debug mode' branch of the segmentation cell"
38 | ],
39 | "execution_count": 0,
40 | "outputs": []
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {
45 | "id": "pMQoq6dvStey",
46 | "colab_type": "text"
47 | },
48 | "source": [
49 | "# CWRU files.\n",
50 | "\n",
51 | "Associate each Matlab file name to a bearing condition in a Python dictionary.\n",
52 | "The dictionary keys identify the conditions.\n",
53 | "\n",
54 | "There are only four normal conditions, with loads of 0, 1, 2 and 3 hp.\n",
55 | "All conditions end with an underscore character followed by a digit representing the load (in hp) applied during the acquisitions.\n",
56 | "The remaining conditions follow the pattern:\n",
57 | "\n",
58 | "\n",
59 | "* The first two characters represent the bearing location, i.e. drive end (DE) or fan end (FE).\n",
60 | "* The following characters represent the failure location in the bearing, i.e. ball (B), inner race (IR) or outer race (OR).\n",
61 | "* After a dot, the next three digits indicate the severity of the failure, where 007 stands for 0.007 inches and 021 for 0.021 inches.\n",
62 | "* For outer race failures, the character @ is followed by the clock position at which the fault is centered (3, 6 or 12 o'clock), which places the fault at a different position relative to the load zone. "
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "metadata": {
68 | "id": "K6mp2QrP1lmR",
69 | "colab_type": "code",
70 | "colab": {}
71 | },
72 | "source": [
73 | "def cwru_12khz():\n",
74 | "    '''\n",
75 | "    Returns a dictionary with the names of all Matlab files sampled at 12 kHz located in\n",
76 | "    http://csegroups.case.edu/sites/default/files/bearingdatacenter/files/Datafiles/.\n",
77 | "    The dictionary keys represent the bearing condition, following the pattern\n",
78 | "    <location><fault><severity>[@<clock>]_<load>, e.g. \"DEOR.007@6_0\" -> \"130.mat\".\n",
79 | "    Keys are inserted in a fixed order (condition by condition, loads 0 to 3).\n",
80 | "    '''\n",
81 | "    # One row per condition: (key prefix, file numbers for loads 0, 1, 2, 3).\n",
82 | "    # None marks a load that has no file in this collection.\n",
83 | "    table = [\n",
84 | "        (\"Normal\", 97, 98, 99, 100),\n",
85 | "        (\"DEIR.007\", 105, 106, 107, 108),      # DE inner race 0.007 inches\n",
86 | "        (\"DEB.007\", 118, 119, 120, 121),       # DE ball 0.007 inches\n",
87 | "        (\"DEOR.007@6\", 130, 131, 132, 133),    # DE outer race 0.007 inches centered @6:00\n",
88 | "        (\"DEOR.007@3\", 144, 145, 146, 147),    # DE outer race 0.007 inches centered @3:00\n",
89 | "        (\"DEOR.007@12\", 156, 158, 159, 160),   # DE outer race 0.007 inches centered @12:00\n",
90 | "        (\"DEIR.014\", 169, 170, 171, 172),      # DE inner race 0.014 inches\n",
91 | "        (\"DEB.014\", 185, 186, 187, 188),       # DE ball 0.014 inches\n",
92 | "        (\"DEOR.014@6\", 197, 198, 199, 200),    # DE outer race 0.014 inches centered @6:00\n",
93 | "        (\"DEB.021\", 222, 223, 224, 225),       # DE ball 0.021 inches\n",
94 | "        (\"FEIR.021\", 270, 271, 272, 273),      # FE inner race 0.021 inches\n",
95 | "        (\"FEIR.014\", 274, 275, 276, 277),      # FE inner race 0.014 inches\n",
96 | "        (\"FEB.007\", 282, 283, 284, 285),       # FE ball 0.007 inches\n",
97 | "        (\"DEIR.021\", 209, 210, 211, 212),      # DE inner race 0.021 inches\n",
98 | "        (\"DEOR.021@6\", 234, 235, 236, 237),    # DE outer race 0.021 inches centered @6:00\n",
99 | "        (\"DEOR.021@3\", 246, 247, 248, 249),    # DE outer race 0.021 inches centered @3:00\n",
100 | "        (\"DEOR.021@12\", 258, 259, 260, 261),   # DE outer race 0.021 inches centered @12:00\n",
101 | "        (\"FEIR.007\", 278, 279, 280, 281),      # FE inner race 0.007 inches\n",
102 | "        (\"FEB.014\", 286, 287, 288, 289),       # FE ball 0.014 inches\n",
103 | "        (\"FEB.021\", 290, 291, 292, 293),       # FE ball 0.021 inches\n",
104 | "        (\"FEOR.007@6\", 294, 295, 296, 297),    # FE outer race 0.007 inches centered @6:00\n",
105 | "        (\"FEOR.007@3\", 298, 299, 300, 301),    # FE outer race 0.007 inches centered @3:00\n",
106 | "        (\"FEOR.007@12\", 302, 305, 306, 307),   # FE outer race 0.007 inches centered @12:00\n",
107 | "        (\"FEOR.014@3\", 310, 309, 311, 312),    # FE outer race 0.014 inches centered @3:00\n",
108 | "        (\"FEOR.014@6\", 313, None, None, None), # FE outer race 0.014 inches centered @6:00\n",
109 | "        (\"FEOR.021@6\", 315, None, None, None), # FE outer race 0.021 inches centered @6:00\n",
110 | "        (\"FEOR.021@3\", None, 316, 317, 318),   # FE outer race 0.021 inches centered @3:00\n",
111 | "        (\"DEIR.028\", 3001, 3002, 3003, 3004),  # DE inner race 0.028 inches\n",
112 | "        (\"DEB.028\", 3005, 3006, 3007, 3008),   # DE ball 0.028 inches\n",
113 | "    ]\n",
114 | "    matlab_files_name = {}\n",
115 | "    for prefix, *numbers in table:\n",
116 | "        for load, number in enumerate(numbers):\n",
117 | "            if number is not None:\n",
118 | "                matlab_files_name[\"{}_{}\".format(prefix, load)] = \"{}.mat\".format(number)\n",
119 | "    return matlab_files_name"
219 | ],
220 | "execution_count": 0,
221 | "outputs": []
222 | },
223 | {
224 | "cell_type": "markdown",
225 | "metadata": {
226 | "id": "K9y9byVeSz_u",
227 | "colab_type": "text"
228 | },
229 | "source": [
230 | "## Download Matlab files"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "metadata": {
236 | "id": "wPSGH1401-W2",
237 | "colab_type": "code",
238 | "colab": {}
239 | },
240 | "source": [
241 | "import urllib.request\n",
242 | "import os.path\n",
243 | "\n",
244 | "def download_cwrufiles(matlab_files_name,\n",
245 | "                       url=\"http://csegroups.case.edu/sites/default/files/bearingdatacenter/files/Datafiles/\"):\n",
246 | "    '''\n",
247 | "    Downloads the Matlab files in the dictionary matlab_files_name into the\n",
248 | "    current working directory, skipping files that already exist there.\n",
249 | "\n",
250 | "    matlab_files_name: dict mapping a condition name to a Matlab file name.\n",
251 | "    url: base URL; each file is fetched from url + file name. The default is the\n",
252 | "         address used originally; pass another one if the files are hosted elsewhere.\n",
253 | "\n",
254 | "    Each file is first written to '<file name>.part' and renamed once the download\n",
255 | "    finished, so an interrupted download never leaves a truncated file behind that a\n",
256 | "    later run would skip as already downloaded. Download errors are propagated.\n",
257 | "    '''\n",
258 | "    n = len(matlab_files_name)\n",
259 | "    for i, (key, file_name) in enumerate(matlab_files_name.items()):\n",
260 | "        if not os.path.exists(file_name):\n",
261 | "            partial_name = file_name + '.part'\n",
262 | "            try:\n",
263 | "                urllib.request.urlretrieve(url + file_name, partial_name)\n",
264 | "            except BaseException:\n",
265 | "                # Remove the incomplete file (if any) before propagating the error.\n",
266 | "                if os.path.exists(partial_name):\n",
267 | "                    os.remove(partial_name)\n",
268 | "                raise\n",
269 | "            os.replace(partial_name, file_name)\n",
270 | "        print(\"{}/{}\\t{}\\t{}\".format(i + 1, n, key, file_name))\n"
255 | ],
256 | "execution_count": 0,
257 | "outputs": []
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "metadata": {
262 | "id": "FRijKbOjS-JZ",
263 | "colab_type": "text"
264 | },
265 | "source": [
266 | "## Extract data from Matlab files"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "metadata": {
272 | "id": "BbpFkSI12CUe",
273 | "colab_type": "code",
274 | "colab": {}
275 | },
276 | "source": [
277 | "import scipy.io\n",
278 | "import numpy as np\n",
279 | "\n",
280 | "def get_tensors_from_matlab(matlab_files_name):\n",
281 | "    '''\n",
282 | "    Extracts the acquisitions of each Matlab file in the dictionary matlab_files_name.\n",
283 | "\n",
284 | "    For every condition and every sensor position in ('DE', 'FE', 'BA'), the first\n",
285 | "    Matlab variable whose name ends with '<position>_time' is flattened to 1-D and\n",
286 | "    stored under '<condition><position in lower case>' (e.g. 'Normal_0de').\n",
287 | "    Positions without such a variable are skipped.\n",
288 | "    '''\n",
289 | "    acquisitions = {}\n",
290 | "    for condition, file_name in matlab_files_name.items():\n",
291 | "        contents = scipy.io.loadmat(file_name)\n",
292 | "        for position in ('DE', 'FE', 'BA'):\n",
293 | "            suffix = position + \"_time\"\n",
294 | "            variable = next((name for name in contents if name.endswith(suffix)), None)\n",
295 | "            if variable is None:\n",
296 | "                continue\n",
297 | "            acquisitions[condition + position.lower()] = contents[variable].reshape(1, -1)[0]\n",
298 | "    return acquisitions\n"
294 | ],
295 | "execution_count": 0,
296 | "outputs": []
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {
301 | "id": "x-2lcukz5Nyk",
302 | "colab_type": "text"
303 | },
304 | "source": [
305 | "## Downloading the pickle file\n",
306 | "The auxiliary functions below download a pickle file from a Google Drive account.\n",
307 | "The pickle file already contains the properly extracted acquisitions.\n",
308 | "Therefore, these functions might speed up the whole process."
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "metadata": {
314 | "id": "ZJkpaFxn1xtR",
315 | "colab_type": "code",
316 | "colab": {}
317 | },
318 | "source": [
319 | "import os\n",
320 | "import requests\n",
321 | "\n",
322 | "def download_file_from_google_drive(id, destination):\n",
323 | "    '''\n",
324 | "    Downloads the Google Drive file with the given id and saves it as destination.\n",
325 | "\n",
326 | "    If the first response sets a 'download_warning*' cookie, the request is repeated\n",
327 | "    with that cookie's value as the 'confirm' token (presumably Drive's large-file\n",
328 | "    confirmation step).\n",
329 | "    Raises requests.HTTPError for an HTTP error status and ValueError when Drive\n",
330 | "    answers with an HTML page instead of the file, so such a page is never saved\n",
331 | "    over destination.\n",
332 | "    '''\n",
333 | "    URL = \"https://docs.google.com/uc?export=download\"\n",
334 | "    session = requests.Session()\n",
335 | "    response = session.get(URL, params = { 'id' : id }, stream = True)\n",
336 | "    token = get_confirm_token(response)\n",
337 | "    if token:\n",
338 | "        params = { 'id' : id, 'confirm' : token }\n",
339 | "        response = session.get(URL, params = params, stream = True)\n",
340 | "    response.raise_for_status()\n",
341 | "    if response.headers.get('Content-Type', '').startswith('text/html'):\n",
342 | "        raise ValueError(\"Google Drive returned an HTML page instead of the file\")\n",
343 | "    save_response_content(response, destination)\n",
344 | "\n",
345 | "def get_confirm_token(response):\n",
346 | "    '''Returns the value of the first cookie whose name starts with 'download_warning', or None.'''\n",
347 | "    for key, value in response.cookies.items():\n",
348 | "        if key.startswith('download_warning'):\n",
349 | "            return value\n",
350 | "    return None\n",
351 | "\n",
352 | "def save_response_content(response, destination):\n",
353 | "    '''\n",
354 | "    Streams the response body to destination in chunks of 32768 bytes.\n",
355 | "    The body is written to '<destination>.part' and renamed only once it has been\n",
356 | "    completely received, so a failed transfer never truncates an existing file.\n",
357 | "    '''\n",
358 | "    CHUNK_SIZE = 32768\n",
359 | "    partial = destination + '.part'\n",
360 | "    try:\n",
361 | "        with open(partial, \"wb\") as f:\n",
362 | "            for chunk in response.iter_content(CHUNK_SIZE):\n",
363 | "                if chunk: # filter out keep-alive new chunks\n",
364 | "                    f.write(chunk)\n",
365 | "    except BaseException:\n",
366 | "        if os.path.exists(partial):\n",
367 | "            os.remove(partial)\n",
368 | "        raise\n",
369 | "    os.replace(partial, destination)\n",
370 | "\n",
371 | "file_id = \"1qJezMiROz9NAYafPUDPh9BFkxYF4nOi2\"\n",
372 | "destination = 'cwru.pickle'\n",
373 | "\n",
374 | "try:\n",
375 | "    download_file_from_google_drive(file_id, destination)\n",
376 | "except Exception as error:\n",
377 | "    # Best effort: if no pickle is available, the next cell rebuilds the\n",
378 | "    # acquisitions from the Matlab files instead.\n",
379 | "    print(\"Download failed! ({})\".format(error))"
351 | ],
352 | "execution_count": 0,
353 | "outputs": []
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {
358 | "id": "oEzRboCpgtlx",
359 | "colab_type": "text"
360 | },
361 | "source": [
362 | "## Save/Load data\n",
363 | "If the cwru pickle file already exists (e.g. downloaded by the previous cell), the dictionary with the acquisitions is loaded from it.\n",
364 | "Otherwise, the Matlab files are downloaded, the acquisitions are extracted, and the result is saved to the pickle file."
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "metadata": {
370 | "id": "Z1m5Q3OUbvqa",
371 | "colab_type": "code",
372 | "colab": {}
373 | },
374 | "source": [
375 | "import pickle\n",
376 | "import os\n",
377 | "\n",
378 | "pickle_file = 'cwru.pickle'\n",
379 | "loaded = False\n",
380 | "if os.path.isfile(pickle_file):\n",
381 | "    # NOTE: pickle.load can execute arbitrary code; only use a cache file you trust\n",
382 | "    # (this one may have been downloaded from Google Drive by the previous cell).\n",
383 | "    try:\n",
384 | "        with open(pickle_file, 'rb') as handle:\n",
385 | "            acquisitions = pickle.load(handle)\n",
386 | "        loaded = True\n",
387 | "    except (pickle.UnpicklingError, EOFError, ValueError) as error:\n",
388 | "        # A truncated or non-pickle file (e.g. an HTML page saved by a failed\n",
389 | "        # download) must not stop the notebook: rebuild the cache instead.\n",
390 | "        print(\"Ignoring unreadable {}: {}\".format(pickle_file, error))\n",
391 | "if not loaded:\n",
392 | "    matlab_files_name = cwru_12khz()\n",
393 | "    download_cwrufiles(matlab_files_name)\n",
394 | "    acquisitions = get_tensors_from_matlab(matlab_files_name)\n",
395 | "    with open(pickle_file, 'wb') as handle:\n",
396 | "        pickle.dump(acquisitions, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
388 | ],
389 | "execution_count": 0,
390 | "outputs": []
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "metadata": {
395 | "id": "AT7hgDnzTcNP",
396 | "colab_type": "text"
397 | },
398 | "source": [
399 | "## Segment data"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "metadata": {
405 | "id": "7BKfioJFzAKA",
406 | "colab_type": "code",
407 | "outputId": "4151c14e-c533-48ac-9e68-1f72c376d868",
408 | "colab": {
409 | "base_uri": "https://localhost:8080/",
410 | "height": 35
411 | }
412 | },
413 | "source": [
414 | "import numpy as np\n",
415 | "def cwru_segmentation(acquisitions, sample_size=512, max_samples=None):\n",
416 | "    '''\n",
417 | "    Segments the acquisitions into non-overlapping windows.\n",
418 | "\n",
419 | "    acquisitions: dict mapping an acquisition name to a 1-D numpy array\n",
420 | "        (as built by get_tensors_from_matlab).\n",
421 | "    sample_size: size of each segment; trailing points that do not fill a\n",
422 | "        whole segment are dropped.\n",
423 | "    max_samples: if a positive int, keeps at most this many segments per\n",
424 | "        acquisition (used for debug purposes).\n",
425 | "    Returns (data, origin): data has shape (n_segments, sample_size, 1) and\n",
426 | "    origin[i] is the name of the acquisition segment i came from.\n",
427 | "    '''\n",
428 | "    origin = []\n",
429 | "    # Collect the pieces and concatenate once: concatenating inside the loop\n",
430 | "    # copies the growing array at every step (quadratic in the number of segments).\n",
431 | "    # The leading empty float block keeps the result dtype and the empty case.\n",
432 | "    pieces = [np.empty((0, sample_size, 1))]\n",
433 | "    n = len(acquisitions)\n",
434 | "    for i, key in enumerate(acquisitions):\n",
435 | "        acquisition_size = len(acquisitions[key])\n",
436 | "        n_samples = acquisition_size // sample_size\n",
437 | "        if max_samples is not None and max_samples > 0 and n_samples > max_samples:\n",
438 | "            n_samples = max_samples\n",
439 | "        print('\\r{}/{} --- {}:\\t{}'.format(i+1, n, key, n_samples), end='')\n",
440 | "        origin.extend([key] * n_samples)\n",
441 | "        pieces.append(acquisitions[key][:(n_samples*sample_size)].reshape(\n",
442 | "            (n_samples, sample_size, 1)))\n",
443 | "    return np.concatenate(pieces), origin\n",
444 | "\n",
445 | "if not debug:\n",
446 | "    signal_data, signal_origin = cwru_segmentation(acquisitions, 512)\n",
447 | "else:\n",
448 | "    signal_data, signal_origin = cwru_segmentation(acquisitions, 512, 2)  # debug mode\n",
449 | "\n",
450 | "signal_data.shape\n"
443 | ],
444 | "execution_count": 240,
445 | "outputs": [
446 | {
447 | "output_type": "stream",
448 | "text": [
449 | "307/307 --- DEB.028_3de:\t236"
450 | ],
451 | "name": "stdout"
452 | },
453 | {
454 | "output_type": "execute_result",
455 | "data": {
456 | "text/plain": [
457 | "(77527, 512, 1)"
458 | ]
459 | },
460 | "metadata": {
461 | "tags": []
462 | },
463 | "execution_count": 240
464 | }
465 | ]
466 | },
467 | {
468 | "cell_type": "markdown",
469 | "metadata": {
470 | "id": "fpVAzW73-buw",
471 | "colab_type": "text"
472 | },
473 | "source": [
474 | "# Experiments"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "metadata": {
480 | "id": "NLeArq-uThHW",
481 | "colab_type": "text"
482 | },
483 | "source": [
484 | "## Feature Extraction Models"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "metadata": {
490 | "id": "WuSNj6YIEhu0",
491 | "colab_type": "code",
492 | "colab": {}
493 | "### Statistical functions"
494 | "source": [
495 | "from sklearn.base import TransformerMixin"
496 | ],
497 | "execution_count": 0,
498 | "outputs": []
499 | },
500 | {
501 | "cell_type": "markdown",
502 | "metadata": {
503 | "id": "Mm95T4CsDxaN",
504 | "colab_type": "text"
505 | },
506 | "source": [
507 | "###Statistical functions"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "metadata": {
513 | "id": "vWCPUON8D1A8",
514 | "colab_type": "code",
515 | "colab": {}
516 | },
517 | "source": [
518 | "import numpy as np\n",
519 | "import scipy.stats as stats\n",
520 | "\n",
521 | "# Time-domain statistics used as features. Every function converts its\n",
522 | "# argument with np.array first, so lists and arrays are both accepted.\n",
523 | "\n",
524 | "def _peak(values):\n",
525 | "    '''largest absolute value: max(|x|)'''\n",
526 | "    return np.max(np.absolute(values))\n",
527 | "\n",
528 | "def _mean_abs(values):\n",
529 | "    '''mean absolute value: mean(|x|)'''\n",
530 | "    return np.mean(np.absolute(values))\n",
531 | "\n",
532 | "def rms(x):\n",
533 | "    '''root mean square: sqrt(mean(x^2))'''\n",
534 | "    values = np.array(x)\n",
535 | "    return np.sqrt(np.mean(np.square(values)))\n",
536 | "\n",
537 | "def sra(x):\n",
538 | "    '''square root amplitude: mean(sqrt(|x|))^2'''\n",
539 | "    root_abs = np.sqrt(np.absolute(np.array(x)))\n",
540 | "    return np.mean(root_abs) ** 2\n",
541 | "\n",
542 | "def ppv(x):\n",
543 | "    '''peak to peak value: max(x) - min(x)'''\n",
544 | "    values = np.array(x)\n",
545 | "    return np.max(values) - np.min(values)\n",
546 | "\n",
547 | "def cf(x):\n",
548 | "    '''crest factor: max(|x|) / rms(x)'''\n",
549 | "    values = np.array(x)\n",
550 | "    return _peak(values) / rms(values)\n",
551 | "\n",
552 | "def ifa(x):\n",
553 | "    '''impact factor: max(|x|) / mean(|x|)'''\n",
554 | "    values = np.array(x)\n",
555 | "    return _peak(values) / _mean_abs(values)\n",
556 | "\n",
557 | "def mf(x):\n",
558 | "    '''margin factor: max(|x|) / sra(x)'''\n",
559 | "    values = np.array(x)\n",
560 | "    return _peak(values) / sra(values)\n",
561 | "\n",
562 | "def sf(x):\n",
563 | "    '''shape factor: rms(x) / mean(|x|)'''\n",
564 | "    values = np.array(x)\n",
565 | "    return rms(values) / _mean_abs(values)\n",
566 | "\n",
567 | "def kf(x):\n",
568 | "    '''\n",
569 | "    kurtosis factor: kurtosis(x) / mean(x^2)^2, i.e. kurtosis over rms^4.\n",
570 | "    stats.kurtosis is called with scipy's defaults (Fisher definition, biased).\n",
571 | "    '''\n",
572 | "    values = np.array(x)\n",
573 | "    return stats.kurtosis(values) / (np.mean(values ** 2) ** 2)\n"
577 | ],
578 | "execution_count": 0,
579 | "outputs": []
580 | },
581 | {
582 | "cell_type": "markdown",
583 | "metadata": {
584 | "id": "njMb9HtUEBrI",
585 | "colab_type": "text"
586 | },
587 | "source": [
588 | "### Statistical Features from Time Domain"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "metadata": {
594 | "id": "oSN2_c28D_Zr",
595 | "colab_type": "code",
596 | "colab": {}
597 | },
598 | "source": [
599 | "class StatisticalTime(TransformerMixin):\n",
600 | "    '''\n",
601 | "    Extracts statistical features from the time domain.\n",
602 | "\n",
603 | "    transform uses channel 0 of X (X[:, :, 0], one row per sample) and returns\n",
604 | "    one row of 10 features per sample, in this order: rms, sra, kurtosis,\n",
605 | "    skewness, ppv, cf, ifa, mf, sf, kf.\n",
606 | "    '''\n",
607 | "    def fit(self, X, y=None):\n",
608 | "        # Stateless: nothing to learn.\n",
609 | "        return self\n",
610 | "\n",
611 | "    def transform(self, X, y=None):\n",
612 | "        rows = []\n",
613 | "        for signal in X[:, :, 0]:\n",
614 | "            rows.append([\n",
615 | "                rms(signal),             # root mean square\n",
616 | "                sra(signal),             # square root amplitude\n",
617 | "                stats.kurtosis(signal),  # kurtosis\n",
618 | "                stats.skew(signal),      # skewness\n",
619 | "                ppv(signal),             # peak to peak value\n",
620 | "                cf(signal),              # crest factor\n",
621 | "                ifa(signal),             # impact factor\n",
622 | "                mf(signal),              # margin factor\n",
623 | "                sf(signal),              # shape factor\n",
624 | "                kf(signal),              # kurtosis factor\n",
625 | "            ])\n",
626 | "        return np.array(rows)\n"
620 | ],
621 | "execution_count": 0,
622 | "outputs": []
623 | },
624 | {
625 | "cell_type": "markdown",
626 | "metadata": {
627 | "id": "dXDWD3JZEnep",
628 | "colab_type": "text"
629 | },
630 | "source": [
631 | "### Statistical Features from Frequency Domain"
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "metadata": {
637 | "id": "Sj3XTpVTEvAp",
638 | "colab_type": "code",
639 | "colab": {}
640 | },
641 | "source": [
642 | "class StatisticalFrequency(TransformerMixin):\n",
643 | "    '''\n",
644 | "    Extracts statistical features from the frequency domain.\n",
645 | "\n",
646 | "    For every sample (channel 0 of X) the magnitude spectrum |FFT(x)| is\n",
647 | "    computed and three features are returned: its mean (frequency center),\n",
648 | "    its RMS, and the RMS of its deviation from that mean (root variance frequency).\n",
649 | "    '''\n",
650 | "    def fit(self, X, y=None):\n",
651 | "        # Stateless: nothing to learn.\n",
652 | "        return self\n",
653 | "\n",
654 | "    def transform(self, X, y=None):\n",
655 | "        def spectral_features(signal):\n",
656 | "            spectrum = np.absolute(np.fft.fft(signal))\n",
657 | "            center = np.mean(spectrum)\n",
658 | "            return [center, rms(spectrum), rms(spectrum - center)]\n",
659 | "        return np.array([spectral_features(signal) for signal in X[:, :, 0]])\n"
659 | ],
660 | "execution_count": 0,
661 | "outputs": []
662 | },
663 | {
664 | "cell_type": "markdown",
665 | "metadata": {
666 | "id": "c0YBmzTb6ARb",
667 | "colab_type": "text"
668 | },
669 | "source": [
670 | "### Statistical Features"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "metadata": {
676 | "id": "kep4ubkR6DR0",
677 | "colab_type": "code",
678 | "colab": {}
679 | },
680 | "source": [
681 | "class Statistical(TransformerMixin):\n",
682 | "    '''\n",
683 | "    Extracts statistical features from both time and frequency domain:\n",
684 | "    the StatisticalTime columns followed by the StatisticalFrequency columns.\n",
685 | "    '''\n",
686 | "    def fit(self, X, y=None):\n",
687 | "        # Stateless: nothing to learn.\n",
688 | "        return self\n",
689 | "\n",
690 | "    def transform(self, X, y=None):\n",
691 | "        extractors = (StatisticalTime(), StatisticalFrequency())\n",
692 | "        return np.concatenate([extractor.transform(X) for extractor in extractors], axis=1)"
693 | ],
694 | "execution_count": 0,
695 | "outputs": []
696 | },
697 | {
698 | "cell_type": "markdown",
699 | "metadata": {
700 | "id": "ZuiVsHNzFORr",
701 | "colab_type": "text"
702 | },
703 | "source": [
704 | "### Wavelet Packet Features"
705 | ]
706 | },
707 | {
708 | "cell_type": "code",
709 | "metadata": {
710 | "id": "oPd92xtJhaH3",
711 | "colab_type": "code",
712 | "cellView": "code",
713 | "colab": {}
714 | },
715 | "source": [
716 | "import numpy as np\n",
717 | "import pywt\n",
718 | "\n",
719 | "class WaveletPackage(TransformerMixin):\n",
720 | "    '''\n",
721 | "    Extracts Wavelet Packet features.\n",
722 | "\n",
723 | "    Each sample (channel 0 of X) is decomposed with a 4-level 'db4' wavelet\n",
724 | "    packet ('symmetric' mode). For each of the 2**maxlevel leaf nodes, the\n",
725 | "    feature is the Euclidean norm of the node's coefficients divided by their\n",
726 | "    count (called energy in this notebook); no signal reconstruction is done.\n",
727 | "    Features follow the leaf order returned by get_leaf_nodes, except that\n",
728 | "    the leaves after the first one appear reversed (indices 0, -1, -2, ...).\n",
729 | "    '''\n",
730 | "    def fit(self, X, y=None):\n",
731 | "        # Stateless: nothing to learn.\n",
732 | "        return self\n",
733 | "\n",
734 | "    @staticmethod\n",
735 | "    def _leaf_energies(signal):\n",
736 | "        packet = pywt.WaveletPacket(data=signal, wavelet='db4',\n",
737 | "                                    mode='symmetric', maxlevel=4)\n",
738 | "        leaves = np.asarray([node.data for node in packet.get_leaf_nodes(True)])\n",
739 | "        # Index -k for k = 0, 1, 2, ... yields leaf 0, then the last leaf, and so on.\n",
740 | "        ordered = [leaves[-k] for k in range(2**packet.maxlevel)]\n",
741 | "        return np.asarray([np.sqrt(np.sum(np.array(leaf) ** 2)) / len(leaf)\n",
742 | "                           for leaf in ordered])\n",
743 | "\n",
744 | "    def transform(self, X, y=None):\n",
745 | "        return np.array([self._leaf_energies(signal) for signal in X[:, :, 0]])"
736 | ],
737 | "execution_count": 0,
738 | "outputs": []
739 | },
740 | {
741 | "cell_type": "markdown",
742 | "metadata": {
743 | "id": "i_sonHkjFYbB",
744 | "colab_type": "text"
745 | },
746 | "source": [
747 | "### Heterogeneous Features"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "metadata": {
753 | "id": "IZsZhuVfFZsQ",
754 | "colab_type": "code",
755 | "colab": {}
756 | },
757 | "source": [
758 | "class Heterogeneous(TransformerMixin):\n",
759 | "    '''\n",
760 | "    Mixes Statistical and Wavelet Package features: the StatisticalTime,\n",
761 | "    StatisticalFrequency and WaveletPackage columns, concatenated in that order.\n",
762 | "    '''\n",
763 | "    def fit(self, X, y=None):\n",
764 | "        # Stateless: nothing to learn.\n",
765 | "        return self\n",
766 | "\n",
767 | "    def transform(self, X, y=None):\n",
768 | "        extractors = (StatisticalTime(), StatisticalFrequency(), WaveletPackage())\n",
769 | "        return np.concatenate([extractor.transform(X) for extractor in extractors], axis=1)\n"
772 | ],
773 | "execution_count": 0,
774 | "outputs": []
775 | },
776 | {
777 | "cell_type": "markdown",
778 | "metadata": {
779 | "id": "_jCN8XZ5dOF3",
780 | "colab_type": "text"
781 | },
782 | "source": [
783 | "## Clean dataset functions\n",
784 | "The functions below help to select samples from acquisitions and form groups according to these acquisitions, using regular expressions."
785 | ]
786 | },
787 | {
788 | "cell_type": "code",
789 | "metadata": {
790 | "id": "eOOP9H2c3AaZ",
791 | "colab_type": "code",
792 | "colab": {}
793 | },
794 | "source": [
795 | "import re\n",
796 | "import numpy as np\n",
797 | "\n",
798 | "def select_samples(regex, X, y):\n",
799 | "    '''\n",
800 | "    Selects the samples whose label contains a match of regex (re.search).\n",
801 | "    X and y are filtered with the same boolean mask; y must be a numpy array.\n",
802 | "    Returns the filtered (X, y).\n",
803 | "    '''\n",
804 | "    pattern = re.compile(regex)\n",
805 | "    keep = [pattern.search(label) is not None for label in y]\n",
806 | "    return X[keep], y[keep]\n",
807 | "\n",
808 | "def join_labels(regex, y):\n",
809 | "    '''\n",
810 | "    Removes every match of regex from the labels, so that samples whose labels\n",
811 | "    differ only in the removed parts end up with the same label.\n",
812 | "    Returns a numpy array with the new labels.\n",
813 | "    '''\n",
814 | "    pattern = re.compile(regex)\n",
815 | "    return np.array([pattern.sub('', label) for label in y])\n",
816 | "\n",
817 | "def get_groups(regex, y):\n",
818 | "    '''\n",
819 | "    Returns, for every label, the first substring matching regex (the group the\n",
820 | "    sample belongs to), or None when the label has no match.\n",
821 | "    '''\n",
822 | "    pattern = re.compile(regex)\n",
823 | "    groups = []\n",
824 | "    for label in y:\n",
825 | "        found = pattern.search(label)\n",
826 | "        groups.append(found.group(0) if found else None)\n",
827 | "    return groups"
822 | ],
823 | "execution_count": 0,
824 | "outputs": []
825 | },
826 | {
827 | "cell_type": "markdown",
828 | "metadata": {
829 | "id": "yFA-8l02RplD",
830 | "colab_type": "text"
831 | },
832 | "source": [
833 | "## Selecting samples"
834 | ]
835 | },
836 | {
837 | "cell_type": "code",
838 | "metadata": {
839 | "id": "r89dYOJm8gzW",
840 | "colab_type": "code",
841 | "outputId": "0804d5c3-9c0e-4d9a-8901-5bc54a6b3a53",
842 | "colab": {
843 | "base_uri": "https://localhost:8080/",
844 | "height": 55
845 | }
846 | },
847 | "source": [
848 | "# Alternative selection: DE faults from the 'de' sensor, FE faults from the 'fe' sensor, and Normal:\n",
849 | "# samples = '^(DE).*(de)|^(FE).*(fe)|(Normal).*'\n",
850 | "# Only acquisitions from the 'de' sensor, for DE failures and Normal conditions.\n",
851 | "# Labels end with the sensor suffix (de/fe/ba), hence the '$' anchors.\n",
852 | "samples = '^(DE).*(de)$|^(Normal).*(de)$'\n",
853 | "X, y = select_samples(samples, signal_data, np.array(signal_origin))\n",
854 | "# Sorted so that the printed label list is identical on every run.\n",
855 | "print(len(set(y)), sorted(set(y)))"
852 | ],
853 | "execution_count": 249,
854 | "outputs": [
855 | {
856 | "output_type": "stream",
857 | "text": [
858 | "64 {'DEIR.021_3de', 'DEIR.021_1de', 'DEB.007_2de', 'DEIR.021_0de', 'DEOR.021@12_0de', 'DEIR.014_2de', 'Normal_3de', 'DEIR.007_2de', 'DEOR.021@6_0de', 'Normal_1de', 'DEB.014_2de', 'DEOR.021@6_2de', 'DEOR.021@6_3de', 'DEB.028_0de', 'DEOR.021@3_3de', 'DEB.028_3de', 'DEIR.007_1de', 'Normal_0de', 'DEOR.014@6_3de', 'DEOR.014@6_1de', 'DEIR.021_2de', 'DEOR.021@3_1de', 'DEOR.021@3_2de', 'DEB.028_1de', 'DEB.021_1de', 'DEB.021_3de', 'DEOR.007@12_3de', 'DEOR.014@6_2de', 'DEOR.021@12_1de', 'DEB.021_2de', 'DEIR.014_1de', 'DEOR.014@6_0de', 'DEIR.028_1de', 'DEIR.014_3de', 'DEIR.028_2de', 'DEB.007_1de', 'DEB.028_2de', 'DEOR.007@12_2de', 'DEB.014_1de', 'Normal_2de', 'DEOR.021@12_3de', 'DEOR.007@6_2de', 'DEB.007_0de', 'DEIR.028_3de', 'DEOR.007@3_1de', 'DEOR.007@6_3de', 'DEOR.021@6_1de', 'DEOR.007@12_1de', 'DEB.014_3de', 'DEB.021_0de', 'DEOR.007@6_1de', 'DEOR.007@6_0de', 'DEOR.007@3_2de', 'DEIR.014_0de', 'DEB.007_3de', 'DEOR.021@3_0de', 'DEIR.007_3de', 'DEOR.007@3_0de', 'DEOR.021@12_2de', 'DEIR.028_0de', 'DEIR.007_0de', 'DEOR.007@12_0de', 'DEOR.007@3_3de', 'DEB.014_0de'}\n"
859 | ],
860 | "name": "stdout"
861 | }
862 | ]
863 | },
864 | {
865 | "cell_type": "markdown",
866 | "metadata": {
867 | "id": "dWYfuxcxFjt8",
868 | "colab_type": "text"
869 | },
870 | "source": [
871 | "## Customized Splitter"
872 | ]
873 | },
874 | {
875 | "cell_type": "code",
876 | "metadata": {
877 | "id": "xRdfG-uzhPm4",
878 | "colab_type": "code",
879 | "colab": {}
880 | },
881 | "source": [
882 | "from sklearn.model_selection import KFold\n",
883 | "from sklearn.utils import shuffle\n",
884 | "from sklearn.utils.validation import check_array\n",
885 | "import numpy as np\n",
886 | "\n",
887 | "class GroupShuffleKFold(KFold):\n",
888 | "    '''\n",
889 | "    K-fold splitter in which every group lies entirely in a single test fold.\n",
890 | "\n",
891 | "    Neither GroupShuffleSplit nor GroupKFold are good splitters for this case,\n",
892 | "    so a custom one is used. The unique group labels are sorted (np.unique)\n",
893 | "    and cut into consecutive blocks of n_splits groups; with shuffle=True each\n",
894 | "    complete block is shuffled independently (seeded with random_state + block\n",
895 | "    index when random_state is given). The group at position p of the resulting\n",
896 | "    order goes to fold p % n_splits, so every fold gets one group from each\n",
897 | "    complete block.\n",
898 | "    '''\n",
899 | "    def __init__(self, n_splits=4, shuffle=True, random_state=None):\n",
900 | "        super().__init__(n_splits, shuffle=shuffle, random_state=random_state)\n",
901 | "\n",
902 | "    def get_n_splits(self, X=None, y=None, groups=None):\n",
903 | "        '''Returns the number of folds; the arguments are ignored (scikit-learn API).'''\n",
904 | "        return self.n_splits\n",
905 | "\n",
906 | "    def _iter_test_indices(self, X=None, y=None, groups=None):\n",
907 | "        if groups is None:\n",
908 | "            raise ValueError(\"The 'groups' parameter should not be None.\")\n",
909 | "        groups = check_array(groups, ensure_2d=False, dtype=None)\n",
910 | "        unique_groups, groups = np.unique(groups, return_inverse=True)\n",
911 | "        n_groups = len(unique_groups)\n",
912 | "        if self.n_splits > n_groups:\n",
913 | "            raise ValueError(\"Cannot have number of splits n_splits=%d greater\"\n",
914 | "                             \" than the number of groups: %d.\"\n",
915 | "                             % (self.n_splits, n_groups))\n",
916 | "        # Order in which the groups are dealt to the folds.\n",
917 | "        order = np.arange(n_groups)\n",
918 | "        if self.shuffle:\n",
919 | "            for i in range(n_groups // self.n_splits):\n",
920 | "                block = slice(self.n_splits * i, self.n_splits * (i + 1))\n",
921 | "                if self.random_state is None:\n",
922 | "                    order[block] = shuffle(order[block])\n",
923 | "                else:\n",
924 | "                    order[block] = shuffle(order[block], random_state=self.random_state + i)\n",
925 | "        # Deal the groups round-robin: position p in `order` -> fold p % n_splits.\n",
926 | "        group_to_fold = np.empty(n_groups, dtype=int)\n",
927 | "        for position, group_index in enumerate(order):\n",
928 | "            group_to_fold[group_index] = position % self.n_splits\n",
929 | "        sample_to_fold = group_to_fold[groups]\n",
930 | "        for fold in range(self.n_splits):\n",
931 | "            yield np.where(sample_to_fold == fold)[0]"
928 | ],
929 | "execution_count": 0,
930 | "outputs": []
931 | },
932 | {
933 | "cell_type": "markdown",
934 | "metadata": {
935 | "id": "5pRQQK0Mhm1_",
936 | "colab_type": "text"
937 | },
938 | "source": [
939 | "##Experimenter definition"
940 | ]
941 | },
942 | {
943 | "cell_type": "code",
944 | "metadata": {
945 | "id": "GE4TTG1-hmH7",
946 | "colab_type": "code",
947 | "colab": {}
948 | },
949 | "source": [
950 | "from sklearn.model_selection import cross_validate, KFold, PredefinedSplit\n",
951 | "\n",
952 | "def experimenter(model, X, y, \n",
953 | " groups=None, \n",
954 | " scoring=None,\n",
955 | " cv=KFold(4, True), \n",
956 | " verbose=0):\n",
957 | " '''\n",
958 | " Performs a experiment with some estimator (model) and validation.\n",
959 | " It works like a cross_validate function from sklearn, however, \n",
960 | " when a estimator has an internal validation with groups, \n",
961 | " it maintains the groups from the external validation.\n",
962 | " '''\n",
963 | " if hasattr(model,'cv') or (hasattr(model,'steps') and any(['gs' in step[0] for step in model.steps])):\n",
964 | " scores = {}\n",
965 | " lstval = list(validation.split(X,y,groups))\n",
966 | " for tr,te in lstval:\n",
967 | " if groups is not None:\n",
968 | " innercv = list(GroupShuffleKFold(validation.n_splits-1).split(X[tr],y[tr],np.array(groups)[tr]))\n",
969 | " else:\n",
970 | " innercv = list(KFold(validation.n_splits-1, True).split(X[tr],y[tr]))\n",
971 | " if hasattr(model,'cv'):\n",
972 | " model.cv = innercv\n",
973 | " else:\n",
974 | " for step in model.steps:\n",
975 | " if 'gs' in step[0]:\n",
976 | " step[1].cv = innercv\n",
977 | " test_fold = np.zeros((len(y),), dtype=int)\n",
978 | " test_fold[tr] = -1\n",
979 | " score = cross_validate(model, X, y, groups, scoring, PredefinedSplit(test_fold), verbose=verbose)\n",
980 | " for k in score.keys():\n",
981 | " if k not in scores:\n",
982 | " scores[k] = []\n",
983 | " scores[k].extend(score[k])\n",
984 | " return scores\n",
985 | " return cross_validate(model, X, y, groups, scoring, cv, verbose=verbose)"
986 | ],
987 | "execution_count": 0,
988 | "outputs": []
989 | },
990 | {
991 | "cell_type": "markdown",
992 | "metadata": {
993 | "id": "e48W6KkIhesw",
994 | "colab_type": "text"
995 | },
996 | "source": [
997 | "##Experiment setup"
998 | ]
999 | },
1000 | {
1001 | "cell_type": "code",
1002 | "metadata": {
1003 | "id": "K9dImVH3hh_Y",
1004 | "colab_type": "code",
1005 | "colab": {}
1006 | },
1007 | "source": [
1008 | "from collections import namedtuple\n",
1009 | "\n",
1010 | "ExperimentSetup = namedtuple('ExperimentSetup', 'groups, splitter_name')\n",
1011 | "\n",
1012 | "validations = {\n",
1013 | " # Validation usually seen in publications with CWRU bearing dataset.\n",
1014 | " \"Usual K-Fold\": ExperimentSetup(groups = None, splitter_name = 'KFold'), \n",
1015 | " # Samples from the same original Matlab file cannot be in the \n",
1016 | " # trainning folds and the test fold simultaneously.\n",
1017 | " \"By Acquisition\": ExperimentSetup(groups = join_labels('(de)|(fe)|(ba)',y), \n",
1018 | " splitter_name = 'GroupShuffleKFold'),\n",
1019 | " # Samples with the same severity cannot be in the trainning folds and\n",
1020 | " # the test folds simultaneously.\n",
1021 | " \"By Severity\": ExperimentSetup(groups = get_groups('(\\.\\d{3})|(Normal_\\d)',y),\n",
1022 | " splitter_name = 'GroupShuffleKFold'),\n",
1023 | "}\n",
1024 | "if debug:\n",
1025 | " validations = {\n",
1026 | " \"By Severity\": ExperimentSetup(groups = get_groups('(\\.\\d{3})|(Normal_\\d)',y),\n",
1027 | " splitter_name = 'GroupShuffleKFold')}"
1028 | ],
1029 | "execution_count": 0,
1030 | "outputs": []
1031 | },
1032 | {
1033 | "cell_type": "markdown",
1034 | "metadata": {
1035 | "id": "E9yn44VAoRFo",
1036 | "colab_type": "text"
1037 | },
1038 | "source": [
1039 | "##Common Variables"
1040 | ]
1041 | },
1042 | {
1043 | "cell_type": "code",
1044 | "metadata": {
1045 | "id": "ogpdDde4oTsS",
1046 | "colab_type": "code",
1047 | "colab": {}
1048 | },
1049 | "source": [
1050 | "# Only four conditions are considered: Normal, Ball, Inner Race and Outer Race.\n",
1051 | "selected_y = join_labels('_\\d|@\\d{1,3}|(de)|(fe)|\\.\\d{3}|(DE)|(FE)',y)\n",
1052 | "rounds = 4 if not debug else 1\n",
1053 | "verbose = 0\n",
1054 | "random_state = 42\n",
1055 | "scoring = ['accuracy', 'f1_macro']#, 'precision_macro', 'recall_macro']"
1056 | ],
1057 | "execution_count": 0,
1058 | "outputs": []
1059 | },
1060 | {
1061 | "cell_type": "markdown",
1062 | "metadata": {
1063 | "id": "mq30RtWYToeu",
1064 | "colab_type": "text"
1065 | },
1066 | "source": [
1067 | "##Classification Models Definition"
1068 | ]
1069 | },
1070 | {
1071 | "cell_type": "code",
1072 | "metadata": {
1073 | "id": "IstS2gTeY7pg",
1074 | "colab_type": "code",
1075 | "colab": {}
1076 | },
1077 | "source": [
1078 | "import warnings\n",
1079 | "warnings.filterwarnings('ignore')"
1080 | ],
1081 | "execution_count": 0,
1082 | "outputs": []
1083 | },
1084 | {
1085 | "cell_type": "markdown",
1086 | "metadata": {
1087 | "id": "1bnyL67EUcxV",
1088 | "colab_type": "text"
1089 | },
1090 | "source": [
1091 | "###K-NN"
1092 | ]
1093 | },
1094 | {
1095 | "cell_type": "code",
1096 | "metadata": {
1097 | "id": "F--sjKZRUh5G",
1098 | "colab_type": "code",
1099 | "outputId": "def285ca-d979-4e52-92f1-e8fdb22638d9",
1100 | "colab": {
1101 | "base_uri": "https://localhost:8080/",
1102 | "height": 208
1103 | }
1104 | },
1105 | "source": [
1106 | "from sklearn.neighbors import KNeighborsClassifier\n",
1107 | "from sklearn.pipeline import Pipeline\n",
1108 | "from sklearn.preprocessing import StandardScaler\n",
1109 | "\n",
1110 | "knn = Pipeline([\n",
1111 | " ('FeatureExtraction', WaveletPackage()),\n",
1112 | " ('scaler', StandardScaler()),\n",
1113 | " ('knn', KNeighborsClassifier()),\n",
1114 | " ])\n",
1115 | "knn"
1116 | ],
1117 | "execution_count": 255,
1118 | "outputs": [
1119 | {
1120 | "output_type": "execute_result",
1121 | "data": {
1122 | "text/plain": [
1123 | "Pipeline(memory=None,\n",
1124 | " steps=[('FeatureExtraction',\n",
1125 | " <__main__.WaveletPackage object at 0x7f923e3a0358>),\n",
1126 | " ('scaler',\n",
1127 | " StandardScaler(copy=True, with_mean=True, with_std=True)),\n",
1128 | " ('knn',\n",
1129 | " KNeighborsClassifier(algorithm='auto', leaf_size=30,\n",
1130 | " metric='minkowski', metric_params=None,\n",
1131 | " n_jobs=None, n_neighbors=5, p=2,\n",
1132 | " weights='uniform'))],\n",
1133 | " verbose=False)"
1134 | ]
1135 | },
1136 | "metadata": {
1137 | "tags": []
1138 | },
1139 | "execution_count": 255
1140 | }
1141 | ]
1142 | },
1143 | {
1144 | "cell_type": "markdown",
1145 | "metadata": {
1146 | "id": "d7DfMTS_ujeE",
1147 | "colab_type": "text"
1148 | },
1149 | "source": [
1150 | "###SVM with GridSearchCV"
1151 | ]
1152 | },
1153 | {
1154 | "cell_type": "code",
1155 | "metadata": {
1156 | "id": "6v_wXxiDupvF",
1157 | "colab_type": "code",
1158 | "outputId": "243d3878-22f8-4b39-8055-9521cc3c528a",
1159 | "colab": {
1160 | "base_uri": "https://localhost:8080/",
1161 | "height": 399
1162 | }
1163 | },
1164 | "source": [
1165 | "from sklearn.svm import SVC\n",
1166 | "from sklearn.pipeline import Pipeline\n",
1167 | "from sklearn.preprocessing import StandardScaler\n",
1168 | "from sklearn.model_selection import GridSearchCV\n",
1169 | "\n",
1170 | "parameters = {'C':[10**c for c in range(5)]}\n",
1171 | "svm = Pipeline([\n",
1172 | " ('FeatureExtraction', WaveletPackage()),\n",
1173 | " ('scaler', StandardScaler()),\n",
1174 | " ('svc_gs', GridSearchCV(SVC(), parameters)),\n",
1175 | " ])\n",
1176 | "svm"
1177 | ],
1178 | "execution_count": 256,
1179 | "outputs": [
1180 | {
1181 | "output_type": "execute_result",
1182 | "data": {
1183 | "text/plain": [
1184 | "Pipeline(memory=None,\n",
1185 | " steps=[('FeatureExtraction',\n",
1186 | " <__main__.WaveletPackage object at 0x7f923e3a0a20>),\n",
1187 | " ('scaler',\n",
1188 | " StandardScaler(copy=True, with_mean=True, with_std=True)),\n",
1189 | " ('svc_gs',\n",
1190 | " GridSearchCV(cv=None, error_score=nan,\n",
1191 | " estimator=SVC(C=1.0, break_ties=False,\n",
1192 | " cache_size=200, class_weight=None,\n",
1193 | " coef0=0.0,\n",
1194 | " decision_function_shape='ovr',\n",
1195 | " degree=3, gamma='scale',\n",
1196 | " kernel='rbf', max_iter=-1,\n",
1197 | " probability=False,\n",
1198 | " random_state=None, shrinking=True,\n",
1199 | " tol=0.001, verbose=False),\n",
1200 | " iid='deprecated', n_jobs=None,\n",
1201 | " param_grid={'C': [1, 10, 100, 1000, 10000]},\n",
1202 | " pre_dispatch='2*n_jobs', refit=True,\n",
1203 | " return_train_score=False, scoring=None,\n",
1204 | " verbose=0))],\n",
1205 | " verbose=False)"
1206 | ]
1207 | },
1208 | "metadata": {
1209 | "tags": []
1210 | },
1211 | "execution_count": 256
1212 | }
1213 | ]
1214 | },
1215 | {
1216 | "cell_type": "markdown",
1217 | "metadata": {
1218 | "id": "yU36xsi4JGZv",
1219 | "colab_type": "text"
1220 | },
1221 | "source": [
1222 | "###Random Forest"
1223 | ]
1224 | },
1225 | {
1226 | "cell_type": "code",
1227 | "metadata": {
1228 | "id": "GXABo6HpJJY_",
1229 | "colab_type": "code",
1230 | "outputId": "99a5847b-dcc1-4dac-c745-16ae8148ac5e",
1231 | "colab": {
1232 | "base_uri": "https://localhost:8080/",
1233 | "height": 329
1234 | }
1235 | },
1236 | "source": [
1237 | "from sklearn.ensemble import RandomForestClassifier\n",
1238 | "from sklearn.pipeline import Pipeline\n",
1239 | "from sklearn.preprocessing import StandardScaler\n",
1240 | "\n",
1241 | "rf = Pipeline([\n",
1242 | " ('FeatureExtraction', WaveletPackage()),\n",
1243 | " ('scaler', StandardScaler()),\n",
1244 | " ('rf', RandomForestClassifier()),\n",
1245 | " ])\n",
1246 | "rf"
1247 | ],
1248 | "execution_count": 257,
1249 | "outputs": [
1250 | {
1251 | "output_type": "execute_result",
1252 | "data": {
1253 | "text/plain": [
1254 | "Pipeline(memory=None,\n",
1255 | " steps=[('FeatureExtraction',\n",
1256 | " <__main__.WaveletPackage object at 0x7f923e3a06a0>),\n",
1257 | " ('scaler',\n",
1258 | " StandardScaler(copy=True, with_mean=True, with_std=True)),\n",
1259 | " ('rf',\n",
1260 | " RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,\n",
1261 | " class_weight=None, criterion='gini',\n",
1262 | " max_depth=None, max_features='auto',\n",
1263 | " max_leaf_nodes=None, max_samples=None,\n",
1264 | " min_impurity_decrease=0.0,\n",
1265 | " min_impurity_split=None,\n",
1266 | " min_samples_leaf=1, min_samples_split=2,\n",
1267 | " min_weight_fraction_leaf=0.0,\n",
1268 | " n_estimators=100, n_jobs=None,\n",
1269 | " oob_score=False, random_state=None,\n",
1270 | " verbose=0, warm_start=False))],\n",
1271 | " verbose=False)"
1272 | ]
1273 | },
1274 | "metadata": {
1275 | "tags": []
1276 | },
1277 | "execution_count": 257
1278 | }
1279 | ]
1280 | },
1281 | {
1282 | "cell_type": "markdown",
1283 | "metadata": {
1284 | "id": "REy6ykvWSbJc",
1285 | "colab_type": "text"
1286 | },
1287 | "source": [
1288 | "###Convolutional Neural Network"
1289 | ]
1290 | },
1291 | {
1292 | "cell_type": "code",
1293 | "metadata": {
1294 | "id": "8YpHSjvNcEx5",
1295 | "colab_type": "code",
1296 | "colab": {}
1297 | },
1298 | "source": [
1299 | "%tensorflow_version 2.x"
1300 | ],
1301 | "execution_count": 0,
1302 | "outputs": []
1303 | },
1304 | {
1305 | "cell_type": "markdown",
1306 | "metadata": {
1307 | "id": "FF1hCxJ9b5h0",
1308 | "colab_type": "text"
1309 | },
1310 | "source": [
1311 | "####F1-score macro averaged implemented for Keras"
1312 | ]
1313 | },
1314 | {
1315 | "cell_type": "code",
1316 | "metadata": {
1317 | "id": "LCJErrQIcIZ0",
1318 | "colab_type": "code",
1319 | "colab": {}
1320 | },
1321 | "source": [
1322 | "from tensorflow.keras import backend as K\n",
1323 | "def f1_score_macro(y_true,y_pred):\n",
1324 | " def recall(y_true, y_pred):\n",
1325 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
1326 | " possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n",
1327 | " recall = true_positives / (possible_positives + K.epsilon())\n",
1328 | " return recall\n",
1329 | " def precision(y_true, y_pred):\n",
1330 | " true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
1331 | " predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))\n",
1332 | " precision = true_positives / (predicted_positives + K.epsilon())\n",
1333 | " return precision\n",
1334 | " precision = precision(y_true, y_pred)\n",
1335 | " recall = recall(y_true, y_pred)\n",
1336 | " return 2*((precision*recall)/(precision+recall+K.epsilon()))"
1337 | ],
1338 | "execution_count": 0,
1339 | "outputs": []
1340 | },
1341 | {
1342 | "cell_type": "markdown",
1343 | "metadata": {
1344 | "id": "PeLXSU7-cLfk",
1345 | "colab_type": "text"
1346 | },
1347 | "source": [
1348 | "####ANN wrapped in a scikit-learn estimator."
1349 | ]
1350 | },
1351 | {
1352 | "cell_type": "code",
1353 | "metadata": {
1354 | "id": "ny7otiW6Siz_",
1355 | "colab_type": "code",
1356 | "colab": {}
1357 | },
1358 | "source": [
1359 | "from tensorflow.keras import layers\n",
1360 | "from tensorflow.keras.models import Sequential\n",
1361 | "from tensorflow.keras.utils import to_categorical\n",
1362 | "from sklearn.base import BaseEstimator, ClassifierMixin\n",
1363 | "import numpy as np\n",
1364 | "from tensorflow.keras.callbacks import EarlyStopping,ReduceLROnPlateau\n",
1365 | "\n",
1366 | "class ANN(BaseEstimator, ClassifierMixin):\n",
1367 | " def __init__(self, \n",
1368 | " dense_layer_sizes=[64], \n",
1369 | " kernel_size=32, \n",
1370 | " filters=32, \n",
1371 | " n_conv_layers=2,\n",
1372 | " pool_size=8,\n",
1373 | " dropout=0.25,\n",
1374 | " epochs=50,\n",
1375 | " validation_split=0.05,\n",
1376 | " optimizer='sgd'#'nadam'#'rmsprop'#\n",
1377 | " ):\n",
1378 | " self.dense_layer_sizes = dense_layer_sizes\n",
1379 | " self.kernel_size = kernel_size\n",
1380 | " self.filters = filters\n",
1381 | " self.n_conv_layers = n_conv_layers\n",
1382 | " self.pool_size = pool_size\n",
1383 | " self.dropout = dropout\n",
1384 | " self.epochs = epochs\n",
1385 | " self.validation_split = validation_split\n",
1386 | " self.optimizer = optimizer\n",
1387 | " \n",
1388 | " def fit(self, X, y=None):\n",
1389 | " dense_layer_sizes = self.dense_layer_sizes\n",
1390 | " kernel_size = self.kernel_size\n",
1391 | " filters = self.filters\n",
1392 | " n_conv_layers = self.n_conv_layers\n",
1393 | " pool_size = self.pool_size\n",
1394 | " dropout = self.dropout\n",
1395 | " epochs = self.epochs\n",
1396 | " optimizer = self.optimizer\n",
1397 | " validation_split = self.validation_split\n",
1398 | "\n",
1399 | " self.labels, ids = np.unique(y, return_inverse=True)\n",
1400 | " y_cat = to_categorical(ids)\n",
1401 | " num_classes = y_cat.shape[1]\n",
1402 | " \n",
1403 | " self.model = Sequential()\n",
1404 | " self.model.add(layers.InputLayer(input_shape=(X.shape[1],X.shape[-1])))\n",
1405 | " for _ in range(n_conv_layers):\n",
1406 | " self.model.add(layers.Conv1D(filters, kernel_size))#, padding='valid'))\n",
1407 | " self.model.add(layers.Activation('relu'))\n",
1408 | " if pool_size>1:\n",
1409 | " self.model.add(layers.MaxPooling1D(pool_size=pool_size))\n",
1410 | " #self.model.add(layers.Dropout(0.25))\n",
1411 | " self.model.add(layers.Flatten())\n",
1412 | " for layer_size in dense_layer_sizes:\n",
1413 | " self.model.add(layers.Dense(layer_size))\n",
1414 | " self.model.add(layers.Activation('relu'))\n",
1415 | " if dropout>0 and dropout<1:\n",
1416 | " self.model.add(layers.Dropout(dropout))\n",
1417 | " self.model.add(layers.Dense(num_classes))\n",
1418 | " self.model.add(layers.Activation('softmax'))\n",
1419 | " self.model.compile(loss='categorical_crossentropy',\n",
1420 | " optimizer=optimizer,\n",
1421 | " metrics=[f1_score_macro])\n",
1422 | " if validation_split>0 and validation_split<1:\n",
1423 | " prop = int(1/validation_split)\n",
1424 | " mask = np.array([i%prop==0 for i in range(len(y))])\n",
1425 | " self.history = self.model.fit(X[~mask], y_cat[~mask], epochs=epochs, \n",
1426 | " validation_data=(X[mask],y_cat[mask]),\n",
1427 | " callbacks=[EarlyStopping(patience=3), ReduceLROnPlateau()],\n",
1428 | " verbose=False\n",
1429 | " ) \n",
1430 | " else:\n",
1431 | " self.history = self.model.fit(X, y_cat, epochs=epochs, verbose=False) \n",
1432 | " \n",
1433 | " def predict_proba(self, X, y=None):\n",
1434 | " return self.model.predict(X)\n",
1435 | "\n",
1436 | " def predict(self, X, y=None):\n",
1437 | " predictions = self.model.predict(X)\n",
1438 | " return self.labels[np.argmax(predictions,axis=1)]"
1439 | ],
1440 | "execution_count": 0,
1441 | "outputs": []
1442 | },
1443 | {
1444 | "cell_type": "markdown",
1445 | "metadata": {
1446 | "id": "UL5Zv2-xcgEG",
1447 | "colab_type": "text"
1448 | },
1449 | "source": [
1450 | "####ANN instantiation"
1451 | ]
1452 | },
1453 | {
1454 | "cell_type": "code",
1455 | "metadata": {
1456 | "id": "MmQe7jeRcb6F",
1457 | "colab_type": "code",
1458 | "outputId": "bf06238e-5bf9-44b9-af3a-a963b4108d5b",
1459 | "colab": {
1460 | "base_uri": "https://localhost:8080/",
1461 | "height": 52
1462 | }
1463 | },
1464 | "source": [
1465 | "ann = ANN()\n",
1466 | "ann"
1467 | ],
1468 | "execution_count": 261,
1469 | "outputs": [
1470 | {
1471 | "output_type": "execute_result",
1472 | "data": {
1473 | "text/plain": [
1474 | "ANN(dense_layer_sizes=[64], dropout=0.25, epochs=50, filters=32, kernel_size=32,\n",
1475 | " n_conv_layers=2, optimizer='sgd', pool_size=8, validation_split=0.05)"
1476 | ]
1477 | },
1478 | "metadata": {
1479 | "tags": []
1480 | },
1481 | "execution_count": 261
1482 | }
1483 | ]
1484 | },
1485 | {
1486 | "cell_type": "markdown",
1487 | "metadata": {
1488 | "id": "4SQenKxoSkME",
1489 | "colab_type": "text"
1490 | },
1491 | "source": [
1492 | "###List of Estimators"
1493 | ]
1494 | },
1495 | {
1496 | "cell_type": "code",
1497 | "metadata": {
1498 | "id": "Vb6AGKFJJqvL",
1499 | "colab_type": "code",
1500 | "colab": {}
1501 | },
1502 | "source": [
1503 | "clfs = [\n",
1504 | " ('KNN - KNeighborsClassifier', knn),\n",
1505 | " ('SVM - SVC with GridSearchCV', svm),\n",
1506 | " ('RF - RandomForestClassifier', rf),\n",
1507 | " ('ANN - Convolutional Layers', ann),\n",
1508 | " ]\n",
1509 | "if debug:\n",
1510 | " clfs = [('ANN - Convolutional Layers', ann)]"
1511 | ],
1512 | "execution_count": 0,
1513 | "outputs": []
1514 | },
1515 | {
1516 | "cell_type": "markdown",
1517 | "metadata": {
1518 | "colab_type": "text",
1519 | "id": "dEl_vSYaq-s2"
1520 | },
1521 | "source": [
1522 | "##Performing Experiments"
1523 | ]
1524 | },
1525 | {
1526 | "cell_type": "code",
1527 | "metadata": {
1528 | "id": "CH4LVC3Zj3jC",
1529 | "colab_type": "code",
1530 | "outputId": "4e93054f-9629-4185-e15f-e890a55ac4d3",
1531 | "colab": {
1532 | "base_uri": "https://localhost:8080/",
1533 | "height": 1000
1534 | }
1535 | },
1536 | "source": [
1537 | "import numpy as np\n",
1538 | "\n",
1539 | "scores = {}\n",
1540 | "trtime = {}\n",
1541 | "tetime = {}\n",
1542 | "# Number of repetitions\n",
1543 | "for r in range(rounds):\n",
1544 | " round_str = \"Round {}\".format(r+1)\n",
1545 | " print(\"@\"*(len(round_str)+8),'\\n@@@',round_str,'@@@\\n'+\"@\"*(len(round_str)+8))\n",
1546 | " # Estimators\n",
1547 | " for clf_name, estimator in clfs:\n",
1548 | " if clf_name not in scores:\n",
1549 | " scores[clf_name] = {}\n",
1550 | " trtime[clf_name] = {}\n",
1551 | " tetime[clf_name] = {}\n",
1552 | " print(\"*\"*(len(clf_name)+8),'\\n***',clf_name,'***\\n'+\"*\"*(len(clf_name)+8))\n",
1553 | " # Validation forms\n",
1554 | " for val_name in validations.keys():\n",
1555 | " print(\"#\"*(len(val_name)+8),'\\n###',val_name,'###\\n'+\"#\"*(len(val_name)+8))\n",
1556 | " groups = validations[val_name].groups\n",
1557 | " if val_name not in scores[clf_name]:\n",
1558 | " scores[clf_name][val_name] = {}\n",
1559 | " validation = eval(validations[val_name].splitter_name\n",
1560 | " +'(4,shuffle=True,random_state='\n",
1561 | " +str(random_state+r)+')')\n",
1562 | " score = experimenter(estimator, X, selected_y, groups, \n",
1563 | " scoring, validation, verbose)\n",
1564 | " for metric,s in score.items():\n",
1565 | " print(metric, ' \\t', s)\n",
1566 | " if metric not in scores[clf_name][val_name]:\n",
1567 | " scores[clf_name][val_name][metric] = []\n",
1568 | " scores[clf_name][val_name][metric].append(s)"
1569 | ],
1570 | "execution_count": 263,
1571 | "outputs": [
1572 | {
1573 | "output_type": "stream",
1574 | "text": [
1575 | "@@@@@@@@@@@@@@@ \n",
1576 | "@@@ Round 1 @@@\n",
1577 | "@@@@@@@@@@@@@@@\n",
1578 | "********************************** \n",
1579 | "*** KNN - KNeighborsClassifier ***\n",
1580 | "**********************************\n",
1581 | "#################### \n",
1582 | "### Usual K-Fold ###\n",
1583 | "####################\n",
1584 | "fit_time \t [7.3245039 7.77022791 7.45543742 7.41583681]\n",
1585 | "score_time \t [2.92294478 2.85009789 2.95585728 3.00658798]\n",
1586 | "test_accuracy \t [0.98132969 0.98383424 0.98269581 0.98474152]\n",
1587 | "test_f1_macro \t [0.98271227 0.98476938 0.9835539 0.98528629]\n",
1588 | "###################### \n",
1589 | "### By Acquisition ###\n",
1590 | "######################\n",
1591 | "fit_time \t [8.43344092 8.36202717 7.17392468 7.22651124]\n",
1592 | "score_time \t [2.93316627 2.96407676 2.92852592 2.90356755]\n",
1593 | "test_accuracy \t [0.96135744 0.95139814 0.97252382 0.93371758]\n",
1594 | "test_f1_macro \t [0.96483569 0.95292091 0.97159579 0.92957374]\n",
1595 | "################### \n",
1596 | "### By Severity ###\n",
1597 | "###################\n",
1598 | "fit_time \t [7.58721757 7.70569801 6.42010689 7.44451952]\n",
1599 | "score_time \t [2.27661967 2.0538609 3.97748995 3.55925274]\n",
1600 | "test_accuracy \t [0.44491302 0.33368607 0.59228747 0.3497038 ]\n",
1601 | "test_f1_macro \t [0.37886062 0.25052798 0.62872652 0.46349296]\n",
1602 | "*********************************** \n",
1603 | "*** SVM - SVC with GridSearchCV ***\n",
1604 | "***********************************\n",
1605 | "#################### \n",
1606 | "### Usual K-Fold ###\n",
1607 | "####################\n",
1608 | "fit_time \t [19.994742155075073, 19.85019087791443, 20.457500219345093, 19.90899920463562]\n",
1609 | "score_time \t [2.5596060752868652, 2.4824955463409424, 2.6774775981903076, 2.5800788402557373]\n",
1610 | "test_accuracy \t [0.9842896174863388, 0.9867941712204007, 0.9877049180327869, 0.9847415167387839]\n",
1611 | "test_f1_macro \t [0.9854838394407729, 0.9876760467628855, 0.9883849486399726, 0.9853136020490524]\n",
1612 | "###################### \n",
1613 | "### By Acquisition ###\n",
1614 | "######################\n",
1615 | "fit_time \t [19.594305992126465, 19.326322317123413, 18.958588123321533, 19.277965307235718]\n",
1616 | "score_time \t [2.4327797889709473, 2.51678466796875, 2.71537446975708, 2.6660337448120117]\n",
1617 | "test_accuracy \t [0.9734951696804558, 0.9829116733244563, 0.9780633724795037, 0.97251163821769]\n",
1618 | "test_f1_macro \t [0.976858416946694, 0.9835580577060097, 0.9776931023587536, 0.9721119164504329]\n",
1619 | "################### \n",
1620 | "### By Severity ###\n",
1621 | "###################\n",
1622 | "fit_time \t [12.453691959381104, 16.324783086776733, 11.726444959640503, 12.092777729034424]\n",
1623 | "score_time \t [2.2399401664733887, 1.6088571548461914, 3.6347174644470215, 3.1187779903411865]\n",
1624 | "test_accuracy \t [0.5071164997364259, 0.3333333333333333, 0.5177914110429448, 0.43626982610357345]\n",
1625 | "test_f1_macro \t [0.4353144040274436, 0.25, 0.5967175448345157, 0.49589598650693223]\n",
1626 | "*********************************** \n",
1627 | "*** RF - RandomForestClassifier ***\n",
1628 | "***********************************\n",
1629 | "#################### \n",
1630 | "### Usual K-Fold ###\n",
1631 | "####################\n",
1632 | "fit_time \t [10.50702095 10.39010954 10.5258615 10.69215393]\n",
1633 | "score_time \t [2.36112642 2.35864496 2.3171196 2.60548401]\n",
1634 | "test_accuracy \t [0.98588342 0.98884335 0.98907104 0.98610795]\n",
1635 | "test_f1_macro \t [0.98700832 0.9896653 0.98960034 0.9868356 ]\n",
1636 | "###################### \n",
1637 | "### By Acquisition ###\n",
1638 | "######################\n",
1639 | "fit_time \t [11.06583214 10.70239806 10.54217863 10.48058772]\n",
1640 | "score_time \t [2.33086109 2.6757555 2.57460976 2.50441241]\n",
1641 | "test_accuracy \t [0.97077037 0.9873502 0.97052958 0.93482598]\n",
1642 | "test_f1_macro \t [0.97489487 0.98788971 0.97031567 0.93293216]\n",
1643 | "################### \n",
1644 | "### By Severity ###\n",
1645 | "###################\n",
1646 | "fit_time \t [11.59657001 11.91939855 9.70096159 9.6287353 ]\n",
1647 | "score_time \t [2.06921029 1.62476182 3.09937906 2.99023056]\n",
1648 | "test_accuracy \t [0.43199789 0.33544974 0.62997371 0.51213453]\n",
1649 | "test_f1_macro \t [0.38000346 0.25315126 0.64502238 0.57749991]\n",
1650 | "********************************** \n",
1651 | "*** ANN - Convolutional Layers ***\n",
1652 | "**********************************\n",
1653 | "#################### \n",
1654 | "### Usual K-Fold ###\n",
1655 | "####################\n",
1656 | "fit_time \t [24.15791512 38.80671716 57.07147145 37.02953219]\n",
1657 | "score_time \t [0.28071547 0.29240489 0.24401665 0.28297114]\n",
1658 | "test_accuracy \t [0.93351548 0.99430783 0.99681239 0.99590071]\n",
1659 | "test_f1_macro \t [0.94137472 0.99478775 0.99695729 0.99626914]\n",
1660 | "###################### \n",
1661 | "### By Acquisition ###\n",
1662 | "######################\n",
1663 | "fit_time \t [18.5324192 35.4609468 21.69827628 37.27756262]\n",
1664 | "score_time \t [0.25750256 0.28676701 0.25203919 0.29622126]\n",
1665 | "test_accuracy \t [0.9799356 0.96981802 0.98736982 0.99157615]\n",
1666 | "test_f1_macro \t [0.98284078 0.97219341 0.98804686 0.99082701]\n",
1667 | "################### \n",
1668 | "### By Severity ###\n",
1669 | "###################\n",
1670 | "fit_time \t [44.0867362 74.66411805 25.14348674 31.85565972]\n",
1671 | "score_time \t [0.24089551 0.19461489 0.34162545 0.29567337]\n",
1672 | "test_accuracy \t [0.66631523 0.38659612 0.80736196 0.56086375]\n",
1673 | "test_f1_macro \t [0.58519461 0.31894977 0.84225609 0.59022724]\n",
1674 | "@@@@@@@@@@@@@@@ \n",
1675 | "@@@ Round 2 @@@\n",
1676 | "@@@@@@@@@@@@@@@\n",
1677 | "********************************** \n",
1678 | "*** KNN - KNeighborsClassifier ***\n",
1679 | "**********************************\n",
1680 | "#################### \n",
1681 | "### Usual K-Fold ###\n",
1682 | "####################\n",
1683 | "fit_time \t [6.74547529 7.33138275 6.9889462 7.11015272]\n",
1684 | "score_time \t [2.7566607 2.67893839 2.88710332 2.7776835 ]\n",
1685 | "test_accuracy \t [0.98816029 0.97996357 0.98178506 0.98223639]\n",
1686 | "test_f1_macro \t [0.98891536 0.98106058 0.98281653 0.98321669]\n",
1687 | "###################### \n",
1688 | "### By Acquisition ###\n",
1689 | "######################\n",
1690 | "fit_time \t [7.60175681 6.84476066 7.37081504 7.13797212]\n",
1691 | "score_time \t [2.79434514 2.95569777 2.59766102 2.75889134]\n",
1692 | "test_accuracy \t [0.97159343 0.93984462 0.94951745 0.9641196 ]\n",
1693 | "test_f1_macro \t [0.97055871 0.93494208 0.95597614 0.96388501]\n",
1694 | "################### \n",
1695 | "### By Severity ###\n",
1696 | "###################\n",
1697 | "fit_time \t [6.38719606 7.26009154 8.30787253 6.39034128]\n",
1698 | "score_time \t [3.86299062 2.43005657 1.70710516 3.53779507]\n",
1699 | "test_accuracy \t [0.4035057 0.44491302 0.33368607 0.55455762]\n",
1700 | "test_f1_macro \t [0.46342948 0.37886062 0.25052798 0.62811191]\n",
1701 | "*********************************** \n",
1702 | "*** SVM - SVC with GridSearchCV ***\n",
1703 | "***********************************\n",
1704 | "#################### \n",
1705 | "### Usual K-Fold ###\n",
1706 | "####################\n",
1707 | "fit_time \t [19.724173069000244, 19.745147943496704, 19.873310327529907, 19.93139123916626]\n",
1708 | "score_time \t [2.5572855472564697, 2.4192516803741455, 2.5142109394073486, 2.5450191497802734]\n",
1709 | "test_accuracy \t [0.9888433515482696, 0.9847449908925319, 0.9842896174863388, 0.9886130721931223]\n",
1710 | "test_f1_macro \t [0.9894045941109493, 0.9857217535870164, 0.9853516908688675, 0.9893416987274131]\n",
1711 | "###################### \n",
1712 | "### By Acquisition ###\n",
1713 | "######################\n",
1714 | "fit_time \t [18.95182466506958, 18.980645656585693, 18.83344292640686, 19.036104679107666]\n",
1715 | "score_time \t [2.414243459701538, 2.6076626777648926, 2.332960367202759, 2.543945074081421]\n",
1716 | "test_accuracy \t [0.9731469152241455, 0.9718091009988902, 0.9819351645632269, 0.9718715393133998]\n",
1717 | "test_f1_macro \t [0.973038296824626, 0.9716621059136256, 0.9842450077949008, 0.9727993254399118]\n",
1718 | "################### \n",
1719 | "### By Severity ###\n",
1720 | "###################\n",
1721 | "fit_time \t [11.033088445663452, 11.847302198410034, 16.41715431213379, 12.386778116226196]\n",
1722 | "score_time \t [3.145768404006958, 2.163776397705078, 1.8239169120788574, 3.433562994003296]\n",
1723 | "test_accuracy \t [0.48255915863277826, 0.5071164997364259, 0.3333333333333333, 0.4750621058666157]\n",
1724 | "test_f1_macro \t [0.49597490715437265, 0.4353144040274436, 0.25, 0.5966747860549317]\n",
1725 | "*********************************** \n",
1726 | "*** RF - RandomForestClassifier ***\n",
1727 | "***********************************\n",
1728 | "#################### \n",
1729 | "### Usual K-Fold ###\n",
1730 | "####################\n",
1731 | "fit_time \t [10.44487453 10.59957457 10.42241383 10.69838881]\n",
1732 | "score_time \t [2.35236216 2.39362836 2.48229885 2.45215535]\n",
1733 | "test_accuracy \t [0.98861566 0.9879326 0.98770492 0.98656343]\n",
1734 | "test_f1_macro \t [0.98905118 0.98857401 0.9885108 0.98764181]\n",
1735 | "###################### \n",
1736 | "### By Acquisition ###\n",
1737 | "######################\n",
1738 | "fit_time \t [10.59389687 10.46544337 10.41703629 10.5358088 ]\n",
1739 | "score_time \t [2.51146913 2.39405012 2.18060732 2.73209858]\n",
1740 | "test_accuracy \t [0.97492233 0.94605993 0.98663697 0.96256921]\n",
1741 | "test_f1_macro \t [0.97484002 0.94414127 0.98853825 0.96386836]\n",
1742 | "################### \n",
1743 | "### By Severity ###\n",
1744 | "###################\n",
1745 | "fit_time \t [ 9.03128004 10.79873133 11.67754769 9.37620449]\n",
1746 | "score_time \t [3.00855851 2.03731585 1.45516348 2.93091846]\n",
1747 | "test_accuracy \t [0.55740578 0.43384291 0.33368607 0.59889165]\n",
1748 | "test_f1_macro \t [0.5833592 0.38408874 0.25052798 0.64727801]\n",
1749 | "********************************** \n",
1750 | "*** ANN - Convolutional Layers ***\n",
1751 | "**********************************\n",
1752 | "#################### \n",
1753 | "### Usual K-Fold ###\n",
1754 | "####################\n",
1755 | "fit_time \t [18.51238251 38.71945739 34.63079643 15.25693583]\n",
1756 | "score_time \t [0.26096463 0.27576256 0.26916194 0.24006557]\n",
1757 | "test_accuracy \t [0.98998179 0.99408015 0.99544627 0.98588021]\n",
1758 | "test_f1_macro \t [0.99040384 0.99462804 0.99576567 0.9860312 ]\n",
1759 | "###################### \n",
1760 | "### By Acquisition ###\n",
1761 | "######################\n",
1762 | "fit_time \t [60.96851325 28.83005214 48.82531905 22.15704131]\n",
1763 | "score_time \t [0.24486303 0.29043531 0.25695372 0.305897 ]\n",
1764 | "test_accuracy \t [0.98956946 0.97602664 0.98737936 0.96943522]\n",
1765 | "test_f1_macro \t [0.99003041 0.97660823 0.98929215 0.97055929]\n",
1766 | "################### \n",
1767 | "### By Severity ###\n",
1768 | "###################\n",
1769 | "fit_time \t [45.4839859 47.39291573 44.47006106 42.55098677]\n",
1770 | "score_time \t [0.31366587 0.28757834 0.19485497 0.28546786]\n",
1771 | "test_accuracy \t [0.5964943 0.71402214 0.40070547 0.79629276]\n",
1772 | "test_f1_macro \t [0.63582133 0.638561 0.33414097 0.83572249]\n",
1773 | "@@@@@@@@@@@@@@@ \n",
1774 | "@@@ Round 3 @@@\n",
1775 | "@@@@@@@@@@@@@@@\n",
1776 | "********************************** \n",
1777 | "*** KNN - KNeighborsClassifier ***\n",
1778 | "**********************************\n",
1779 | "#################### \n",
1780 | "### Usual K-Fold ###\n",
1781 | "####################\n",
1782 | "fit_time \t [6.78158879 7.24524999 7.16766906 6.97739434]\n",
1783 | "score_time \t [2.80046415 2.78269243 2.78030491 2.89765549]\n",
1784 | "test_accuracy \t [0.98542805 0.98201275 0.98064663 0.98747438]\n",
1785 | "test_f1_macro \t [0.98671034 0.98245946 0.98173025 0.98822911]\n",
1786 | "###################### \n",
1787 | "### By Acquisition ###\n",
1788 | "######################\n",
1789 | "fit_time \t [7.31945705 6.84883904 7.04814601 7.12005329]\n",
1790 | "score_time \t [2.73347282 2.85485291 2.54963541 2.60869932]\n",
1791 | "test_accuracy \t [0.91494559 0.95856415 0.97327394 0.97050998]\n",
1792 | "test_f1_macro \t [0.91379771 0.96057775 0.97510474 0.96973573]\n",
1793 | "################### \n",
1794 | "### By Severity ###\n",
1795 | "###################\n",
1796 | "fit_time \t [7.95710611 6.05790138 7.73230648 6.2113421 ]\n",
1797 | "score_time \t [1.48542976 3.87246108 2.30525112 3.90099192]\n",
1798 | "test_accuracy \t [0.20160609 0.40319186 0.44491302 0.59228747]\n",
1799 | "test_f1_macro \t [0.25052798 0.46323317 0.37886062 0.62872652]\n",
1800 | "*********************************** \n",
1801 | "*** SVM - SVC with GridSearchCV ***\n",
1802 | "***********************************\n",
1803 | "#################### \n",
1804 | "### Usual K-Fold ###\n",
1805 | "####################\n",
1806 | "fit_time \t [19.78970241546631, 20.034347772598267, 18.8914213180542, 19.659440517425537]\n",
1807 | "score_time \t [2.5892982482910156, 2.7418341636657715, 2.492760181427002, 2.530878782272339]\n",
1808 | "test_accuracy \t [0.985655737704918, 0.9854280510018215, 0.9831511839708561, 0.989296287861535]\n",
1809 | "test_f1_macro \t [0.9866728190171588, 0.9862340473534927, 0.9843512312779796, 0.9900168204996117]\n",
1810 | "###################### \n",
1811 | "### By Acquisition ###\n",
1812 | "######################\n",
1813 | "fit_time \t [19.65701723098755, 17.885798692703247, 19.174787044525146, 18.710833072662354]\n",
1814 | "score_time \t [2.474280595779419, 2.664071559906006, 2.2238361835479736, 2.491318702697754]\n",
1815 | "test_accuracy \t [0.9715745058849656, 0.95989364059384, 0.9755011135857461, 0.9809312638580931]\n",
1816 | "test_f1_macro \t [0.9705238821018404, 0.9612368138045917, 0.9776007988451276, 0.9808835530107685]\n",
1817 | "################### \n",
1818 | "### By Severity ###\n",
1819 | "###################\n",
1820 | "fit_time \t [16.430733919143677, 10.869746685028076, 11.728529214859009, 11.905654907226562]\n",
1821 | "score_time \t [1.3907830715179443, 3.1103897094726562, 2.5582427978515625, 3.69344425201416]\n",
1822 | "test_accuracy \t [0.20118343195266272, 0.48228691687127323, 0.5071164997364259, 0.5177914110429448]\n",
1823 | "test_f1_macro \t [0.25, 0.49593706609267424, 0.4353144040274436, 0.5967175448345157]\n",
1824 | "*********************************** \n",
1825 | "*** RF - RandomForestClassifier ***\n",
1826 | "***********************************\n",
1827 | "#################### \n",
1828 | "### Usual K-Fold ###\n",
1829 | "####################\n",
1830 | "fit_time \t [10.53787589 10.2915628 10.51413083 10.52403307]\n",
1831 | "score_time \t [2.34554863 2.51538396 2.38000703 2.48478985]\n",
1832 | "test_accuracy \t [0.98884335 0.98770492 0.98315118 0.98952403]\n",
1833 | "test_f1_macro \t [0.98951368 0.98830651 0.98414733 0.9902686 ]\n",
1834 | "###################### \n",
1835 | "### By Acquisition ###\n",
1836 | "######################\n",
1837 | "fit_time \t [10.13828254 10.56666279 10.40775037 10.29485846]\n",
1838 | "score_time \t [2.49524665 2.41962385 2.23708034 2.41174793]\n",
1839 | "test_accuracy \t [0.94603598 0.95413251 0.98267756 0.96430155]\n",
1840 | "test_f1_macro \t [0.94438752 0.95593309 0.9850004 0.96383986]\n",
1841 | "################### \n",
1842 | "### By Severity ###\n",
1843 | "###################\n",
1844 | "fit_time \t [12.50928974 8.8536787 10.62946248 9.24698877]\n",
1845 | "score_time \t [1.23483443 3.08676791 2.0546658 3.11678696]\n",
1846 | "test_accuracy \t [0.20245139 0.53998597 0.42698998 0.6312007 ]\n",
1847 | "test_f1_macro \t [0.25158061 0.56366944 0.37801829 0.64596422]\n",
1848 | "********************************** \n",
1849 | "*** ANN - Convolutional Layers ***\n",
1850 | "**********************************\n",
1851 | "#################### \n",
1852 | "### Usual K-Fold ###\n",
1853 | "####################\n",
1854 | "fit_time \t [28.18626499 28.47523022 12.85678458 32.59321523]\n",
1855 | "score_time \t [0.25498891 0.2511642 0.33200359 0.25324249]\n",
1856 | "test_accuracy \t [0.99248634 0.99385246 0.93715847 0.99316784]\n",
1857 | "test_f1_macro \t [0.9932205 0.99409611 0.93403695 0.99367418]\n",
1858 | "###################### \n",
1859 | "### By Acquisition ###\n",
1860 | "######################\n",
1861 | "fit_time \t [27.29787254 46.76934981 26.48193836 39.39669919]\n",
1862 | "score_time \t [0.2738874 0.2835393 0.24724627 0.27070951]\n",
1863 | "test_accuracy \t [0.9862314 0.95922889 0.95421925 0.98514412]\n",
1864 | "test_f1_macro \t [0.98634371 0.96265069 0.96120691 0.98445983]\n",
1865 | "################### \n",
1866 | "### By Severity ###\n",
1867 | "###################\n",
1868 | "fit_time \t [37.14484739 22.21058679 34.79163575 28.47509933]\n",
1869 | "score_time \t [0.63821244 0.32159853 0.25547314 0.30236888]\n",
1870 | "test_accuracy \t [0.33431953 0.49508944 0.68502899 0.82979842]\n",
1871 | "test_f1_macro \t [0.37509929 0.53362145 0.60254297 0.86029085]\n",
1872 | "@@@@@@@@@@@@@@@ \n",
1873 | "@@@ Round 4 @@@\n",
1874 | "@@@@@@@@@@@@@@@\n",
1875 | "********************************** \n",
1876 | "*** KNN - KNeighborsClassifier ***\n",
1877 | "**********************************\n",
1878 | "#################### \n",
1879 | "### Usual K-Fold ###\n",
1880 | "####################\n",
1881 | "fit_time \t [7.33682466 7.82628226 8.24803853 8.0868969 ]\n",
1882 | "score_time \t [3.09780622 3.13119817 3.04925036 3.29486322]\n",
1883 | "test_accuracy \t [0.98315118 0.98019126 0.98497268 0.98428604]\n",
1884 | "test_f1_macro \t [0.9842249 0.98173285 0.98593065 0.98452977]\n",
1885 | "###################### \n",
1886 | "### By Acquisition ###\n",
1887 | "######################\n",
1888 | "fit_time \t [8.42992783 8.95111179 7.82080698 7.99395847]\n",
1889 | "score_time \t [3.36267686 3.11097527 3.34500456 3.29506826]\n",
1890 | "test_accuracy \t [0.97317668 0.96682347 0.91905079 0.95363798]\n",
1891 | "test_f1_macro \t [0.97240717 0.96991399 0.91785393 0.95397016]\n",
1892 | "################### \n",
1893 | "### By Severity ###\n",
1894 | "###################\n",
1895 | "fit_time \t [6.89547205 9.04102612 7.35568476 9.10485649]\n",
1896 | "score_time \t [4.11352658 2.37339211 4.37783432 2.04778314]\n",
1897 | "test_accuracy \t [0.59154683 0.36631579 0.4035057 0.33368607]\n",
1898 | "test_f1_macro \t [0.62843038 0.3757482 0.46342948 0.25052798]\n",
1899 | "*********************************** \n",
1900 | "*** SVM - SVC with GridSearchCV ***\n",
1901 | "***********************************\n",
1902 | "#################### \n",
1903 | "### Usual K-Fold ###\n",
1904 | "####################\n",
1905 | "fit_time \t [23.347135543823242, 22.43061351776123, 22.15123987197876, 22.907995223999023]\n",
1906 | "score_time \t [2.816004514694214, 2.8565492630004883, 2.770206928253174, 3.04343318939209]\n",
1907 | "test_accuracy \t [0.9877049180327869, 0.9861111111111112, 0.9865664845173042, 0.9849692552949214]\n",
1908 | "test_f1_macro \t [0.9887125445867451, 0.9872882515313055, 0.987242023020228, 0.9855826041520989]\n",
1909 | "###################### \n",
1910 | "### By Acquisition ###\n",
1911 | "######################\n",
1912 | "fit_time \t [20.628316402435303, 20.755019187927246, 22.277571201324463, 21.142418384552002]\n",
1913 | "score_time \t [2.8379127979278564, 2.792255401611328, 3.4283406734466553, 2.702456474304199]\n",
1914 | "test_accuracy \t [0.9778319663045888, 0.9796979450359, 0.9773785761809713, 0.9660603371783496]\n",
1915 | "test_f1_macro \t [0.9779191989920553, 0.9825195830106137, 0.9761171817465767, 0.967098736717328]\n",
1916 | "################### \n",
1917 | "### By Severity ###\n",
1918 | "###################\n",
1919 | "fit_time \t [13.05809497833252, 13.786409854888916, 12.301506042480469, 18.482519388198853]\n",
1920 | "score_time \t [4.19962477684021, 2.222336530685425, 3.6370270252227783, 1.8126635551452637]\n",
1921 | "test_accuracy \t [0.5175377060680463, 0.4378947368421053, 0.48255915863277826, 0.3333333333333333]\n",
1922 | "test_f1_macro \t [0.5967175448345157, 0.4342974114125963, 0.49597490715437265, 0.25]\n",
1923 | "*********************************** \n",
1924 | "*** RF - RandomForestClassifier ***\n",
1925 | "***********************************\n",
1926 | "#################### \n",
1927 | "### Usual K-Fold ###\n",
1928 | "####################\n",
1929 | "fit_time \t [12.2456286 12.25842261 11.98340058 12.06807566]\n",
1930 | "score_time \t [2.94996428 2.70877385 2.85489178 2.9848876 ]\n",
1931 | "test_accuracy \t [0.98952641 0.98611111 0.98611111 0.98929629]\n",
1932 | "test_f1_macro \t [0.99028957 0.98709689 0.98699236 0.98962434]\n",
1933 | "###################### \n",
1934 | "### By Acquisition ###\n",
1935 | "######################\n",
1936 | "fit_time \t [12.61375833 12.56573796 11.90711355 11.94932938]\n",
1937 | "score_time \t [2.99247122 2.58260298 2.7625227 3.24548745]\n",
1938 | "test_accuracy \t [0.95012192 0.97301312 0.95076514 0.96983141]\n",
1939 | "test_f1_macro \t [0.95053768 0.97694082 0.94878309 0.97110356]\n",
1940 | "################### \n",
1941 | "### By Severity ###\n",
1942 | "###################\n",
1943 | "fit_time \t [10.4012599 12.43954158 9.98400068 13.11084819]\n",
1944 | "score_time \t [3.6818769 2.0685308 3.49031782 1.81863332]\n",
1945 | "test_accuracy \t [0.62627148 0.34345865 0.55530237 0.3340388 ]\n",
1946 | "test_f1_macro \t [0.64359854 0.37516027 0.5891077 0.25105485]\n",
1947 | "********************************** \n",
1948 | "*** ANN - Convolutional Layers ***\n",
1949 | "**********************************\n",
1950 | "#################### \n",
1951 | "### Usual K-Fold ###\n",
1952 | "####################\n",
1953 | "fit_time \t [20.75646663 19.90564919 19.69050145 19.89894247]\n",
1954 | "score_time \t [0.2843616 0.28579712 0.28508973 0.28671575]\n",
1955 | "test_accuracy \t [0.99066485 0.98816029 0.98019126 0.9740378 ]\n",
1956 | "test_f1_macro \t [0.99133974 0.98931809 0.98222924 0.97482592]\n",
1957 | "###################### \n",
1958 | "### By Acquisition ###\n",
1959 | "######################\n",
1960 | "fit_time \t [32.8947382 56.9355247 27.15734267 25.78575087]\n",
1961 | "score_time \t [0.29065132 0.25645137 0.2865026 0.28651834]\n",
1962 | "test_accuracy \t [0.98293061 0.99083932 0.98935462 0.9784827 ]\n",
1963 | "test_f1_macro \t [0.98399107 0.99227822 0.98875738 0.97939376]\n",
1964 | "################### \n",
1965 | "### By Severity ###\n",
1966 | "###################\n",
1967 | "fit_time \t [24.2693305 20.64741468 52.46021819 34.91594076]\n",
1968 | "score_time \t [0.38094687 0.2325995 0.36206555 0.21522617]\n",
1969 | "test_accuracy \t [0.82427219 0.60240602 0.5363716 0.37707231]\n",
1970 | "test_f1_macro \t [0.85629104 0.57016226 0.57061501 0.30805243]\n"
1971 | ],
1972 | "name": "stdout"
1973 | }
1974 | ]
1975 | },
1976 | {
1977 | "cell_type": "markdown",
1978 | "metadata": {
1979 | "id": "QJ-qe0MIhM-z",
1980 | "colab_type": "text"
1981 | },
1982 | "source": [
1983 | "## Save results"
1984 | ]
1985 | },
1986 | {
1987 | "cell_type": "code",
1988 | "metadata": {
1989 | "id": "qrp8uvOonKpd",
1990 | "colab_type": "code",
1991 | "outputId": "9846470e-ddf5-4480-84ed-695d505a461f",
1992 | "colab": {
1993 | "base_uri": "https://localhost:8080/",
1994 | "height": 121
1995 | }
1996 | },
1997 | "source": [
1998 | "clf = {}\n",
1999 | "val = {}\n",
2000 | "src = {}\n",
2001 | "for c, clf_name in enumerate(scores.keys()):\n",
2002 | " if c not in clf:\n",
2003 | " clf[c] = clf_name\n",
2004 | " for v, val_name in enumerate(scores[clf_name].keys()):\n",
2005 | " if v not in val:\n",
2006 | " val[v] = val_name\n",
2007 | " for s, scr_name in enumerate(scores[clf_name][val_name].keys()):\n",
2008 | " scores[clf_name][val_name][scr_name] = np.array(scores[clf_name][val_name][scr_name])\n",
2009 | " if s not in src:\n",
2010 | " src[s] = scr_name\n",
2011 | " np.savetxt('{}-{}-{}.txt'.format(clf_name,val_name,scr_name), \n",
2012 | " scores[clf_name][val_name][scr_name], delimiter=',')\n",
2013 | "clf, val, src"
2014 | ],
2015 | "execution_count": 264,
2016 | "outputs": [
2017 | {
2018 | "output_type": "execute_result",
2019 | "data": {
2020 | "text/plain": [
2021 | "({0: 'KNN - KNeighborsClassifier',\n",
2022 | " 1: 'SVM - SVC with GridSearchCV',\n",
2023 | " 2: 'RF - RandomForestClassifier',\n",
2024 | " 3: 'ANN - Convolutional Layers'},\n",
2025 | " {0: 'Usual K-Fold', 1: 'By Acquisition', 2: 'By Severity'},\n",
2026 | " {0: 'fit_time', 1: 'score_time', 2: 'test_accuracy', 3: 'test_f1_macro'})"
2027 | ]
2028 | },
2029 | "metadata": {
2030 | "tags": []
2031 | },
2032 | "execution_count": 264
2033 | }
2034 | ]
2035 | },
2036 | {
2037 | "cell_type": "markdown",
2038 | "metadata": {
2039 | "id": "rkb8XMN-Ht58",
2040 | "colab_type": "text"
2041 | },
2042 | "source": [
2043 | "## Average & Standard Deviation"
2044 | ]
2045 | },
2046 | {
2047 | "cell_type": "code",
2048 | "metadata": {
2049 | "id": "hJxdjboqtuNb",
2050 | "colab_type": "code",
2051 | "outputId": "def599b3-d563-45ce-fba6-652d417aac23",
2052 | "colab": {
2053 | "base_uri": "https://localhost:8080/",
2054 | "height": 503
2055 | }
2056 | },
2057 | "source": [
2058 | "c,v,s = len(clf),len(val),len(src)\n",
2059 | "for i in range(s):\n",
2060 | " print(src[i])\n",
2061 | " for k in range(v):\n",
2062 | " print('\\t'+val[k]+' ', end='')\n",
2063 | " print()\n",
2064 | " for j in range(c):\n",
2065 | " print(clf[j].split('-')[0], end='\\t')\n",
2066 | " for k in range(v):\n",
2067 | " print(\"{0:.3f} ({1:.3f})\".format(\n",
2068 | " scores[clf[j]][val[k]][src[i]].mean(),\n",
2069 | " scores[clf[j]][val[k]][src[i]].std()), end='\\t')\n",
2070 | " print()\n",
2071 | " print()"
2072 | ],
2073 | "execution_count": 265,
2074 | "outputs": [
2075 | {
2076 | "output_type": "stream",
2077 | "text": [
2078 | "fit_time\n",
2079 | "\tUsual K-Fold \tBy Acquisition \tBy Severity \n",
2080 | "KNN \t7.363 (0.420)\t7.605 (0.629)\t7.366 (0.917)\t\n",
2081 | "SVM \t20.544 (1.305)\t19.574 (1.065)\t13.303 (2.234)\t\n",
2082 | "RF \t10.919 (0.713)\t10.953 (0.795)\t10.682 (1.327)\t\n",
2083 | "ANN \t27.909 (11.104)\t34.779 (12.357)\t38.160 (13.346)\t\n",
2084 | "\n",
2085 | "score_time\n",
2086 | "\tUsual K-Fold \tBy Acquisition \tBy Severity \n",
2087 | "KNN \t2.917 (0.159)\t2.919 (0.247)\t2.993 (0.952)\t\n",
2088 | "SVM \t2.636 (0.162)\t2.615 (0.265)\t2.737 (0.839)\t\n",
2089 | "RF \t2.534 (0.216)\t2.566 (0.265)\t2.485 (0.744)\t\n",
2090 | "ANN \t0.274 (0.022)\t0.274 (0.019)\t0.304 (0.102)\t\n",
2091 | "\n",
2092 | "test_accuracy\n",
2093 | "\tUsual K-Fold \tBy Acquisition \tBy Severity \n",
2094 | "KNN \t0.983 (0.002)\t0.955 (0.018)\t0.425 (0.108)\t\n",
2095 | "SVM \t0.986 (0.002)\t0.975 (0.006)\t0.442 (0.090)\t\n",
2096 | "RF \t0.988 (0.002)\t0.964 (0.015)\t0.468 (0.128)\t\n",
2097 | "ANN \t0.983 (0.019)\t0.980 (0.011)\t0.601 (0.164)\t\n",
2098 | "\n",
2099 | "test_f1_macro\n",
2100 | "\tUsual K-Fold \tBy Acquisition \tBy Severity \n",
2101 | "KNN \t0.984 (0.002)\t0.955 (0.019)\t0.430 (0.137)\t\n",
2102 | "SVM \t0.987 (0.002)\t0.975 (0.006)\t0.444 (0.126)\t\n",
2103 | "RF \t0.988 (0.002)\t0.965 (0.016)\t0.464 (0.157)\t\n",
2104 | "ANN \t0.984 (0.019)\t0.981 (0.010)\t0.591 (0.184)\t\n",
2105 | "\n"
2106 | ],
2107 | "name": "stdout"
2108 | }
2109 | ]
2110 | },
2111 | {
2112 | "cell_type": "markdown",
2113 | "metadata": {
2114 | "id": "hKh0QQncJpzF",
2115 | "colab_type": "text"
2116 | },
2117 | "source": [
2118 | "## Average of differences"
2119 | ]
2120 | },
2121 | {
2122 | "cell_type": "code",
2123 | "metadata": {
2124 | "id": "4Pejb-yt_7aW",
2125 | "colab_type": "code",
2126 | "outputId": "9f59da4c-8620-426a-eb02-e6f7f918bfb3",
2127 | "colab": {
2128 | "base_uri": "https://localhost:8080/",
2129 | "height": 1000
2130 | }
2131 | },
2132 | "source": [
2133 | "c,v,s = len(clf),len(val),len(src)\n",
2134 | "compclf = np.zeros((s,v,c,c))\n",
2135 | "for i in range(s):\n",
2136 | " print('*'*3, src[i], '*'*3)\n",
2137 | " for j in range(v):\n",
2138 | " print(val[j])\n",
2139 | " for k in range(c):\n",
2140 | " print(' '+clf[k].split('-')[0],end=' ')\n",
2141 | " print()\n",
2142 | " for k in range(c):\n",
2143 | " for l in range(k):\n",
2144 | " diff = scores[clf[k]][val[j]][src[i]]-scores[clf[l]][val[j]][src[i]]\n",
2145 | " compclf[i][j][l][k] = diff.mean()\n",
2146 | " compclf[i][j][k][l] = diff.std()\n",
2147 | " print(compclf[i][j])"
2148 | ],
2149 | "execution_count": 266,
2150 | "outputs": [
2151 | {
2152 | "output_type": "stream",
2153 | "text": [
2154 | "*** fit_time ***\n",
2155 | "Usual K-Fold\n",
2156 | " KNN SVM RF ANN \n",
2157 | "[[ 0. 13.18034023 3.55572626 20.54602221]\n",
2158 | " [ 1.0788785 0. -9.62461397 7.36568198]\n",
2159 | " [ 0.53641173 0.66331547 0. 16.99029595]\n",
2160 | " [11.12176592 11.50298509 11.41326661 0. ]]\n",
2161 | "By Acquisition\n",
2162 | " KNN SVM RF ANN \n",
2163 | "[[ 0. 11.96921574 3.34769788 27.17411487]\n",
2164 | " [ 0.83696112 0. -8.62151785 15.20489913]\n",
2165 | " [ 0.49760762 0.64829297 0. 23.82641698]\n",
2166 | " [12.28457824 12.6105018 12.31768712 0. ]]\n",
2167 | "By Severity\n",
2168 | " KNN SVM RF ANN \n",
2169 | "[[ 0. 5.93665481 3.31536001 30.79401779]\n",
2170 | " [ 1.66633661 0. -2.6212948 24.85736299]\n",
2171 | " [ 0.62381463 1.31881945 0. 27.47865778]\n",
2172 | " [13.28103369 12.88261266 13.14789378 0. ]]\n",
2173 | "*** score_time ***\n",
2174 | "Usual K-Fold\n",
2175 | " KNN SVM RF ANN \n",
2176 | "[[ 0. -0.28085738 -0.38269013 -2.64316766]\n",
2177 | " [ 0.0881125 0. -0.10183275 -2.36231028]\n",
2178 | " [ 0.12266843 0.12437923 0. -2.26047753]\n",
2179 | " [ 0.15559155 0.16260064 0.21143513 0. ]]\n",
2180 | "By Acquisition\n",
2181 | " KNN SVM RF ANN \n",
2182 | "[[ 0. -0.30325378 -0.35285439 -2.64438325]\n",
2183 | " [ 0.16228367 0. -0.0496006 -2.34112947]\n",
2184 | " [ 0.16754303 0.25355818 0. -2.29152887]\n",
2185 | " [ 0.24034233 0.25769257 0.25571127 0. ]]\n",
2186 | "By Severity\n",
2187 | " KNN SVM RF ANN \n",
2188 | "[[ 0. -0.25550072 -0.50711805 -2.68868583]\n",
2189 | " [ 0.29657587 0. -0.25161733 -2.43318512]\n",
2190 | " [ 0.25009033 0.20421526 0. -2.18156779]\n",
2191 | " [ 0.94661453 0.83494238 0.74165047 0. ]]\n",
2192 | "*** test_accuracy ***\n",
2193 | "Usual K-Fold\n",
2194 | " KNN SVM RF ANN \n",
2195 | "[[ 0. 0.00287467 0.00425508 0.00017091]\n",
2196 | " [ 0.00194861 0. 0.00138041 -0.00270376]\n",
2197 | " [ 0.0021324 0.0016528 0. -0.00408417]\n",
2198 | " [ 0.01851219 0.01831394 0.01805251 0. ]]\n",
2199 | "By Acquisition\n",
2200 | " KNN SVM RF ANN \n",
2201 | "[[ 0. 0.02003495 0.00940557 0.02521791]\n",
2202 | " [ 0.01846429 0. -0.01062938 0.00518295]\n",
2203 | " [ 0.01649239 0.01373783 0. 0.01581234]\n",
2204 | " [ 0.02379938 0.01027753 0.02076568 0. ]]\n",
2205 | "By Severity\n",
2206 | " KNN SVM RF ANN \n",
2207 | "[[0. 0.01737374 0.0436735 0.17616904]\n",
2208 | " [0.06176932 0. 0.02629977 0.1587953 ]\n",
2209 | " [0.06545448 0.07410789 0. 0.13249554]\n",
2210 | " [0.07426634 0.10014654 0.10242866 0. ]]\n",
2211 | "*** test_f1_macro ***\n",
2212 | "Usual K-Fold\n",
2213 | " KNN SVM RF ANN \n",
2214 | "[[ 0.00000000e+00 2.80626148e-03 4.07801969e-03 6.75036343e-05]\n",
2215 | " [ 1.89606647e-03 0.00000000e+00 1.27175821e-03 -2.73875784e-03]\n",
2216 | " [ 2.06854552e-03 1.51490424e-03 0.00000000e+00 -4.01051605e-03]\n",
2217 | " [ 1.79745829e-02 1.78024950e-02 1.75102946e-02 0.00000000e+00]]\n",
2218 | "By Acquisition\n",
2219 | " KNN SVM RF ANN \n",
2220 | "[[ 0. 0.02051354 0.00976857 0.0263644 ]\n",
2221 | " [ 0.01859959 0. -0.01074498 0.00585086]\n",
2222 | " [ 0.01549432 0.01419552 0. 0.01659584]\n",
2223 | " [ 0.02352099 0.00933817 0.02048926 0. ]]\n",
2224 | "By Severity\n",
2225 | " KNN SVM RF ANN \n",
2226 | "[[0. 0.01430178 0.0335664 0.1609704 ]\n",
2227 | " [0.03353973 0. 0.01926462 0.14666862]\n",
2228 | " [0.04771923 0.05220928 0. 0.127404 ]\n",
2229 | " [0.0662515 0.07356697 0.09191026 0. ]]\n"
2230 | ],
2231 | "name": "stdout"
2232 | }
2233 | ]
2234 | }
2235 | ]
2236 | }
--------------------------------------------------------------------------------
/KFold_RF_CWRU_EvaluationFramework.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "KFold RF CWRU-EvaluationFramework.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true,
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "llMoY5AQGXQY",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | "The code is separated into four sections.\n",
37 | "\n",
38 | "* In section \"CWRU database\", the CWRU Matlab files are downloaded and the acquisitions are extracted from the Matlab files into NumPy arrays. Then, the data is segmented and the samples are selected by their labels using regular expressions.\n",
39 | "\n",
40 | "* Section \"Experimenter\" defines the splitters mentioned in this work, i.e. GroupShuffleKFold and BySeverityKFold, and the setup of each experiment. The samples are grouped by their original labels also using regular expressions. For instance, to group samples by load, the regular expression '_\\d' may be used. A list of evaluation methods is defined in this section. \n",
41 | "\n",
42 | "* Section \"Classification Models\" defines the estimators and their feature extraction methods. To instantiate a classification method with feature extraction, a Pipeline must be made. A list of classification methods is defined in this section. \n",
43 | "\n",
44 | "* Finally, section \"Performing Experiments\" executes the experiments as defined in the previous sections, showing and saving their results. It iterates over one list of classification methods and one list of evaluation methods for $r$ rounds. In this work, four classification methods were assessed with three evaluation methods over four rounds, resulting in $4\\times 3\\times 4 = 48$ experiments. New classification or evaluation methods can be tested by adding them to their respective lists.\n",
45 | "\n",
46 | "The code can be executed directly in the Colab environment when few samples and simple classification methods are tested. For the experiments presented in this work, it must be run on a local GPU with enough memory and processing capacity.\n",
47 | "The results are presented for each round of each experiment of each classification method. The average and standard deviation over the rounds are also presented, as well as the average of the differences among the classification methods.\n",
48 | "\n",
49 | "New feature extraction methods must receive a 3-D NumPy array and return a 2-D NumPy array. This is necessary because the same raw dataset is used both by methods that need feature extraction from the signal acquisitions, like K-NN, SVM and Random Forest, and by convolutional neural networks, which deal with 3-D arrays. The experiments presented here used just one channel of each acquisition, but newer experiments may use more channels.\n",
50 | "\n",
51 | "New classification methods that need feature extraction can easily be added: only a Pipeline with the feature extraction method and the classifier is required, like those presented in the code. A new neural network base architecture must be wrapped in a scikit-learn estimator, and its definition must be placed in the method *fit*. It cannot be defined in the method *\\_\\_init\\_\\_* due to the implementation of the Keras library: if the network architecture is defined in the method *\\_\\_init\\_\\_*, it will remember the samples across the folds and the rounds, giving outstanding results that will never be replicated in a real scenario. It is worth highlighting that some parameters of the network, like kernel size and number of filters, should be selected by a GridSearchCV method to provide fairer results when compared with other methods."
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {
57 | "id": "sm5t8TYBqkDu",
58 | "colab_type": "text"
59 | },
60 | "source": [
61 | "# CWRU database"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "metadata": {
67 | "id": "bSSOMru17Z6c",
68 | "colab_type": "code",
69 | "colab": {}
70 | },
71 | "source": [
72 | "debug = False"
73 | ],
74 | "execution_count": 1,
75 | "outputs": []
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {
80 | "id": "pMQoq6dvStey",
81 | "colab_type": "text"
82 | },
83 | "source": [
84 | "## CWRU files.\n",
85 | "\n",
86 | "Associate each Matlab file name to a bearing condition in a Python dictionary.\n",
87 | "The dictionary keys identify the conditions.\n",
88 | "\n",
89 | "There are only four normal conditions, with loads of 0, 1, 2 and 3 hp.\n",
90 | "All conditions end with an underscore character followed by a digit representing the load (in hp) applied during the acquisitions.\n",
91 | "The remaining conditions follow the pattern:\n",
92 | "\n",
93 | "\n",
94 | "* First two characters represent the bearing location, i.e. drive end (DE) and fan end (FE).\n",
95 | "* The following one or two characters represent the failure location in the bearing, i.e. ball (B), inner race (IR) and outer race (OR).\n",
96 | "* The next characters, a period followed by three digits, indicate the severity of the failure, where .007 stands for 0.007 inches and .021 for 0.021 inches.\n",
97 | "* For outer race failures, the character @ is followed by a number indicating the clock position (3, 6 or 12 o'clock) at which the fault is centered, i.e. its position relative to the load zone."
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "metadata": {
103 | "id": "K6mp2QrP1lmR",
104 | "colab_type": "code",
105 | "colab": {}
106 | },
107 | "source": [
108 | "def cwru_12khz():\n",
109 | " '''\n",
110 | " Retuns a dictionary with the names of all Matlab files read in 12kHz located in\n",
111 | " http://csegroups.case.edu/sites/default/files/bearingdatacenter/files/Datafiles/.\n",
112 | " The dictionary keys represent the bearing condition.\n",
113 | " '''\n",
114 | " matlab_files_name = {}\n",
115 | " # Normal\n",
116 | " matlab_files_name[\"Normal_0\"] = \"97.mat\"\n",
117 | " matlab_files_name[\"Normal_1\"] = \"98.mat\"\n",
118 | " matlab_files_name[\"Normal_2\"] = \"99.mat\"\n",
119 | " matlab_files_name[\"Normal_3\"] = \"100.mat\"\n",
120 | " # DE Inner Race 0.007 inches\n",
121 | " matlab_files_name[\"DEIR.007_0\"] = \"105.mat\"\n",
122 | " matlab_files_name[\"DEIR.007_1\"] = \"106.mat\"\n",
123 | " matlab_files_name[\"DEIR.007_2\"] = \"107.mat\"\n",
124 | " matlab_files_name[\"DEIR.007_3\"] = \"108.mat\"\n",
125 | " # DE Ball 0.007 inches\n",
126 | " matlab_files_name[\"DEB.007_0\"] = \"118.mat\"\n",
127 | " matlab_files_name[\"DEB.007_1\"] = \"119.mat\"\n",
128 | " matlab_files_name[\"DEB.007_2\"] = \"120.mat\"\n",
129 | " matlab_files_name[\"DEB.007_3\"] = \"121.mat\"\n",
130 | " # DE Outer race 0.007 inches centered @6:00\n",
131 | " matlab_files_name[\"DEOR.007@6_0\"] = \"130.mat\"\n",
132 | " matlab_files_name[\"DEOR.007@6_1\"] = \"131.mat\"\n",
133 | " matlab_files_name[\"DEOR.007@6_2\"] = \"132.mat\"\n",
134 | " matlab_files_name[\"DEOR.007@6_3\"] = \"133.mat\"\n",
135 | " # DE Outer race 0.007 inches centered @3:00\n",
136 | " matlab_files_name[\"DEOR.007@3_0\"] = \"144.mat\"\n",
137 | " matlab_files_name[\"DEOR.007@3_1\"] = \"145.mat\"\n",
138 | " matlab_files_name[\"DEOR.007@3_2\"] = \"146.mat\"\n",
139 | " matlab_files_name[\"DEOR.007@3_3\"] = \"147.mat\"\n",
140 | " # DE Outer race 0.007 inches centered @12:00\n",
141 | " matlab_files_name[\"DEOR.007@12_0\"] = \"156.mat\"\n",
142 | " matlab_files_name[\"DEOR.007@12_1\"] = \"158.mat\"\n",
143 | " matlab_files_name[\"DEOR.007@12_2\"] = \"159.mat\"\n",
144 | " matlab_files_name[\"DEOR.007@12_3\"] = \"160.mat\"\n",
145 | " # DE Inner Race 0.014 inches\n",
146 | " matlab_files_name[\"DEIR.014_0\"] = \"169.mat\"\n",
147 | " matlab_files_name[\"DEIR.014_1\"] = \"170.mat\"\n",
148 | " matlab_files_name[\"DEIR.014_2\"] = \"171.mat\"\n",
149 | " matlab_files_name[\"DEIR.014_3\"] = \"172.mat\"\n",
150 | " # DE Ball 0.014 inches\n",
151 | " matlab_files_name[\"DEB.014_0\"] = \"185.mat\"\n",
152 | " matlab_files_name[\"DEB.014_1\"] = \"186.mat\"\n",
153 | " matlab_files_name[\"DEB.014_2\"] = \"187.mat\"\n",
154 | " matlab_files_name[\"DEB.014_3\"] = \"188.mat\"\n",
155 | " # DE Outer race 0.014 inches centered @6:00\n",
156 | " matlab_files_name[\"DEOR.014@6_0\"] = \"197.mat\"\n",
157 | " matlab_files_name[\"DEOR.014@6_1\"] = \"198.mat\"\n",
158 | " matlab_files_name[\"DEOR.014@6_2\"] = \"199.mat\"\n",
159 | " matlab_files_name[\"DEOR.014@6_3\"] = \"200.mat\"\n",
160 | " # DE Ball 0.021 inches\n",
161 | " matlab_files_name[\"DEB.021_0\"] = \"222.mat\"\n",
162 | " matlab_files_name[\"DEB.021_1\"] = \"223.mat\"\n",
163 | " matlab_files_name[\"DEB.021_2\"] = \"224.mat\"\n",
164 | " matlab_files_name[\"DEB.021_3\"] = \"225.mat\"\n",
165 | " # FE Inner Race 0.021 inches\n",
166 | " matlab_files_name[\"FEIR.021_0\"] = \"270.mat\"\n",
167 | " matlab_files_name[\"FEIR.021_1\"] = \"271.mat\"\n",
168 | " matlab_files_name[\"FEIR.021_2\"] = \"272.mat\"\n",
169 | " matlab_files_name[\"FEIR.021_3\"] = \"273.mat\"\n",
170 | " # FE Inner Race 0.014 inches\n",
171 | " matlab_files_name[\"FEIR.014_0\"] = \"274.mat\"\n",
172 | " matlab_files_name[\"FEIR.014_1\"] = \"275.mat\"\n",
173 | " matlab_files_name[\"FEIR.014_2\"] = \"276.mat\"\n",
174 | " matlab_files_name[\"FEIR.014_3\"] = \"277.mat\"\n",
175 | " # FE Ball 0.007 inches\n",
176 | " matlab_files_name[\"FEB.007_0\"] = \"282.mat\"\n",
177 | " matlab_files_name[\"FEB.007_1\"] = \"283.mat\"\n",
178 | " matlab_files_name[\"FEB.007_2\"] = \"284.mat\"\n",
179 | " matlab_files_name[\"FEB.007_3\"] = \"285.mat\"\n",
180 | " # DE Inner Race 0.021 inches\n",
181 | " matlab_files_name[\"DEIR.021_0\"] = \"209.mat\"\n",
182 | " matlab_files_name[\"DEIR.021_1\"] = \"210.mat\"\n",
183 | " matlab_files_name[\"DEIR.021_2\"] = \"211.mat\"\n",
184 | " matlab_files_name[\"DEIR.021_3\"] = \"212.mat\"\n",
185 | " # DE Outer race 0.021 inches centered @6:00\n",
186 | " matlab_files_name[\"DEOR.021@6_0\"] = \"234.mat\"\n",
187 | " matlab_files_name[\"DEOR.021@6_1\"] = \"235.mat\"\n",
188 | " matlab_files_name[\"DEOR.021@6_2\"] = \"236.mat\"\n",
189 | " matlab_files_name[\"DEOR.021@6_3\"] = \"237.mat\"\n",
190 | " # DE Outer race 0.021 inches centered @3:00\n",
191 | " matlab_files_name[\"DEOR.021@3_0\"] = \"246.mat\"\n",
192 | " matlab_files_name[\"DEOR.021@3_1\"] = \"247.mat\"\n",
193 | " matlab_files_name[\"DEOR.021@3_2\"] = \"248.mat\"\n",
194 | " matlab_files_name[\"DEOR.021@3_3\"] = \"249.mat\"\n",
195 | " # DE Outer race 0.021 inches centered @12:00\n",
196 | " matlab_files_name[\"DEOR.021@12_0\"] = \"258.mat\"\n",
197 | " matlab_files_name[\"DEOR.021@12_1\"] = \"259.mat\"\n",
198 | " matlab_files_name[\"DEOR.021@12_2\"] = \"260.mat\"\n",
199 | " matlab_files_name[\"DEOR.021@12_3\"] = \"261.mat\"\n",
200 | " # FE Inner Race 0.007 inches\n",
201 | " matlab_files_name[\"FEIR.007_0\"] = \"278.mat\"\n",
202 | " matlab_files_name[\"FEIR.007_1\"] = \"279.mat\"\n",
203 | " matlab_files_name[\"FEIR.007_2\"] = \"280.mat\"\n",
204 | " matlab_files_name[\"FEIR.007_3\"] = \"281.mat\"\n",
205 | " # FE Ball 0.014 inches\n",
206 | " matlab_files_name[\"FEB.014_0\"] = \"286.mat\"\n",
207 | " matlab_files_name[\"FEB.014_1\"] = \"287.mat\"\n",
208 | " matlab_files_name[\"FEB.014_2\"] = \"288.mat\"\n",
209 | " matlab_files_name[\"FEB.014_3\"] = \"289.mat\"\n",
210 | " # FE Ball 0.021 inches\n",
211 | " matlab_files_name[\"FEB.021_0\"] = \"290.mat\"\n",
212 | " matlab_files_name[\"FEB.021_1\"] = \"291.mat\"\n",
213 | " matlab_files_name[\"FEB.021_2\"] = \"292.mat\"\n",
214 | " matlab_files_name[\"FEB.021_3\"] = \"293.mat\"\n",
215 | " # FE Outer race 0.007 inches centered @6:00\n",
216 | " matlab_files_name[\"FEOR.007@6_0\"] = \"294.mat\"\n",
217 | " matlab_files_name[\"FEOR.007@6_1\"] = \"295.mat\"\n",
218 | " matlab_files_name[\"FEOR.007@6_2\"] = \"296.mat\"\n",
219 | " matlab_files_name[\"FEOR.007@6_3\"] = \"297.mat\"\n",
220 | " # FE Outer race 0.007 inches centered @3:00\n",
221 | " matlab_files_name[\"FEOR.007@3_0\"] = \"298.mat\"\n",
222 | " matlab_files_name[\"FEOR.007@3_1\"] = \"299.mat\"\n",
223 | " matlab_files_name[\"FEOR.007@3_2\"] = \"300.mat\"\n",
224 | " matlab_files_name[\"FEOR.007@3_3\"] = \"301.mat\"\n",
225 | " # FE Outer race 0.007 inches centered @12:00\n",
226 | " matlab_files_name[\"FEOR.007@12_0\"] = \"302.mat\"\n",
227 | " matlab_files_name[\"FEOR.007@12_1\"] = \"305.mat\"\n",
228 | " matlab_files_name[\"FEOR.007@12_2\"] = \"306.mat\"\n",
229 | " matlab_files_name[\"FEOR.007@12_3\"] = \"307.mat\"\n",
230 | " # FE Outer race 0.014 inches centered @3:00\n",
231 | " matlab_files_name[\"FEOR.014@3_0\"] = \"310.mat\"\n",
232 | " matlab_files_name[\"FEOR.014@3_1\"] = \"309.mat\"\n",
233 | " matlab_files_name[\"FEOR.014@3_2\"] = \"311.mat\"\n",
234 | " matlab_files_name[\"FEOR.014@3_3\"] = \"312.mat\"\n",
235 | " # FE Outer race 0.014 inches centered @6:00\n",
236 | " matlab_files_name[\"FEOR.014@6_0\"] = \"313.mat\"\n",
237 | " # FE Outer race 0.021 inches centered @6:00\n",
238 | " matlab_files_name[\"FEOR.021@6_0\"] = \"315.mat\"\n",
239 | " # FE Outer race 0.021 inches centered @3:00\n",
240 | " matlab_files_name[\"FEOR.021@3_1\"] = \"316.mat\"\n",
241 | " matlab_files_name[\"FEOR.021@3_2\"] = \"317.mat\"\n",
242 | " matlab_files_name[\"FEOR.021@3_3\"] = \"318.mat\"\n",
243 | " # DE Inner Race 0.028 inches\n",
244 | " matlab_files_name[\"DEIR.028_0\"] = \"3001.mat\"\n",
245 | " matlab_files_name[\"DEIR.028_1\"] = \"3002.mat\"\n",
246 | " matlab_files_name[\"DEIR.028_2\"] = \"3003.mat\"\n",
247 | " matlab_files_name[\"DEIR.028_3\"] = \"3004.mat\"\n",
248 | " # DE Ball 0.028 inches\n",
249 | " matlab_files_name[\"DEB.028_0\"] = \"3005.mat\"\n",
250 | " matlab_files_name[\"DEB.028_1\"] = \"3006.mat\"\n",
251 | " matlab_files_name[\"DEB.028_2\"] = \"3007.mat\"\n",
252 | " matlab_files_name[\"DEB.028_3\"] = \"3008.mat\"\n",
253 | " return matlab_files_name\n",
254 | "\n",
255 | "def files_debug():\n",
256 | "    \"\"\"\n",
257 | "    Build a reduced condition -> Matlab file dictionary for quick runs.\n",
258 | "    Keys follow the same naming scheme as cwru_12khz().\n",
259 | "    \n",
260 | "    NOTE: Used only for debug.\n",
261 | "    \"\"\"\n",
262 | "    return {\n",
263 | "        # Normal\n",
264 | "        \"Normal_0\": \"97.mat\",\n",
265 | "        \"Normal_1\": \"98.mat\",\n",
266 | "        \"Normal_2\": \"99.mat\",\n",
267 | "        \"Normal_3\": \"100.mat\",\n",
268 | "        # FE Inner Race 0.007 inches\n",
269 | "        \"FEIR.007_2\": \"280.mat\",\n",
270 | "        # DE Outer race 0.014 inches centered @6:00\n",
271 | "        \"DEOR.014@6_1\": \"198.mat\",\n",
272 | "        # FE Outer race 0.021 inches centered @6:00\n",
273 | "        \"FEOR.021@6_0\": \"315.mat\",\n",
274 | "        # DE Ball 0.028 inches\n",
275 | "        \"DEB.028_3\": \"3008.mat\",\n",
276 | "    }"
277 | ],
278 | "execution_count": 2,
279 | "outputs": []
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {
284 | "id": "K9y9byVeSz_u",
285 | "colab_type": "text"
286 | },
287 | "source": [
288 | "## Download Matlab files\n",
289 | "Downloads the Matlab files in the dictionary matlab_files_name."
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "metadata": {
295 | "id": "wPSGH1401-W2",
296 | "colab_type": "code",
297 | "colab": {}
298 | },
299 | "source": [
300 | "import urllib.request\n",
301 | "import os.path\n",
302 | "\n",
303 | "def download_cwrufiles(matlab_files_name, url=\"http://csegroups.case.edu/sites/default/files/bearingdatacenter/files/Datafiles/\"):\n",
304 | "    '''\n",
305 | "    Downloads the Matlab files in the dictionary matlab_files_name (condition -> file name) from url, skipping files already on disk.\n",
306 | "    Each file is fetched under a temporary '.part' name and renamed when complete, so an interrupted run cannot leave a truncated .mat file that later runs would skip.\n",
307 | "    '''\n",
308 | "    n = len(matlab_files_name)\n",
309 | "    for i, (key, file_name) in enumerate(matlab_files_name.items()):\n",
310 | "        if not os.path.exists(file_name):\n",
311 | "            urllib.request.urlretrieve(url + file_name, file_name + '.part')\n",
312 | "            os.replace(file_name + '.part', file_name)\n",
313 | "        print(\"{}/{}\\t{}\\t{}\".format(i+1, n, key, file_name))\n"
314 | ],
315 | "execution_count": 3,
316 | "outputs": []
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "metadata": {
321 | "id": "FRijKbOjS-JZ",
322 | "colab_type": "text"
323 | },
324 | "source": [
325 | "## Extract data from Matlab files\n",
326 | "Extracts the acquisitions of each Matlab file in the dictionary matlab_files_name."
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "metadata": {
332 | "id": "BbpFkSI12CUe",
333 | "colab_type": "code",
334 | "colab": {}
335 | },
336 | "source": [
337 | "import scipy.io\n",
338 | "import numpy as np\n",
339 | "\n",
340 | "def get_tensors_from_matlab(matlab_files_name):\n",
341 | "    '''\n",
342 | "    Extracts the DE/FE/BA time series of each Matlab file in matlab_files_name, keyed by condition + 'de'/'fe'/'ba'.\n",
343 | "    Some files appear to hold several recordings (e.g. 99.mat also X098_*), so the variable named after the file (X099_*) is preferred.\n",
344 | "    '''\n",
345 | "    acquisitions = {}\n",
346 | "    for key, file_name in matlab_files_name.items():\n",
347 | "        matlab_file = scipy.io.loadmat(file_name)\n",
348 | "        own_prefix = 'X' + file_name.rsplit('.', 1)[0].zfill(3) + '_'\n",
349 | "        for position in ['DE', 'FE', 'BA']:\n",
350 | "            names = sorted((name for name in matlab_file if name.endswith(position + \"_time\")), key=lambda name: not name.startswith(own_prefix))\n",
351 | "            if names:  # stable sort: own recording first, otherwise the file's original order (previous behaviour)\n",
352 | "                acquisitions[key + position.lower()] = matlab_file[names[0]].reshape(1, -1)[0]\n",
353 | "    return acquisitions\n"
354 | ],
355 | "execution_count": 4,
356 | "outputs": []
357 | },
358 | {
359 | "cell_type": "markdown",
360 | "metadata": {
361 | "id": "x-2lcukz5Nyk",
362 | "colab_type": "text"
363 | },
364 | "source": [
365 | "## Download the pickle file\n",
366 | "The following auxiliary functions download a pickle file from a Google Drive account.\n",
367 | "The pickle file already contains the properly extracted acquisitions,\n",
368 | "so using it can speed up the whole process."
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "metadata": {
374 | "id": "ZJkpaFxn1xtR",
375 | "colab_type": "code",
376 | "colab": {}
377 | },
378 | "source": [
379 | "import requests\n",
380 | "import os.path\n",
381 | "\n",
382 | "def download_file_from_google_drive(id, destination):\n",
383 | "    '''Downloads the Google Drive file with the given id to destination.'''\n",
384 | "    URL = \"https://docs.google.com/uc?export=download\"\n",
385 | "    session = requests.Session()\n",
386 | "    response = session.get(URL, params={'id': id}, stream=True)\n",
387 | "    token = get_confirm_token(response)\n",
388 | "    if token:\n",
389 | "        response = session.get(URL, params={'id': id, 'confirm': token}, stream=True)\n",
390 | "    response.raise_for_status()  # never save an HTTP error page as the pickle\n",
391 | "    save_response_content(response, destination)\n",
392 | "\n",
393 | "def get_confirm_token(response):\n",
394 | "    '''Returns the value of the first cookie whose name starts with 'download_warning', or None.'''\n",
395 | "    return next((value for key, value in response.cookies.items() if key.startswith('download_warning')), None)\n",
396 | "\n",
397 | "def save_response_content(response, destination, chunk_size=32768):\n",
398 | "    '''Streams the response body to destination, chunk_size bytes at a time.'''\n",
399 | "    partial = destination + '.part'  # renamed only when complete, so an interrupted download never leaves a truncated destination\n",
400 | "    with open(partial, \"wb\") as f:\n",
401 | "        for chunk in response.iter_content(chunk_size):\n",
402 | "            if chunk:  # filter out keep-alive new chunks\n",
403 | "                f.write(chunk)\n",
404 | "    os.replace(partial, destination)\n",
405 | "\n",
406 | "# https://drive.google.com/file/d/1qJezMiROz9NAYafPUDPh9BFkxYF4nOi2/view?usp=sharing\n",
407 | "file_id = \"1qJezMiROz9NAYafPUDPh9BFkxYF4nOi2\"\n",
408 | "# NOTE: pickle_file is later read with pickle.load, which can execute arbitrary code; only use a source you trust.\n",
409 | "pickle_file = 'debug.pickle' if debug else 'cwru.pickle'\n",
410 | "\n",
411 | "if not os.path.isfile(pickle_file) and not debug:\n",
412 | "    try:\n",
413 | "        download_file_from_google_drive(file_id, pickle_file)\n",
414 | "    except Exception as e:\n",
415 | "        # Not fatal: the next cell builds the pickle from the Matlab files instead.\n",
416 | "        print(\"Download failed!\", e)"
417 | ],
418 | "execution_count": 5,
419 | "outputs": []
420 | },
421 | {
422 | "cell_type": "markdown",
423 | "metadata": {
424 | "id": "oEzRboCpgtlx",
425 | "colab_type": "text"
426 | },
427 | "source": [
428 | "## Save/Load data\n",
429 | "If the cwru pickle file has already been downloaded, it is not downloaded again and the dictionary with the acquisitions is loaded from it.\n",
430 | "Otherwise, the desired Matlab files are downloaded and the acquisitions are extracted."
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "metadata": {
436 | "id": "Z1m5Q3OUbvqa",
437 | "colab_type": "code",
438 | "colab": {}
439 | },
440 | "source": [
441 | "import pickle\n",
442 | "import os\n",
443 | "\n",
444 | "# Condition -> Matlab file mapping: the full cwru_12khz() set, or the small files_debug() subset.\n",
445 | "matlab_files_name = files_debug() if debug else cwru_12khz()\n",
446 | "\n",
447 | "# Reuse the cached acquisitions unless debugging; debug runs always rebuild them.\n",
448 | "if not debug and os.path.isfile(pickle_file):\n",
449 | "    with open(pickle_file, 'rb') as handle:\n",
450 | "        acquisitions = pickle.load(handle)\n",
451 | "else:\n",
452 | "    # fetch the Matlab files, extract their signals and cache the resulting dictionary\n",
453 | "    download_cwrufiles(matlab_files_name)\n",
454 | "    acquisitions = get_tensors_from_matlab(matlab_files_name)\n",
455 | "    with open(pickle_file, 'wb') as handle:\n",
456 | "        pickle.dump(acquisitions, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
457 | ],
458 | "execution_count": 6,
459 | "outputs": []
460 | },
461 | {
462 | "cell_type": "markdown",
463 | "metadata": {
464 | "id": "AT7hgDnzTcNP",
465 | "colab_type": "text"
466 | },
467 | "source": [
468 | "## Segment data\n",
469 | "Segments the acquisitions into non-overlapping windows.\n",
470 | "* `sample_size` is the size of each segment.\n",
471 | "* `max_samples` is used for debug purposes and\n",
472 | "  reduces the number of samples taken from each acquisition.\n"
473 | ]
474 | },
475 | {
476 | "cell_type": "code",
477 | "metadata": {
478 | "id": "7BKfioJFzAKA",
479 | "colab_type": "code",
480 | "colab": {
481 | "base_uri": "https://localhost:8080/",
482 | "height": 1000
483 | },
484 | "outputId": "a072ec32-a722-46a0-f908-e80d943e4699"
485 | },
486 | "source": [
487 | "import numpy as np\n",
488 | "def cwru_segmentation(acquisitions, sample_size=512, max_samples=None):\n",
489 | "    '''\n",
490 | "    Segments the acquisitions into consecutive, non-overlapping windows.\n",
491 | "    sample_size is the size of each segment; a trailing remainder shorter than sample_size is dropped.\n",
492 | "    max_samples is used for debug purposes and caps the number of segments taken from each acquisition.\n",
493 | "    Returns (data, origin): data has shape (n_segments, sample_size, 1) and origin[i] is the acquisition key of data[i].\n",
494 | "    '''\n",
495 | "    origin = []\n",
496 | "    segments = [np.empty((0, sample_size, 1))]  # keeps the empty-input shape and the float64 result dtype\n",
497 | "    n = len(acquisitions)\n",
498 | "    for i, key in enumerate(acquisitions):\n",
499 | "        acquisition_size = len(acquisitions[key])\n",
500 | "        n_samples = acquisition_size // sample_size\n",
501 | "        if max_samples is not None and max_samples > 0 and n_samples > max_samples:\n",
502 | "            n_samples = max_samples\n",
503 | "        print('{}/{} --- {}: {}'.format(i+1, n, key, n_samples))\n",
504 | "        origin.extend([key] * n_samples)\n",
505 | "        segments.append(acquisitions[key][:n_samples * sample_size]\n",
506 | "                        .reshape((n_samples, sample_size, 1)))\n",
507 | "    # Concatenate once: growing the array inside the loop re-copies every previous segment each iteration.\n",
508 | "    return np.concatenate(segments), origin\n",
509 | "\n",
510 | "if not debug:\n",
511 | "    signal_data, signal_origin = cwru_segmentation(acquisitions, 512)\n",
512 | "else:\n",
513 | "    signal_data, signal_origin = cwru_segmentation(acquisitions, 512, 15)  # at most 15 segments per acquisition\n",
514 | "signal_data.shape"
515 | ],
516 | "execution_count": 7,
517 | "outputs": [
518 | {
519 | "output_type": "stream",
520 | "text": [
521 | "1/307 --- Normal_0de: 476\n",
522 | "2/307 --- Normal_0fe: 476\n",
523 | "3/307 --- Normal_1de: 945\n",
524 | "4/307 --- Normal_1fe: 945\n",
525 | "5/307 --- Normal_2de: 945\n",
526 | "6/307 --- Normal_2fe: 945\n",
527 | "7/307 --- Normal_3de: 948\n",
528 | "8/307 --- Normal_3fe: 948\n",
529 | "9/307 --- DEIR.007_0de: 236\n",
530 | "10/307 --- DEIR.007_0fe: 236\n",
531 | "11/307 --- DEIR.007_0ba: 236\n",
532 | "12/307 --- DEIR.007_1de: 238\n",
533 | "13/307 --- DEIR.007_1fe: 238\n",
534 | "14/307 --- DEIR.007_1ba: 238\n",
535 | "15/307 --- DEIR.007_2de: 238\n",
536 | "16/307 --- DEIR.007_2fe: 238\n",
537 | "17/307 --- DEIR.007_2ba: 238\n",
538 | "18/307 --- DEIR.007_3de: 240\n",
539 | "19/307 --- DEIR.007_3fe: 240\n",
540 | "20/307 --- DEIR.007_3ba: 240\n",
541 | "21/307 --- DEB.007_0de: 239\n",
542 | "22/307 --- DEB.007_0fe: 239\n",
543 | "23/307 --- DEB.007_0ba: 239\n",
544 | "24/307 --- DEB.007_1de: 237\n",
545 | "25/307 --- DEB.007_1fe: 237\n",
546 | "26/307 --- DEB.007_1ba: 237\n",
547 | "27/307 --- DEB.007_2de: 237\n",
548 | "28/307 --- DEB.007_2fe: 237\n",
549 | "29/307 --- DEB.007_2ba: 237\n",
550 | "30/307 --- DEB.007_3de: 237\n",
551 | "31/307 --- DEB.007_3fe: 237\n",
552 | "32/307 --- DEB.007_3ba: 237\n",
553 | "33/307 --- DEOR.007@6_0de: 238\n",
554 | "34/307 --- DEOR.007@6_0fe: 238\n",
555 | "35/307 --- DEOR.007@6_0ba: 238\n",
556 | "36/307 --- DEOR.007@6_1de: 239\n",
557 | "37/307 --- DEOR.007@6_1fe: 239\n",
558 | "38/307 --- DEOR.007@6_1ba: 239\n",
559 | "39/307 --- DEOR.007@6_2de: 237\n",
560 | "40/307 --- DEOR.007@6_2fe: 237\n",
561 | "41/307 --- DEOR.007@6_2ba: 237\n",
562 | "42/307 --- DEOR.007@6_3de: 239\n",
563 | "43/307 --- DEOR.007@6_3fe: 239\n",
564 | "44/307 --- DEOR.007@6_3ba: 239\n",
565 | "45/307 --- DEOR.007@3_0de: 238\n",
566 | "46/307 --- DEOR.007@3_0fe: 238\n",
567 | "47/307 --- DEOR.007@3_0ba: 238\n",
568 | "48/307 --- DEOR.007@3_1de: 237\n",
569 | "49/307 --- DEOR.007@3_1fe: 237\n",
570 | "50/307 --- DEOR.007@3_1ba: 237\n",
571 | "51/307 --- DEOR.007@3_2de: 237\n",
572 | "52/307 --- DEOR.007@3_2fe: 237\n",
573 | "53/307 --- DEOR.007@3_2ba: 237\n",
574 | "54/307 --- DEOR.007@3_3de: 238\n",
575 | "55/307 --- DEOR.007@3_3fe: 238\n",
576 | "56/307 --- DEOR.007@3_3ba: 238\n",
577 | "57/307 --- DEOR.007@12_0de: 238\n",
578 | "58/307 --- DEOR.007@12_0fe: 238\n",
579 | "59/307 --- DEOR.007@12_0ba: 238\n",
580 | "60/307 --- DEOR.007@12_1de: 238\n",
581 | "61/307 --- DEOR.007@12_1fe: 238\n",
582 | "62/307 --- DEOR.007@12_1ba: 238\n",
583 | "63/307 --- DEOR.007@12_2de: 238\n",
584 | "64/307 --- DEOR.007@12_2fe: 238\n",
585 | "65/307 --- DEOR.007@12_2ba: 238\n",
586 | "66/307 --- DEOR.007@12_3de: 238\n",
587 | "67/307 --- DEOR.007@12_3fe: 238\n",
588 | "68/307 --- DEOR.007@12_3ba: 238\n",
589 | "69/307 --- DEIR.014_0de: 237\n",
590 | "70/307 --- DEIR.014_0fe: 237\n",
591 | "71/307 --- DEIR.014_0ba: 237\n",
592 | "72/307 --- DEIR.014_1de: 237\n",
593 | "73/307 --- DEIR.014_1fe: 237\n",
594 | "74/307 --- DEIR.014_1ba: 237\n",
595 | "75/307 --- DEIR.014_2de: 237\n",
596 | "76/307 --- DEIR.014_2fe: 237\n",
597 | "77/307 --- DEIR.014_2ba: 237\n",
598 | "78/307 --- DEIR.014_3de: 237\n",
599 | "79/307 --- DEIR.014_3fe: 237\n",
600 | "80/307 --- DEIR.014_3ba: 237\n",
601 | "81/307 --- DEB.014_0de: 237\n",
602 | "82/307 --- DEB.014_0fe: 237\n",
603 | "83/307 --- DEB.014_0ba: 237\n",
604 | "84/307 --- DEB.014_1de: 238\n",
605 | "85/307 --- DEB.014_1fe: 238\n",
606 | "86/307 --- DEB.014_1ba: 238\n",
607 | "87/307 --- DEB.014_2de: 238\n",
608 | "88/307 --- DEB.014_2fe: 238\n",
609 | "89/307 --- DEB.014_2ba: 238\n",
610 | "90/307 --- DEB.014_3de: 238\n",
611 | "91/307 --- DEB.014_3fe: 238\n",
612 | "92/307 --- DEB.014_3ba: 238\n",
613 | "93/307 --- DEOR.014@6_0de: 237\n",
614 | "94/307 --- DEOR.014@6_0fe: 237\n",
615 | "95/307 --- DEOR.014@6_0ba: 237\n",
616 | "96/307 --- DEOR.014@6_1de: 238\n",
617 | "97/307 --- DEOR.014@6_1fe: 238\n",
618 | "98/307 --- DEOR.014@6_1ba: 238\n",
619 | "99/307 --- DEOR.014@6_2de: 237\n",
620 | "100/307 --- DEOR.014@6_2fe: 237\n",
621 | "101/307 --- DEOR.014@6_2ba: 237\n",
622 | "102/307 --- DEOR.014@6_3de: 238\n",
623 | "103/307 --- DEOR.014@6_3fe: 238\n",
624 | "104/307 --- DEOR.014@6_3ba: 238\n",
625 | "105/307 --- DEB.021_0de: 238\n",
626 | "106/307 --- DEB.021_0fe: 238\n",
627 | "107/307 --- DEB.021_0ba: 238\n",
628 | "108/307 --- DEB.021_1de: 237\n",
629 | "109/307 --- DEB.021_1fe: 237\n",
630 | "110/307 --- DEB.021_1ba: 237\n",
631 | "111/307 --- DEB.021_2de: 238\n",
632 | "112/307 --- DEB.021_2fe: 238\n",
633 | "113/307 --- DEB.021_2ba: 238\n",
634 | "114/307 --- DEB.021_3de: 238\n",
635 | "115/307 --- DEB.021_3fe: 238\n",
636 | "116/307 --- DEB.021_3ba: 238\n",
637 | "117/307 --- FEIR.021_0de: 236\n",
638 | "118/307 --- FEIR.021_0fe: 236\n",
639 | "119/307 --- FEIR.021_0ba: 236\n",
640 | "120/307 --- FEIR.021_1de: 236\n",
641 | "121/307 --- FEIR.021_1fe: 236\n",
642 | "122/307 --- FEIR.021_1ba: 236\n",
643 | "123/307 --- FEIR.021_2de: 236\n",
644 | "124/307 --- FEIR.021_2fe: 236\n",
645 | "125/307 --- FEIR.021_2ba: 236\n",
646 | "126/307 --- FEIR.021_3de: 236\n",
647 | "127/307 --- FEIR.021_3fe: 236\n",
648 | "128/307 --- FEIR.021_3ba: 236\n",
649 | "129/307 --- FEIR.014_0de: 237\n",
650 | "130/307 --- FEIR.014_0fe: 237\n",
651 | "131/307 --- FEIR.014_0ba: 237\n",
652 | "132/307 --- FEIR.014_1de: 237\n",
653 | "133/307 --- FEIR.014_1fe: 237\n",
654 | "134/307 --- FEIR.014_1ba: 237\n",
655 | "135/307 --- FEIR.014_2de: 237\n",
656 | "136/307 --- FEIR.014_2fe: 237\n",
657 | "137/307 --- FEIR.014_2ba: 237\n",
658 | "138/307 --- FEIR.014_3de: 236\n",
659 | "139/307 --- FEIR.014_3fe: 236\n",
660 | "140/307 --- FEIR.014_3ba: 236\n",
661 | "141/307 --- FEB.007_0de: 236\n",
662 | "142/307 --- FEB.007_0fe: 236\n",
663 | "143/307 --- FEB.007_0ba: 236\n",
664 | "144/307 --- FEB.007_1de: 235\n",
665 | "145/307 --- FEB.007_1fe: 235\n",
666 | "146/307 --- FEB.007_1ba: 235\n",
667 | "147/307 --- FEB.007_2de: 237\n",
668 | "148/307 --- FEB.007_2fe: 237\n",
669 | "149/307 --- FEB.007_2ba: 237\n",
670 | "150/307 --- FEB.007_3de: 236\n",
671 | "151/307 --- FEB.007_3fe: 236\n",
672 | "152/307 --- FEB.007_3ba: 236\n",
673 | "153/307 --- DEIR.021_0de: 238\n",
674 | "154/307 --- DEIR.021_0fe: 238\n",
675 | "155/307 --- DEIR.021_0ba: 238\n",
676 | "156/307 --- DEIR.021_1de: 237\n",
677 | "157/307 --- DEIR.021_1fe: 237\n",
678 | "158/307 --- DEIR.021_1ba: 237\n",
679 | "159/307 --- DEIR.021_2de: 237\n",
680 | "160/307 --- DEIR.021_2fe: 237\n",
681 | "161/307 --- DEIR.021_2ba: 237\n",
682 | "162/307 --- DEIR.021_3de: 238\n",
683 | "163/307 --- DEIR.021_3fe: 238\n",
684 | "164/307 --- DEIR.021_3ba: 238\n",
685 | "165/307 --- DEOR.021@6_0de: 239\n",
686 | "166/307 --- DEOR.021@6_0fe: 239\n",
687 | "167/307 --- DEOR.021@6_0ba: 239\n",
688 | "168/307 --- DEOR.021@6_1de: 238\n",
689 | "169/307 --- DEOR.021@6_1fe: 238\n",
690 | "170/307 --- DEOR.021@6_1ba: 238\n",
691 | "171/307 --- DEOR.021@6_2de: 238\n",
692 | "172/307 --- DEOR.021@6_2fe: 238\n",
693 | "173/307 --- DEOR.021@6_2ba: 238\n",
694 | "174/307 --- DEOR.021@6_3de: 238\n",
695 | "175/307 --- DEOR.021@6_3fe: 238\n",
696 | "176/307 --- DEOR.021@6_3ba: 238\n",
697 | "177/307 --- DEOR.021@3_0de: 237\n",
698 | "178/307 --- DEOR.021@3_0fe: 237\n",
699 | "179/307 --- DEOR.021@3_0ba: 237\n",
700 | "180/307 --- DEOR.021@3_1de: 238\n",
701 | "181/307 --- DEOR.021@3_1fe: 238\n",
702 | "182/307 --- DEOR.021@3_1ba: 238\n",
703 | "183/307 --- DEOR.021@3_2de: 238\n",
704 | "184/307 --- DEOR.021@3_2fe: 238\n",
705 | "185/307 --- DEOR.021@3_2ba: 238\n",
706 | "186/307 --- DEOR.021@3_3de: 238\n",
707 | "187/307 --- DEOR.021@3_3fe: 238\n",
708 | "188/307 --- DEOR.021@3_3ba: 238\n",
709 | "189/307 --- DEOR.021@12_0de: 237\n",
710 | "190/307 --- DEOR.021@12_0fe: 237\n",
711 | "191/307 --- DEOR.021@12_0ba: 237\n",
712 | "192/307 --- DEOR.021@12_1de: 239\n",
713 | "193/307 --- DEOR.021@12_1fe: 239\n",
714 | "194/307 --- DEOR.021@12_1ba: 239\n",
715 | "195/307 --- DEOR.021@12_2de: 239\n",
716 | "196/307 --- DEOR.021@12_2fe: 239\n",
717 | "197/307 --- DEOR.021@12_2ba: 239\n",
718 | "198/307 --- DEOR.021@12_3de: 237\n",
719 | "199/307 --- DEOR.021@12_3fe: 237\n",
720 | "200/307 --- DEOR.021@12_3ba: 237\n",
721 | "201/307 --- FEIR.007_0de: 237\n",
722 | "202/307 --- FEIR.007_0fe: 237\n",
723 | "203/307 --- FEIR.007_0ba: 237\n",
724 | "204/307 --- FEIR.007_1de: 237\n",
725 | "205/307 --- FEIR.007_1fe: 237\n",
726 | "206/307 --- FEIR.007_1ba: 237\n",
727 | "207/307 --- FEIR.007_2de: 237\n",
728 | "208/307 --- FEIR.007_2fe: 237\n",
729 | "209/307 --- FEIR.007_2ba: 237\n",
730 | "210/307 --- FEIR.007_3de: 237\n",
731 | "211/307 --- FEIR.007_3fe: 237\n",
732 | "212/307 --- FEIR.007_3ba: 237\n",
733 | "213/307 --- FEB.014_0de: 238\n",
734 | "214/307 --- FEB.014_0fe: 238\n",
735 | "215/307 --- FEB.014_0ba: 238\n",
736 | "216/307 --- FEB.014_1de: 237\n",
737 | "217/307 --- FEB.014_1fe: 237\n",
738 | "218/307 --- FEB.014_1ba: 237\n",
739 | "219/307 --- FEB.014_2de: 238\n",
740 | "220/307 --- FEB.014_2fe: 238\n",
741 | "221/307 --- FEB.014_2ba: 238\n",
742 | "222/307 --- FEB.014_3de: 236\n",
743 | "223/307 --- FEB.014_3fe: 236\n",
744 | "224/307 --- FEB.014_3ba: 236\n",
745 | "225/307 --- FEB.021_0de: 237\n",
746 | "226/307 --- FEB.021_0fe: 237\n",
747 | "227/307 --- FEB.021_0ba: 237\n",
748 | "228/307 --- FEB.021_1de: 237\n",
749 | "229/307 --- FEB.021_1fe: 237\n",
750 | "230/307 --- FEB.021_1ba: 237\n",
751 | "231/307 --- FEB.021_2de: 237\n",
752 | "232/307 --- FEB.021_2fe: 237\n",
753 | "233/307 --- FEB.021_2ba: 237\n",
754 | "234/307 --- FEB.021_3de: 235\n",
755 | "235/307 --- FEB.021_3fe: 235\n",
756 | "236/307 --- FEB.021_3ba: 235\n",
757 | "237/307 --- FEOR.007@6_0de: 236\n",
758 | "238/307 --- FEOR.007@6_0fe: 236\n",
759 | "239/307 --- FEOR.007@6_0ba: 236\n",
760 | "240/307 --- FEOR.007@6_1de: 237\n",
761 | "241/307 --- FEOR.007@6_1fe: 237\n",
762 | "242/307 --- FEOR.007@6_1ba: 237\n",
763 | "243/307 --- FEOR.007@6_2de: 236\n",
764 | "244/307 --- FEOR.007@6_2fe: 236\n",
765 | "245/307 --- FEOR.007@6_2ba: 236\n",
766 | "246/307 --- FEOR.007@6_3de: 238\n",
767 | "247/307 --- FEOR.007@6_3fe: 238\n",
768 | "248/307 --- FEOR.007@6_3ba: 238\n",
769 | "249/307 --- FEOR.007@3_0de: 236\n",
770 | "250/307 --- FEOR.007@3_0fe: 236\n",
771 | "251/307 --- FEOR.007@3_0ba: 236\n",
772 | "252/307 --- FEOR.007@3_1de: 236\n",
773 | "253/307 --- FEOR.007@3_1fe: 236\n",
774 | "254/307 --- FEOR.007@3_1ba: 236\n",
775 | "255/307 --- FEOR.007@3_2de: 238\n",
776 | "256/307 --- FEOR.007@3_2fe: 238\n",
777 | "257/307 --- FEOR.007@3_2ba: 238\n",
778 | "258/307 --- FEOR.007@3_3de: 237\n",
779 | "259/307 --- FEOR.007@3_3fe: 237\n",
780 | "260/307 --- FEOR.007@3_3ba: 237\n",
781 | "261/307 --- FEOR.007@12_0de: 236\n",
782 | "262/307 --- FEOR.007@12_0fe: 236\n",
783 | "263/307 --- FEOR.007@12_0ba: 236\n",
784 | "264/307 --- FEOR.007@12_1de: 236\n",
785 | "265/307 --- FEOR.007@12_1fe: 236\n",
786 | "266/307 --- FEOR.007@12_1ba: 236\n",
787 | "267/307 --- FEOR.007@12_2de: 237\n",
788 | "268/307 --- FEOR.007@12_2fe: 237\n",
789 | "269/307 --- FEOR.007@12_2ba: 237\n",
790 | "270/307 --- FEOR.007@12_3de: 236\n",
791 | "271/307 --- FEOR.007@12_3fe: 236\n",
792 | "272/307 --- FEOR.007@12_3ba: 236\n",
793 | "273/307 --- FEOR.014@3_0de: 237\n",
794 | "274/307 --- FEOR.014@3_0fe: 237\n",
795 | "275/307 --- FEOR.014@3_0ba: 237\n",
796 | "276/307 --- FEOR.014@3_1de: 236\n",
797 | "277/307 --- FEOR.014@3_1fe: 236\n",
798 | "278/307 --- FEOR.014@3_1ba: 236\n",
799 | "279/307 --- FEOR.014@3_2de: 236\n",
800 | "280/307 --- FEOR.014@3_2fe: 236\n",
801 | "281/307 --- FEOR.014@3_2ba: 236\n",
802 | "282/307 --- FEOR.014@3_3de: 236\n",
803 | "283/307 --- FEOR.014@3_3fe: 236\n",
804 | "284/307 --- FEOR.014@3_3ba: 236\n",
805 | "285/307 --- FEOR.014@6_0de: 236\n",
806 | "286/307 --- FEOR.014@6_0fe: 236\n",
807 | "287/307 --- FEOR.014@6_0ba: 236\n",
808 | "288/307 --- FEOR.021@6_0de: 235\n",
809 | "289/307 --- FEOR.021@6_0fe: 235\n",
810 | "290/307 --- FEOR.021@6_0ba: 235\n",
811 | "291/307 --- FEOR.021@3_1de: 235\n",
812 | "292/307 --- FEOR.021@3_1fe: 235\n",
813 | "293/307 --- FEOR.021@3_1ba: 235\n",
814 | "294/307 --- FEOR.021@3_2de: 235\n",
815 | "295/307 --- FEOR.021@3_2fe: 235\n",
816 | "296/307 --- FEOR.021@3_2ba: 235\n",
817 | "297/307 --- FEOR.021@3_3de: 237\n",
818 | "298/307 --- FEOR.021@3_3fe: 237\n",
819 | "299/307 --- FEOR.021@3_3ba: 237\n",
820 | "300/307 --- DEIR.028_0de: 235\n",
821 | "301/307 --- DEIR.028_1de: 237\n",
822 | "302/307 --- DEIR.028_2de: 237\n",
823 | "303/307 --- DEIR.028_3de: 237\n",
824 | "304/307 --- DEB.028_0de: 235\n",
825 | "305/307 --- DEB.028_1de: 237\n",
826 | "306/307 --- DEB.028_2de: 236\n",
827 | "307/307 --- DEB.028_3de: 236\n"
828 | ],
829 | "name": "stdout"
830 | },
831 | {
832 | "output_type": "execute_result",
833 | "data": {
834 | "text/plain": [
835 | "(77527, 512, 1)"
836 | ]
837 | },
838 | "metadata": {
839 | "tags": []
840 | },
841 | "execution_count": 7
842 | }
843 | ]
844 | },
845 | {
846 | "cell_type": "markdown",
847 | "metadata": {
848 | "id": "_jCN8XZ5dOF3",
849 | "colab_type": "text"
850 | },
851 | "source": [
852 | "## Clean dataset functions\n",
853 | "The functions below help to select samples from acquisitions and form groups according to these acquisitions, using regular expressions."
854 | ]
855 | },
856 | {
857 | "cell_type": "code",
858 | "metadata": {
859 | "id": "eOOP9H2c3AaZ",
860 | "colab_type": "code",
861 | "colab": {}
862 | },
863 | "source": [
864 | "import re\n",
865 | "import numpy as np\n",
866 | "\n",
867 | "def select_samples(regex, X, y):\n",
868 | "    '''\n",
869 | "    Selects the samples whose label contains a match for regex; returns the filtered (X, y).\n",
870 | "    '''\n",
871 | "    pattern = re.compile(regex)\n",
872 | "    mask = [pattern.search(label) is not None for label in y]\n",
873 | "    return X[mask], y[mask]\n",
874 | "\n",
875 | "def join_labels(regex, y):\n",
876 | "    '''\n",
877 | "    Removes every match of regex from the labels,\n",
878 | "    so that samples differing only in the removed part end up with the same label.\n",
879 | "    '''\n",
880 | "    pattern = re.compile(regex)\n",
881 | "    return np.array([pattern.sub('', label) for label in y])\n",
882 | "\n",
883 | "def get_groups(regex, y):\n",
884 | "    '''\n",
885 | "    Returns, for each label in y, the first substring matching regex (None if there is none); equal substrings form a group.\n",
886 | "    '''\n",
887 | "    def first_match(label):\n",
888 | "        found = re.search(regex, label)\n",
889 | "        return found.group(0) if found else None\n",
890 | "    return [first_match(label) for label in y]"
891 | ],
892 | "execution_count": 8,
893 | "outputs": []
894 | },
895 | {
896 | "cell_type": "markdown",
897 | "metadata": {
898 | "id": "yFA-8l02RplD",
899 | "colab_type": "text"
900 | },
901 | "source": [
902 | "## Selecting samples"
903 | ]
904 | },
905 | {
906 | "cell_type": "code",
907 | "metadata": {
908 | "id": "r89dYOJm8gzW",
909 | "colab_type": "code",
910 | "colab": {
911 | "base_uri": "https://localhost:8080/",
912 | "height": 54
913 | },
914 | "outputId": "3834f56c-049c-46cf-9be4-5aa9f9e34908"
915 | },
916 | "source": [
917 | "# Keep DE faults from the 'de' channel, FE faults from the 'fe' channel and Normal from both channels\n",
918 | "samples = '^(DE).*(de)|^(FE).*(fe)|(Normal).*'\n",
919 | "X, y = select_samples(samples, signal_data, np.asarray(signal_origin))\n",
920 | "print(len(set(y)), set(y))"
921 | ],
922 | "execution_count": 9,
923 | "outputs": [
924 | {
925 | "output_type": "stream",
926 | "text": [
927 | "113 {'FEIR.021_2fe', 'DEOR.021@12_3de', 'FEOR.007@3_2fe', 'DEB.021_3de', 'DEB.007_2de', 'DEIR.007_2de', 'DEOR.021@6_3de', 'FEOR.014@3_2fe', 'DEB.028_2de', 'FEIR.021_1fe', 'FEB.007_2fe', 'DEIR.007_1de', 'DEOR.014@6_2de', 'DEOR.021@12_1de', 'FEB.007_1fe', 'DEIR.021_2de', 'FEOR.021@3_3fe', 'FEIR.014_0fe', 'FEOR.007@12_3fe', 'FEOR.007@12_1fe', 'DEOR.014@6_0de', 'FEOR.014@6_0fe', 'DEOR.007@6_3de', 'FEOR.007@3_0fe', 'DEOR.007@6_0de', 'DEOR.007@12_1de', 'FEB.014_3fe', 'FEB.014_2fe', 'FEOR.021@3_2fe', 'DEIR.028_3de', 'DEOR.007@6_2de', 'DEIR.021_1de', 'FEOR.007@3_3fe', 'DEB.028_1de', 'DEB.028_3de', 'DEIR.021_0de', 'FEOR.007@3_1fe', 'DEB.007_0de', 'DEB.014_1de', 'FEB.007_3fe', 'Normal_0de', 'FEB.021_2fe', 'FEOR.014@3_1fe', 'FEOR.014@3_0fe', 'DEOR.014@6_1de', 'DEIR.021_3de', 'DEOR.007@3_2de', 'FEB.021_0fe', 'DEOR.007@12_0de', 'DEB.021_0de', 'DEOR.021@3_2de', 'DEIR.007_3de', 'FEOR.007@12_2fe', 'FEIR.014_1fe', 'DEB.021_2de', 'FEOR.007@6_3fe', 'FEB.014_1fe', 'FEB.007_0fe', 'DEOR.014@6_3de', 'FEIR.014_3fe', 'FEIR.007_0fe', 'Normal_2de', 'FEOR.007@6_0fe', 'DEIR.014_0de', 'FEIR.007_3fe', 'DEB.014_2de', 'FEOR.021@3_1fe', 'DEB.028_0de', 'DEOR.021@12_0de', 'FEIR.021_0fe', 'DEOR.021@6_2de', 'DEOR.007@12_3de', 'DEOR.021@3_1de', 'FEOR.007@6_1fe', 'Normal_1fe', 'DEOR.021@3_3de', 'FEOR.007@12_0fe', 'DEIR.014_1de', 'DEOR.021@6_0de', 'FEIR.014_2fe', 'FEB.014_0fe', 'DEIR.007_0de', 'DEIR.028_0de', 'DEIR.028_2de', 'Normal_3fe', 'DEIR.014_3de', 'FEB.021_3fe', 'FEIR.007_1fe', 'FEOR.021@6_0fe', 'FEIR.021_3fe', 'DEIR.028_1de', 'DEB.014_3de', 'DEOR.007@3_3de', 'DEIR.014_2de', 'FEOR.007@6_2fe', 'DEB.007_1de', 'DEB.014_0de', 'FEOR.014@3_3fe', 'DEOR.021@3_0de', 'DEOR.021@12_2de', 'Normal_1de', 'Normal_0fe', 'FEB.021_1fe', 'DEOR.007@3_1de', 'DEOR.007@6_1de', 'Normal_3de', 'Normal_2fe', 'DEOR.007@3_0de', 'DEOR.007@12_2de', 'DEB.021_1de', 'DEOR.021@6_1de', 'DEB.007_3de', 'FEIR.007_2fe'}\n"
928 | ],
929 | "name": "stdout"
930 | }
931 | ]
932 | },
933 | {
934 | "cell_type": "markdown",
935 | "metadata": {
936 | "id": "5pRQQK0Mhm1_",
937 | "colab_type": "text"
938 | },
939 | "source": [
940 | "# Experimenter"
941 | ]
942 | },
943 | {
944 | "cell_type": "code",
945 | "metadata": {
946 | "id": "GE4TTG1-hmH7",
947 | "colab_type": "code",
948 | "colab": {
949 | "base_uri": "https://localhost:8080/",
950 | "height": 71
951 | },
952 | "outputId": "324eda7a-47d9-4366-f6fd-62a1808ae70f"
953 | },
954 | "source": [
955 | "from sklearn.model_selection import cross_validate, KFold, PredefinedSplit\n",
956 | "\n",
957 | "def experimenter(model, X, y, groups=None, scoring=None, cv=KFold(4, shuffle=True), verbose=0):\n",
958 | "    '''\n",
959 | "    Performs an experiment with some estimator (model) and outer validation (cv); returns cross_validate-like scores.\n",
960 | "    It works like sklearn's cross_validate; however, when the estimator has an internal validation (a cv attribute,\n",
961 | "    or a pipeline step whose name contains 'gs'), that inner validation is rebuilt from each outer training fold\n",
962 | "    with cv.n_splits-1 folds, using GroupShuffleKFold when groups are given so the inner split keeps the groups too.\n",
963 | "    '''\n",
964 | "    if hasattr(model, 'cv') or (hasattr(model, 'steps') and any('gs' in name for name, _ in model.steps)):\n",
965 | "        scores = {}\n",
966 | "        # cv is the outer validation here, so it must provide split() and n_splits.\n",
967 | "        for tr, _ in cv.split(X, y, groups):\n",
968 | "            # Inner folds are built from the outer training part only.\n",
969 | "            if groups is not None:\n",
970 | "                innercv = list(GroupShuffleKFold(cv.n_splits-1).split(X[tr], y[tr], np.array(groups)[tr]))\n",
971 | "            else:\n",
972 | "                innercv = list(KFold(cv.n_splits-1, shuffle=True).split(X[tr], y[tr]))\n",
973 | "            if hasattr(model, 'cv'):\n",
974 | "                model.cv = innercv\n",
975 | "            else:\n",
976 | "                for name, step in model.steps:\n",
977 | "                    if 'gs' in name:\n",
978 | "                        step.cv = innercv\n",
979 | "            # PredefinedSplit: -1 keeps a sample in training, 0 puts it in the single test fold.\n",
980 | "            test_fold = np.zeros((len(y),), dtype=int)\n",
981 | "            test_fold[tr] = -1\n",
982 | "            score = cross_validate(model, X, y, groups=groups, scoring=scoring,\n",
983 | "                                   cv=PredefinedSplit(test_fold), verbose=verbose)\n",
984 | "            for k, values in score.items():\n",
985 | "                scores.setdefault(k, []).extend(values)\n",
986 | "        return scores\n",
987 | "    return cross_validate(model, X, y, groups=groups, scoring=scoring, cv=cv, verbose=verbose)"
988 | ],
989 | "execution_count": 10,
990 | "outputs": [
991 | {
992 | "output_type": "stream",
993 | "text": [
994 | "/home/francisco/Jupyter/.venv/lib/python3.8/site-packages/sklearn/utils/validation.py:68: FutureWarning: Pass shuffle=True as keyword args. From version 0.25 passing these as positional arguments will result in an error\n",
995 | " warnings.warn(\"Pass {} as keyword args. From version 0.25 \"\n"
996 | ],
997 | "name": "stderr"
998 | }
999 | ]
1000 | },
1001 | {
1002 | "cell_type": "markdown",
1003 | "metadata": {
1004 | "id": "dWYfuxcxFjt8",
1005 | "colab_type": "text"
1006 | },
1007 | "source": [
1008 | "## Custom Splitter"
1009 | ]
1010 | },
1011 | {
1012 | "cell_type": "code",
1013 | "metadata": {
1014 | "id": "xRdfG-uzhPm4",
1015 | "colab_type": "code",
1016 | "colab": {}
1017 | },
1018 | "source": [
1019 | "from sklearn.model_selection import KFold\n",
1020 | "from sklearn.utils import shuffle\n",
1021 | "from sklearn.utils.validation import check_array\n",
1022 | "import numpy as np\n",
1023 | "\n",
1024 | "class GroupShuffleKFold(KFold):\n",
1025 | "    '''\n",
1026 | "    K-fold splitter that keeps every group inside a single fold (neither\n",
1027 | "    GroupShuffleSplit nor GroupKFold fit this case). Groups are dealt round-robin to\n",
1028 | "    the folds; with shuffle=True each complete block of n_splits groups is shuffled.\n",
1029 | "    '''\n",
1030 | "    def __init__(self, n_splits=4, shuffle=False, random_state=None):\n",
1031 | "        # sklearn >= 0.24 rejects a random_state when shuffle=False; it is unused then.\n",
1032 | "        super().__init__(n_splits, shuffle=shuffle,\n",
1033 | "                         random_state=random_state if shuffle else None)\n",
1034 | "\n",
1035 | "    def get_n_splits(self, X=None, y=None, groups=None):\n",
1036 | "        # X, y and groups are optional, as in scikit-learn's splitters.\n",
1037 | "        return self.n_splits\n",
1038 | "\n",
1039 | "    def _iter_test_indices(self, X=None, y=None, groups=None):\n",
1040 | "        if groups is None:\n",
1041 | "            raise ValueError(\"The 'groups' parameter should not be None.\")\n",
1042 | "        groups = check_array(groups, ensure_2d=False, dtype=None)\n",
1043 | "        unique_groups, groups = np.unique(groups, return_inverse=True)\n",
1044 | "        n_groups = len(unique_groups)\n",
1045 | "        if self.n_splits > n_groups:\n",
1046 | "            raise ValueError(\"Cannot have number of splits n_splits=%d greater\"\n",
1047 | "                             \" than the number of groups: %d.\"\n",
1048 | "                             % (self.n_splits, n_groups))\n",
1049 | "        # Order in which the groups are dealt to the folds.\n",
1050 | "        order = np.arange(n_groups)\n",
1051 | "        if self.shuffle:\n",
1052 | "            # Shuffle only inside complete blocks of n_splits groups, so each fold gets\n",
1053 | "            # one group per block; trailing groups keep their order.\n",
1054 | "            for i in range(n_groups // self.n_splits):\n",
1055 | "                block = slice(self.n_splits * i, self.n_splits * (i + 1))\n",
1056 | "                seed = None if self.random_state is None else self.random_state + i\n",
1057 | "                order[block] = shuffle(order[block], random_state=seed)\n",
1058 | "        # The group at position p of `order` goes to fold p % n_splits.\n",
1059 | "        group_to_fold = np.empty(n_groups, dtype=int)\n",
1060 | "        group_to_fold[order] = np.arange(n_groups) % self.n_splits\n",
1061 | "        # Every sample goes to the fold of its group.\n",
1062 | "        sample_fold = group_to_fold[groups]\n",
1063 | "        for f in range(self.n_splits):\n",
1064 | "            yield np.where(sample_fold == f)[0]"
1065 | ],
1066 | "execution_count": 11,
1067 | "outputs": []
1068 | },
1069 | {
1070 | "cell_type": "markdown",
1071 | "metadata": {
1072 | "id": "Frmm3FVivQvg",
1073 | "colab_type": "text"
1074 | },
1075 | "source": [
1076 | "## BySeverity Splitter"
1077 | ]
1078 | },
1079 | {
1080 | "cell_type": "code",
1081 | "metadata": {
1082 | "id": "ZFDWy6zqvGTF",
1083 | "colab_type": "code",
1084 | "colab": {}
1085 | },
1086 | "source": [
1087 | "from sklearn.model_selection import KFold\n",
1088 | "from sklearn.utils import shuffle\n",
1089 | "from sklearn.utils.validation import check_array\n",
1090 | "import numpy as np\n",
1091 | "\n",
1092 | "class BySeverityKFold(KFold):\n",
1093 | "    '''\n",
1094 | "    Splits the CWRU dataset by severity with a deterministic rotation of the groups;\n",
1095 | "    `random_state` only selects the rotation round. Needs at least 2*n_splits groups.\n",
1096 | "    '''\n",
1097 | "    # Compatibility constructor: always 4 unshuffled folds; random_state picks the round.\n",
1098 | "    def __init__(self, n_splits=4, shuffle=False, random_state=None, round_offset=None):\n",
1099 | "        super().__init__(n_splits=4, shuffle=False, random_state=None)\n",
1100 | "        self.nround = random_state\n",
1101 | "        # Subtracted from random_state to get the round index. None keeps the original\n",
1102 | "        # behaviour: use the notebook-level global `random_state`.\n",
1103 | "        self.round_offset = round_offset\n",
1104 | "\n",
1105 | "    def _iter_test_indices(self, X=None, y=None, groups=None):\n",
1106 | "        if groups is None:\n",
1107 | "            raise ValueError(\"The 'groups' parameter should not be None.\")\n",
1108 | "        groups = check_array(groups, ensure_2d=False, dtype=None)\n",
1109 | "        unique_groups, groups = np.unique(groups, return_inverse=True)\n",
1110 | "        n_groups = len(unique_groups)\n",
1111 | "        if self.n_splits > n_groups:\n",
1112 | "            raise ValueError(\"Cannot have number of splits n_splits=%d greater\"\n",
1113 | "                             \" than the number of groups: %d.\"\n",
1114 | "                             % (self.n_splits, n_groups))\n",
1115 | "        offset = random_state if self.round_offset is None else self.round_offset\n",
1116 | "        nround = 0 if self.nround is None else self.nround - offset\n",
1117 | "        indices = np.arange(n_groups)\n",
1118 | "        for i in range(nround // 4):  # nround >= 4: swap positions (i, i+1) for i < nround // 4\n",
1119 | "            indices[i], indices[i + 1] = indices[i + 1], indices[i]\n",
1120 | "        for i in range(self.n_splits):  # rotate the second block of groups by nround\n",
1121 | "            indices[i + self.n_splits] = (i + nround) % self.n_splits + self.n_splits\n",
1122 | "        # Assign groups to folds from the (possibly rotated) order above.\n",
1123 | "        group_to_fold = np.zeros(len(unique_groups))\n",
1124 | "        for group_index in indices:\n",
1125 | "            group_to_fold[indices[group_index]] = group_index % self.n_splits\n",
1126 | "        sample_fold = group_to_fold[groups]\n",
1127 | "        for f in range(self.n_splits):\n",
1128 | "            yield np.where(sample_fold == f)[0]"
1129 | ],
1130 | "execution_count": 12,
1131 | "outputs": []
1132 | },
1133 | {
1134 | "cell_type": "markdown",
1135 | "metadata": {
1136 | "id": "e48W6KkIhesw",
1137 | "colab_type": "text"
1138 | },
1139 | "source": [
1140 | "## Experiment setup"
1141 | ]
1142 | },
1143 | {
1144 | "cell_type": "code",
1145 | "metadata": {
1146 | "id": "K9dImVH3hh_Y",
1147 | "colab_type": "code",
1148 | "colab": {}
1149 | },
1150 | "source": [
1151 | "from collections import namedtuple\n",
1152 | "\n",
1153 | "# One validation scheme: sample groups, splitter class name, shuffle flag and\n",
1154 | "# number of rounds (round r uses random_state + r).\n",
1155 | "ExperimentSetup = namedtuple('ExperimentSetup',\n",
1156 | "                             ['groups', 'splitter_name', 'shuffle', 'rounds'])\n",
1157 | "\n",
1158 | "validations = {}\n",
1159 | "# # By Load: samples with the same load cannot be in the training and test folds\n",
1160 | "# # simultaneously.\n",
1161 | "# validations[\"By Load\"] = ExperimentSetup(groups=get_groups(r'_\\d', y),\n",
1162 | "#                                           splitter_name='GroupShuffleKFold',\n",
1163 | "#                                           shuffle=False,\n",
1164 | "#                                           rounds=1)\n",
1165 | "# # By Severity: samples with the same severity cannot be in the training and\n",
1166 | "# # test folds simultaneously.\n",
1167 | "# validations[\"By Severity\"] = ExperimentSetup(\n",
1168 | "#     groups=get_groups(r'(\\.\\d{3})|(Normal_\\d)', y),\n",
1169 | "#     splitter_name='BySeverityKFold', shuffle=False, rounds=8)\n",
1170 | "# Validation usually seen in publications with the CWRU bearing dataset.\n",
1171 | "validations[\"Usual K-Fold\"] = ExperimentSetup(groups=None,\n",
1172 | "                                              splitter_name='KFold',\n",
1173 | "                                              shuffle=True,\n",
1174 | "                                              rounds=8)\n",
1175 | "\n"
1176 | ],
1177 | "execution_count": 13,
1178 | "outputs": []
1179 | },
1180 | {
1181 | "cell_type": "markdown",
1182 | "metadata": {
1183 | "id": "E9yn44VAoRFo",
1184 | "colab_type": "text"
1185 | },
1186 | "source": [
1187 | "## Common Variables"
1188 | ]
1189 | },
1190 | {
1191 | "cell_type": "code",
1192 | "metadata": {
1193 | "id": "ogpdDde4oTsS",
1194 | "colab_type": "code",
1195 | "colab": {}
1196 | },
1197 | "source": [
1198 | "# Only four conditions are considered: Normal, Ball, Inner Race and Outer Race.\n",
1199 | "selected_y = join_labels(r'_\\d|@\\d{1,3}|(de)|(fe)|\\.\\d{3}|(DE)|(FE)', y)  # raw regex string\n",
1200 | "verbose = 0  # e.g. 3 for detailed GridSearchCV/cross_validate logs\n",
1201 | "random_state = 42  # base seed; round r of each validation uses random_state + r\n",
1202 | "scoring = ['accuracy', 'f1_macro']  # also possible: 'precision_macro', 'recall_macro'"
1203 | ],
1204 | "execution_count": 14,
1205 | "outputs": []
1206 | },
1207 | {
1208 | "cell_type": "markdown",
1209 | "metadata": {
1210 | "id": "mq30RtWYToeu",
1211 | "colab_type": "text"
1212 | },
1213 | "source": [
1214 | "# Classification Models"
1215 | ]
1216 | },
1217 | {
1218 | "cell_type": "code",
1219 | "metadata": {
1220 | "id": "IstS2gTeY7pg",
1221 | "colab_type": "code",
1222 | "colab": {}
1223 | },
1224 | "source": [
1225 | "import warnings\n",
1226 | "warnings.simplefilter('ignore')  # silence every warning category for the rest of the session"
1227 | ],
1228 | "execution_count": 15,
1229 | "outputs": []
1230 | },
1231 | {
1232 | "cell_type": "markdown",
1233 | "metadata": {
1234 | "id": "NLeArq-uThHW",
1235 | "colab_type": "text"
1236 | },
1237 | "source": [
1238 | "## Feature Extraction Models"
1239 | ]
1240 | },
1241 | {
1242 | "cell_type": "code",
1243 | "metadata": {
1244 | "id": "WuSNj6YIEhu0",
1245 | "colab_type": "code",
1246 | "colab": {}
1247 | },
1248 | "source": [
1249 | "from sklearn.base import TransformerMixin"
1250 | ],
1251 | "execution_count": 16,
1252 | "outputs": []
1253 | },
1254 | {
1255 | "cell_type": "markdown",
1256 | "metadata": {
1257 | "id": "Mm95T4CsDxaN",
1258 | "colab_type": "text"
1259 | },
1260 | "source": [
1261 | "### Statistical functions"
1262 | ]
1263 | },
1264 | {
1265 | "cell_type": "code",
1266 | "metadata": {
1267 | "id": "vWCPUON8D1A8",
1268 | "colab_type": "code",
1269 | "colab": {}
1270 | },
1271 | "source": [
1272 | "import numpy as np\n",
1273 | "import scipy.stats as stats\n",
1274 | "\n",
1275 | "\n",
1276 | "def _peak(values):\n",
1277 | "    '''Largest absolute amplitude, max(|values|).'''\n",
1278 | "    return np.max(np.absolute(values))\n",
1279 | "\n",
1280 | "\n",
1281 | "def _mean_abs(values):\n",
1282 | "    '''Mean absolute amplitude, mean(|values|).'''\n",
1283 | "    return np.mean(np.absolute(values))\n",
1284 | "\n",
1285 | "\n",
1286 | "def rms(x):\n",
1287 | "    '''root mean square: sqrt(mean(x**2))'''\n",
1288 | "    values = np.array(x)\n",
1289 | "    return np.sqrt(np.mean(np.square(values)))\n",
1290 | "\n",
1291 | "def sra(x):\n",
1292 | "    '''square root amplitude: mean(sqrt(|x|))**2'''\n",
1293 | "    values = np.array(x)\n",
1294 | "    return np.mean(np.sqrt(np.absolute(values))) ** 2\n",
1295 | "\n",
1296 | "def ppv(x):\n",
1297 | "    '''peak to peak value: max(x) - min(x)'''\n",
1298 | "    values = np.array(x)\n",
1299 | "    return np.max(values) - np.min(values)\n",
1300 | "\n",
1301 | "def cf(x):\n",
1302 | "    '''crest factor: max(|x|) / rms(x)'''\n",
1303 | "    values = np.array(x)\n",
1304 | "    return _peak(values) / rms(values)\n",
1305 | "\n",
1306 | "def ifa(x):\n",
1307 | "    '''impact factor: max(|x|) / mean(|x|)'''\n",
1308 | "    values = np.array(x)\n",
1309 | "    return _peak(values) / _mean_abs(values)\n",
1310 | "\n",
1311 | "def mf(x):\n",
1312 | "    '''margin factor: max(|x|) / sra(x)'''\n",
1313 | "    values = np.array(x)\n",
1314 | "    return _peak(values) / sra(values)\n",
1315 | "\n",
1316 | "def sf(x):\n",
1317 | "    '''shape factor: rms(x) / mean(|x|)'''\n",
1318 | "    values = np.array(x)\n",
1319 | "    return rms(values) / _mean_abs(values)\n",
1320 | "\n",
1321 | "def kf(x):\n",
1322 | "    '''\n",
1323 | "    kurtosis factor: scipy's kurtosis (Fisher definition by default) divided by\n",
1324 | "    mean(x**2)**2, i.e. by rms(x)**4.\n",
1325 | "    '''\n",
1326 | "    values = np.array(x)\n",
1327 | "    mean_square = np.mean(values ** 2)\n",
1328 | "    return stats.kurtosis(values) / mean_square ** 2\n",
1329 | "\n",
1330 | "\n"
1331 | ],
1332 | "execution_count": 17,
1333 | "outputs": []
1334 | },
1335 | {
1336 | "cell_type": "markdown",
1337 | "metadata": {
1338 | "id": "njMb9HtUEBrI",
1339 | "colab_type": "text"
1340 | },
1341 | "source": [
1342 | "### Statistical Features from Time Domain"
1343 | ]
1344 | },
1345 | {
1346 | "cell_type": "code",
1347 | "metadata": {
1348 | "id": "oSN2_c28D_Zr",
1349 | "colab_type": "code",
1350 | "colab": {}
1351 | },
1352 | "source": [
1353 | "class StatisticalTime(TransformerMixin):\n",
1354 | "    '''\n",
1355 | "    Extracts ten statistical features from the time domain of channel 0\n",
1356 | "    (X[:, :, 0]); one output row per segment. Stateless: fit is a no-op.\n",
1357 | "    '''\n",
1358 | "    def fit(self, X, y=None):\n",
1359 | "        return self\n",
1360 | "\n",
1361 | "    def transform(self, X, y=None):\n",
1362 | "        # Output columns, in order.\n",
1363 | "        features = (\n",
1364 | "            rms, sra,                    # root mean square, square root amplitude\n",
1365 | "            stats.kurtosis, stats.skew,  # kurtosis, skewness\n",
1366 | "            ppv,                         # peak to peak value\n",
1367 | "            cf, ifa, mf,                 # crest, impact and margin factors\n",
1368 | "            sf,                          # shape factor\n",
1369 | "            kf,                          # kurtosis factor\n",
1370 | "        )\n",
1371 | "        return np.array([[feature(signal) for feature in features]\n",
1372 | "                         for signal in X[:, :, 0]])\n",
1373 | "\n"
1374 | ],
1375 | "execution_count": 18,
1376 | "outputs": []
1377 | },
1378 | {
1379 | "cell_type": "markdown",
1380 | "metadata": {
1381 | "id": "dXDWD3JZEnep",
1382 | "colab_type": "text"
1383 | },
1384 | "source": [
1385 | "### Statistical Features from Frequency Domain"
1386 | ]
1387 | },
1388 | {
1389 | "cell_type": "code",
1390 | "metadata": {
1391 | "id": "Sj3XTpVTEvAp",
1392 | "colab_type": "code",
1393 | "colab": {}
1394 | },
1395 | "source": [
1396 | "class StatisticalFrequency(TransformerMixin):\n",
1397 | "    '''\n",
1398 | "    Extracts three statistics from the magnitude spectrum |FFT(x)| of channel 0\n",
1399 | "    (X[:, :, 0]); one output row per segment. Stateless: fit is a no-op.\n",
1400 | "    '''\n",
1401 | "    def fit(self, X, y=None):\n",
1402 | "        return self\n",
1403 | "\n",
1404 | "    def transform(self, X, y=None):\n",
1405 | "        rows = []\n",
1406 | "        for signal in X[:, :, 0]:\n",
1407 | "            spectrum = np.absolute(np.fft.fft(signal))  # full (two-sided) spectrum\n",
1408 | "            center = np.mean(spectrum)  # frequency center\n",
1409 | "            rows.append([center,                   # frequency center\n",
1410 | "                         rms(spectrum),            # RMS of the spectrum\n",
1411 | "                         rms(spectrum - center)])  # root variance frequency\n",
1412 | "        return np.array(rows)\n"
1413 | ],
1414 | "execution_count": 19,
1415 | "outputs": []
1416 | },
1417 | {
1418 | "cell_type": "markdown",
1419 | "metadata": {
1420 | "id": "c0YBmzTb6ARb",
1421 | "colab_type": "text"
1422 | },
1423 | "source": [
1424 | "### Statistical Features"
1425 | ]
1426 | },
1427 | {
1428 | "cell_type": "code",
1429 | "metadata": {
1430 | "id": "kep4ubkR6DR0",
1431 | "colab_type": "code",
1432 | "colab": {}
1433 | },
1434 | "source": [
1435 | "class Statistical(TransformerMixin):\n",
1436 | "    '''\n",
1437 | "    Time-domain (StatisticalTime) and frequency-domain (StatisticalFrequency)\n",
1438 | "    features side by side, time columns first. Stateless: fit is a no-op.\n",
1439 | "    '''\n",
1440 | "    def fit(self, X, y=None):\n",
1441 | "        return self\n",
1442 | "\n",
1443 | "    def transform(self, X, y=None):\n",
1444 | "        extractors = (StatisticalTime(), StatisticalFrequency())\n",
1445 | "        blocks = [extractor.transform(X) for extractor in extractors]\n",
1446 | "        return np.concatenate(blocks, axis=1)"
1447 | ],
1448 | "execution_count": 20,
1449 | "outputs": []
1450 | },
1451 | {
1452 | "cell_type": "markdown",
1453 | "metadata": {
1454 | "id": "ZuiVsHNzFORr",
1455 | "colab_type": "text"
1456 | },
1457 | "source": [
1458 | "### Wavelet Package Features"
1459 | ]
1460 | },
1461 | {
1462 | "cell_type": "code",
1463 | "metadata": {
1464 | "id": "oPd92xtJhaH3",
1465 | "colab_type": "code",
1466 | "cellView": "code",
1467 | "colab": {}
1468 | },
1469 | "source": [
1470 | "import numpy as np\n",
1471 | "import pywt\n",
1472 | "\n",
1473 | "class WaveletPackage(TransformerMixin):\n",
1474 | "    '''\n",
1475 | "    Extracts Wavelet Package features from channel 0 (X[:, :, 0]): a 'db4' packet\n",
1476 | "    decomposition up to level 4 and, per leaf node, the L2 norm of its coefficients\n",
1477 | "    divided by their count. The coefficients are used directly (no reconstruction);\n",
1478 | "    columns follow get_leaf_nodes(True) indices 0, -1, -2, ..., -(2**maxlevel - 1).\n",
1479 | "    '''\n",
1480 | "    def fit(self, X, y=None):\n",
1481 | "        return self\n",
1482 | "    def transform(self, X, y=None):\n",
1483 | "        rows = []\n",
1484 | "        for signal in X[:, :, 0]:\n",
1485 | "            wp = pywt.WaveletPacket(data=signal, wavelet='db4', mode='symmetric', maxlevel=4)\n",
1486 | "            leaves = np.asarray([node.data for node in wp.get_leaf_nodes(True)])\n",
1487 | "            rows.append(np.asarray([np.sqrt(np.sum(np.array(leaves[-k]) ** 2)) / len(leaves[-k])\n",
1488 | "                                    for k in range(2 ** wp.maxlevel)]))\n",
1489 | "        return np.array(rows)"
1490 | ],
1491 | "execution_count": 21,
1492 | "outputs": []
1493 | },
1494 | {
1495 | "cell_type": "markdown",
1496 | "metadata": {
1497 | "id": "i_sonHkjFYbB",
1498 | "colab_type": "text"
1499 | },
1500 | "source": [
1501 | "### Heterogeneous Features"
1502 | ]
1503 | },
1504 | {
1505 | "cell_type": "code",
1506 | "metadata": {
1507 | "id": "IZsZhuVfFZsQ",
1508 | "colab_type": "code",
1509 | "colab": {}
1510 | },
1511 | "source": [
1512 | "class Heterogeneous(TransformerMixin):\n",
1513 | "    '''\n",
1514 | "    Heterogeneous feature set: time-domain statistics, frequency-domain statistics\n",
1515 | "    and Wavelet Package features, concatenated column-wise in that order.\n",
1516 | "    Stateless: fit is a no-op.\n",
1517 | "    '''\n",
1518 | "    def fit(self, X, y=None):\n",
1519 | "        return self\n",
1520 | "    def transform(self, X, y=None):\n",
1521 | "        extractors = (StatisticalTime(), StatisticalFrequency(), WaveletPackage())\n",
1522 | "        return np.concatenate(\n",
1523 | "            [extractor.transform(X) for extractor in extractors],\n",
1524 | "            axis=1,\n",
1525 | "        )\n"
1526 | ],
1527 | "execution_count": 22,
1528 | "outputs": []
1529 | },
1530 | {
1531 | "cell_type": "markdown",
1532 | "metadata": {
1533 | "id": "1bnyL67EUcxV",
1534 | "colab_type": "text"
1535 | },
1536 | "source": [
1537 | "## K-NN with Heterogeneous Features"
1538 | ]
1539 | },
1540 | {
1541 | "cell_type": "code",
1542 | "metadata": {
1543 | "id": "F--sjKZRUh5G",
1544 | "colab_type": "code",
1545 | "colab": {
1546 | "base_uri": "https://localhost:8080/",
1547 | "height": 102
1548 | },
1549 | "outputId": "fcb36a08-9664-4d60-a877-2f8a68a32bc6"
1550 | },
1551 | "source": [
1552 | "from sklearn.neighbors import KNeighborsClassifier\n",
1553 | "from sklearn.pipeline import Pipeline\n",
1554 | "from sklearn.preprocessing import StandardScaler\n",
1555 | "from sklearn.model_selection import GridSearchCV\n",
1556 | "\n",
1557 | "knn = Pipeline([\n",
1558 | "    ('FeatureExtraction', Heterogeneous()),\n",
1559 | "    ('scaler', StandardScaler()),\n",
1560 | "    ('knn', KNeighborsClassifier()),\n",
1561 | "])\n",
1562 | "\n",
1563 | "parameters = {'knn__n_neighbors': list(range(1, 16, 2))}  # odd k from 1 to 15\n",
1564 | "knn = GridSearchCV(\n",
1565 | "    knn,\n",
1566 | "    {'knn__n_neighbors': list(range(1, 4, 2))} if debug else parameters,  # debug: k in {1, 3}\n",
1567 | "    verbose=verbose)\n",
1568 | "knn"
1569 | ],
1570 | "execution_count": 23,
1571 | "outputs": [
1572 | {
1573 | "output_type": "execute_result",
1574 | "data": {
1575 | "text/plain": [
1576 | "GridSearchCV(estimator=Pipeline(steps=[('FeatureExtraction',\n",
1577 | " <__main__.Heterogeneous object at 0x7ffaee193d30>),\n",
1578 | " ('scaler', StandardScaler()),\n",
1579 | " ('knn', KNeighborsClassifier())]),\n",
1580 | " param_grid={'knn__n_neighbors': [1, 3, 5, 7, 9, 11, 13, 15]})"
1581 | ]
1582 | },
1583 | "metadata": {
1584 | "tags": []
1585 | },
1586 | "execution_count": 23
1587 | }
1588 | ]
1589 | },
1590 | {
1591 | "cell_type": "markdown",
1592 | "metadata": {
1593 | "id": "d7DfMTS_ujeE",
1594 | "colab_type": "text"
1595 | },
1596 | "source": [
1597 | "## SVM with Heterogeneous Features"
1598 | ]
1599 | },
1600 | {
1601 | "cell_type": "code",
1602 | "metadata": {
1603 | "id": "6v_wXxiDupvF",
1604 | "colab_type": "code",
1605 | "colab": {
1606 | "base_uri": "https://localhost:8080/",
1607 | "height": 119
1608 | },
1609 | "outputId": "d198d0b0-4b20-4037-940d-c793b9d1c9b5"
1610 | },
1611 | "source": [
1612 | "from sklearn.svm import SVC\n",
1613 | "from sklearn.pipeline import Pipeline\n",
1614 | "from sklearn.preprocessing import StandardScaler\n",
1615 | "from sklearn.model_selection import GridSearchCV\n",
1616 | "\n",
1617 | "svm = Pipeline([\n",
1618 | "    ('FeatureExtraction', Heterogeneous()),\n",
1619 | "    ('scaler', StandardScaler()),\n",
1620 | "    ('svc', SVC()),\n",
1621 | "])\n",
1622 | "\n",
1623 | "# Powers of ten: C from 1e-3 to 10, gamma from 1e-3 to 1.\n",
1624 | "parameters = {\n",
1625 | "    'svc__C': [10**e for e in range(-3, 2)],\n",
1626 | "    'svc__gamma': [10**e for e in range(-3, 1)],\n",
1627 | "}\n",
1628 | "svm = svm if debug else GridSearchCV(svm, parameters, verbose=verbose)  # debug: no search\n",
1629 | "svm"
1630 | ],
1631 | "execution_count": 24,
1632 | "outputs": [
1633 | {
1634 | "output_type": "execute_result",
1635 | "data": {
1636 | "text/plain": [
1637 | "GridSearchCV(estimator=Pipeline(steps=[('FeatureExtraction',\n",
1638 | " <__main__.Heterogeneous object at 0x7ffaee403550>),\n",
1639 | " ('scaler', StandardScaler()),\n",
1640 | " ('svc', SVC())]),\n",
1641 | " param_grid={'svc__C': [0.001, 0.01, 0.1, 1, 10],\n",
1642 | " 'svc__gamma': [0.001, 0.01, 0.1, 1]})"
1643 | ]
1644 | },
1645 | "metadata": {
1646 | "tags": []
1647 | },
1648 | "execution_count": 24
1649 | }
1650 | ]
1651 | },
1652 | {
1653 | "cell_type": "markdown",
1654 | "metadata": {
1655 | "id": "yU36xsi4JGZv",
1656 | "colab_type": "text"
1657 | },
1658 | "source": [
1659 | "## Random Forest with Heterogeneous Features"
1660 | ]
1661 | },
1662 | {
1663 | "cell_type": "code",
1664 | "metadata": {
1665 | "id": "GXABo6HpJJY_",
1666 | "colab_type": "code",
1667 | "colab": {
1668 | "base_uri": "https://localhost:8080/",
1669 | "height": 119
1670 | },
1671 | "outputId": "41e9b1ad-c008-4c83-b8b6-a8d9385e9f4a"
1672 | },
1673 | "source": [
1674 | "from sklearn.ensemble import RandomForestClassifier\n",
1675 | "from sklearn.pipeline import Pipeline\n",
1676 | "from sklearn.preprocessing import StandardScaler\n",
1677 | "from sklearn.model_selection import GridSearchCV\n",
1678 | "\n",
1679 | "# Seeded with the notebook-wide random_state so the forests are reproducible.\n",
1680 | "rf = Pipeline([\n",
1681 | "    ('FeatureExtraction', Heterogeneous()),\n",
1682 | "    ('scaler', StandardScaler()),\n",
1683 | "    ('rf', RandomForestClassifier(random_state=random_state)),\n",
1684 | "])\n",
1685 | "parameters = {\n",
1686 | "    'rf__n_estimators': [10, 20, 50, 100, 200, 500],\n",
1687 | "    'rf__max_features': [1, 5, 10, 15, 20],  # Heterogeneous yields 10 + 3 + 16 = 29 features\n",
1688 | "}\n",
1689 | "if not debug:\n",
1690 | "    rf = GridSearchCV(rf, parameters, verbose=verbose)\n",
1691 | "rf"
1692 | ],
1693 | "execution_count": 25,
1694 | "outputs": [
1695 | {
1696 | "output_type": "execute_result",
1697 | "data": {
1698 | "text/plain": [
1699 | "GridSearchCV(estimator=Pipeline(steps=[('FeatureExtraction',\n",
1700 | " <__main__.Heterogeneous object at 0x7ffaff7a9910>),\n",
1701 | " ('scaler', StandardScaler()),\n",
1702 | " ('rf', RandomForestClassifier())]),\n",
1703 | " param_grid={'rf__max_features': [1, 5, 10, 15, 20],\n",
1704 | " 'rf__n_estimators': [10, 20, 50, 100, 200, 500]})"
1705 | ]
1706 | },
1707 | "metadata": {
1708 | "tags": []
1709 | },
1710 | "execution_count": 25
1711 | }
1712 | ]
1713 | },
1714 | {
1715 | "cell_type": "markdown",
1716 | "metadata": {
1717 | "id": "REy6ykvWSbJc",
1718 | "colab_type": "text"
1719 | },
1720 | "source": [
1721 | "## Convolutional Neural Network"
1722 | ]
1723 | },
1724 | {
1725 | "cell_type": "code",
1726 | "metadata": {
1727 | "id": "8YpHSjvNcEx5",
1728 | "colab_type": "code",
1729 | "colab": {
1730 | "base_uri": "https://localhost:8080/",
1731 | "height": 34
1732 | },
1733 | "outputId": "8a111264-ff86-4a96-ca5b-3e82add0d389"
1734 | },
1735 | "source": [
1736 | "try:  # %tensorflow_version is a Colab-only magic (TensorFlow version selector)\n",
1737 | "    %tensorflow_version 2.x\n",
1738 | "except Exception:  # outside Colab, IPython raises UsageError for the unknown magic\n",
1739 | "    print(\"Out of Colab\")"
1740 | ],
1741 | "execution_count": 26,
1742 | "outputs": [
1743 | {
1744 | "output_type": "stream",
1745 | "text": [
1746 | "Out of Colab\n"
1747 | ],
1748 | "name": "stdout"
1749 | }
1750 | ]
1751 | },
1752 | {
1753 | "cell_type": "markdown",
1754 | "metadata": {
1755 | "id": "FF1hCxJ9b5h0",
1756 | "colab_type": "text"
1757 | },
1758 | "source": [
1759 | "### Macro-averaged F1-score implemented for Keras"
1760 | ]
1761 | },
1762 | {
1763 | "cell_type": "code",
1764 | "metadata": {
1765 | "id": "LCJErrQIcIZ0",
1766 | "colab_type": "code",
1767 | "colab": {}
1768 | },
1769 | "source": [
1770 | "import tensorflow as tf\n",
1771 | "from tensorflow.keras import backend as K\n",
1772 | "\n",
1773 | "def f1_score_macro(y_true, y_pred):\n",
1774 | "    '''\n",
1775 | "    Batch-wise macro-averaged F1 for one-hot targets and softmax outputs.\n",
1776 | "    Outputs are rounded (p > 0.5 is a positive), F1 is computed per class (column)\n",
1777 | "    and then averaged, so every class weighs the same. A class that is neither present\n",
1778 | "    nor predicted in the batch contributes 0; meant as a training monitor only.\n",
1779 | "    '''\n",
1780 | "    y_true = K.round(K.clip(y_true, 0, 1))\n",
1781 | "    y_pred = K.round(K.clip(y_pred, 0, 1))\n",
1782 | "    tp = K.sum(y_true * y_pred, axis=0)  # per-class counts (axis 0 = samples)\n",
1783 | "    precision = tp / (K.sum(y_pred, axis=0) + K.epsilon())\n",
1784 | "    recall = tp / (K.sum(y_true, axis=0) + K.epsilon())\n",
1785 | "    f1 = 2 * precision * recall / (precision + recall + K.epsilon())\n",
1786 | "    return K.mean(f1)"
1787 | ],
1788 | "execution_count": 27,
1789 | "outputs": []
1790 | },
1791 | {
1792 | "cell_type": "markdown",
1793 | "metadata": {
1794 | "id": "PeLXSU7-cLfk",
1795 | "colab_type": "text"
1796 | },
1797 | "source": [
1798 | "### ANN wrapped in a scikit-learn estimator"
1799 | ]
1800 | },
1801 | {
1802 | "cell_type": "code",
1803 | "metadata": {
1804 | "id": "ny7otiW6Siz_",
1805 | "colab_type": "code",
1806 | "colab": {}
1807 | },
1808 | "source": [
1809 | "from tensorflow.keras import layers\n",
1810 | "from tensorflow.keras.models import Sequential\n",
1811 | "from tensorflow.keras.utils import to_categorical\n",
1812 | "from sklearn.base import BaseEstimator, ClassifierMixin\n",
1813 | "import numpy as np\n",
1814 | "from tensorflow.keras.callbacks import EarlyStopping,ReduceLROnPlateau\n",
1815 | "\n",
1816 | "class ANN(BaseEstimator, ClassifierMixin):\n",
1817 | "    '''\n",
1818 | "    1-D convolutional network wrapped as a scikit-learn classifier:\n",
1819 | "    n_conv_layers x [Conv1D(filters, kernel_size) -> ReLU -> MaxPooling1D(pool_size)\n",
1820 | "    when pool_size > 1] -> Flatten -> per dense_layer_sizes entry Dense -> ReLU\n",
1821 | "    (-> Dropout) -> Dense(num_classes) -> softmax. X is expected 3-D (samples, length, channels).\n",
1822 | "    '''\n",
1823 | "    def __init__(self,\n",
1824 | "                 dense_layer_sizes=(64,),  # tuple: an immutable shared default\n",
1825 | "                 kernel_size=32,\n",
1826 | "                 filters=32,\n",
1827 | "                 n_conv_layers=2,\n",
1828 | "                 pool_size=8,\n",
1829 | "                 dropout=0.25,\n",
1830 | "                 epochs=50,\n",
1831 | "                 validation_split=0.05,\n",
1832 | "                 optimizer='sgd',  # alternatives tried: 'nadam', 'rmsprop'\n",
1833 | "                 ):\n",
1834 | "        # scikit-learn convention: __init__ only stores the hyper-parameters.\n",
1835 | "        self.dense_layer_sizes = dense_layer_sizes\n",
1836 | "        self.kernel_size = kernel_size\n",
1837 | "        self.filters = filters\n",
1838 | "        self.n_conv_layers = n_conv_layers\n",
1839 | "        self.pool_size = pool_size\n",
1840 | "        self.dropout = dropout\n",
1841 | "        self.epochs = epochs\n",
1842 | "        self.validation_split = validation_split\n",
1843 | "        self.optimizer = optimizer\n",
1844 | "\n",
1845 | "    def fit(self, X, y=None):\n",
1846 | "        '''Builds, compiles and trains a fresh Keras model on (X, y); returns self.'''\n",
1847 | "        # Labels are encoded as 0..n-1; predict maps the argmax back through them.\n",
1848 | "        self.labels, ids = np.unique(y, return_inverse=True)\n",
1849 | "        self.classes_ = self.labels  # scikit-learn name for the fitted labels\n",
1850 | "        y_cat = to_categorical(ids)\n",
1851 | "        num_classes = y_cat.shape[1]\n",
1852 | "\n",
1853 | "        self.model = Sequential()\n",
1854 | "        self.model.add(layers.InputLayer(input_shape=(X.shape[1], X.shape[-1])))\n",
1855 | "        for _ in range(self.n_conv_layers):\n",
1856 | "            self.model.add(layers.Conv1D(self.filters, self.kernel_size))\n",
1857 | "            self.model.add(layers.Activation('relu'))\n",
1858 | "            if self.pool_size > 1:\n",
1859 | "                self.model.add(layers.MaxPooling1D(pool_size=self.pool_size))\n",
1860 | "        self.model.add(layers.Flatten())\n",
1861 | "        for layer_size in self.dense_layer_sizes:\n",
1862 | "            self.model.add(layers.Dense(layer_size))\n",
1863 | "            self.model.add(layers.Activation('relu'))\n",
1864 | "            if 0 < self.dropout < 1:\n",
1865 | "                self.model.add(layers.Dropout(self.dropout))\n",
1866 | "        self.model.add(layers.Dense(num_classes))\n",
1867 | "        self.model.add(layers.Activation('softmax'))\n",
1868 | "        self.model.compile(loss='categorical_crossentropy',\n",
1869 | "                           optimizer=self.optimizer,\n",
1870 | "                           metrics=[f1_score_macro])\n",
1871 | "        if 0 < self.validation_split < 1:\n",
1872 | "            # Deterministic hold-out: every int(1/validation_split)-th sample validates\n",
1873 | "            # (monitored by EarlyStopping and ReduceLROnPlateau).\n",
1874 | "            prop = int(1 / self.validation_split)\n",
1875 | "            mask = np.arange(len(y)) % prop == 0\n",
1876 | "            self.history = self.model.fit(X[~mask], y_cat[~mask], epochs=self.epochs,\n",
1877 | "                                          validation_data=(X[mask], y_cat[mask]),\n",
1878 | "                                          callbacks=[EarlyStopping(patience=3), ReduceLROnPlateau()],\n",
1879 | "                                          verbose=False)\n",
1880 | "        else:\n",
1881 | "            self.history = self.model.fit(X, y_cat, epochs=self.epochs, verbose=False)\n",
1882 | "        return self  # scikit-learn convention: fit returns the estimator\n",
1883 | "\n",
1884 | "    def predict_proba(self, X, y=None):\n",
1885 | "        return self.model.predict(X)\n",
1886 | "\n",
1887 | "    def predict(self, X, y=None):\n",
1888 | "        return self.labels[np.argmax(self.model.predict(X), axis=1)]"
1889 | ],
1890 | "execution_count": 28,
1891 | "outputs": []
1892 | },
1893 | {
1894 | "cell_type": "markdown",
1895 | "metadata": {
1896 | "id": "UL5Zv2-xcgEG",
1897 | "colab_type": "text"
1898 | },
1899 | "source": [
1900 | "### ANN instantiation"
1901 | ]
1902 | },
1903 | {
1904 | "cell_type": "code",
1905 | "metadata": {
1906 | "id": "MmQe7jeRcb6F",
1907 | "colab_type": "code",
1908 | "colab": {
1909 | "base_uri": "https://localhost:8080/",
1910 | "height": 68
1911 | },
1912 | "outputId": "84682a4e-d768-439d-ddf4-63853c6f3b15"
1913 | },
1914 | "source": [
1915 | "# CNN hyper-parameter grid (pool_size, e.g. 2/4/6/8, keeps its default).\n",
1916 | "parameters = dict(\n",
1917 | "    filters=[16, 32],\n",
1918 | "    kernel_size=[16, 32],\n",
1919 | "    n_conv_layers=[1, 2],\n",
1920 | ")\n",
1921 | "# In debug mode a single, untuned network is used.\n",
1922 | "ann = (ANN() if debug\n",
1923 | "       else GridSearchCV(ANN(), parameters, verbose=verbose))\n",
1924 | "ann"
1925 | ],
1926 | "execution_count": 29,
1927 | "outputs": [
1928 | {
1929 | "output_type": "execute_result",
1930 | "data": {
1931 | "text/plain": [
1932 | "GridSearchCV(estimator=ANN(),\n",
1933 | " param_grid={'filters': [16, 32], 'kernel_size': [16, 32],\n",
1934 | " 'n_conv_layers': [1, 2]})"
1935 | ]
1936 | },
1937 | "metadata": {
1938 | "tags": []
1939 | },
1940 | "execution_count": 29
1941 | }
1942 | ]
1943 | },
1944 | {
1945 | "cell_type": "markdown",
1946 | "metadata": {
1947 | "id": "4SQenKxoSkME",
1948 | "colab_type": "text"
1949 | },
1950 | "source": [
1951 | "## List of Estimators"
1952 | ]
1953 | },
1954 | {
1955 | "cell_type": "code",
1956 | "metadata": {
1957 | "id": "Vb6AGKFJJqvL",
1958 | "colab_type": "code",
1959 | "colab": {}
1960 | },
1961 | "source": [
1962 | "# (display name, estimator) pairs to evaluate; uncomment entries to add models.\n",
1963 | "clfs = [\n",
1964 | "    # ('KNN - KNeighborsClassifier, Heterogeneous Features', knn),\n",
1965 | "    # ('SVM - SVC with Heterogeneous Features', svm),\n",
1966 | "    # ('ANN - Artificial Neural Network with Convolutional Layers', ann),\n",
1967 | "    ('RF - RandomForestClassifier with Heterogeneous Features', rf),\n",
1968 | "]\n",
1969 | "\n",
1970 | "# dirres: 'debugres' in debug mode, otherwise 'cwru_rf' (alternative: 'cwru_res').\n",
1971 | "dirres = 'debugres' if debug else 'cwru_rf'\n",
1972 | "\n"
1973 | ],
1974 | "execution_count": 30,
1975 | "outputs": []
1976 | },
1977 | {
1978 | "cell_type": "markdown",
1979 | "metadata": {
1980 | "colab_type": "text",
1981 | "id": "dEl_vSYaq-s2"
1982 | },
1983 | "source": [
1984 | "# Performing Experiments"
1985 | ]
1986 | },
1987 | {
1988 | "cell_type": "code",
1989 | "metadata": {
1990 | "id": "CH4LVC3Zj3jC",
1991 | "colab_type": "code",
1992 | "colab": {
1993 | "base_uri": "https://localhost:8080/",
1994 | "height": 1000
1995 | },
1996 | "outputId": "6c9dbb69-d06b-4067-c7ff-3671ed2e8497"
1997 | },
1998 | "source": [
1999 | "import numpy as np\n",
2000 | "\n",
2001 | "# Splitter classes an ExperimentSetup.splitter_name may refer to (no eval needed).\n",
2002 | "splitters = {'KFold': KFold, 'GroupShuffleKFold': GroupShuffleKFold,\n",
2003 | "             'BySeverityKFold': BySeverityKFold}\n",
2004 | "scores = {}\n",
2005 | "trtime = {}\n",
2006 | "tetime = {}\n",
2007 | "# Estimators\n",
2008 | "for clf_name, estimator in clfs:\n",
2009 | "    if clf_name not in scores:\n",
2010 | "        scores[clf_name] = {}\n",
2011 | "        trtime[clf_name] = {}\n",
2012 | "        tetime[clf_name] = {}\n",
2013 | "    print(\"*\"*(len(clf_name)+8),'\\n***',clf_name,'***\\n'+\"*\"*(len(clf_name)+8))\n",
2014 | "    # Validation forms\n",
2015 | "    for val_name, setup in validations.items():\n",
2016 | "        print(\"#\"*(len(val_name)+8),'\\n###',val_name,'###\\n'+\"#\"*(len(val_name)+8))\n",
2017 | "        # Number of repetitions; round r seeds the splitter with random_state + r\n",
2018 | "        for r in range(setup.rounds):\n",
2019 | "            round_str = \"Round {}\".format(r+1)\n",
2020 | "            print(\"@\"*(len(round_str)+8),'\\n@@@',round_str,'@@@\\n'+\"@\"*(len(round_str)+8))\n",
2021 | "            groups = setup.groups\n",
2022 | "            scores[clf_name].setdefault(val_name, {})\n",
2023 | "            validation = splitters[setup.splitter_name](\n",
2024 | "                4, shuffle=setup.shuffle, random_state=random_state + r)\n",
2025 | "            score = experimenter(estimator, X, selected_y, groups,\n",
2026 | "                                 scoring, validation, verbose)\n",
2027 | "            for metric, s in score.items():\n",
2028 | "                print(metric, ' \\t', s)\n",
2029 | "                scores[clf_name][val_name].setdefault(metric, []).append(s)\n",
2030 | "\n"
2031 | ],
2032 | "execution_count": 31,
2033 | "outputs": [
2034 | {
2035 | "output_type": "stream",
2036 | "text": [
2037 | "*************************************************************** \n",
2038 | "*** RF - RandomForestClassifier with Heterogeneous Features ***\n",
2039 | "***************************************************************\n",
2040 | "#################### \n",
2041 | "### Usual K-Fold ###\n",
2042 | "####################\n",
2043 | "@@@@@@@@@@@@@@@ \n",
2044 | "@@@ Round 1 @@@\n",
2045 | "@@@@@@@@@@@@@@@\n",
2046 | "fit_time \t [3927.1396749019623, 3926.9082491397858, 3922.4425542354584, 3890.117686986923]\n",
2047 | "score_time \t [9.353840112686157, 9.795709609985352, 9.666334867477417, 9.215996742248535]\n",
2048 | "test_accuracy \t [0.9887070168760309, 0.98997461928934, 0.9866751269035533, 0.9881979695431472]\n",
2049 | "test_f1_macro \t [0.9886136086699508, 0.9901947681580123, 0.986953524688857, 0.9883360840427271]\n",
2050 | "@@@@@@@@@@@@@@@ \n",
2051 | "@@@ Round 2 @@@\n",
2052 | "@@@@@@@@@@@@@@@\n",
2053 | "fit_time \t [3897.857188463211, 3916.1644649505615, 3834.143933534622, 3819.244938135147]\n",
2054 | "score_time \t [9.50651240348816, 9.435830116271973, 9.096308469772339, 9.4651358127594]\n",
2055 | "test_accuracy \t [0.9883263545235377, 0.9871827411167513, 0.9871827411167513, 0.9901015228426396]\n",
2056 | "test_f1_macro \t [0.9883608348253965, 0.9874413624249705, 0.9873425362681483, 0.9899966059721018]\n",
2057 | "@@@@@@@@@@@@@@@ \n",
2058 | "@@@ Round 3 @@@\n",
2059 | "@@@@@@@@@@@@@@@\n",
2060 | "fit_time \t [3814.7653987407684, 3840.2795662879944, 3819.407868862152, 3839.219763278961]\n",
2061 | "score_time \t [9.151491403579712, 9.516995429992676, 9.212405443191528, 9.516690492630005]\n",
2062 | "test_accuracy \t [0.9883263545235377, 0.9875634517766497, 0.9884517766497461, 0.9885786802030457]\n",
2063 | "test_f1_macro \t [0.9883440181460252, 0.9877101771624337, 0.9888442981541188, 0.9886192131365508]\n",
2064 | "@@@@@@@@@@@@@@@ \n",
2065 | "@@@ Round 4 @@@\n",
2066 | "@@@@@@@@@@@@@@@\n",
2067 | "fit_time \t [3808.983441591263, 3969.9939017295837, 4037.284947872162, 3986.3975813388824]\n",
2068 | "score_time \t [9.403703689575195, 9.974312543869019, 9.711014032363892, 9.435908555984497]\n",
2069 | "test_accuracy \t [0.9876919172693821, 0.9902284263959391, 0.9869289340101522, 0.9887055837563452]\n",
2070 | "test_f1_macro \t [0.9876078716592412, 0.990513192151286, 0.987057167184373, 0.9888626833159644]\n",
2071 | "@@@@@@@@@@@@@@@ \n",
2072 | "@@@ Round 5 @@@\n",
2073 | "@@@@@@@@@@@@@@@\n",
2074 | "fit_time \t [3947.0644967556, 3906.2969613075256, 3875.4539771080017, 3888.7648227214813]\n",
2075 | "score_time \t [9.5697500705719, 9.350503921508789, 9.444551229476929, 9.442347764968872]\n",
2076 | "test_accuracy \t [0.9892145666793554, 0.9860406091370558, 0.9902284263959391, 0.9881979695431472]\n",
2077 | "test_f1_macro \t [0.9892790984399402, 0.9862119307489339, 0.9902912120810636, 0.9883950548411021]\n",
2078 | "@@@@@@@@@@@@@@@ \n",
2079 | "@@@ Round 6 @@@\n",
2080 | "@@@@@@@@@@@@@@@\n",
2081 | "fit_time \t [3820.27006816864, 3681.8083748817444, 3659.199379682541, 2798.5435979366302]\n",
2082 | "score_time \t [9.140054702758789, 9.149131536483765, 7.1341307163238525, 6.6482696533203125]\n",
2083 | "test_accuracy \t [0.987565029818551, 0.98997461928934, 0.9869289340101522, 0.9874365482233503]\n",
2084 | "test_f1_macro \t [0.9876976020722907, 0.9898417825129663, 0.9869844556366362, 0.9878141195353268]\n",
2085 | "@@@@@@@@@@@@@@@ \n",
2086 | "@@@ Round 7 @@@\n",
2087 | "@@@@@@@@@@@@@@@\n",
2088 | "fit_time \t [2804.366559743881, 2842.63476061821, 2831.7397742271423, 2773.038544178009]\n",
2089 | "score_time \t [6.896226406097412, 6.899433374404907, 6.891962289810181, 6.762506008148193]\n",
2090 | "test_accuracy \t [0.9884532419743688, 0.9885786802030457, 0.9862944162436548, 0.98997461928934]\n",
2091 | "test_f1_macro \t [0.9886298254910706, 0.9885760908191492, 0.9864342655085306, 0.9898799942990018]\n",
2092 | "@@@@@@@@@@@@@@@ \n",
2093 | "@@@ Round 8 @@@\n",
2094 | "@@@@@@@@@@@@@@@\n",
2095 | "fit_time \t [2797.9537930488586, 2780.3243157863617, 2799.031191110611, 2836.1921730041504]\n",
2096 | "score_time \t [6.8698036670684814, 6.825961589813232, 6.896572113037109, 6.948557138442993]\n",
2097 | "test_accuracy \t [0.987565029818551, 0.9884517766497461, 0.9884517766497461, 0.9887055837563452]\n",
2098 | "test_f1_macro \t [0.9875351708470765, 0.988539246865983, 0.9886221962501263, 0.9888823429350329]\n"
2099 | ],
2100 | "name": "stdout"
2101 | }
2102 | ]
2103 | },
2104 | {
2105 | "cell_type": "markdown",
2106 | "metadata": {
2107 | "id": "QJ-qe0MIhM-z",
2108 | "colab_type": "text"
2109 | },
2110 | "source": [
2111 | "## Save results"
2112 | ]
2113 | },
2114 | {
2115 | "cell_type": "code",
2116 | "metadata": {
2117 | "id": "qrp8uvOonKpd",
2118 | "colab_type": "code",
2119 | "colab": {
2120 | "base_uri": "https://localhost:8080/",
2121 | "height": 629
2122 | },
2123 | "outputId": "e1b8b9af-97da-48ab-c064-6a859ace8d6e"
2124 | },
2125 | "source": [
2126 | "from pathlib import Path\n",
2127 | "\n",
2128 | "# clf/val/src map positional indices to classifier, validation and metric names.\n",
2129 | "clf = {}\n",
2130 | "val = {}\n",
2131 | "src = {}\n",
2132 | "for c, (clf_name, clf_scores) in enumerate(scores.items()):\n",
2133 | "    clf[c] = clf_name\n",
2134 | "    for v, (val_name, val_scores) in enumerate(clf_scores.items()):\n",
2135 | "        val.setdefault(v, val_name)\n",
2136 | "        for s, scr_name in enumerate(val_scores):\n",
2137 | "            # Per-round lists become a (rounds x folds) array so .mean()/.std() work.\n",
2138 | "            val_scores[scr_name] = np.array(val_scores[scr_name])\n",
2139 | "            src.setdefault(s, scr_name)\n",
2140 | "            Path(dirres).mkdir(parents=True, exist_ok=True)\n",
2141 | "            # One text file per (classifier, validation, metric) combination.\n",
2142 | "            np.savetxt('{}/{}-{}-{}.txt'.format(dirres, clf_name, val_name, scr_name),\n",
2143 | "                       val_scores[scr_name], delimiter=',')\n",
2144 | "            print('{}/{} - {} - {}\\n'.format(dirres, clf_name.split('-')[0], val_name, scr_name),\n",
2145 | "                  val_scores[scr_name])\n"
2146 | ],
2147 | "execution_count": 32,
2148 | "outputs": [
2149 | {
2150 | "output_type": "stream",
2151 | "text": [
2152 | "cwru_rf/RF - Usual K-Fold - fit_time\n",
2153 | " [[3927.1396749 3926.90824914 3922.44255424 3890.11768699]\n",
2154 | " [3897.85718846 3916.16446495 3834.14393353 3819.24493814]\n",
2155 | " [3814.76539874 3840.27956629 3819.40786886 3839.21976328]\n",
2156 | " [3808.98344159 3969.99390173 4037.28494787 3986.39758134]\n",
2157 | " [3947.06449676 3906.29696131 3875.45397711 3888.76482272]\n",
2158 | " [3820.27006817 3681.80837488 3659.19937968 2798.54359794]\n",
2159 | " [2804.36655974 2842.63476062 2831.73977423 2773.03854418]\n",
2160 | " [2797.95379305 2780.32431579 2799.03119111 2836.192173 ]]\n",
2161 | "cwru_rf/RF - Usual K-Fold - score_time\n",
2162 | " [[9.35384011 9.79570961 9.66633487 9.21599674]\n",
2163 | " [9.5065124 9.43583012 9.09630847 9.46513581]\n",
2164 | " [9.1514914 9.51699543 9.21240544 9.51669049]\n",
2165 | " [9.40370369 9.97431254 9.71101403 9.43590856]\n",
2166 | " [9.56975007 9.35050392 9.44455123 9.44234776]\n",
2167 | " [9.1400547 9.14913154 7.13413072 6.64826965]\n",
2168 | " [6.89622641 6.89943337 6.89196229 6.76250601]\n",
2169 | " [6.86980367 6.82596159 6.89657211 6.94855714]]\n",
2170 | "cwru_rf/RF - Usual K-Fold - test_accuracy\n",
2171 | " [[0.98870702 0.98997462 0.98667513 0.98819797]\n",
2172 | " [0.98832635 0.98718274 0.98718274 0.99010152]\n",
2173 | " [0.98832635 0.98756345 0.98845178 0.98857868]\n",
2174 | " [0.98769192 0.99022843 0.98692893 0.98870558]\n",
2175 | " [0.98921457 0.98604061 0.99022843 0.98819797]\n",
2176 | " [0.98756503 0.98997462 0.98692893 0.98743655]\n",
2177 | " [0.98845324 0.98857868 0.98629442 0.98997462]\n",
2178 | " [0.98756503 0.98845178 0.98845178 0.98870558]]\n",
2179 | "cwru_rf/RF - Usual K-Fold - test_f1_macro\n",
2180 | " [[0.98861361 0.99019477 0.98695352 0.98833608]\n",
2181 | " [0.98836083 0.98744136 0.98734254 0.98999661]\n",
2182 | " [0.98834402 0.98771018 0.9888443 0.98861921]\n",
2183 | " [0.98760787 0.99051319 0.98705717 0.98886268]\n",
2184 | " [0.9892791 0.98621193 0.99029121 0.98839505]\n",
2185 | " [0.9876976 0.98984178 0.98698446 0.98781412]\n",
2186 | " [0.98862983 0.98857609 0.98643427 0.98987999]\n",
2187 | " [0.98753517 0.98853925 0.9886222 0.98888234]]\n"
2188 | ],
2189 | "name": "stdout"
2190 | }
2191 | ]
2192 | },
2193 | {
2194 | "cell_type": "markdown",
2195 | "metadata": {
2196 | "id": "rkb8XMN-Ht58",
2197 | "colab_type": "text"
2198 | },
2199 | "source": [
2200 | "## Average & Standard Deviation"
2201 | ]
2202 | },
2203 | {
2204 | "cell_type": "code",
2205 | "metadata": {
2206 | "id": "hJxdjboqtuNb",
2207 | "colab_type": "code",
2208 | "colab": {
2209 | "base_uri": "https://localhost:8080/",
2210 | "height": 289
2211 | },
2212 | "outputId": "38f75cc3-b7c3-48ae-e9e9-9828492b6546"
2213 | },
2214 | "source": [
2215 | "c,v,s = len(clf),len(val),len(src)\n",
2216 | "# One table per metric: rows are classifiers, columns are validation schemes,\n",
2217 | "# and each cell is \"mean (std)\" over every round and fold of that metric.\n",
2218 | "for i in range(s):\n",
2219 | "    print(src[i])\n",
2220 | "    print(''.join('\\t' + val[k] + ' ' for k in range(v)))\n",
2221 | "    for j in range(c):\n",
2222 | "        print(clf[j].split('-')[0], end='\\t')\n",
2223 | "        for k in range(v):\n",
2224 | "            fold_scores = scores[clf[j]][val[k]][src[i]]\n",
2225 | "            print(\"{0:.3f} ({1:.3f})\".format(fold_scores.mean(), fold_scores.std()),\n",
2226 | "                  end='\\t')\n",
2227 | "        print()\n",
2228 | "    print()"
2229 | ],
2230 | "execution_count": 33,
2231 | "outputs": [
2232 | {
2233 | "output_type": "stream",
2234 | "text": [
2235 | "fit_time\n",
2236 | "\tUsual K-Fold \n",
2237 | "RF \t3571.657 (483.917)\t\n",
2238 | "\n",
2239 | "score_time\n",
2240 | "\tUsual K-Fold \n",
2241 | "RF \t8.635 (1.201)\t\n",
2242 | "\n",
2243 | "test_accuracy\n",
2244 | "\tUsual K-Fold \n",
2245 | "RF \t0.988 (0.001)\t\n",
2246 | "\n",
2247 | "test_f1_macro\n",
2248 | "\tUsual K-Fold \n",
2249 | "RF \t0.988 (0.001)\t\n",
2250 | "\n"
2251 | ],
2252 | "name": "stdout"
2253 | }
2254 | ]
2255 | },
2256 | {
2257 | "cell_type": "markdown",
2258 | "metadata": {
2259 | "id": "qZG2ZnLQAeJL",
2260 | "colab_type": "text"
2261 | },
2262 | "source": [
2263 | "## Experiment results"
2264 | ]
2265 | },
2266 | {
2267 | "cell_type": "markdown",
2268 | "metadata": {
2269 | "id": "manFfTfh_9Ta",
2270 | "colab_type": "text"
2271 | },
2272 | "source": [
2273 | ""
2274 | ]
2275 | }
2276 | ]
2277 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Signal Processing
2 | Experiments performed to analyse the CNN on the CWRU data set.
3 |
--------------------------------------------------------------------------------
/cwru_segmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "cwru-segmentation.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "toc_visible": true,
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "
"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "colab_type": "text",
32 | "id": "O_WyaUYJoAAB"
33 | },
34 | "source": [
35 | "# CWRU files.\n"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {
41 | "id": "BxnouJISoB7q",
42 | "colab_type": "text"
43 | },
44 | "source": [
45 | "Associate each Matlab file name to a bearing condition in a Python dictionary. The dictionary keys identify the conditions.\n",
46 | "\n",
47 | "There are only four normal conditions, with loads of 0, 1, 2 and 3 hp. All conditions end with an underscore character followed by a digit representing the load (in hp) applied during the acquisition. The remaining conditions follow the pattern:\n",
48 | "\n",
49 | "The first two characters represent the bearing location, i.e. drive end (DE) and fan end (FE). The following one or two characters represent the failure location in the bearing, i.e. ball (B), inner race (IR) and outer race (OR). After a dot, the next three digits indicate the severity of the failure, where 007 stands for 0.007 inches and 021 for 0.021 inches. For outer race failures, the character @ is followed by the clock position (3, 6 or 12 o'clock) at which the failure is centered.\n",
50 | "\n"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "metadata": {
56 | "id": "ZVVoHvQEn8hi",
57 | "colab_type": "code",
58 | "colab": {}
59 | },
60 | "source": [
61 | "debug = False\n",
62 | "# Signal points per segment: 512 normally; 32768 (fewer, longer segments) when debug is True\n",
63 | "sample_size = 32768\n",
64 | "if not debug:\n",
65 | " sample_size = 512\n",
66 | "acquisitions = {}  # '<condition>_<load in hp>' -> CWRU Matlab file name\n",
67 | "# Normal (no fault), loads 0-3 hp. Insertion order below fixes label numbering and sample order.\n",
68 | "acquisitions[\"Normal_0\"] = \"97.mat\"\n",
69 | "acquisitions[\"Normal_1\"] = \"98.mat\"\n",
70 | "acquisitions[\"Normal_2\"] = \"99.mat\"\n",
71 | "acquisitions[\"Normal_3\"] = \"100.mat\"\n",
72 | "# DE Inner Race 0.007 inches\n",
73 | "acquisitions[\"DEIR.007_0\"] = \"105.mat\"\n",
74 | "acquisitions[\"DEIR.007_1\"] = \"106.mat\"\n",
75 | "acquisitions[\"DEIR.007_2\"] = \"107.mat\"\n",
76 | "acquisitions[\"DEIR.007_3\"] = \"108.mat\"\n",
77 | "# DE Ball 0.007 inches\n",
78 | "acquisitions[\"DEB.007_0\"] = \"118.mat\"\n",
79 | "acquisitions[\"DEB.007_1\"] = \"119.mat\"\n",
80 | "acquisitions[\"DEB.007_2\"] = \"120.mat\"\n",
81 | "acquisitions[\"DEB.007_3\"] = \"121.mat\"\n",
82 | "# DE Outer race 0.007 inches centered @6:00\n",
83 | "acquisitions[\"DEOR.007@6_0\"] = \"130.mat\"\n",
84 | "acquisitions[\"DEOR.007@6_1\"] = \"131.mat\"\n",
85 | "acquisitions[\"DEOR.007@6_2\"] = \"132.mat\"\n",
86 | "acquisitions[\"DEOR.007@6_3\"] = \"133.mat\"\n",
87 | "# DE Outer race 0.007 inches centered @3:00\n",
88 | "acquisitions[\"DEOR.007@3_0\"] = \"144.mat\"\n",
89 | "acquisitions[\"DEOR.007@3_1\"] = \"145.mat\"\n",
90 | "acquisitions[\"DEOR.007@3_2\"] = \"146.mat\"\n",
91 | "acquisitions[\"DEOR.007@3_3\"] = \"147.mat\"\n",
92 | "# DE Outer race 0.007 inches centered @12:00\n",
93 | "acquisitions[\"DEOR.007@12_0\"] = \"156.mat\"\n",
94 | "acquisitions[\"DEOR.007@12_1\"] = \"158.mat\"\n",
95 | "acquisitions[\"DEOR.007@12_2\"] = \"159.mat\"\n",
96 | "acquisitions[\"DEOR.007@12_3\"] = \"160.mat\"\n",
97 | "# DE Inner Race 0.014 inches\n",
98 | "acquisitions[\"DEIR.014_0\"] = \"169.mat\"\n",
99 | "acquisitions[\"DEIR.014_1\"] = \"170.mat\"\n",
100 | "acquisitions[\"DEIR.014_2\"] = \"171.mat\"\n",
101 | "acquisitions[\"DEIR.014_3\"] = \"172.mat\"\n",
102 | "# DE Ball 0.014 inches\n",
103 | "acquisitions[\"DEB.014_0\"] = \"185.mat\"\n",
104 | "acquisitions[\"DEB.014_1\"] = \"186.mat\"\n",
105 | "acquisitions[\"DEB.014_2\"] = \"187.mat\"\n",
106 | "acquisitions[\"DEB.014_3\"] = \"188.mat\"\n",
107 | "# DE Outer race 0.014 inches centered @6:00\n",
108 | "acquisitions[\"DEOR.014@6_0\"] = \"197.mat\"\n",
109 | "acquisitions[\"DEOR.014@6_1\"] = \"198.mat\"\n",
110 | "acquisitions[\"DEOR.014@6_2\"] = \"199.mat\"\n",
111 | "acquisitions[\"DEOR.014@6_3\"] = \"200.mat\"\n",
112 | "# DE Ball 0.021 inches\n",
113 | "acquisitions[\"DEB.021_0\"] = \"222.mat\"\n",
114 | "acquisitions[\"DEB.021_1\"] = \"223.mat\"\n",
115 | "acquisitions[\"DEB.021_2\"] = \"224.mat\"\n",
116 | "acquisitions[\"DEB.021_3\"] = \"225.mat\"\n",
117 | "# FE Inner Race 0.021 inches\n",
118 | "acquisitions[\"FEIR.021_0\"] = \"270.mat\"\n",
119 | "acquisitions[\"FEIR.021_1\"] = \"271.mat\"\n",
120 | "acquisitions[\"FEIR.021_2\"] = \"272.mat\"\n",
121 | "acquisitions[\"FEIR.021_3\"] = \"273.mat\"\n",
122 | "# FE Inner Race 0.014 inches\n",
123 | "acquisitions[\"FEIR.014_0\"] = \"274.mat\"\n",
124 | "acquisitions[\"FEIR.014_1\"] = \"275.mat\"\n",
125 | "acquisitions[\"FEIR.014_2\"] = \"276.mat\"\n",
126 | "acquisitions[\"FEIR.014_3\"] = \"277.mat\"\n",
127 | "# FE Ball 0.007 inches\n",
128 | "acquisitions[\"FEB.007_0\"] = \"282.mat\"\n",
129 | "acquisitions[\"FEB.007_1\"] = \"283.mat\"\n",
130 | "acquisitions[\"FEB.007_2\"] = \"284.mat\"\n",
131 | "acquisitions[\"FEB.007_3\"] = \"285.mat\"\n",
132 | "# DE Inner Race 0.021 inches\n",
133 | "acquisitions[\"DEIR.021_0\"] = \"209.mat\"\n",
134 | "acquisitions[\"DEIR.021_1\"] = \"210.mat\"\n",
135 | "acquisitions[\"DEIR.021_2\"] = \"211.mat\"\n",
136 | "acquisitions[\"DEIR.021_3\"] = \"212.mat\"\n",
137 | "# DE Outer race 0.021 inches centered @6:00\n",
138 | "acquisitions[\"DEOR.021@6_0\"] = \"234.mat\"\n",
139 | "acquisitions[\"DEOR.021@6_1\"] = \"235.mat\"\n",
140 | "acquisitions[\"DEOR.021@6_2\"] = \"236.mat\"\n",
141 | "acquisitions[\"DEOR.021@6_3\"] = \"237.mat\"\n",
142 | "# DE Outer race 0.021 inches centered @3:00\n",
143 | "acquisitions[\"DEOR.021@3_0\"] = \"246.mat\"\n",
144 | "acquisitions[\"DEOR.021@3_1\"] = \"247.mat\"\n",
145 | "acquisitions[\"DEOR.021@3_2\"] = \"248.mat\"\n",
146 | "acquisitions[\"DEOR.021@3_3\"] = \"249.mat\"\n",
147 | "# DE Outer race 0.021 inches centered @12:00\n",
148 | "acquisitions[\"DEOR.021@12_0\"] = \"258.mat\"\n",
149 | "acquisitions[\"DEOR.021@12_1\"] = \"259.mat\"\n",
150 | "acquisitions[\"DEOR.021@12_2\"] = \"260.mat\"\n",
151 | "acquisitions[\"DEOR.021@12_3\"] = \"261.mat\"\n",
152 | "# FE Inner Race 0.007 inches\n",
153 | "acquisitions[\"FEIR.007_0\"] = \"278.mat\"\n",
154 | "acquisitions[\"FEIR.007_1\"] = \"279.mat\"\n",
155 | "acquisitions[\"FEIR.007_2\"] = \"280.mat\"\n",
156 | "acquisitions[\"FEIR.007_3\"] = \"281.mat\"\n",
157 | "# FE Ball 0.014 inches\n",
158 | "acquisitions[\"FEB.014_0\"] = \"286.mat\"\n",
159 | "acquisitions[\"FEB.014_1\"] = \"287.mat\"\n",
160 | "acquisitions[\"FEB.014_2\"] = \"288.mat\"\n",
161 | "acquisitions[\"FEB.014_3\"] = \"289.mat\"\n",
162 | "# FE Ball 0.021 inches\n",
163 | "acquisitions[\"FEB.021_0\"] = \"290.mat\"\n",
164 | "acquisitions[\"FEB.021_1\"] = \"291.mat\"\n",
165 | "acquisitions[\"FEB.021_2\"] = \"292.mat\"\n",
166 | "acquisitions[\"FEB.021_3\"] = \"293.mat\"\n",
167 | "# FE Outer race 0.007 inches centered @6:00\n",
168 | "acquisitions[\"FEOR.007@6_0\"] = \"294.mat\"\n",
169 | "acquisitions[\"FEOR.007@6_1\"] = \"295.mat\"\n",
170 | "acquisitions[\"FEOR.007@6_2\"] = \"296.mat\"\n",
171 | "acquisitions[\"FEOR.007@6_3\"] = \"297.mat\"\n",
172 | "# FE Outer race 0.007 inches centered @3:00\n",
173 | "acquisitions[\"FEOR.007@3_0\"] = \"298.mat\"\n",
174 | "acquisitions[\"FEOR.007@3_1\"] = \"299.mat\"\n",
175 | "acquisitions[\"FEOR.007@3_2\"] = \"300.mat\"\n",
176 | "acquisitions[\"FEOR.007@3_3\"] = \"301.mat\"\n",
177 | "# FE Outer race 0.007 inches centered @12:00\n",
178 | "acquisitions[\"FEOR.007@12_0\"] = \"302.mat\"\n",
179 | "acquisitions[\"FEOR.007@12_1\"] = \"305.mat\"\n",
180 | "acquisitions[\"FEOR.007@12_2\"] = \"306.mat\"\n",
181 | "acquisitions[\"FEOR.007@12_3\"] = \"307.mat\"\n",
182 | "# FE Outer race 0.014 inches centered @3:00 (NOTE(review): 310 is mapped to load 0 and 309 to load 1; confirm against the CWRU file table)\n",
183 | "acquisitions[\"FEOR.014@3_0\"] = \"310.mat\"\n",
184 | "acquisitions[\"FEOR.014@3_1\"] = \"309.mat\"\n",
185 | "acquisitions[\"FEOR.014@3_2\"] = \"311.mat\"\n",
186 | "acquisitions[\"FEOR.014@3_3\"] = \"312.mat\"\n",
187 | "# FE Outer race 0.014 inches centered @6:00 (only the 0 hp recording is listed)\n",
188 | "acquisitions[\"FEOR.014@6_0\"] = \"313.mat\"\n",
189 | "# FE Outer race 0.021 inches centered @6:00 (only the 0 hp recording is listed)\n",
190 | "acquisitions[\"FEOR.021@6_0\"] = \"315.mat\"\n",
191 | "# FE Outer race 0.021 inches centered @3:00 (only loads 1-3 hp are listed)\n",
192 | "acquisitions[\"FEOR.021@3_1\"] = \"316.mat\"\n",
193 | "acquisitions[\"FEOR.021@3_2\"] = \"317.mat\"\n",
194 | "acquisitions[\"FEOR.021@3_3\"] = \"318.mat\"\n",
195 | "# DE Inner Race 0.028 inches\n",
196 | "acquisitions[\"DEIR.028_0\"] = \"3001.mat\"\n",
197 | "acquisitions[\"DEIR.028_1\"] = \"3002.mat\"\n",
198 | "acquisitions[\"DEIR.028_2\"] = \"3003.mat\"\n",
199 | "acquisitions[\"DEIR.028_3\"] = \"3004.mat\"\n",
200 | "# DE Ball 0.028 inches\n",
201 | "acquisitions[\"DEB.028_0\"] = \"3005.mat\"\n",
202 | "acquisitions[\"DEB.028_1\"] = \"3006.mat\"\n",
203 | "acquisitions[\"DEB.028_2\"] = \"3007.mat\"\n",
204 | "acquisitions[\"DEB.028_3\"] = \"3008.mat\""
205 | ],
206 | "execution_count": 0,
207 | "outputs": []
208 | },
209 | {
210 | "cell_type": "markdown",
211 | "metadata": {
212 | "id": "hQu5HBLioPIu",
213 | "colab_type": "text"
214 | },
215 | "source": [
216 | "# Functions definitions\n"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "metadata": {
222 | "id": "YcKzJQ-XoSCg",
223 | "colab_type": "code",
224 | "colab": {}
225 | },
226 | "source": [
227 | "def get_labels_dict(acquisitions, separator='_', detectPosition=True):\n",
228 | "    \"\"\"Map each condition in `acquisitions` to a consecutive integer label.\n",
229 | "\n",
230 | "    Keys look like '<condition>_<load>'; the load suffix is dropped and the\n",
231 | "    first `separator`-separated field is kept. Unless that field is 'Normal' or\n",
232 | "    `detectPosition` is True, its first two characters (DE/FE) are removed,\n",
233 | "    e.g. 'DEIR' -> 'IR'. Labels are numbered in first-seen key order.\n",
234 | "    \"\"\"\n",
235 | "    labels_dict = {}\n",
236 | "    for key in acquisitions:\n",
237 | "        condition = key.split('_')[0].split(separator)[0]\n",
238 | "        label = condition if (condition == \"Normal\" or detectPosition) else condition[2:]\n",
239 | "        if label not in labels_dict:\n",
240 | "            labels_dict[label] = len(labels_dict)\n",
241 | "    return labels_dict"
242 | ],
243 | "execution_count": 0,
244 | "outputs": []
245 | },
246 | {
247 | "cell_type": "markdown",
248 | "metadata": {
249 | "id": "bL85KQ_noXAV",
250 | "colab_type": "text"
251 | },
252 | "source": [
253 | "Convert Matlab file into tensors.\n"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "metadata": {
259 | "id": "i52OZlbIodym",
260 | "colab_type": "code",
261 | "colab": {}
262 | },
263 | "source": [
264 | "import scipy.io\n",
265 | "import numpy as np\n",
266 | "\n",
267 | "def acquisition2tensor(file_name, position=None, sample_size=sample_size):\n",
268 | "    \"\"\"\n",
269 | "    Convert a Matlab file into a tensor of consecutive, non-overlapping segments.\n",
270 | "    position: None keeps both the drive end (DE) and fan end (FE) channels;\n",
271 | "              'DE' or 'FE' keeps only that channel.\n",
272 | "    sample_size: number of signal rows per segment. A trailing partial segment\n",
273 | "                 is discarded; a segment ending exactly at the end is kept.\n",
274 | "    Returns a float32 array of shape (n_segments, sample_size, n_channels)\n",
275 | "    when the Matlab signals are column vectors.\n",
276 | "    Raises ValueError for any other position.\n",
277 | "    Prints file_name (without a newline) as progress output.\n",
278 | "    \"\"\"\n",
279 | "    print(file_name, end=' ')\n",
280 | "    matlab_file = scipy.io.loadmat(file_name)\n",
281 | "    # Suffixes of the Matlab keys holding the requested channels.\n",
282 | "    if position is None:\n",
283 | "        suffixes = [\"DE_time\", \"FE_time\"]\n",
284 | "    elif position in ('DE', 'FE'):\n",
285 | "        suffixes = [position + \"_time\"]\n",
286 | "    else:\n",
287 | "        raise ValueError(\"position must be None, 'DE' or 'FE', got {!r}\".format(position))\n",
288 | "    signals = []\n",
289 | "    for suffix in suffixes:\n",
290 | "        key = [k for k in matlab_file if k.endswith(suffix)][0]  # acquisition key name\n",
291 | "        signals.append(np.asarray(matlab_file[key]))\n",
292 | "    # Segment only the span every channel covers, so all segments are complete.\n",
293 | "    acquisition_size = min(len(signal) for signal in signals)\n",
294 | "    n_segments = acquisition_size // sample_size\n",
295 | "    channels = []\n",
296 | "    for signal in signals:\n",
297 | "        # Segment i = rows [i*sample_size, (i+1)*sample_size) flattened row-major.\n",
298 | "        row_len = sample_size * int(np.prod(signal.shape[1:]))\n",
299 | "        channels.append(signal[:n_segments*sample_size].reshape(n_segments, row_len))\n",
300 | "    sample_tensor = np.stack(channels, axis=2).astype('float32')\n",
301 | "    return sample_tensor"
302 | ],
303 | "execution_count": 0,
304 | "outputs": []
305 | },
306 | {
307 | "cell_type": "markdown",
308 | "metadata": {
309 | "id": "GlnwDCj9olhH",
310 | "colab_type": "text"
311 | },
312 | "source": [
313 | "Extract datasets from acquisitions.\n"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "metadata": {
319 | "id": "-xJlTxjTomWi",
320 | "colab_type": "code",
321 | "colab": {}
322 | },
323 | "source": [
324 | "def concatenate_datasets(xd,yd,xo,yo):\n",
325 | "    \"\"\"\n",
326 | "    Append origin (xo, yo) after destination (xd, yd) along the first axis.\n",
327 | "    A None destination (nothing collected yet) is replaced by the origin.\n",
328 | "    xd, yd: destination patterns / labels tensors\n",
329 | "    xo, yo: origin patterns / labels tensors to be concatenated\n",
330 | "    \"\"\"\n",
331 | "    if xd is None or yd is None:\n",
332 | "        return xo, yo\n",
333 | "    return np.concatenate((xd, xo)), np.concatenate((yd, yo))\n",
334 | "\n",
335 | "import urllib.request\n",
336 | "import os.path\n",
337 | "\n",
338 | "def acquisitions_from_substr(substr, acquisitions, labels_dict, position=None,\n",
339 | "                             url=\"http://csegroups.case.edu/sites/default/files/bearingdatacenter/files/Datafiles/\"):\n",
340 | "    \"\"\"\n",
341 | "    Extract and label the segments of every acquisition whose key contains substr,\n",
342 | "    downloading missing Matlab files from url first (substr: e.g. '007', 'Normal_0').\n",
343 | "    position: 'DE'/'FE' keeps that channel only; with None, keys sharing substr's\n",
344 | "    first two characters keep both channels and the others use their key's prefix.\n",
345 | "    Returns (samples, labels), or (None, None) when no key matches.\n",
346 | "    \"\"\"\n",
347 | "    samples = None\n",
348 | "    labels = None\n",
349 | "    for key, file_name in acquisitions.items():\n",
350 | "        if str(substr) not in key:\n",
351 | "            continue\n",
352 | "        if not os.path.exists(file_name):\n",
353 | "            urllib.request.urlretrieve(url+file_name, file_name)\n",
354 | "        if position is None and substr[:2] == key[:2]:\n",
355 | "            channel = None\n",
356 | "        elif position in ('DE', 'FE'):\n",
357 | "            channel = position\n",
358 | "        else:\n",
359 | "            channel = key[:2]\n",
360 | "        acquisition_samples = acquisition2tensor(file_name, channel)\n",
361 | "        # First label whose name occurs in the key; never fall back to a stale one.\n",
362 | "        label = next((name for name in labels_dict if name in key), None)\n",
363 | "        if label is None:\n",
364 | "            raise ValueError(\"no label in labels_dict matches acquisition {!r}\".format(key))\n",
365 | "        acquisition_labels = np.ones(acquisition_samples.shape[0])*labels_dict[label]\n",
366 | "        samples,labels = concatenate_datasets(samples,labels,acquisition_samples,acquisition_labels)\n",
367 | "    print(substr)\n",
368 | "    return samples,labels"
369 | ],
370 | "execution_count": 0,
371 | "outputs": []
372 | },
373 | {
374 | "cell_type": "markdown",
375 | "metadata": {
376 | "id": "h1bmT_Proq-y",
377 | "colab_type": "text"
378 | },
379 | "source": [
380 | "# Downloading and loading Matlab files\n"
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {
386 | "id": "LOwEbwZPowH2",
387 | "colab_type": "text"
388 | },
389 | "source": [
390 | "Extract samples.\n"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "metadata": {
396 | "id": "iZG8Cuylorth",
397 | "colab_type": "code",
398 | "outputId": "e5e81450-0bb8-44c8-ce7f-ee3bff174b36",
399 | "colab": {
400 | "base_uri": "https://localhost:8080/",
401 | "height": 258
402 | }
403 | },
404 | "source": [
405 | "labels_dict = get_labels_dict(acquisitions, '.', False)\n",
406 | "print(labels_dict)\n",
407 | "def normal_indenpendent_position_acquisitions(load,acquisitions,labels_dict):\n",
408 | "    \"\"\"Concatenate the Normal_<load> segments of the DE channel, then of the FE channel.\"\"\"\n",
409 | "    x, y = None, None\n",
410 | "    for channel in ('DE', 'FE'):\n",
411 | "        xc, yc = acquisitions_from_substr('Normal_'+str(load), acquisitions, labels_dict, channel)\n",
412 | "        x, y = concatenate_datasets(x, y, xc, yc)\n",
413 | "    return x, y\n",
414 | "\n",
415 | "xn_0, yn_0 = normal_indenpendent_position_acquisitions(0, acquisitions, labels_dict)\n",
416 | "xn_1, yn_1 = normal_indenpendent_position_acquisitions(1, acquisitions, labels_dict)\n",
417 | "xn_2, yn_2 = normal_indenpendent_position_acquisitions(2, acquisitions, labels_dict)\n",
418 | "xn_3, yn_3 = normal_indenpendent_position_acquisitions(3, acquisitions, labels_dict)\n",
419 | "\n",
420 | "severities = ['007','014','021','028']\n",
421 | "x007, y007 = acquisitions_from_substr(severities[0], acquisitions, labels_dict)\n",
422 | "x014, y014 = acquisitions_from_substr(severities[1], acquisitions, labels_dict)\n",
423 | "x021, y021 = acquisitions_from_substr(severities[2], acquisitions, labels_dict)\n",
424 | "x028, y028 = acquisitions_from_substr(severities[3], acquisitions, labels_dict)"
425 | ],
426 | "execution_count": 0,
427 | "outputs": [
428 | {
429 | "output_type": "stream",
430 | "text": [
431 | "{'Normal': 0, 'IR': 1, 'B': 2, 'OR': 3}\n",
432 | "97.mat Normal_0\n",
433 | "97.mat Normal_0\n",
434 | "98.mat Normal_1\n",
435 | "98.mat Normal_1\n",
436 | "99.mat Normal_2\n",
437 | "99.mat Normal_2\n",
438 | "100.mat Normal_3\n",
439 | "100.mat Normal_3\n",
440 | "105.mat 106.mat 107.mat 108.mat 118.mat 119.mat 120.mat 121.mat 130.mat 131.mat 132.mat 133.mat 144.mat 145.mat 146.mat 147.mat 156.mat 158.mat 159.mat 160.mat 282.mat 283.mat 284.mat 285.mat 278.mat 279.mat 280.mat 281.mat 294.mat 295.mat 296.mat 297.mat 298.mat 299.mat 300.mat 301.mat 302.mat 305.mat 306.mat 307.mat 007\n",
441 | "169.mat 170.mat 171.mat 172.mat 185.mat 186.mat 187.mat 188.mat 197.mat 198.mat 199.mat 200.mat 274.mat 275.mat 276.mat 277.mat 286.mat 287.mat 288.mat 289.mat 310.mat 309.mat 311.mat 312.mat 313.mat 014\n",
442 | "222.mat 223.mat 224.mat 225.mat 270.mat 271.mat 272.mat 273.mat 209.mat 210.mat 211.mat 212.mat 234.mat 235.mat 236.mat 237.mat 246.mat 247.mat 248.mat 249.mat 258.mat 259.mat 260.mat 261.mat 290.mat 291.mat 292.mat 293.mat 315.mat 316.mat 317.mat 318.mat 021\n",
443 | "3001.mat 3002.mat 3003.mat 3004.mat 3005.mat 3006.mat 3007.mat 3008.mat 028\n"
444 | ],
445 | "name": "stdout"
446 | }
447 | ]
448 | },
449 | {
450 | "cell_type": "markdown",
451 | "metadata": {
452 | "id": "v4iA0dt9pXyK",
453 | "colab_type": "text"
454 | },
455 | "source": [
456 | "Count the number of samples per label and severity.\n",
457 | "\n"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "metadata": {
463 | "id": "CEHJBo9bpZGQ",
464 | "colab_type": "code",
465 | "outputId": "3135c587-a075-4ab3-9f76-98fd99b8706b",
466 | "colab": {
467 | "base_uri": "https://localhost:8080/",
468 | "height": 187
469 | }
470 | },
471 | "source": [
472 | "print(\"Label\", end='\\t')\n",
473 | "for s in severities:\n",
474 | "    print(s, end='\\t')\n",
475 | "print(\"total\")\n",
476 | "# Label vectors of each severity subset and of each normal load. The Normal row pairs\n",
477 | "# severity column j with load (j+len(severities)+1)%4, i.e. 007->1, 014->2, 021->3, 028->0.\n",
478 | "y_by_severity = {'007': y007, '014': y014, '021': y021, '028': y028}\n",
479 | "yn_by_load = {0: yn_0, 1: yn_1, 2: yn_2, 3: yn_3}\n",
480 | "mat = np.zeros((len(labels_dict), len(severities)))\n",
481 | "for i, (label, value) in enumerate(labels_dict.items()):\n",
482 | "    print(label, end='\\t')\n",
483 | "    tsamples = 0\n",
484 | "    if label == 'Normal':\n",
485 | "        print(4*'\\t'+'...')\n",
486 | "        for j in range(len(severities)):\n",
487 | "            load = (j+len(severities)+1)%4\n",
488 | "            print(' '+str(load)+j*'\\t', end='\\t')\n",
489 | "            mat[i][j] = list(yn_by_load[load]).count(value)\n",
490 | "            print(int(mat[i][j]))\n",
491 | "    else:\n",
492 | "        for j, severity in enumerate(severities):\n",
493 | "            ys = y_by_severity[severity]\n",
494 | "            nsamples = list(ys).count(value) if ys is not None else 0\n",
495 | "            mat[i][j] = nsamples\n",
496 | "            print(nsamples, end='\\t')\n",
497 | "            tsamples += nsamples\n",
498 | "        print(tsamples)\n",
499 | "total = np.sum(mat, axis=0)\n",
500 | "print(\"Total:\", end='\\t')\n",
501 | "for t in total:\n",
502 | "    print(int(t), end='\\t')\n",
503 | "print(int(np.sum(total)))"
504 | ],
505 | "execution_count": 0,
506 | "outputs": [
507 | {
508 | "output_type": "stream",
509 | "text": [
510 | "Label\t007\t014\t021\t028\ttotal\n",
511 | "Normal\t\t\t\t\t...\n",
512 | " 1\t1890\n",
513 | " 2\t\t1890\n",
514 | " 3\t\t\t1896\n",
515 | " 0\t\t\t\t952\n",
516 | "IR\t1900\t1895\t1894\t946\t6635\n",
517 | "B\t1894\t1900\t1897\t944\t6635\n",
518 | "OR\t5694\t2131\t3798\t0\t11623\n",
519 | "Total:\t11378\t7816\t9485\t2842\t31521\n"
520 | ],
521 | "name": "stdout"
522 | }
523 | ]
524 | },
525 | {
526 | "cell_type": "markdown",
527 | "metadata": {
528 | "id": "BQutiWuupevT",
529 | "colab_type": "text"
530 | },
531 | "source": [
532 | "# Generating CSV files\n"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "metadata": {
538 | "id": "P2p5W3ZvU1MW",
539 | "colab_type": "code",
540 | "colab": {}
541 | },
542 | "source": [
543 | "import os\n",
544 | "\n",
545 | "# Map each label value back to its condition name (the keys of labels_dict).\n",
546 | "cond_dict = {v: k for k, v in labels_dict.items()}\n",
547 | "\n",
548 | "def write_csv(severity, samples, labels, root='cwru'):\n",
549 | "    \"\"\"Save every sample as its own CSV file in root/severity/condition/.\n",
550 | "\n",
551 | "    Args:\n",
552 | "        severity: severity identifier (e.g. '007'), used as a sub-directory name.\n",
553 | "        samples: indexable samples; samples[i] is written with np.savetxt.\n",
554 | "        labels: label array (needs .shape); labels[i] is the label of samples[i]\n",
555 | "            and selects the condition directory via cond_dict.\n",
556 | "        root: top-level output directory (default 'cwru').\n",
557 | "\n",
558 | "    Files are named by sample index, zero-padded to the digit count of\n",
559 | "    labels.shape[0] so they sort in sample order.\n",
560 | "    \"\"\"\n",
561 | "    sevdir = os.path.join(root, severity)\n",
562 | "    # exist_ok replaces the exists()/makedirs() pairs and also creates root.\n",
563 | "    os.makedirs(sevdir, exist_ok=True)\n",
564 | "    width = len(str(labels.shape[0]))\n",
565 | "    for i, value in enumerate(labels):\n",
566 | "        condir = os.path.join(sevdir, cond_dict[value])\n",
567 | "        os.makedirs(condir, exist_ok=True)\n",
568 | "        sample_name = os.path.join(condir, str(i).zfill(width) + '.csv')\n",
569 | "        np.savetxt(sample_name, samples[i], delimiter=',')"
559 | ],
560 | "execution_count": 0,
561 | "outputs": []
562 | },
563 | {
564 | "cell_type": "code",
565 | "metadata": {
566 | "id": "ZAT4AJoGZJ7k",
567 | "colab_type": "code",
568 | "colab": {}
569 | },
570 | "source": [
571 | "# Each severity dataset is paired with a different normal-condition load.\n",
572 | "NORMAL_LOAD_BY_SEVERITY = {'007': 1, '014': 2, '021': 3, '028': 0}\n",
573 | "\n",
574 | "def write_dataset(severity, normal_load):\n",
575 | "    \"\"\"Concatenate one severity's samples with one normal load and write them as CSV.\n",
576 | "\n",
577 | "    The arrays are looked up by name in the notebook globals\n",
578 | "    (x<severity>, y<severity>, xn_<load>, yn_<load>) rather than via eval.\n",
579 | "    \"\"\"\n",
580 | "    ns = globals()\n",
581 | "    x, y = concatenate_datasets(ns['x' + severity],\n",
582 | "                                ns['y' + severity],\n",
583 | "                                ns['xn_' + str(normal_load)],\n",
584 | "                                ns['yn_' + str(normal_load)])\n",
585 | "    write_csv(severity, x, y)\n",
586 | "\n",
587 | "for severity, normal_load in NORMAL_LOAD_BY_SEVERITY.items():\n",
588 | "    write_dataset(severity, normal_load)"
582 | ],
583 | "execution_count": 0,
584 | "outputs": []
585 | },
586 | {
587 | "cell_type": "code",
588 | "metadata": {
589 | "id": "aNQ5Wfs4X4On",
590 | "colab_type": "code",
591 | "colab": {}
592 | },
593 | "source": [
594 | "import shutil\n",
595 | "\n",
596 | "# Bundle the whole cwru/ tree into a single <output_filename>.<ext> archive\n",
597 | "# (path relative to the working directory) for download.\n",
598 | "output_filename, ext = 'cwru_segmented', 'zip'\n",
599 | "shutil.make_archive(output_filename, ext, root_dir='cwru')\n",
600 | "zipfile_name = '{}.{}'.format(output_filename, ext)"
599 | ],
600 | "execution_count": 0,
601 | "outputs": []
602 | },
603 | {
604 | "cell_type": "code",
605 | "metadata": {
606 | "id": "Iv3iuhNOTZ5E",
607 | "colab_type": "code",
608 | "colab": {}
609 | },
610 | "source": [
611 | "import os\n",
612 | "import time\n",
613 | "\n",
614 | "from google.colab import files\n",
615 | "\n",
616 | "# shutil.make_archive returns only after the archive is written, so the file\n",
617 | "# should already exist; wait a bounded time anyway and fail loudly rather\n",
618 | "# than loop forever if it is missing.\n",
619 | "DOWNLOAD_TIMEOUT_S = 60\n",
620 | "deadline = time.monotonic() + DOWNLOAD_TIMEOUT_S\n",
621 | "while not os.path.exists(zipfile_name):\n",
622 | "    if time.monotonic() > deadline:\n",
623 | "        raise FileNotFoundError('{} not found after {} s'.format(zipfile_name, DOWNLOAD_TIMEOUT_S))\n",
624 | "    time.sleep(1)\n",
625 | "files.download(zipfile_name)"
616 | ],
617 | "execution_count": 0,
618 | "outputs": []
619 | }
620 | ]
621 | }
--------------------------------------------------------------------------------