├── 2.1 EEG ML classification.ipynb
├── 2.2 EEG DL Classification.ipynb
├── BCI_Competition_IV.ipynb
├── README.md
├── chrononet_keras.ipynb
├── chrononet_pytorch.ipynb
├── cityscape-tutorial.ipynb
├── eeg-conv2d.ipynb
├── eeg_epilepsy.ipynb
└── video_classification_end2end.ipynb

--------------------------------------------------------------------------------
/2.1 EEG ML classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "id": "9c4a3bc5",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "from glob import glob\n",
11 |     "import os\n",
12 |     "import mne\n",
13 |     "import numpy as np\n",
14 |     "import matplotlib.pyplot as plt\n",
15 |     "import pandas as pd"
16 |    ]
17 |   },
18 |   {
19 |    "cell_type": "code",
20 |    "execution_count": 2,
21 |    "id": "2b11214d",
22 |    "metadata": {},
23 |    "outputs": [
24 |     {
25 |      "name": "stdout",
26 |      "output_type": "stream",
27 |      "text": [
28 |       "28\n"
29 |      ]
30 |     }
31 |    ],
32 |    "source": [
33 |     "# read all EDF files\n",
34 |     "all_files_path=glob('dataverse_files/*.edf')\n",
35 |     "print(len(all_files_path))"
36 |    ]
37 |   },
38 |   {
39 |    "cell_type": "code",
40 |    "execution_count": 3,
41 |    "id": "3126e54c",
42 |    "metadata": {},
43 |    "outputs": [
44 |     {
45 |      "data": {
46 |       "text/plain": [
47 |        "'dataverse_files\\\\h01.edf'"
48 |       ]
49 |      },
50 |      "execution_count": 3,
51 |      "metadata": {},
52 |      "output_type": "execute_result"
53 |     }
54 |    ],
55 |    "source": [
56 |     "all_files_path[0]"
57 |    ]
58 |   },
59 |   {
60 |    "cell_type": "code",
61 |    "execution_count": 4,
62 |    "id": "7bd827d0",
63 |    "metadata": {},
64 |    "outputs": [],
65 |    "source": [
66 |     "healthy_file_path=[i for i in all_files_path if 'h' in i.split('\\\\')[1]]\n",
67 |     "patient_file_path=[i for i in all_files_path if 's' in i.split('\\\\')[1]]"
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "code",
72 |    "execution_count": 5,
73 |    "id": "1d84e167",
74 |    "metadata": {},
75 |    "outputs": [],
76 |    "source": [
77 |     "def read_data(file_path):\n",
78 |     "    datax=mne.io.read_raw_edf(file_path,preload=True)\n",
79 |     "    datax.set_eeg_reference()\n",
80 |     "    datax.filter(l_freq=1,h_freq=45)\n",
81 |     "    epochs=mne.make_fixed_length_epochs(datax,duration=25,overlap=0)\n",
82 |     "    epochs=epochs.get_data()\n",
83 |     "    return epochs  # (trials, channels, length)"
84 |    ]
85 |   },
86 |   {
87 |    "cell_type": "code",
88 |    "execution_count": 6,
89 |    "id": "b1352819",
90 |    "metadata": {},
91 |    "outputs": [
92 |     {
93 |      "name": "stdout",
94 |      "output_type": "stream",
95 |      "text": [
96 |       "Extracting EDF parameters from D:\\complete_yt\\EEG_classification\\dataverse_files\\h01.edf...\n",
97 |       "EDF file detected\n",
98 |       "Setting channel info structure...\n",
99 |       "Creating raw.info structure...\n",
100 |       "Reading 0 ... 231249 = 0.000 ... 
924.996 secs...\n",
101 |       "EEG channel type selected for re-referencing\n",
102 |       "Applying average reference.\n",
103 |       "Applying a custom EEG reference.\n",
104 |       "Filtering raw data in 1 contiguous segment\n",
105 |       "Setting up band-pass filter from 1 - 45 Hz\n",
106 |       "\n",
107 |       "FIR filter parameters\n",
108 |       "---------------------\n",
109 |       "Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
110 |       "- Windowed time-domain design (firwin) method\n",
111 |       "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
112 |       "- Lower passband edge: 1.00\n",
113 |       "- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)\n",
114 |       "- Upper passband edge: 45.00 Hz\n",
115 |       "- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)\n",
116 |       "- Filter length: 825 samples (3.300 sec)\n",
117 |       "\n",
118 |       "Not setting metadata\n",
119 |       "Not setting metadata\n",
120 |       "37 matching events found\n",
121 |       "No baseline correction applied\n",
122 |       "0 projection items activated\n",
123 |       "Loading data for 37 events and 6250 original time points ...\n",
124 |       "0 bad epochs dropped\n"
125 |      ]
126 |     }
127 |    ],
128 |    "source": [
129 |     "data=read_data(healthy_file_path[0])"
130 |    ]
131 |   },
132 |   {
133 |    "cell_type": "code",
134 |    "execution_count": 7,
135 |    "id": "a48348ea",
136 |    "metadata": {},
137 |    "outputs": [
138 |     {
139 |      "data": {
140 |       "text/plain": [
141 |        "(37, 19, 6250)"
142 |       ]
143 |      },
144 |      "execution_count": 7,
145 |      "metadata": {},
146 |      "output_type": "execute_result"
147 |     }
148 |    ],
149 |    "source": [
150 |     "data.shape"
151 |    ]
152 |   },
153 |   {
154 |    "cell_type": "code",
155 |    "execution_count": 8,
156 |    "id": "e6aea13c",
157 |    "metadata": {},
158 |    "outputs": [
159 |     {
160 |      "name": "stdout",
161 |      "output_type": "stream",
162 |      "text": [
163 |       "Extracting EDF parameters from D:\\complete_yt\\EEG_classification\\dataverse_files\\h01.edf...\n",
[... lines 164-963 omitted: the same extract / average-reference / 1-45 Hz FIR filter / fixed-length-epoch log repeats verbatim for every recording; per-file epoch counts are h01-h14: 37, 36, 36, 37, 37, 37, 36, 36, 36, 44, 36, 36, 38, 34 and s01-s13: 33, 45, 38, 48, 35, 29, 53, 36, 47, 34, 54, 43, 45 ...]
964 |       "Extracting EDF parameters from D:\\complete_yt\\EEG_classification\\dataverse_files\\s14.edf...\n",
965 |       "EDF file detected\n",
966 |       "Setting channel info structure...\n",
967 |       "Creating raw.info structure...\n",
968 |       "Reading 0 ... 542499 = 0.000 ... 
2169.996 secs...\n", 969 | "EEG channel type selected for re-referencing\n", 970 | "Applying average reference.\n", 971 | "Applying a custom EEG reference.\n", 972 | "Filtering raw data in 1 contiguous segment\n", 973 | "Setting up band-pass filter from 1 - 45 Hz\n", 974 | "\n", 975 | "FIR filter parameters\n", 976 | "---------------------\n", 977 | "Designing a one-pass, zero-phase, non-causal bandpass filter:\n", 978 | "- Windowed time-domain design (firwin) method\n", 979 | "- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", 980 | "- Lower passband edge: 1.00\n", 981 | "- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)\n", 982 | "- Upper passband edge: 45.00 Hz\n", 983 | "- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)\n", 984 | "- Filter length: 825 samples (3.300 sec)\n", 985 | "\n", 986 | "Not setting metadata\n", 987 | "Not setting metadata\n", 988 | "86 matching events found\n", 989 | "No baseline correction applied\n", 990 | "0 projection items activated\n", 991 | "Loading data for 86 events and 6250 original time points ...\n", 992 | "0 bad epochs dropped\n" 993 | ] 994 | } 995 | ], 996 | "source": [ 997 | "control_epochs_array=[read_data(subject) for subject in healthy_file_path]\n", 998 | "patients_epochs_array=[read_data(subject) for subject in patient_file_path]" 999 | ] 1000 | }, 1001 | { 1002 | "cell_type": "code", 1003 | "execution_count": 9, 1004 | "id": "1110efbc", 1005 | "metadata": {}, 1006 | "outputs": [ 1007 | { 1008 | "name": "stdout", 1009 | "output_type": "stream", 1010 | "text": [ 1011 | "14 14\n" 1012 | ] 1013 | } 1014 | ], 1015 | "source": [ 1016 | "control_epochs_labels=[len(i)*[0] for i in control_epochs_array]\n", 1017 | "patients_epochs_labels=[len(i)*[1] for i in patients_epochs_array]\n", 1018 | "print(len(control_epochs_labels),len(patients_epochs_labels))" 1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "code", 1023 | "execution_count": 10, 1024 | "id": "f4993102", 1025 | "metadata": {}, 1026 | "outputs": [ 1027 | { 1028 | "name": "stdout", 1029 | "output_type": "stream", 1030 | "text": [ 1031 | "28 28\n" 1032 | ] 1033 | } 1034 | ], 1035 | "source": [ 1036 | "data_list=control_epochs_array+patients_epochs_array\n", 1037 | "label_list=control_epochs_labels+patients_epochs_labels\n", 1038 | "print(len(data_list),len(label_list))" 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "code", 1043 | "execution_count": 11, 1044 | "id": "4ea68f45", 1045 | "metadata": {}, 1046 | "outputs": [], 1047 | "source": [ 1048 | "groups_list=[[i]*len(j) for i, j in enumerate(data_list)]" 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "code", 1053 | "execution_count": 13, 1054 | "id": "b6f64fbb", 1055 | "metadata": {}, 1056 | "outputs": [ 1057 | { 1058 | "name": "stdout", 1059 | "output_type": "stream", 1060 | "text": [ 1061 | "(1142, 19, 6250) (1142,) (1142,)\n" 1062 | ] 1063 | } 1064 | ], 1065 | "source": [ 1066 | "data_array=np.vstack(data_list)\n", 1067 | "label_array=np.hstack(label_list)\n", 1068 | "group_array=np.hstack(groups_list)\n", 1069 | "print(data_array.shape,label_array.shape,group_array.shape)" 1070 | ] 1071 | }, 1072 | { 1073 | "cell_type": "code", 1074 | "execution_count": 14, 1075 | "id": "8de8a58b", 1076 | "metadata": {}, 1077 | "outputs": [], 1078 | "source": [ 1079 | "from scipy import stats\n", 1080 | "def mean(data):\n", 1081 | " return np.mean(data,axis=-1)\n", 1082 | " \n", 1083 | "def std(data):\n", 1084 | " return np.std(data,axis=-1)\n", 1085 | "\n", 1086 | "def 
ptp(data):\n", 1087 | " return np.ptp(data,axis=-1)\n", 1088 | "\n", 1089 | "def var(data):\n", 1090 | " return np.var(data,axis=-1)\n", 1091 | "\n", 1092 | "def minim(data):\n", 1093 | " return np.min(data,axis=-1)\n", 1094 | "\n", 1095 | "\n", 1096 | "def maxim(data):\n", 1097 | " return np.max(data,axis=-1)\n", 1098 | "\n", 1099 | "def argminim(data):\n", 1100 | " return np.argmin(data,axis=-1)\n", 1101 | "\n", 1102 | "\n", 1103 | "def argmaxim(data):\n", 1104 | " return np.argmax(data,axis=-1)\n", 1105 | "\n", 1106 | "def mean_square(data):\n", 1107 | " return np.mean(data**2,axis=-1)\n", 1108 | "\n", 1109 | "def rms(data): #root mean square\n", 1110 | " return np.sqrt(np.mean(data**2,axis=-1)) \n", 1111 | "\n", 1112 | "def abs_diffs_signal(data):\n", 1113 | " return np.sum(np.abs(np.diff(data,axis=-1)),axis=-1)\n", 1114 | "\n", 1115 | "\n", 1116 | "def skewness(data):\n", 1117 | " return stats.skew(data,axis=-1)\n", 1118 | "\n", 1119 | "def kurtosis(data):\n", 1120 | " return stats.kurtosis(data,axis=-1)\n", 1121 | "\n", 1122 | "def concatenate_features(data):\n", 1123 | " return np.concatenate((mean(data),std(data),ptp(data),var(data),minim(data),maxim(data),argminim(data),argmaxim(data),\n", 1124 | " mean_square(data),rms(data),abs_diffs_signal(data),\n", 1125 | " skewness(data),kurtosis(data)),axis=-1)" 1126 | ] 1127 | }, 1128 | { 1129 | "cell_type": "code", 1130 | "execution_count": 26, 1131 | "id": "23fdd5a4", 1132 | "metadata": {}, 1133 | "outputs": [ 1134 | { 1135 | "name": "stderr", 1136 | "output_type": "stream", 1137 | "text": [ 1138 | ":3: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n", 1139 | "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n", 1140 | " for data in tqdm_notebook(data_array):\n" 1141 | ] 1142 | }, 1143 | { 1144 | "data": { 1145 | "application/vnd.jupyter.widget-view+json": { 1146 | "model_id": "7bb96952d6a345f58783302547022d7a", 1147 | "version_major": 2, 1148 | "version_minor": 0 1149 | }, 1150 | "text/plain": [ 1151 | " 0%| | 0/1142 [00:00" 403 | ] 404 | }, 405 | "execution_count": null, 406 | "metadata": {}, 407 | "output_type": "execute_result" 408 | } 409 | ], 410 | "source": [ 411 | "model.fit(train_features,train_labels,epochs=10,batch_size=128,validation_data=(val_features,val_labels))" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": { 418 | "colab": { 419 | "background_save": true 420 | }, 421 | "id": "LSvEAYOeV2AQ" 422 | }, 423 | "outputs": [], 424 | "source": [ 425 | "" 426 | ] 427 | } 428 | ], 429 | "metadata": { 430 | "accelerator": "GPU", 431 | "colab": { 432 | "collapsed_sections": [], 433 | "name": "chrononet-keras_x.ipynb", 434 | "provenance": [] 435 | }, 436 | "kernelspec": { 437 | "display_name": "Python 3", 438 | "name": "python3" 439 | }, 440 | "language_info": { 441 | "name": "python" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 0 446 | } -------------------------------------------------------------------------------- /chrononet_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "chrononet-pytorch.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | 
"d906d8b3258a4897bd5301d4feaa91c7": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "model_module_version": "1.5.0", 23 | "state": { 24 | "_view_name": "HBoxView", 25 | "_dom_classes": [], 26 | "_model_name": "HBoxModel", 27 | "_view_module": "@jupyter-widgets/controls", 28 | "_model_module_version": "1.5.0", 29 | "_view_count": null, 30 | "_view_module_version": "1.5.0", 31 | "box_style": "", 32 | "layout": "IPY_MODEL_ddb353b0e36747839bebbc72689b29a8", 33 | "_model_module": "@jupyter-widgets/controls", 34 | "children": [ 35 | "IPY_MODEL_6f0e8c4ca555468fb2d14b929434fc49", 36 | "IPY_MODEL_7c19031e4b1a4306ae282a6cdf4a0935", 37 | "IPY_MODEL_960da51a7dc342c4ae478cfa551de087" 38 | ] 39 | } 40 | }, 41 | "ddb353b0e36747839bebbc72689b29a8": { 42 | "model_module": "@jupyter-widgets/base", 43 | "model_name": "LayoutModel", 44 | "model_module_version": "1.2.0", 45 | "state": { 46 | "_view_name": "LayoutView", 47 | "grid_template_rows": null, 48 | "right": null, 49 | "justify_content": null, 50 | "_view_module": "@jupyter-widgets/base", 51 | "overflow": null, 52 | "_model_module_version": "1.2.0", 53 | "_view_count": null, 54 | "flex_flow": "row wrap", 55 | "width": "100%", 56 | "min_width": null, 57 | "border": null, 58 | "align_items": null, 59 | "bottom": null, 60 | "_model_module": "@jupyter-widgets/base", 61 | "top": null, 62 | "grid_column": null, 63 | "overflow_y": null, 64 | "overflow_x": null, 65 | "grid_auto_flow": null, 66 | "grid_area": null, 67 | "grid_template_columns": null, 68 | "flex": null, 69 | "_model_name": "LayoutModel", 70 | "justify_items": null, 71 | "grid_row": null, 72 | "max_height": null, 73 | "align_content": null, 74 | "visibility": null, 75 | "align_self": null, 76 | "height": null, 77 | "min_height": null, 78 | "padding": null, 79 | "grid_auto_rows": null, 80 | "grid_gap": null, 81 | "max_width": null, 82 | "order": null, 83 | "_view_module_version": "1.2.0", 84 | "grid_template_areas": null, 85 | "object_position": null, 86 | "object_fit": null, 87 | "grid_auto_columns": null, 88 | "margin": null, 89 | "display": "inline-flex", 90 | "left": null 91 | } 92 | }, 93 | "6f0e8c4ca555468fb2d14b929434fc49": { 94 | "model_module": "@jupyter-widgets/controls", 95 | "model_name": "HTMLModel", 96 | "model_module_version": "1.5.0", 97 | "state": { 98 | "_view_name": "HTMLView", 99 | "style": "IPY_MODEL_c26031cf9401466dad958a13dd606dc8", 100 | "_dom_classes": [], 101 | "description": "", 102 | "_model_name": "HTMLModel", 103 | "placeholder": "​", 104 | "_view_module": "@jupyter-widgets/controls", 105 | "_model_module_version": "1.5.0", 106 | "value": "Validation sanity check: 100%", 107 | "_view_count": null, 108 | "_view_module_version": "1.5.0", 109 | "description_tooltip": null, 110 | "_model_module": "@jupyter-widgets/controls", 111 | "layout": "IPY_MODEL_c3ac49e08f2243cf8d0fd5beb631a10f" 112 | } 113 | }, 114 | "7c19031e4b1a4306ae282a6cdf4a0935": { 115 | "model_module": "@jupyter-widgets/controls", 116 | "model_name": "FloatProgressModel", 117 | "model_module_version": "1.5.0", 118 | "state": { 119 | "_view_name": "ProgressView", 120 | "style": "IPY_MODEL_874dce2b24c6419b8a8f32ab264b4daa", 121 | "_dom_classes": [], 122 | "description": "", 123 | "_model_name": "FloatProgressModel", 124 | "bar_style": "", 125 | "max": 2, 126 | "_view_module": "@jupyter-widgets/controls", 127 | "_model_module_version": "1.5.0", 128 | "value": 2, 129 | "_view_count": null, 130 | "_view_module_version": "1.5.0", 131 | "orientation": "horizontal", 132 | 
"min": 0, 133 | "description_tooltip": null, 134 | "_model_module": "@jupyter-widgets/controls", 135 | "layout": "IPY_MODEL_dbf216bbe04a45ba89dc71d9e0632f6e" 136 | } 137 | }, 138 | "960da51a7dc342c4ae478cfa551de087": { 139 | "model_module": "@jupyter-widgets/controls", 140 | "model_name": "HTMLModel", 141 | "model_module_version": "1.5.0", 142 | "state": { 143 | "_view_name": "HTMLView", 144 | "style": "IPY_MODEL_330e6b856caf4754a302be2f7fd66099", 145 | "_dom_classes": [], 146 | "description": "", 147 | "_model_name": "HTMLModel", 148 | "placeholder": "​", 149 | "_view_module": "@jupyter-widgets/controls", 150 | "_model_module_version": "1.5.0", 151 | "value": " 2/2 [00:00<00:00, 11.97it/s]", 152 | "_view_count": null, 153 | "_view_module_version": "1.5.0", 154 | "description_tooltip": null, 155 | "_model_module": "@jupyter-widgets/controls", 156 | "layout": "IPY_MODEL_ee28e600a91944cab7a85d06b78add4c" 157 | } 158 | }, 159 | "c26031cf9401466dad958a13dd606dc8": { 160 | "model_module": "@jupyter-widgets/controls", 161 | "model_name": "DescriptionStyleModel", 162 | "model_module_version": "1.5.0", 163 | "state": { 164 | "_view_name": "StyleView", 165 | "_model_name": "DescriptionStyleModel", 166 | "description_width": "", 167 | "_view_module": "@jupyter-widgets/base", 168 | "_model_module_version": "1.5.0", 169 | "_view_count": null, 170 | "_view_module_version": "1.2.0", 171 | "_model_module": "@jupyter-widgets/controls" 172 | } 173 | }, 174 | "c3ac49e08f2243cf8d0fd5beb631a10f": { 175 | "model_module": "@jupyter-widgets/base", 176 | "model_name": "LayoutModel", 177 | "model_module_version": "1.2.0", 178 | "state": { 179 | "_view_name": "LayoutView", 180 | "grid_template_rows": null, 181 | "right": null, 182 | "justify_content": null, 183 | "_view_module": "@jupyter-widgets/base", 184 | "overflow": null, 185 | "_model_module_version": "1.2.0", 186 | "_view_count": null, 187 | "flex_flow": null, 188 | "width": null, 189 | "min_width": null, 190 | "border": null, 191 | "align_items": null, 192 | "bottom": null, 193 | "_model_module": "@jupyter-widgets/base", 194 | "top": null, 195 | "grid_column": null, 196 | "overflow_y": null, 197 | "overflow_x": null, 198 | "grid_auto_flow": null, 199 | "grid_area": null, 200 | "grid_template_columns": null, 201 | "flex": null, 202 | "_model_name": "LayoutModel", 203 | "justify_items": null, 204 | "grid_row": null, 205 | "max_height": null, 206 | "align_content": null, 207 | "visibility": null, 208 | "align_self": null, 209 | "height": null, 210 | "min_height": null, 211 | "padding": null, 212 | "grid_auto_rows": null, 213 | "grid_gap": null, 214 | "max_width": null, 215 | "order": null, 216 | "_view_module_version": "1.2.0", 217 | "grid_template_areas": null, 218 | "object_position": null, 219 | "object_fit": null, 220 | "grid_auto_columns": null, 221 | "margin": null, 222 | "display": null, 223 | "left": null 224 | } 225 | }, 226 | "874dce2b24c6419b8a8f32ab264b4daa": { 227 | "model_module": "@jupyter-widgets/controls", 228 | "model_name": "ProgressStyleModel", 229 | "model_module_version": "1.5.0", 230 | "state": { 231 | "_view_name": "StyleView", 232 | "_model_name": "ProgressStyleModel", 233 | "description_width": "", 234 | "_view_module": "@jupyter-widgets/base", 235 | "_model_module_version": "1.5.0", 236 | "_view_count": null, 237 | "_view_module_version": "1.2.0", 238 | "bar_color": null, 239 | "_model_module": "@jupyter-widgets/controls" 240 | } 241 | }, 242 | "dbf216bbe04a45ba89dc71d9e0632f6e": { 243 | "model_module": "@jupyter-widgets/base", 244 | 
"model_name": "LayoutModel", 245 | "model_module_version": "1.2.0", 246 | "state": { 247 | "_view_name": "LayoutView", 248 | "grid_template_rows": null, 249 | "right": null, 250 | "justify_content": null, 251 | "_view_module": "@jupyter-widgets/base", 252 | "overflow": null, 253 | "_model_module_version": "1.2.0", 254 | "_view_count": null, 255 | "flex_flow": null, 256 | "width": null, 257 | "min_width": null, 258 | "border": null, 259 | "align_items": null, 260 | "bottom": null, 261 | "_model_module": "@jupyter-widgets/base", 262 | "top": null, 263 | "grid_column": null, 264 | "overflow_y": null, 265 | "overflow_x": null, 266 | "grid_auto_flow": null, 267 | "grid_area": null, 268 | "grid_template_columns": null, 269 | "flex": "2", 270 | "_model_name": "LayoutModel", 271 | "justify_items": null, 272 | "grid_row": null, 273 | "max_height": null, 274 | "align_content": null, 275 | "visibility": null, 276 | "align_self": null, 277 | "height": null, 278 | "min_height": null, 279 | "padding": null, 280 | "grid_auto_rows": null, 281 | "grid_gap": null, 282 | "max_width": null, 283 | "order": null, 284 | "_view_module_version": "1.2.0", 285 | "grid_template_areas": null, 286 | "object_position": null, 287 | "object_fit": null, 288 | "grid_auto_columns": null, 289 | "margin": null, 290 | "display": null, 291 | "left": null 292 | } 293 | }, 294 | "330e6b856caf4754a302be2f7fd66099": { 295 | "model_module": "@jupyter-widgets/controls", 296 | "model_name": "DescriptionStyleModel", 297 | "model_module_version": "1.5.0", 298 | "state": { 299 | "_view_name": "StyleView", 300 | "_model_name": "DescriptionStyleModel", 301 | "description_width": "", 302 | "_view_module": "@jupyter-widgets/base", 303 | "_model_module_version": "1.5.0", 304 | "_view_count": null, 305 | "_view_module_version": "1.2.0", 306 | "_model_module": "@jupyter-widgets/controls" 307 | } 308 | }, 309 | "ee28e600a91944cab7a85d06b78add4c": { 310 | "model_module": "@jupyter-widgets/base", 311 | "model_name": "LayoutModel", 312 | "model_module_version": "1.2.0", 313 | "state": { 314 | "_view_name": "LayoutView", 315 | "grid_template_rows": null, 316 | "right": null, 317 | "justify_content": null, 318 | "_view_module": "@jupyter-widgets/base", 319 | "overflow": null, 320 | "_model_module_version": "1.2.0", 321 | "_view_count": null, 322 | "flex_flow": null, 323 | "width": null, 324 | "min_width": null, 325 | "border": null, 326 | "align_items": null, 327 | "bottom": null, 328 | "_model_module": "@jupyter-widgets/base", 329 | "top": null, 330 | "grid_column": null, 331 | "overflow_y": null, 332 | "overflow_x": null, 333 | "grid_auto_flow": null, 334 | "grid_area": null, 335 | "grid_template_columns": null, 336 | "flex": null, 337 | "_model_name": "LayoutModel", 338 | "justify_items": null, 339 | "grid_row": null, 340 | "max_height": null, 341 | "align_content": null, 342 | "visibility": null, 343 | "align_self": null, 344 | "height": null, 345 | "min_height": null, 346 | "padding": null, 347 | "grid_auto_rows": null, 348 | "grid_gap": null, 349 | "max_width": null, 350 | "order": null, 351 | "_view_module_version": "1.2.0", 352 | "grid_template_areas": null, 353 | "object_position": null, 354 | "object_fit": null, 355 | "grid_auto_columns": null, 356 | "margin": null, 357 | "display": null, 358 | "left": null 359 | } 360 | }, 361 | "c08f1e54817a4cb399d13f70cc631556": { 362 | "model_module": "@jupyter-widgets/controls", 363 | "model_name": "HBoxModel", 364 | "model_module_version": "1.5.0", 365 | "state": { 366 | "_view_name": "HBoxView", 367 | 
"_dom_classes": [], 368 | "_model_name": "HBoxModel", 369 | "_view_module": "@jupyter-widgets/controls", 370 | "_model_module_version": "1.5.0", 371 | "_view_count": null, 372 | "_view_module_version": "1.5.0", 373 | "box_style": "", 374 | "layout": "IPY_MODEL_29db7d286e214531afde8496c7eac53b", 375 | "_model_module": "@jupyter-widgets/controls", 376 | "children": [ 377 | "IPY_MODEL_da67d06f25b341938d96611e7ca9021b", 378 | "IPY_MODEL_2e774d5dbdf24f70a7c6fcf361111f69", 379 | "IPY_MODEL_ac2821869fb44bc8afa959c9a9f80bd3" 380 | ] 381 | } 382 | }, 383 | "29db7d286e214531afde8496c7eac53b": { 384 | "model_module": "@jupyter-widgets/base", 385 | "model_name": "LayoutModel", 386 | "model_module_version": "1.2.0", 387 | "state": { 388 | "_view_name": "LayoutView", 389 | "grid_template_rows": null, 390 | "right": null, 391 | "justify_content": null, 392 | "_view_module": "@jupyter-widgets/base", 393 | "overflow": null, 394 | "_model_module_version": "1.2.0", 395 | "_view_count": null, 396 | "flex_flow": "row wrap", 397 | "width": "100%", 398 | "min_width": null, 399 | "border": null, 400 | "align_items": null, 401 | "bottom": null, 402 | "_model_module": "@jupyter-widgets/base", 403 | "top": null, 404 | "grid_column": null, 405 | "overflow_y": null, 406 | "overflow_x": null, 407 | "grid_auto_flow": null, 408 | "grid_area": null, 409 | "grid_template_columns": null, 410 | "flex": null, 411 | "_model_name": "LayoutModel", 412 | "justify_items": null, 413 | "grid_row": null, 414 | "max_height": null, 415 | "align_content": null, 416 | "visibility": null, 417 | "align_self": null, 418 | "height": null, 419 | "min_height": null, 420 | "padding": null, 421 | "grid_auto_rows": null, 422 | "grid_gap": null, 423 | "max_width": null, 424 | "order": null, 425 | "_view_module_version": "1.2.0", 426 | "grid_template_areas": null, 427 | "object_position": null, 428 | "object_fit": null, 429 | "grid_auto_columns": null, 430 | "margin": null, 431 | "display": "inline-flex", 432 | "left": null 433 | } 434 | }, 435 | "da67d06f25b341938d96611e7ca9021b": { 436 | "model_module": "@jupyter-widgets/controls", 437 | "model_name": "HTMLModel", 438 | "model_module_version": "1.5.0", 439 | "state": { 440 | "_view_name": "HTMLView", 441 | "style": "IPY_MODEL_c66197c3cd064d858215136526c36970", 442 | "_dom_classes": [], 443 | "description": "", 444 | "_model_name": "HTMLModel", 445 | "placeholder": "​", 446 | "_view_module": "@jupyter-widgets/controls", 447 | "_model_module_version": "1.5.0", 448 | "value": "Epoch 0: 100%", 449 | "_view_count": null, 450 | "_view_module_version": "1.5.0", 451 | "description_tooltip": null, 452 | "_model_module": "@jupyter-widgets/controls", 453 | "layout": "IPY_MODEL_a44ca1d1c7e949f1a237253483683f14" 454 | } 455 | }, 456 | "2e774d5dbdf24f70a7c6fcf361111f69": { 457 | "model_module": "@jupyter-widgets/controls", 458 | "model_name": "FloatProgressModel", 459 | "model_module_version": "1.5.0", 460 | "state": { 461 | "_view_name": "ProgressView", 462 | "style": "IPY_MODEL_85b17c14e13a452f98f84c6b75b00915", 463 | "_dom_classes": [], 464 | "description": "", 465 | "_model_name": "FloatProgressModel", 466 | "bar_style": "success", 467 | "max": 36, 468 | "_view_module": "@jupyter-widgets/controls", 469 | "_model_module_version": "1.5.0", 470 | "value": 36, 471 | "_view_count": null, 472 | "_view_module_version": "1.5.0", 473 | "orientation": "horizontal", 474 | "min": 0, 475 | "description_tooltip": null, 476 | "_model_module": "@jupyter-widgets/controls", 477 | "layout": 
"IPY_MODEL_0326747e2fce41819917749f6f95b184" 478 | } 479 | }, 480 | "ac2821869fb44bc8afa959c9a9f80bd3": { 481 | "model_module": "@jupyter-widgets/controls", 482 | "model_name": "HTMLModel", 483 | "model_module_version": "1.5.0", 484 | "state": { 485 | "_view_name": "HTMLView", 486 | "style": "IPY_MODEL_516a6e31fa704f9a87f32fd131a6ad2b", 487 | "_dom_classes": [], 488 | "description": "", 489 | "_model_name": "HTMLModel", 490 | "placeholder": "​", 491 | "_view_module": "@jupyter-widgets/controls", 492 | "_model_module_version": "1.5.0", 493 | "value": " 36/36 [00:02<00:00, 12.98it/s, loss=0.668, v_num=2]", 494 | "_view_count": null, 495 | "_view_module_version": "1.5.0", 496 | "description_tooltip": null, 497 | "_model_module": "@jupyter-widgets/controls", 498 | "layout": "IPY_MODEL_6f3262e846fe491daa0dc4c091989888" 499 | } 500 | }, 501 | "c66197c3cd064d858215136526c36970": { 502 | "model_module": "@jupyter-widgets/controls", 503 | "model_name": "DescriptionStyleModel", 504 | "model_module_version": "1.5.0", 505 | "state": { 506 | "_view_name": "StyleView", 507 | "_model_name": "DescriptionStyleModel", 508 | "description_width": "", 509 | "_view_module": "@jupyter-widgets/base", 510 | "_model_module_version": "1.5.0", 511 | "_view_count": null, 512 | "_view_module_version": "1.2.0", 513 | "_model_module": "@jupyter-widgets/controls" 514 | } 515 | }, 516 | "a44ca1d1c7e949f1a237253483683f14": { 517 | "model_module": "@jupyter-widgets/base", 518 | "model_name": "LayoutModel", 519 | "model_module_version": "1.2.0", 520 | "state": { 521 | "_view_name": "LayoutView", 522 | "grid_template_rows": null, 523 | "right": null, 524 | "justify_content": null, 525 | "_view_module": "@jupyter-widgets/base", 526 | "overflow": null, 527 | "_model_module_version": "1.2.0", 528 | "_view_count": null, 529 | "flex_flow": null, 530 | "width": null, 531 | "min_width": null, 532 | "border": null, 533 | "align_items": null, 534 | "bottom": null, 535 | "_model_module": "@jupyter-widgets/base", 536 | "top": null, 537 | "grid_column": null, 538 | "overflow_y": null, 539 | "overflow_x": null, 540 | "grid_auto_flow": null, 541 | "grid_area": null, 542 | "grid_template_columns": null, 543 | "flex": null, 544 | "_model_name": "LayoutModel", 545 | "justify_items": null, 546 | "grid_row": null, 547 | "max_height": null, 548 | "align_content": null, 549 | "visibility": null, 550 | "align_self": null, 551 | "height": null, 552 | "min_height": null, 553 | "padding": null, 554 | "grid_auto_rows": null, 555 | "grid_gap": null, 556 | "max_width": null, 557 | "order": null, 558 | "_view_module_version": "1.2.0", 559 | "grid_template_areas": null, 560 | "object_position": null, 561 | "object_fit": null, 562 | "grid_auto_columns": null, 563 | "margin": null, 564 | "display": null, 565 | "left": null 566 | } 567 | }, 568 | "85b17c14e13a452f98f84c6b75b00915": { 569 | "model_module": "@jupyter-widgets/controls", 570 | "model_name": "ProgressStyleModel", 571 | "model_module_version": "1.5.0", 572 | "state": { 573 | "_view_name": "StyleView", 574 | "_model_name": "ProgressStyleModel", 575 | "description_width": "", 576 | "_view_module": "@jupyter-widgets/base", 577 | "_model_module_version": "1.5.0", 578 | "_view_count": null, 579 | "_view_module_version": "1.2.0", 580 | "bar_color": null, 581 | "_model_module": "@jupyter-widgets/controls" 582 | } 583 | }, 584 | "0326747e2fce41819917749f6f95b184": { 585 | "model_module": "@jupyter-widgets/base", 586 | "model_name": "LayoutModel", 587 | "model_module_version": "1.2.0", 588 | "state": { 589 
| "_view_name": "LayoutView", 590 | "grid_template_rows": null, 591 | "right": null, 592 | "justify_content": null, 593 | "_view_module": "@jupyter-widgets/base", 594 | "overflow": null, 595 | "_model_module_version": "1.2.0", 596 | "_view_count": null, 597 | "flex_flow": null, 598 | "width": null, 599 | "min_width": null, 600 | "border": null, 601 | "align_items": null, 602 | "bottom": null, 603 | "_model_module": "@jupyter-widgets/base", 604 | "top": null, 605 | "grid_column": null, 606 | "overflow_y": null, 607 | "overflow_x": null, 608 | "grid_auto_flow": null, 609 | "grid_area": null, 610 | "grid_template_columns": null, 611 | "flex": "2", 612 | "_model_name": "LayoutModel", 613 | "justify_items": null, 614 | "grid_row": null, 615 | "max_height": null, 616 | "align_content": null, 617 | "visibility": null, 618 | "align_self": null, 619 | "height": null, 620 | "min_height": null, 621 | "padding": null, 622 | "grid_auto_rows": null, 623 | "grid_gap": null, 624 | "max_width": null, 625 | "order": null, 626 | "_view_module_version": "1.2.0", 627 | "grid_template_areas": null, 628 | "object_position": null, 629 | "object_fit": null, 630 | "grid_auto_columns": null, 631 | "margin": null, 632 | "display": null, 633 | "left": null 634 | } 635 | }, 636 | "516a6e31fa704f9a87f32fd131a6ad2b": { 637 | "model_module": "@jupyter-widgets/controls", 638 | "model_name": "DescriptionStyleModel", 639 | "model_module_version": "1.5.0", 640 | "state": { 641 | "_view_name": "StyleView", 642 | "_model_name": "DescriptionStyleModel", 643 | "description_width": "", 644 | "_view_module": "@jupyter-widgets/base", 645 | "_model_module_version": "1.5.0", 646 | "_view_count": null, 647 | "_view_module_version": "1.2.0", 648 | "_model_module": "@jupyter-widgets/controls" 649 | } 650 | }, 651 | "6f3262e846fe491daa0dc4c091989888": { 652 | "model_module": "@jupyter-widgets/base", 653 | "model_name": "LayoutModel", 654 | "model_module_version": "1.2.0", 655 | "state": { 656 | "_view_name": "LayoutView", 657 | "grid_template_rows": null, 658 | "right": null, 659 | "justify_content": null, 660 | "_view_module": "@jupyter-widgets/base", 661 | "overflow": null, 662 | "_model_module_version": "1.2.0", 663 | "_view_count": null, 664 | "flex_flow": null, 665 | "width": null, 666 | "min_width": null, 667 | "border": null, 668 | "align_items": null, 669 | "bottom": null, 670 | "_model_module": "@jupyter-widgets/base", 671 | "top": null, 672 | "grid_column": null, 673 | "overflow_y": null, 674 | "overflow_x": null, 675 | "grid_auto_flow": null, 676 | "grid_area": null, 677 | "grid_template_columns": null, 678 | "flex": null, 679 | "_model_name": "LayoutModel", 680 | "justify_items": null, 681 | "grid_row": null, 682 | "max_height": null, 683 | "align_content": null, 684 | "visibility": null, 685 | "align_self": null, 686 | "height": null, 687 | "min_height": null, 688 | "padding": null, 689 | "grid_auto_rows": null, 690 | "grid_gap": null, 691 | "max_width": null, 692 | "order": null, 693 | "_view_module_version": "1.2.0", 694 | "grid_template_areas": null, 695 | "object_position": null, 696 | "object_fit": null, 697 | "grid_auto_columns": null, 698 | "margin": null, 699 | "display": null, 700 | "left": null 701 | } 702 | }, 703 | "d0192776eb604b769b232b368d4eab27": { 704 | "model_module": "@jupyter-widgets/controls", 705 | "model_name": "HBoxModel", 706 | "model_module_version": "1.5.0", 707 | "state": { 708 | "_view_name": "HBoxView", 709 | "_dom_classes": [], 710 | "_model_name": "HBoxModel", 711 | "_view_module": 
"@jupyter-widgets/controls", 712 | "_model_module_version": "1.5.0", 713 | "_view_count": null, 714 | "_view_module_version": "1.5.0", 715 | "box_style": "", 716 | "layout": "IPY_MODEL_c903b21f5d4745388d9b0eca0300ca4c", 717 | "_model_module": "@jupyter-widgets/controls", 718 | "children": [ 719 | "IPY_MODEL_ff7261f767004966a8d6516f695b6196", 720 | "IPY_MODEL_93f7cff642bd46fcb860b57856b5ffa5", 721 | "IPY_MODEL_2aaaeda7133e412fb5ce83f22b86776d" 722 | ] 723 | } 724 | }, 725 | "c903b21f5d4745388d9b0eca0300ca4c": { 726 | "model_module": "@jupyter-widgets/base", 727 | "model_name": "LayoutModel", 728 | "model_module_version": "1.2.0", 729 | "state": { 730 | "_view_name": "LayoutView", 731 | "grid_template_rows": null, 732 | "right": null, 733 | "justify_content": null, 734 | "_view_module": "@jupyter-widgets/base", 735 | "overflow": null, 736 | "_model_module_version": "1.2.0", 737 | "_view_count": null, 738 | "flex_flow": "row wrap", 739 | "width": "100%", 740 | "min_width": null, 741 | "border": null, 742 | "align_items": null, 743 | "bottom": null, 744 | "_model_module": "@jupyter-widgets/base", 745 | "top": null, 746 | "grid_column": null, 747 | "overflow_y": null, 748 | "overflow_x": null, 749 | "grid_auto_flow": null, 750 | "grid_area": null, 751 | "grid_template_columns": null, 752 | "flex": null, 753 | "_model_name": "LayoutModel", 754 | "justify_items": null, 755 | "grid_row": null, 756 | "max_height": null, 757 | "align_content": null, 758 | "visibility": null, 759 | "align_self": null, 760 | "height": null, 761 | "min_height": null, 762 | "padding": null, 763 | "grid_auto_rows": null, 764 | "grid_gap": null, 765 | "max_width": null, 766 | "order": null, 767 | "_view_module_version": "1.2.0", 768 | "grid_template_areas": null, 769 | "object_position": null, 770 | "object_fit": null, 771 | "grid_auto_columns": null, 772 | "margin": null, 773 | "display": "inline-flex", 774 | "left": null 775 | } 776 | }, 777 | "ff7261f767004966a8d6516f695b6196": { 778 | "model_module": "@jupyter-widgets/controls", 779 | "model_name": "HTMLModel", 780 | "model_module_version": "1.5.0", 781 | "state": { 782 | "_view_name": "HTMLView", 783 | "style": "IPY_MODEL_74b70ff5e7ce40bd85a7acb38741df9c", 784 | "_dom_classes": [], 785 | "description": "", 786 | "_model_name": "HTMLModel", 787 | "placeholder": "​", 788 | "_view_module": "@jupyter-widgets/controls", 789 | "_model_module_version": "1.5.0", 790 | "value": "Validating: 100%", 791 | "_view_count": null, 792 | "_view_module_version": "1.5.0", 793 | "description_tooltip": null, 794 | "_model_module": "@jupyter-widgets/controls", 795 | "layout": "IPY_MODEL_a442f21cc52b4c4ab80f47906230a20c" 796 | } 797 | }, 798 | "93f7cff642bd46fcb860b57856b5ffa5": { 799 | "model_module": "@jupyter-widgets/controls", 800 | "model_name": "FloatProgressModel", 801 | "model_module_version": "1.5.0", 802 | "state": { 803 | "_view_name": "ProgressView", 804 | "style": "IPY_MODEL_6d3233c23d394ee3898a9fbcf74da4b4", 805 | "_dom_classes": [], 806 | "description": "", 807 | "_model_name": "FloatProgressModel", 808 | "bar_style": "", 809 | "max": 8, 810 | "_view_module": "@jupyter-widgets/controls", 811 | "_model_module_version": "1.5.0", 812 | "value": 8, 813 | "_view_count": null, 814 | "_view_module_version": "1.5.0", 815 | "orientation": "horizontal", 816 | "min": 0, 817 | "description_tooltip": null, 818 | "_model_module": "@jupyter-widgets/controls", 819 | "layout": "IPY_MODEL_be81a64971d547e1aefd6829024ca619" 820 | } 821 | }, 822 | "2aaaeda7133e412fb5ce83f22b86776d": { 823 | 
"model_module": "@jupyter-widgets/controls", 824 | "model_name": "HTMLModel", 825 | "model_module_version": "1.5.0", 826 | "state": { 827 | "_view_name": "HTMLView", 828 | "style": "IPY_MODEL_7348ab1840f0481781d597b41d390f95", 829 | "_dom_classes": [], 830 | "description": "", 831 | "_model_name": "HTMLModel", 832 | "placeholder": "​", 833 | "_view_module": "@jupyter-widgets/controls", 834 | "_model_module_version": "1.5.0", 835 | "value": " 8/8 [00:00<00:00, 25.63it/s]", 836 | "_view_count": null, 837 | "_view_module_version": "1.5.0", 838 | "description_tooltip": null, 839 | "_model_module": "@jupyter-widgets/controls", 840 | "layout": "IPY_MODEL_9e96dcb1fd064488b720be58dcc4a1a3" 841 | } 842 | }, 843 | "74b70ff5e7ce40bd85a7acb38741df9c": { 844 | "model_module": "@jupyter-widgets/controls", 845 | "model_name": "DescriptionStyleModel", 846 | "model_module_version": "1.5.0", 847 | "state": { 848 | "_view_name": "StyleView", 849 | "_model_name": "DescriptionStyleModel", 850 | "description_width": "", 851 | "_view_module": "@jupyter-widgets/base", 852 | "_model_module_version": "1.5.0", 853 | "_view_count": null, 854 | "_view_module_version": "1.2.0", 855 | "_model_module": "@jupyter-widgets/controls" 856 | } 857 | }, 858 | "a442f21cc52b4c4ab80f47906230a20c": { 859 | "model_module": "@jupyter-widgets/base", 860 | "model_name": "LayoutModel", 861 | "model_module_version": "1.2.0", 862 | "state": { 863 | "_view_name": "LayoutView", 864 | "grid_template_rows": null, 865 | "right": null, 866 | "justify_content": null, 867 | "_view_module": "@jupyter-widgets/base", 868 | "overflow": null, 869 | "_model_module_version": "1.2.0", 870 | "_view_count": null, 871 | "flex_flow": null, 872 | "width": null, 873 | "min_width": null, 874 | "border": null, 875 | "align_items": null, 876 | "bottom": null, 877 | "_model_module": "@jupyter-widgets/base", 878 | "top": null, 879 | "grid_column": null, 880 | "overflow_y": null, 881 | "overflow_x": null, 882 | "grid_auto_flow": null, 883 | "grid_area": null, 884 | "grid_template_columns": null, 885 | "flex": null, 886 | "_model_name": "LayoutModel", 887 | "justify_items": null, 888 | "grid_row": null, 889 | "max_height": null, 890 | "align_content": null, 891 | "visibility": null, 892 | "align_self": null, 893 | "height": null, 894 | "min_height": null, 895 | "padding": null, 896 | "grid_auto_rows": null, 897 | "grid_gap": null, 898 | "max_width": null, 899 | "order": null, 900 | "_view_module_version": "1.2.0", 901 | "grid_template_areas": null, 902 | "object_position": null, 903 | "object_fit": null, 904 | "grid_auto_columns": null, 905 | "margin": null, 906 | "display": null, 907 | "left": null 908 | } 909 | }, 910 | "6d3233c23d394ee3898a9fbcf74da4b4": { 911 | "model_module": "@jupyter-widgets/controls", 912 | "model_name": "ProgressStyleModel", 913 | "model_module_version": "1.5.0", 914 | "state": { 915 | "_view_name": "StyleView", 916 | "_model_name": "ProgressStyleModel", 917 | "description_width": "", 918 | "_view_module": "@jupyter-widgets/base", 919 | "_model_module_version": "1.5.0", 920 | "_view_count": null, 921 | "_view_module_version": "1.2.0", 922 | "bar_color": null, 923 | "_model_module": "@jupyter-widgets/controls" 924 | } 925 | }, 926 | "be81a64971d547e1aefd6829024ca619": { 927 | "model_module": "@jupyter-widgets/base", 928 | "model_name": "LayoutModel", 929 | "model_module_version": "1.2.0", 930 | "state": { 931 | "_view_name": "LayoutView", 932 | "grid_template_rows": null, 933 | "right": null, 934 | "justify_content": null, 935 | 
"_view_module": "@jupyter-widgets/base", 936 | "overflow": null, 937 | "_model_module_version": "1.2.0", 938 | "_view_count": null, 939 | "flex_flow": null, 940 | "width": null, 941 | "min_width": null, 942 | "border": null, 943 | "align_items": null, 944 | "bottom": null, 945 | "_model_module": "@jupyter-widgets/base", 946 | "top": null, 947 | "grid_column": null, 948 | "overflow_y": null, 949 | "overflow_x": null, 950 | "grid_auto_flow": null, 951 | "grid_area": null, 952 | "grid_template_columns": null, 953 | "flex": "2", 954 | "_model_name": "LayoutModel", 955 | "justify_items": null, 956 | "grid_row": null, 957 | "max_height": null, 958 | "align_content": null, 959 | "visibility": null, 960 | "align_self": null, 961 | "height": null, 962 | "min_height": null, 963 | "padding": null, 964 | "grid_auto_rows": null, 965 | "grid_gap": null, 966 | "max_width": null, 967 | "order": null, 968 | "_view_module_version": "1.2.0", 969 | "grid_template_areas": null, 970 | "object_position": null, 971 | "object_fit": null, 972 | "grid_auto_columns": null, 973 | "margin": null, 974 | "display": null, 975 | "left": null 976 | } 977 | }, 978 | "7348ab1840f0481781d597b41d390f95": { 979 | "model_module": "@jupyter-widgets/controls", 980 | "model_name": "DescriptionStyleModel", 981 | "model_module_version": "1.5.0", 982 | "state": { 983 | "_view_name": "StyleView", 984 | "_model_name": "DescriptionStyleModel", 985 | "description_width": "", 986 | "_view_module": "@jupyter-widgets/base", 987 | "_model_module_version": "1.5.0", 988 | "_view_count": null, 989 | "_view_module_version": "1.2.0", 990 | "_model_module": "@jupyter-widgets/controls" 991 | } 992 | }, 993 | "9e96dcb1fd064488b720be58dcc4a1a3": { 994 | "model_module": "@jupyter-widgets/base", 995 | "model_name": "LayoutModel", 996 | "model_module_version": "1.2.0", 997 | "state": { 998 | "_view_name": "LayoutView", 999 | "grid_template_rows": null, 1000 | "right": null, 1001 | "justify_content": null, 1002 | "_view_module": "@jupyter-widgets/base", 1003 | "overflow": null, 1004 | "_model_module_version": "1.2.0", 1005 | "_view_count": null, 1006 | "flex_flow": null, 1007 | "width": null, 1008 | "min_width": null, 1009 | "border": null, 1010 | "align_items": null, 1011 | "bottom": null, 1012 | "_model_module": "@jupyter-widgets/base", 1013 | "top": null, 1014 | "grid_column": null, 1015 | "overflow_y": null, 1016 | "overflow_x": null, 1017 | "grid_auto_flow": null, 1018 | "grid_area": null, 1019 | "grid_template_columns": null, 1020 | "flex": null, 1021 | "_model_name": "LayoutModel", 1022 | "justify_items": null, 1023 | "grid_row": null, 1024 | "max_height": null, 1025 | "align_content": null, 1026 | "visibility": null, 1027 | "align_self": null, 1028 | "height": null, 1029 | "min_height": null, 1030 | "padding": null, 1031 | "grid_auto_rows": null, 1032 | "grid_gap": null, 1033 | "max_width": null, 1034 | "order": null, 1035 | "_view_module_version": "1.2.0", 1036 | "grid_template_areas": null, 1037 | "object_position": null, 1038 | "object_fit": null, 1039 | "grid_auto_columns": null, 1040 | "margin": null, 1041 | "display": null, 1042 | "left": null 1043 | } 1044 | } 1045 | } 1046 | } 1047 | }, 1048 | "cells": [ 1049 | { 1050 | "cell_type": "code", 1051 | "source": [ 1052 | "%%capture\n", 1053 | "!pip install mne\n", 1054 | "!pip install pytorch-lightning" 1055 | ], 1056 | "metadata": { 1057 | "id": "TJvtghAiSX6A" 1058 | }, 1059 | "execution_count": null, 1060 | "outputs": [] 1061 | }, 1062 | { 1063 | "cell_type": "code", 1064 | "source": [ 1065 | 
"%%capture\n", 1066 | "!wget https://md-datasets-cache-zipfiles-prod.s3.eu-west-1.amazonaws.com/fshy54ypyh-1.zip -O data.zip\n", 1067 | "!unzip data.zip\n" 1068 | ], 1069 | "metadata": { 1070 | "id": "LzFg29vfSewH" 1071 | }, 1072 | "execution_count": null, 1073 | "outputs": [] 1074 | }, 1075 | { 1076 | "cell_type": "code", 1077 | "source": [ 1078 | "from glob import glob\n", 1079 | "import scipy.io\n", 1080 | "import torch.nn as nn\n", 1081 | "import torch\n", 1082 | "import numpy as np\n", 1083 | "import mne" 1084 | ], 1085 | "metadata": { 1086 | "id": "epkRi81yS_y3" 1087 | }, 1088 | "execution_count": null, 1089 | "outputs": [] 1090 | }, 1091 | { 1092 | "cell_type": "code", 1093 | "source": [ 1094 | "input=torch.randn(3,22,15000)\n", 1095 | "input.shape" 1096 | ], 1097 | "metadata": { 1098 | "colab": { 1099 | "base_uri": "https://localhost:8080/" 1100 | }, 1101 | "id": "RtVMJnue8r4A", 1102 | "outputId": "a26ebbe1-e8e8-4017-abd3-dd2083a137da" 1103 | }, 1104 | "execution_count": null, 1105 | "outputs": [ 1106 | { 1107 | "output_type": "execute_result", 1108 | "data": { 1109 | "text/plain": [ 1110 | "torch.Size([3, 22, 15000])" 1111 | ] 1112 | }, 1113 | "metadata": {}, 1114 | "execution_count": 4 1115 | } 1116 | ] 1117 | }, 1118 | { 1119 | "cell_type": "code", 1120 | "source": [ 1121 | "class Block(nn.Module):\n", 1122 | " def __init__(self,inplace):\n", 1123 | " super().__init__()\n", 1124 | " self.conv1=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=2,stride=2,padding=0)\n", 1125 | " self.conv2=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=4,stride=2,padding=1)\n", 1126 | " self.conv3=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=8,stride=2,padding=3)\n", 1127 | " self.relu=nn.ReLU()\n", 1128 | "\n", 1129 | " def forward(self,x):\n", 1130 | " x1=self.relu(self.conv1(x))\n", 1131 | " x2=self.relu(self.conv2(x))\n", 1132 | " x3=self.relu(self.conv3(x))\n", 1133 | " x=torch.cat([x1,x3,x3],dim=1)\n", 1134 | " return x" 1135 | ], 1136 | "metadata": { 1137 | "id": "wtzMgY-M_FZL" 1138 | }, 1139 | "execution_count": null, 1140 | "outputs": [] 1141 | }, 1142 | { 1143 | "cell_type": "code", 1144 | "source": [ 1145 | "class ChronoNet(nn.Module):\n", 1146 | " def __init__(self,channel):\n", 1147 | " super().__init__()\n", 1148 | " self.block1=Block(channel)\n", 1149 | " self.block2=Block(96)\n", 1150 | " self.block3=Block(96)\n", 1151 | " self.gru1=nn.GRU(input_size=96,hidden_size=32,batch_first=True)\n", 1152 | " self.gru2=nn.GRU(input_size=32,hidden_size=32,batch_first=True)\n", 1153 | " self.gru3=nn.GRU(input_size=64,hidden_size=32,batch_first=True)\n", 1154 | " self.gru4=nn.GRU(input_size=96,hidden_size=32,batch_first=True)\n", 1155 | " self.gru_linear=nn.Linear(64,1)\n", 1156 | " self.flatten=nn.Flatten()\n", 1157 | " self.fc1=nn.Linear(32,1)\n", 1158 | " self.relu=nn.ReLU()\n", 1159 | " def forward(self,x):\n", 1160 | " x=self.block1(x)\n", 1161 | " x=self.block2(x)\n", 1162 | " x=self.block3(x)\n", 1163 | " x=x.permute(0,2,1)\n", 1164 | " gru_out1,_=self.gru1(x)\n", 1165 | " gru_out2,_=self.gru2(gru_out1)\n", 1166 | " gru_out=torch.cat([gru_out1,gru_out2],dim=2)\n", 1167 | " gru_out3,_=self.gru3(gru_out)\n", 1168 | " gru_out=torch.cat([gru_out1,gru_out2,gru_out3],dim=2)\n", 1169 | " #print('gru_out',gru_out.shape)\n", 1170 | " linear_out=self.relu(self.gru_linear(gru_out.permute(0,2,1)))\n", 1171 | " gru_out4,_=self.gru4(linear_out.permute(0,2,1))\n", 1172 | " x=self.flatten(gru_out4)\n", 1173 | " x=self.fc1(x)\n", 1174 | " return x" 1175 | ], 1176 | 
"metadata": { 1177 | "id": "yhLl-gVt4gtv" 1178 | }, 1179 | "execution_count": null, 1180 | "outputs": [] 1181 | }, 1182 | { 1183 | "cell_type": "code", 1184 | "source": [ 1185 | "input=torch.randn(3,14,512)\n", 1186 | "input.shape\n", 1187 | "model=ChronoNet(14)\n", 1188 | "out=model(input)\n", 1189 | "out.shape" 1190 | ], 1191 | "metadata": { 1192 | "colab": { 1193 | "base_uri": "https://localhost:8080/" 1194 | }, 1195 | "id": "0KLyQbZh4gwp", 1196 | "outputId": "f68eef69-eb6e-4173-803a-2377a1d7644c" 1197 | }, 1198 | "execution_count": null, 1199 | "outputs": [ 1200 | { 1201 | "output_type": "execute_result", 1202 | "data": { 1203 | "text/plain": [ 1204 | "torch.Size([3, 1])" 1205 | ] 1206 | }, 1207 | "metadata": {}, 1208 | "execution_count": 9 1209 | } 1210 | ] 1211 | }, 1212 | { 1213 | "cell_type": "code", 1214 | "source": [ 1215 | "IDD_data_path='/content/Data/CleanData/CleanData_TDC/Rest'\n", 1216 | "TDC_data_path='/content/Data/Data/CleanData/Data/Data/CleanData/CleanData_IDD/Rest'\n", 1217 | "!rm '/content/Data/Data/CleanData/Data/Data/CleanData/CleanData_IDD/Rest/NDS001_Rest_CD(1).mat'" 1218 | ], 1219 | "metadata": { 1220 | "id": "Z4cFUiGnBxB7" 1221 | }, 1222 | "execution_count": null, 1223 | "outputs": [] 1224 | }, 1225 | { 1226 | "cell_type": "code", 1227 | "source": [ 1228 | "def convertmat2mne(data):\n", 1229 | " ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']\n", 1230 | " ch_types = ['eeg'] * 14\n", 1231 | " sampling_freq=128\n", 1232 | " info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sampling_freq)\n", 1233 | " info.set_montage('standard_1020')\n", 1234 | " data=mne.io.RawArray(data, info)\n", 1235 | " data.set_eeg_reference()\n", 1236 | " data.filter(l_freq=1,h_freq=30)\n", 1237 | " epochs=mne.make_fixed_length_epochs(data,duration=4,overlap=0)\n", 1238 | " return epochs.get_data()" 1239 | ], 1240 | "metadata": { 1241 | "id": "FTESfB9jEylt" 1242 | }, 1243 | "execution_count": null, 1244 | "outputs": [] 1245 | }, 1246 | { 1247 | "cell_type": "code", 1248 | "source": [ 1249 | "%%capture\n", 1250 | "idd_subject=[]\n", 1251 | "for idd in glob(IDD_data_path+'/*.mat'):\n", 1252 | " data=scipy.io.loadmat(idd)['clean_data']\n", 1253 | " data=convertmat2mne(data)\n", 1254 | " idd_subject.append(data)" 1255 | ], 1256 | "metadata": { 1257 | "id": "9hbUxNpSS9U8" 1258 | }, 1259 | "execution_count": null, 1260 | "outputs": [] 1261 | }, 1262 | { 1263 | "cell_type": "code", 1264 | "source": [ 1265 | "%%capture\n", 1266 | "tdc_subject=[]\n", 1267 | "for tdc in glob(TDC_data_path+'/*.mat'):\n", 1268 | " data=scipy.io.loadmat(tdc)['clean_data']\n", 1269 | " data=convertmat2mne(data)\n", 1270 | " tdc_subject.append(data)\n", 1271 | " " 1272 | ], 1273 | "metadata": { 1274 | "id": "9U1fOB2QF2U3" 1275 | }, 1276 | "execution_count": null, 1277 | "outputs": [] 1278 | }, 1279 | { 1280 | "cell_type": "code", 1281 | "source": [ 1282 | "len(idd_subject),len(tdc_subject)" 1283 | ], 1284 | "metadata": { 1285 | "colab": { 1286 | "base_uri": "https://localhost:8080/" 1287 | }, 1288 | "id": "IRVuAHySGnJT", 1289 | "outputId": "cc8c19c8-41a8-4d1f-8b2a-59cf6a59b750" 1290 | }, 1291 | "execution_count": null, 1292 | "outputs": [ 1293 | { 1294 | "output_type": "execute_result", 1295 | "data": { 1296 | "text/plain": [ 1297 | "(7, 7)" 1298 | ] 1299 | }, 1300 | "metadata": {}, 1301 | "execution_count": 14 1302 | } 1303 | ] 1304 | }, 1305 | { 1306 | "cell_type": "code", 1307 | "source": [ 1308 | "control_epochs_labels=[len(i)*[0] for i in tdc_subject]\n", 
1309 | "patients_epochs_labels=[len(i)*[1] for i in idd_subject]\n", 1310 | "print(len(control_epochs_labels),len(patients_epochs_labels))" 1311 | ], 1312 | "metadata": { 1313 | "colab": { 1314 | "base_uri": "https://localhost:8080/" 1315 | }, 1316 | "id": "lc5-ca_RGus7", 1317 | "outputId": "0bcb6cd4-00c8-4c80-ac1d-a866e71766e6" 1318 | }, 1319 | "execution_count": null, 1320 | "outputs": [ 1321 | { 1322 | "output_type": "stream", 1323 | "name": "stdout", 1324 | "text": [ 1325 | "7 7\n" 1326 | ] 1327 | } 1328 | ] 1329 | }, 1330 | { 1331 | "cell_type": "code", 1332 | "source": [ 1333 | "data_list=tdc_subject+idd_subject\n", 1334 | "label_list=control_epochs_labels+patients_epochs_labels\n", 1335 | "groups_list=[[i]*len(j) for i, j in enumerate(data_list)]\n", 1336 | "print(len(data_list),len(label_list),len(groups_list))\n" 1337 | ], 1338 | "metadata": { 1339 | "colab": { 1340 | "base_uri": "https://localhost:8080/" 1341 | }, 1342 | "id": "gIz9MHzAG-HC", 1343 | "outputId": "0b14f19e-d7b0-4264-eeda-cfc15096f017" 1344 | }, 1345 | "execution_count": null, 1346 | "outputs": [ 1347 | { 1348 | "output_type": "stream", 1349 | "name": "stdout", 1350 | "text": [ 1351 | "14 14 14\n" 1352 | ] 1353 | } 1354 | ] 1355 | }, 1356 | { 1357 | "cell_type": "code", 1358 | "source": [ 1359 | "from sklearn.model_selection import GroupKFold,LeaveOneGroupOut\n", 1360 | "from sklearn.preprocessing import StandardScaler\n", 1361 | "gkf=GroupKFold()\n", 1362 | "from sklearn.base import TransformerMixin,BaseEstimator\n", 1363 | "from sklearn.preprocessing import StandardScaler\n", 1364 | "#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix\n", 1365 | "class StandardScaler3D(BaseEstimator,TransformerMixin):\n", 1366 | " #batch, sequence, channels\n", 1367 | " def __init__(self):\n", 1368 | " self.scaler = StandardScaler()\n", 1369 | "\n", 1370 | " def fit(self,X,y=None):\n", 1371 | " self.scaler.fit(X.reshape(-1, X.shape[2]))\n", 1372 | " return self\n", 1373 | "\n", 1374 | " def transform(self,X):\n", 1375 | " return self.scaler.transform(X.reshape( -1,X.shape[2])).reshape(X.shape)" 1376 | ], 1377 | "metadata": { 1378 | "id": "rodPubFyMbwH" 1379 | }, 1380 | "execution_count": null, 1381 | "outputs": [] 1382 | }, 1383 | { 1384 | "cell_type": "code", 1385 | "source": [ 1386 | "import numpy as np\n", 1387 | "data_array=np.concatenate(data_list)\n", 1388 | "label_array=np.concatenate(label_list)\n", 1389 | "group_array=np.concatenate(groups_list)\n", 1390 | "data_array=np.moveaxis(data_array,1,2)\n", 1391 | "\n", 1392 | "print(data_array.shape,label_array.shape,group_array.shape)" 1393 | ], 1394 | "metadata": { 1395 | "colab": { 1396 | "base_uri": "https://localhost:8080/" 1397 | }, 1398 | "id": "k_9aWxq9MlFg", 1399 | "outputId": "f05ceda8-0c1e-440a-8f41-d8f2e33feae6" 1400 | }, 1401 | "execution_count": null, 1402 | "outputs": [ 1403 | { 1404 | "output_type": "stream", 1405 | "name": "stdout", 1406 | "text": [ 1407 | "(420, 512, 14) (420,) (420,)\n" 1408 | ] 1409 | } 1410 | ] 1411 | }, 1412 | { 1413 | "cell_type": "code", 1414 | "source": [ 1415 | "accuracy=[]\n", 1416 | "for train_index, val_index in gkf.split(data_array, label_array, groups=group_array):\n", 1417 | " train_features,train_labels=data_array[train_index],label_array[train_index]\n", 1418 | " val_features,val_labels=data_array[val_index],label_array[val_index]\n", 1419 | " scaler=StandardScaler3D()\n", 1420 | " train_features=scaler.fit_transform(train_features)\n", 1421 | " val_features=scaler.transform(val_features)\n", 1422 | " 
train_features=np.moveaxis(train_features,1,2)\n", 1423 | " val_features=np.moveaxis(val_features,1,2)\n", 1424 | "\n", 1425 | " break" 1426 | ], 1427 | "metadata": { 1428 | "id": "qOq1dRqOMcFd" 1429 | }, 1430 | "execution_count": null, 1431 | "outputs": [] 1432 | }, 1433 | { 1434 | "cell_type": "code", 1435 | "source": [ 1436 | "train_features = torch.Tensor(train_features)\n", 1437 | "val_features = torch.Tensor(val_features)\n", 1438 | "train_labels = torch.Tensor(train_labels)\n", 1439 | "val_labels = torch.Tensor(val_labels)" 1440 | ], 1441 | "metadata": { 1442 | "id": "HDyijz8nbaK2" 1443 | }, 1444 | "execution_count": null, 1445 | "outputs": [] 1446 | }, 1447 | { 1448 | "cell_type": "code", 1449 | "source": [ 1450 | "len(val_features),len(val_labels)" 1451 | ], 1452 | "metadata": { 1453 | "colab": { 1454 | "base_uri": "https://localhost:8080/" 1455 | }, 1456 | "id": "adRqbzk5bF_X", 1457 | "outputId": "09df6384-e35f-4cef-9b10-0e4ec615fa90" 1458 | }, 1459 | "execution_count": null, 1460 | "outputs": [ 1461 | { 1462 | "output_type": "execute_result", 1463 | "data": { 1464 | "text/plain": [ 1465 | "(90, 90)" 1466 | ] 1467 | }, 1468 | "metadata": {}, 1469 | "execution_count": 22 1470 | } 1471 | ] 1472 | }, 1473 | { 1474 | "cell_type": "code", 1475 | "source": [ 1476 | "train_features.shape" 1477 | ], 1478 | "metadata": { 1479 | "colab": { 1480 | "base_uri": "https://localhost:8080/" 1481 | }, 1482 | "id": "qMsWqCDAMhGJ", 1483 | "outputId": "55bde110-fd35-48a3-d6d6-b26c30444d98" 1484 | }, 1485 | "execution_count": null, 1486 | "outputs": [ 1487 | { 1488 | "output_type": "execute_result", 1489 | "data": { 1490 | "text/plain": [ 1491 | "torch.Size([330, 14, 512])" 1492 | ] 1493 | }, 1494 | "metadata": {}, 1495 | "execution_count": 23 1496 | } 1497 | ] 1498 | }, 1499 | { 1500 | "cell_type": "code", 1501 | "source": [ 1502 | "from pytorch_lightning import LightningModule,Trainer\n", 1503 | "import torchmetrics\n", 1504 | "from torch.utils.data import TensorDataset,DataLoader" 1505 | ], 1506 | "metadata": { 1507 | "id": "xNpn_33LMvvT" 1508 | }, 1509 | "execution_count": null, 1510 | "outputs": [] 1511 | }, 1512 | { 1513 | "cell_type": "code", 1514 | "source": [ 1515 | "class ChronoModel(LightningModule):\n", 1516 | " def __init__(self):\n", 1517 | " super(ChronoModel,self).__init__()\n", 1518 | " self.model=ChronoNet(14)\n", 1519 | " self.lr=1e-3\n", 1520 | " self.bs=12\n", 1521 | " self.worker=2\n", 1522 | " self.acc=torchmetrics.Accuracy()\n", 1523 | " self.criterion=nn.BCEWithLogitsLoss()\n", 1524 | "\n", 1525 | " def forward(self,x):\n", 1526 | " x=self.model(x)\n", 1527 | " return x\n", 1528 | "\n", 1529 | " def configure_optimizers(self):\n", 1530 | " return torch.optim.Adam(self.parameters(),lr=self.lr)\n", 1531 | "\n", 1532 | " def train_dataloader(self):\n", 1533 | " dataset=TensorDataset(train_features,train_labels)\n", 1534 | " dataloader=DataLoader(dataset,batch_size=self.bs,num_workers=self.worker,shuffle=True)\n", 1535 | " return dataloader\n", 1536 | "\n", 1537 | " def training_step(self,batch,batch_idx):\n", 1538 | " signal,label=batch\n", 1539 | " out=self(signal.float())\n", 1540 | " loss=self.criterion(out.flatten(),label.float().flatten())\n", 1541 | " acc=self.acc(out.flatten(),label.long().flatten())\n", 1542 | " return {'loss':loss,'acc':acc}\n", 1543 | "\n", 1544 | " def training_epoch_end(self,outputs):\n", 1545 | " acc=torch.stack([x['acc'] for x in outputs]).mean().detach().cpu().numpy().round(2)\n", 1546 | " loss=torch.stack([x['loss'] for x in
outputs]).mean().detach().cpu().numpy().round(2)\n", 1547 | " print('train acc loss',acc,loss)\n", 1548 | "\n", 1549 | " def val_dataloader(self):\n", 1550 | " dataset=TensorDataset(val_features,val_labels)\n", 1551 | " dataloader=DataLoader(dataset,batch_size=self.bs,num_workers=self.worker,shuffle=True)\n", 1552 | " return dataloader\n", 1553 | "\n", 1554 | " def validation_step(self,batch,batch_idx):\n", 1555 | " signal,label=batch\n", 1556 | " out=self(signal.float())\n", 1557 | " loss=self.criterion(out.flatten(),label.float().flatten())\n", 1558 | " acc=self.acc(out.flatten(),label.long().flatten())\n", 1559 | " return {'loss':loss,'acc':acc}\n", 1560 | "\n", 1561 | " def validation_epoch_end(self,outputs):\n", 1562 | " acc=torch.stack([x['acc'] for x in outputs]).mean().detach().cpu().numpy().round(2)\n", 1563 | " loss=torch.stack([x['loss'] for x in outputs]).mean().detach().cpu().numpy().round(2)\n", 1564 | " print('val acc loss',acc,loss)" 1575 | ], 1576 | "metadata": { 1577 | "id": "W8I9ymGnalea" 1578 | }, 1579 | "execution_count": null, 1580 | "outputs": [] 1581 | }, 1582 | { 1583 | "cell_type": "code", 1584 | "source": [ 1585 | "model=ChronoModel()\n" 1586 | ], 1587 | "metadata": { 1588 | "id": "Cv6BN2NKMeds" 1589 | }, 1590 | "execution_count": null, 1591 | "outputs": [] 1592 | }, 1593 | { 1594 | "cell_type": "code", 1595 | "source": [ 1596 | "trainer=Trainer(max_epochs=1)" 1597 | ], 1598 | "metadata": { 1599 | "colab": { 1600 | "base_uri": "https://localhost:8080/" 1601 | }, 1602 | "id": "FcKQ8b1fMjkV", 1603 | "outputId": "b1597491-3c5f-4a13-8c4d-f34b7fe7fb24" 1604 | }, 1605 | "execution_count": null, 1606 | "outputs": [ 1607 | { 1608 | "output_type": "stream", 1609 | "name": "stderr", 1610 | "text": [ 1611 | "GPU available: False, used: False\n", 1612 | "TPU available: False, using: 0 TPU cores\n", 1613 | "IPU available: False, using: 0 IPUs\n" 1614 | ] 1615 | } 1616 | ] 1617 | }, 1618 | { 1619 | "cell_type": "code", 1620 | "source": [ 1621 | "trainer.fit(model)" 1622 | ], 1623 | "metadata": { 1624 | "colab": { 1625 | "base_uri": "https://localhost:8080/", 1626 | "height": 364, 1627 | "referenced_widgets": [ 1628 | "d906d8b3258a4897bd5301d4feaa91c7", 1629 | "ddb353b0e36747839bebbc72689b29a8", 1630 | "6f0e8c4ca555468fb2d14b929434fc49", 1631 | "7c19031e4b1a4306ae282a6cdf4a0935", 1632 | "960da51a7dc342c4ae478cfa551de087", 1633 | "c26031cf9401466dad958a13dd606dc8", 1634 | "c3ac49e08f2243cf8d0fd5beb631a10f", 1635 | "874dce2b24c6419b8a8f32ab264b4daa", 1636 | "dbf216bbe04a45ba89dc71d9e0632f6e", 1637 | "330e6b856caf4754a302be2f7fd66099", 1638 | "ee28e600a91944cab7a85d06b78add4c", 1639 | "c08f1e54817a4cb399d13f70cc631556", 1640 | "29db7d286e214531afde8496c7eac53b", 1641 | "da67d06f25b341938d96611e7ca9021b", 1642 | "2e774d5dbdf24f70a7c6fcf361111f69", 1643 | "ac2821869fb44bc8afa959c9a9f80bd3", 1644 | "c66197c3cd064d858215136526c36970", 1645 | "a44ca1d1c7e949f1a237253483683f14", 1646 | "85b17c14e13a452f98f84c6b75b00915", 1647 | "0326747e2fce41819917749f6f95b184", 1648 | "516a6e31fa704f9a87f32fd131a6ad2b", 1649 | "6f3262e846fe491daa0dc4c091989888", 1650 | "d0192776eb604b769b232b368d4eab27", 1651 | "c903b21f5d4745388d9b0eca0300ca4c", 1652 | "ff7261f767004966a8d6516f695b6196", 1653 | "93f7cff642bd46fcb860b57856b5ffa5", 1654 | "2aaaeda7133e412fb5ce83f22b86776d", 1655 | "74b70ff5e7ce40bd85a7acb38741df9c", 1656 | "a442f21cc52b4c4ab80f47906230a20c", 1657 |
"6d3233c23d394ee3898a9fbcf74da4b4", 1658 | "be81a64971d547e1aefd6829024ca619", 1659 | "7348ab1840f0481781d597b41d390f95", 1660 | "9e96dcb1fd064488b720be58dcc4a1a3" 1661 | ] 1662 | }, 1663 | "id": "ZKqNUh_1MtMf", 1664 | "outputId": "ef88a5d6-ee10-4b41-cda3-5094447ed1fb" 1665 | }, 1666 | "execution_count": null, 1667 | "outputs": [ 1668 | { 1669 | "output_type": "stream", 1670 | "name": "stderr", 1671 | "text": [ 1672 | "\n", 1673 | " | Name | Type | Params\n", 1674 | "------------------------------------------------\n", 1675 | "0 | model | ChronoNet | 133 K \n", 1676 | "1 | acc | Accuracy | 0 \n", 1677 | "2 | creterion | BCEWithLogitsLoss | 0 \n", 1678 | "------------------------------------------------\n", 1679 | "133 K Trainable params\n", 1680 | "0 Non-trainable params\n", 1681 | "133 K Total params\n", 1682 | "0.534 Total estimated model params size (MB)\n" 1683 | ] 1684 | }, 1685 | { 1686 | "output_type": "display_data", 1687 | "data": { 1688 | "application/vnd.jupyter.widget-view+json": { 1689 | "model_id": "d906d8b3258a4897bd5301d4feaa91c7", 1690 | "version_minor": 0, 1691 | "version_major": 2 1692 | }, 1693 | "text/plain": [ 1694 | "Validation sanity check: 0it [00:00, ?it/s]" 1695 | ] 1696 | }, 1697 | "metadata": {} 1698 | }, 1699 | { 1700 | "output_type": "stream", 1701 | "name": "stderr", 1702 | "text": [ 1703 | "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/data_loading.py:658: UserWarning: Your `val_dataloader` has `shuffle=True`, it is strongly recommended that you turn this off for val/test/predict dataloaders.\n", 1704 | " category=UserWarning,\n", 1705 | "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/data_loading.py:429: UserWarning: The number of training samples (28) is smaller than the logging interval Trainer(log_every_n_steps=50). 
Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", 1706 | " f\"The number of training samples ({self.num_training_batches}) is smaller than the logging interval\"\n" 1707 | ] 1708 | }, 1709 | { 1710 | "output_type": "stream", 1711 | "name": "stdout", 1712 | "text": [ 1713 | "val acc loss 0.25 0.7\n" 1714 | ] 1715 | }, 1716 | { 1717 | "output_type": "display_data", 1718 | "data": { 1719 | "application/vnd.jupyter.widget-view+json": { 1720 | "model_id": "c08f1e54817a4cb399d13f70cc631556", 1721 | "version_minor": 0, 1722 | "version_major": 2 1723 | }, 1724 | "text/plain": [ 1725 | "Training: 0it [00:00, ?it/s]" 1726 | ] 1727 | }, 1728 | "metadata": {} 1729 | }, 1730 | { 1731 | "output_type": "display_data", 1732 | "data": { 1733 | "application/vnd.jupyter.widget-view+json": { 1734 | "model_id": "d0192776eb604b769b232b368d4eab27", 1735 | "version_minor": 0, 1736 | "version_major": 2 1737 | }, 1738 | "text/plain": [ 1739 | "Validating: 0it [00:00, ?it/s]" 1740 | ] 1741 | }, 1742 | "metadata": {} 1743 | }, 1744 | { 1745 | "output_type": "stream", 1746 | "name": "stdout", 1747 | "text": [ 1748 | "val acc loss 0.32 0.66\n" 1749 | ] 1750 | } 1751 | ] 1752 | }, 1753 | { 1754 | "cell_type": "code", 1755 | "source": [ 1756 | "" 1757 | ], 1758 | "metadata": { 1759 | "id": "ThvwI1CdMvkt" 1760 | }, 1761 | "execution_count": null, 1762 | "outputs": [] 1763 | } 1764 | ] 1765 | } -------------------------------------------------------------------------------- /eeg-conv2d.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"%%capture\n!pip install ssqueezepy\n!pip install timm\n!pip install pytorch-lightning","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2022-05-16T08:11:19.959525Z","iopub.execute_input":"2022-05-16T08:11:19.960224Z","iopub.status.idle":"2022-05-16T08:11:49.676142Z","shell.execute_reply.started":"2022-05-16T08:11:19.960125Z","shell.execute_reply":"2022-05-16T08:11:49.674917Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"main_path='../input/eeg-data-distance-learning-environment'","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:12:38.720944Z","iopub.execute_input":"2022-05-16T08:12:38.721776Z","iopub.status.idle":"2022-05-16T08:12:38.726422Z","shell.execute_reply.started":"2022-05-16T08:12:38.721721Z","shell.execute_reply":"2022-05-16T08:12:38.725477Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import pandas as pd\nimport os\ndf=pd.read_csv(os.path.join(main_path,'EEG_data.csv'))\ndf.head()","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:12:51.51902Z","iopub.execute_input":"2022-05-16T08:12:51.519317Z","iopub.status.idle":"2022-05-16T08:12:52.978393Z","shell.execute_reply.started":"2022-05-16T08:12:51.519268Z","shell.execute_reply":"2022-05-16T08:12:52.977704Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"#we need first 16 channels to get raw data 14 
columns — they hold the ids and the 14 raw EEG channels; drop the derived-feature columns that follow\n\ncols_remove=df.columns.tolist()[16:-1]\ndf=df.loc[:, ~df.columns.isin(cols_remove)]\ndf.columns = df.columns.str.replace('EEG.', '', regex=False) #remove the 'EEG.' prefix (str.strip would treat it as a character set, not a prefix)\ndf.head()","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:12:58.804973Z","iopub.execute_input":"2022-05-16T08:12:58.805291Z","iopub.status.idle":"2022-05-16T08:12:58.836616Z","shell.execute_reply.started":"2022-05-16T08:12:58.805242Z","shell.execute_reply":"2022-05-16T08:12:58.835548Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"df['subject_understood'].unique()","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:12:59.91882Z","iopub.execute_input":"2022-05-16T08:12:59.91948Z","iopub.status.idle":"2022-05-16T08:12:59.930952Z","shell.execute_reply.started":"2022-05-16T08:12:59.919442Z","shell.execute_reply":"2022-05-16T08:12:59.93018Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"#now I need to reshape the data into (subjects, trials, channels, length)\n#for that, first create one group per (subject_id, video_id) pair\ngroups=df.groupby(['subject_id','video_id'])\ngrp_keys=list(groups.groups.keys())\nprint(grp_keys)\n","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:13:02.184624Z","iopub.execute_input":"2022-05-16T08:13:02.184891Z","iopub.status.idle":"2022-05-16T08:13:02.464364Z","shell.execute_reply.started":"2022-05-16T08:13:02.184864Z","shell.execute_reply":"2022-05-16T08:13:02.463598Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"grpno=grp_keys[0]\ngrp1=groups.get_group(grpno).drop(['subject_id','video_id'],axis=1)\nlabel=grp1['subject_understood']\nsubject_id=grpno[0]\ngrp1=grp1.drop('subject_understood',axis=1)\ngrp1.head()","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:14:27.785874Z","iopub.execute_input":"2022-05-16T08:14:27.786642Z","iopub.status.idle":"2022-05-16T08:14:27.816566Z","shell.execute_reply.started":"2022-05-16T08:14:27.786603Z","shell.execute_reply":"2022-05-16T08:14:27.815822Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import mne\ndef convertDF2MNE(sub):\n info = mne.create_info(list(sub.columns), ch_types=['eeg'] * len(sub.columns), sfreq=128)\n info.set_montage('standard_1020')\n data=mne.io.RawArray(sub.T, info)\n data.set_eeg_reference()\n #data.filter(l_freq=1,h_freq=30)\n epochs=mne.make_fixed_length_epochs(data,duration=3,overlap=2)\n return 
epochs.get_data()","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:17:18.883824Z","iopub.execute_input":"2022-05-16T08:17:18.884085Z","iopub.status.idle":"2022-05-16T08:17:20.147933Z","shell.execute_reply.started":"2022-05-16T08:17:18.884057Z","shell.execute_reply":"2022-05-16T08:17:20.147208Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"test=convertDF2MNE(grp1)\ntest.shape","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:17:27.001782Z","iopub.execute_input":"2022-05-16T08:17:27.002095Z","iopub.status.idle":"2022-05-16T08:17:27.691287Z","shell.execute_reply.started":"2022-05-16T08:17:27.002056Z","shell.execute_reply":"2022-05-16T08:17:27.690352Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"128*3","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:17:42.632414Z","iopub.execute_input":"2022-05-16T08:17:42.63309Z","iopub.status.idle":"2022-05-16T08:17:42.637848Z","shell.execute_reply.started":"2022-05-16T08:17:42.63305Z","shell.execute_reply":"2022-05-16T08:17:42.637137Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"!mkdir scaleogram","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:34:00.692099Z","iopub.execute_input":"2022-05-16T08:34:00.692667Z","iopub.status.idle":"2022-05-16T08:34:01.402097Z","shell.execute_reply.started":"2022-05-16T08:34:00.692628Z","shell.execute_reply":"2022-05-16T08:34:01.401135Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from glob import glob\nimport scipy.io\nimport torch.nn as nn\nimport torch\nimport numpy as np\nimport mne\nfrom ssqueezepy import cwt\nfrom ssqueezepy.visuals import plot, imshow\nimport os\nimport re\nimport pandas as pd","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:34:25.652106Z","iopub.execute_input":"2022-05-16T08:34:25.652434Z","iopub.status.idle":"2022-05-16T08:34:29.424392Z","shell.execute_reply.started":"2022-05-16T08:34:25.652397Z","shell.execute_reply":"2022-05-16T08:34:29.423642Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"test[0][0].shape","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:35:43.458182Z","iopub.execute_input":"2022-05-16T08:35:43.458484Z","iopub.status.idle":"2022-05-16T08:35:43.464306Z","shell.execute_reply.started":"2022-05-16T08:35:43.458456Z","shell.execute_reply":"2022-05-16T08:35:43.463531Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"Wx, scales = cwt(test[0], 'morlet')\nWx.shape","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:36:34.126933Z","iopub.execute_input":"2022-05-16T08:36:34.127488Z","iopub.status.idle":"2022-05-16T08:36:34.205525Z","shell.execute_reply.started":"2022-05-16T08:36:34.127446Z","shell.execute_reply":"2022-05-16T08:36:34.204741Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"imshow(Wx[0])","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:37:16.656726Z","iopub.execute_input":"2022-05-16T08:37:16.656991Z","iopub.status.idle":"2022-05-16T08:37:16.853774Z","shell.execute_reply.started":"2022-05-16T08:37:16.656964Z","shell.execute_reply":"2022-05-16T08:37:16.853129Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"%%capture\ngrpnos,labels,paths=[],[],[]\nfor i,grpno in enumerate(grp_keys):\n grp=groups.get_group(grpno).drop(['subject_id','video_id'],axis=1)\n 
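For orientation (a sketch, not part of the original notebook): cwt() applied to one (channels, time) epoch returns complex wavelet coefficients of shape (channels, n_scales, time), exactly as the Wx.shape cell above shows, and the surrounding loop saves their magnitudes as one .npy scaleogram stack per trial:

    import numpy as np
    from ssqueezepy import cwt

    epoch = np.random.randn(14, 384)     # one 3 s epoch at 128 Hz, 14 channels
    Wx, scales = cwt(epoch, 'morlet')    # complex coefficients per channel and scale
    print(Wx.shape, Wx.dtype)            # (14, n_scales, 384), complex dtype
    print(np.abs(Wx).shape)              # magnitude scaleogram, what gets saved below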
label=int(grp['subject_understood'].unique())\n subject_id=grpno[0]\n grp=grp.drop('subject_understood',axis=1)\n data=convertDF2MNE(grp)#(trials, channels, length)\n for c,x in enumerate(data):#loop trials\n Wx, scales = cwt(x, 'morlet')\n Wx=np.abs(Wx)\n path=os.path.join('./scaleogram',f'subvideo_{grpno}/',)\n os.makedirs(path,exist_ok=True)\n path=path+f'trial_{c}.npy'\n np.save(path,Wx)\n \n grpnos.append(i)\n labels.append(label)\n paths.append(path)","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:40:02.693826Z","iopub.execute_input":"2022-05-16T08:40:02.694126Z","iopub.status.idle":"2022-05-16T08:40:39.452262Z","shell.execute_reply.started":"2022-05-16T08:40:02.694094Z","shell.execute_reply":"2022-05-16T08:40:39.451397Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"Wx, scales = cwt(x, 'morlet')\nimshow(Wx[0])","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:41:13.391668Z","iopub.execute_input":"2022-05-16T08:41:13.392143Z","iopub.status.idle":"2022-05-16T08:41:13.645695Z","shell.execute_reply.started":"2022-05-16T08:41:13.392104Z","shell.execute_reply":"2022-05-16T08:41:13.645029Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"df_scale=pd.DataFrame(zip(paths,labels,grpnos),columns=['path','label','group'])\ndf_scale.head()","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:41:18.806165Z","iopub.execute_input":"2022-05-16T08:41:18.806745Z","iopub.status.idle":"2022-05-16T08:41:18.817645Z","shell.execute_reply.started":"2022-05-16T08:41:18.806708Z","shell.execute_reply":"2022-05-16T08:41:18.816854Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import numpy as np\nfrom pytorch_lightning import seed_everything, LightningModule, Trainer\nfrom sklearn.utils import class_weight\nimport torch.nn as nn\nimport torch\nfrom torch.utils.data.dataloader import DataLoader\nfrom pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor\nfrom torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau,CosineAnnealingWarmRestarts,OneCycleLR,CosineAnnealingLR\nimport torchvision\nfrom sklearn.metrics import classification_report,f1_score,accuracy_score,roc_curve,auc,roc_auc_score\nimport matplotlib.pyplot as plt\nimport pandas as pd\nimport numpy as np\nfrom glob import glob\nfrom PIL import Image\nimport cv2\nfrom torch.utils.data import DataLoader, Dataset,ConcatDataset\nimport torchmetrics\nimport timm\nimport random","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:44:55.856487Z","iopub.execute_input":"2022-05-16T08:44:55.857217Z","iopub.status.idle":"2022-05-16T08:45:02.271095Z","shell.execute_reply.started":"2022-05-16T08:44:55.857178Z","shell.execute_reply":"2022-05-16T08:45:02.270334Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"#read data from folders\nclass DataReader(Dataset):\n def __init__(self, dataset,aug=None):\n self.dataset = dataset\n self.aug=aug\n def __getitem__(self, index):\n x=self.dataset.path[index]\n y=self.dataset.label[index]\n x=np.load(x)\n if self.aug:\n if random.uniform(0, 1)>0.5:\n x=np.flip(x,-1)\n if random.uniform(0, 1)>0.5:\n x=np.flip(x,-2)\n # if random.uniform(0, 1)>0.5:\n # c=np.arange(14)\n # np.random.shuffle(c)\n # x=x[c,:,:]\n x=(x - np.min(x)) / (np.max(x) - np.min(x))\n \n return x, y\n \n def __len__(self):\n return 
len(self.dataset)","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:46:15.725971Z","iopub.execute_input":"2022-05-16T08:46:15.726261Z","iopub.status.idle":"2022-05-16T08:46:15.733754Z","shell.execute_reply.started":"2022-05-16T08:46:15.726229Z","shell.execute_reply":"2022-05-16T08:46:15.733032Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"test_loader=DataLoader(DataReader(df_scale,True), batch_size =8)\ntest_batch=next(iter(test_loader))\ntest_batch[0].shape ,test_batch[1].shape ","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:46:40.046306Z","iopub.execute_input":"2022-05-16T08:46:40.046569Z","iopub.status.idle":"2022-05-16T08:46:40.148911Z","shell.execute_reply.started":"2022-05-16T08:46:40.046542Z","shell.execute_reply":"2022-05-16T08:46:40.14822Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import timm\nclass OurModel(LightningModule):\n def __init__(self,train_split,val_split):\n super(OurModel,self).__init__()\n #architecute\n #lambda resnet\n \n self.train_split=train_split\n self.val_split=val_split\n #########TIMM#################\n model_name='resnest26d'\n self.model = timm.create_model(model_name,pretrained=True)\n self.model.conv1[0]=nn.Conv2d(14, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n \n\n self.fc1=nn.Linear(1000,500)\n self.relu=nn.ReLU()\n self.fc2= nn.Linear(500,250)\n self.fc3= nn.Linear(250,1)\n self.drp=nn.Dropout(0.25)\n #parameters\n self.lr=1e-3\n self.batch_size=16\n self.numworker=2\n self.criterion=nn.BCEWithLogitsLoss()\n self.metrics=torchmetrics.Accuracy()\n\n self.trainloss,self.valloss=[],[]\n self.trainacc,self.valacc=[],[]\n \n self.sub_pred=0\n def forward(self,x):\n x= self.model(x)\n x=self.fc1(x)\n x=self.relu(x)\n x=self.drp(x)\n x=self.fc2(x)\n x=self.relu(x)\n x=self.drp(x)\n x=self.fc3(x)\n return x\n\n def configure_optimizers(self):\n opt=torch.optim.AdamW(params=self.parameters(),lr=self.lr )\n return opt\n \n def train_dataloader(self):\n return DataLoader(DataReader(self.train_split,False), batch_size = self.batch_size, \n num_workers=self.numworker,pin_memory=True,shuffle=True)\n\n def training_step(self,batch,batch_idx):\n image,label=batch\n pred = self(image)\n loss=self.criterion(pred.flatten(),label.float()) #calculate loss\n acc=self.metrics(pred.flatten(),label)#calculate accuracy\n return {'loss':loss,'acc':acc}\n\n def training_epoch_end(self, outputs):\n loss=torch.stack([x[\"loss\"] for x in outputs]).mean().detach().cpu().numpy().round(2)\n acc=torch.stack([x[\"acc\"] for x in outputs]).mean().detach().cpu().numpy().round(2)\n self.trainloss.append(loss)\n self.trainacc.append(acc)\n #print('training acc',acc)\n self.log('train_loss', loss)\n \n def val_dataloader(self):\n ds=DataLoader(DataReader(self.val_split), batch_size = self.batch_size,\n num_workers=self.numworker,pin_memory=True, shuffle=False)\n return ds\n\n def validation_step(self,batch,batch_idx):\n image,label=batch\n pred = self(image)\n loss=self.criterion(pred.flatten(),label.float()) #calculate loss\n acc=self.metrics(pred.flatten(),label)#calculate accuracy\n return {'loss':loss,'acc':acc}\n\n def validation_epoch_end(self, outputs):\n loss=torch.stack([x[\"loss\"] for x in outputs]).mean().detach().cpu().numpy().round(2)\n acc=torch.stack([x[\"acc\"] for x in outputs]).mean().detach().cpu().numpy().round(2)\n self.valloss.append(loss)\n self.valacc.append(acc)\n #print('validation acc',self.current_epoch,acc)\n 
self.log('val_loss', loss)\n self.log('val_acc', acc)\n \n def test_dataloader(self):\n ds=DataLoader(DataReader(self.val_split), batch_size = self.batch_size,\n num_workers=self.numworker,pin_memory=True, shuffle=False)\n return ds\n def test_step(self,batch,batch_idx):\n image,label=batch\n pred = self(image)\n \n return {'label':label,'pred':pred}\n\n def test_epoch_end(self, outputs):\n\n label=torch.cat([x[\"label\"] for x in outputs])\n pred=torch.cat([x[\"pred\"] for x in outputs])\n acc=self.metrics(pred.flatten(),label)\n pred=pred.detach().cpu().numpy().ravel()\n label=label.detach().cpu().numpy().ravel()\n print('sklearn auc',roc_auc_score(label,pred))\n pred=np.where(pred>0.5,1,0).astype(int)\n print('torch acc',acc)\n print(classification_report(label,pred))\n print('sklearn',accuracy_score(label,pred))\n ","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:50:29.447722Z","iopub.execute_input":"2022-05-16T08:50:29.448004Z","iopub.status.idle":"2022-05-16T08:50:29.472917Z","shell.execute_reply.started":"2022-05-16T08:50:29.447974Z","shell.execute_reply":"2022-05-16T08:50:29.472213Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"from sklearn.model_selection import GroupKFold,LeaveOneGroupOut,StratifiedGroupKFold\ngkf=StratifiedGroupKFold(5)\nresult=[]\nvalacc=[]\nfor train_index, val_index in gkf.split(df_scale.path,df_scale.label, groups=df_scale.group):\n train_df=df_scale.iloc[train_index].reset_index(drop=True)\n val_df=df_scale.iloc[val_index].reset_index(drop=True)\n\n\n lr_monitor = LearningRateMonitor(logging_interval='epoch')\n gpu=-1 if torch.cuda.is_available() else 0\n gpup=16 if torch.cuda.is_available() else 32\n model=OurModel(train_df,val_df)\n trainer = Trainer(max_epochs=20, auto_lr_find=True, auto_scale_batch_size=True,\n deterministic=True,\n gpus=gpu,precision=gpup,\n accumulate_grad_batches=2,\n enable_progress_bar = True,\n num_sanity_val_steps=0,\n callbacks=[lr_monitor],\n \n )\n trainer.fit(model)\n res=trainer.validate(model)\n result.append(res)\n valacc.append(model.valacc)\n trainer.test(model)\n ","metadata":{"execution":{"iopub.status.busy":"2022-05-16T08:50:48.023825Z","iopub.execute_input":"2022-05-16T08:50:48.024059Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"model.batch_size","metadata":{"execution":{"iopub.status.busy":"2022-05-05T00:51:07.10088Z","iopub.status.idle":"2022-05-05T00:51:07.101493Z","shell.execute_reply.started":"2022-05-05T00:51:07.101235Z","shell.execute_reply":"2022-05-05T00:51:07.101262Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"plt.plot(model.trainacc,label='train')\nplt.plot(model.valacc,label='val')\nplt.legend()","metadata":{"execution":{"iopub.status.busy":"2022-05-05T00:43:40.498498Z","iopub.status.idle":"2022-05-05T00:43:40.4991Z","shell.execute_reply.started":"2022-05-05T00:43:40.498857Z","shell.execute_reply":"2022-05-05T00:43:40.498885Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"val_df.label.unique(),val_df.group.unique()","metadata":{"execution":{"iopub.status.busy":"2022-05-05T00:43:40.500239Z","iopub.status.idle":"2022-05-05T00:43:40.500838Z","shell.execute_reply.started":"2022-05-05T00:43:40.500564Z","shell.execute_reply":"2022-05-05T00:43:40.5006Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]} 
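A minimal, self-contained sketch of the scaleogram step performed in the %%capture loop above, assuming the 128 Hz sampling rate suggested by the `128*3` cell and the 14-channel montage implied elsewhere in this notebook; the input array here is synthetic, everything else mirrors the loop (per-channel CWT with ssqueezepy, magnitude, save as .npy):

```python
import numpy as np
from ssqueezepy import cwt

fs = 128                                  # assumed sampling rate (cf. the 128*3 cell)
epoch = np.random.randn(14, 3 * fs)       # stand-in trial: (channels, samples)

Wx, scales = cwt(epoch, 'morlet')         # CWT runs along the last axis, one row per channel
scaleogram = np.abs(Wx)                   # magnitude scaleogram: (channels, n_scales, samples)
print(scaleogram.shape)

np.save('trial_demo.npy', scaleogram)     # same lazy-loading .npy layout the DataReader expects
```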
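OurModel swaps the first stem convolution of resnest26d so the backbone accepts 14-channel scaleograms instead of 3-channel RGB. A sketch of two equivalent ways to do this with timm (the manual swap is what the notebook does; `in_chans` is timm's built-in mechanism and has the advantage of re-packing the pretrained RGB stem weights rather than re-initialising them):

```python
import torch
import torch.nn as nn
import timm

# Option A (as in the notebook): manually replace the first conv of the deep stem.
model_a = timm.create_model('resnest26d', pretrained=True)
model_a.conv1[0] = nn.Conv2d(14, 32, kernel_size=3, stride=2, padding=1, bias=False)

# Option B: let timm build (and weight-adapt) the stem for 14 input channels.
model_b = timm.create_model('resnest26d', pretrained=True, in_chans=14)

x = torch.randn(2, 14, 224, 224)           # (batch, EEG channels, scales, time), resized
print(model_a(x).shape, model_b(x).shape)  # both (2, 1000) with the default head
```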
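One behaviour worth flagging in `test_epoch_end` above (the same pattern recurs in video_classification_end2end.ipynb): the model is trained with nn.BCEWithLogitsLoss and has no final sigmoid, so it outputs logits, yet predictions are binarised with `pred > 0.5`. The rank-based roc_auc_score is unaffected, but a 0.5 cut-off on logits is not a 0.5 cut-off on probabilities; the equivalent decision rule on logits is:

```python
import numpy as np

def logits_to_labels(logits):
    # sigmoid(z) > 0.5 exactly when z > 0, so thresholding logits at 0
    # reproduces the usual 0.5 probability cut-off without computing sigmoid
    return (np.asarray(logits) > 0).astype(int)

print(logits_to_labels([-1.2, 0.3, 2.0]))  # [0 1 1]
```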
-------------------------------------------------------------------------------- /eeg_epilepsy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "eeg-epilepsy.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyOshNakVVOZ6KUZLG1Tupsc", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "# Install MNE package" 35 | ], 36 | "metadata": { 37 | "id": "JEUKe9LBPPAC" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "source": [ 43 | "!pip install mne -q" 44 | ], 45 | "metadata": { 46 | "id": "lvebAo_jXlcE" 47 | }, 48 | "execution_count": 1, 49 | "outputs": [] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "source": [ 54 | "# Download data" 55 | ], 56 | "metadata": { 57 | "id": "rl3lgqAxPRtl" 58 | } 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "source": [ 63 | "* Read the [paper](https://www.biorxiv.org/content/10.1101/324954v1 ) to understand data.\n", 64 | "* Datasets were collected using EMOTIVE with the 128 Hz sampling frequency. \n", 65 | "* There are two differnt dataset collected in Guinea Bissau (97 subjects) and Nigeria (112 subjects). \n", 66 | "* Here we are dealing with data collected in Guinea Bissau" 67 | ], 68 | "metadata": { 69 | "id": "O3AwG_JlRVZ-" 70 | } 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": { 76 | "colab": { 77 | "base_uri": "https://localhost:8080/" 78 | }, 79 | "id": "0sw2_4m7UlCq", 80 | "outputId": "eb24ef6e-70e1-4a8e-e13b-4a574705d628" 81 | }, 82 | "outputs": [ 83 | { 84 | "output_type": "stream", 85 | "name": "stdout", 86 | "text": [ 87 | "--2022-08-05 14:19:11-- https://zenodo.org/record/1252141/files/EEGs_Guinea-Bissau.zip\n", 88 | "Resolving zenodo.org (zenodo.org)... 137.138.76.77\n", 89 | "Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.\n", 90 | "HTTP request sent, awaiting response... 200 OK\n", 91 | "Length: 153973086 (147M) [application/octet-stream]\n", 92 | "Saving to: ‘EEGs_Guinea-Bissau.zip.1’\n", 93 | "\n", 94 | "EEGs_Guinea-Bissau. 
100%[===================>] 146.84M 6.85MB/s in 81s \n", 95 | "\n", 96 | "2022-08-05 14:20:37 (1.81 MB/s) - ‘EEGs_Guinea-Bissau.zip.1’ saved [153973086/153973086]\n", 97 | "\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "!wget https://zenodo.org/record/1252141/files/EEGs_Guinea-Bissau.zip" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "source": [ 108 | "# Unzip the data" 109 | ], 110 | "metadata": { 111 | "id": "0dTIQ9mbPThA" 112 | } 113 | }, 114 | { 115 | "cell_type": "code", 116 | "source": [ 117 | "#unzip the files\n", 118 | "from zipfile import ZipFile \n", 119 | "data = ZipFile('EEGs_Guinea-Bissau.zip')\n", 120 | "data.extractall()" 121 | ], 122 | "metadata": { 123 | "id": "EREN4IjUUnmm" 124 | }, 125 | "execution_count": 3, 126 | "outputs": [] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "source": [ 131 | "# Read data" 132 | ], 133 | "metadata": { 134 | "id": "pZ3JhqMAPWFe" 135 | } 136 | }, 137 | { 138 | "cell_type": "code", 139 | "source": [ 140 | "import pandas as pd\n", 141 | "import numpy as np\n", 142 | "from matplotlib import pyplot as plt" 143 | ], 144 | "metadata": { 145 | "id": "CYqLPWjrU1pj" 146 | }, 147 | "execution_count": 4, 148 | "outputs": [] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "source": [ 153 | "meta_df=pd.read_csv('https://zenodo.org/record/1252141/files/metadata_guineabissau.csv')\n", 154 | "meta_df.head()" 155 | ], 156 | "metadata": { 157 | "colab": { 158 | "base_uri": "https://localhost:8080/", 159 | "height": 206 160 | }, 161 | "id": "wsBVGx60VDbY", 162 | "outputId": "bf6ff63b-75cb-4c90-835e-e0fbb09772d4" 163 | }, 164 | "execution_count": 5, 165 | "outputs": [ 166 | { 167 | "output_type": "execute_result", 168 | "data": { 169 | "text/plain": [ 170 | " subject.id Group Eyes.condition \\\n", 171 | "0 1 Epilepsy closed-3min-then-open-2min \n", 172 | "1 2 Control open-3min-then-closed-2min \n", 173 | "2 3 Epilepsy closed-3min-then-open-2min \n", 174 | "3 4 Epilepsy closed-3min-then-open-2min \n", 175 | "4 5 Control closed-3min-then-open-2min \n", 176 | "\n", 177 | " Remarks recordedPeriod \\\n", 178 | "0 by 45s reposition electrodes 301 \n", 179 | "1 NaN 309 \n", 180 | "2 NaN 309 \n", 181 | "3 Green lights not shown, but good EEG traces 299 \n", 182 | "4 NaN 302 \n", 183 | "\n", 184 | " startTime \n", 185 | "0 27/5/2020 14:33 \n", 186 | "1 26/5/2020 22:44 \n", 187 | "2 27/5/2020 14:26 \n", 188 | "3 27/5/2020 15:23 \n", 189 | "4 23/5/2020 19:09 " 190 | ], 191 | "text/html": [ 192 | "\n", 193 | "
\n", 194 | "
\n", 195 | "
\n", 196 | "\n", 209 | "\n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | "
subject.idGroupEyes.conditionRemarksrecordedPeriodstartTime
01Epilepsyclosed-3min-then-open-2minby 45s reposition electrodes30127/5/2020 14:33
12Controlopen-3min-then-closed-2minNaN30926/5/2020 22:44
23Epilepsyclosed-3min-then-open-2minNaN30927/5/2020 14:26
34Epilepsyclosed-3min-then-open-2minGreen lights not shown, but good EEG traces29927/5/2020 15:23
45Controlclosed-3min-then-open-2minNaN30223/5/2020 19:09
\n", 269 | "
\n", 270 | " \n", 280 | " \n", 281 | " \n", 318 | "\n", 319 | " \n", 343 | "
\n", 344 | "
\n", 345 | " " 346 | ] 347 | }, 348 | "metadata": {}, 349 | "execution_count": 5 350 | } 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "source": [ 356 | "#now i need to seprate Epilepsy vs Control subjects\n", 357 | "EP_sub=meta_df['subject.id'][meta_df['Group']=='Epilepsy']\n", 358 | "CT_sub=meta_df['subject.id'][meta_df['Group']=='Control']\n", 359 | "\n" 360 | ], 361 | "metadata": { 362 | "id": "pRRIxnflVJAT" 363 | }, 364 | "execution_count": 6, 365 | "outputs": [] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "source": [ 370 | "#read csv files\n", 371 | "Epilepsy=[pd.read_csv('EEGs_Guinea-Bissau/signal-{}.csv.gz'.format(i), compression='gzip') for i in EP_sub]\n", 372 | "Control=[pd.read_csv('EEGs_Guinea-Bissau/signal-{}.csv.gz'.format(i), compression='gzip') for i in CT_sub]" 373 | ], 374 | "metadata": { 375 | "id": "0HkJWxjtVgF9" 376 | }, 377 | "execution_count": 7, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "source": [ 383 | "Epilepsy[0].head()" 384 | ], 385 | "metadata": { 386 | "colab": { 387 | "base_uri": "https://localhost:8080/", 388 | "height": 317 389 | }, 390 | "id": "C06Vw5PaVwS4", 391 | "outputId": "25abf640-e779-4ab2-ec87-2147fabb7eae" 392 | }, 393 | "execution_count": 8, 394 | "outputs": [ 395 | { 396 | "output_type": "execute_result", 397 | "data": { 398 | "text/plain": [ 399 | " Unnamed: 0 AF3 AF4 F3 F4 \\\n", 400 | "0 1 4426.153846 3994.871795 4408.205128 3847.692308 \n", 401 | "1 2 4420.512821 3986.666667 4394.358974 3836.923077 \n", 402 | "2 3 4413.846154 3986.153846 4386.666667 3831.794872 \n", 403 | "3 4 4407.692308 3984.615385 4384.102564 3832.820513 \n", 404 | "4 5 4407.179487 3978.974359 4382.564103 3832.307692 \n", 405 | "\n", 406 | " F7 F8 FC5 FC6 O1 ... \\\n", 407 | "0 4690.256410 3895.897436 4702.051282 3914.871795 4049.743590 ... \n", 408 | "1 4678.461538 3886.666667 4696.410256 3910.769231 4054.358974 ... \n", 409 | "2 4654.871795 3881.025641 4690.769231 3908.205128 4066.666667 ... \n", 410 | "3 4644.615385 3883.076923 4686.153846 3910.256410 4063.076923 ... \n", 411 | "4 4647.692308 3878.974359 4685.641026 3903.076923 4057.948718 ... \n", 412 | "\n", 413 | " CQ_F3 CQ_P7 CQ_P8 CQ_F4 CQ_AF3 CQ_FC5 CQ_O1 CQ_T8 CQ_F8 CQ_DRL \n", 414 | "0 4 4 4 4 4 4 4 4 4 4 \n", 415 | "1 4 4 4 4 4 4 4 4 4 4 \n", 416 | "2 4 4 4 4 4 4 4 4 4 4 \n", 417 | "3 4 4 4 4 4 4 4 4 4 4 \n", 418 | "4 4 4 4 4 4 4 4 4 4 4 \n", 419 | "\n", 420 | "[5 rows x 36 columns]" 421 | ], 422 | "text/html": [ 423 | "\n", 424 | "
\n", 425 | "
\n", 426 | "
\n", 427 | "\n", 440 | "\n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | "
Unnamed: 0AF3AF4F3F4F7F8FC5FC6O1...CQ_F3CQ_P7CQ_P8CQ_F4CQ_AF3CQ_FC5CQ_O1CQ_T8CQ_F8CQ_DRL
014426.1538463994.8717954408.2051283847.6923084690.2564103895.8974364702.0512823914.8717954049.743590...4444444444
124420.5128213986.6666674394.3589743836.9230774678.4615383886.6666674696.4102563910.7692314054.358974...4444444444
234413.8461543986.1538464386.6666673831.7948724654.8717953881.0256414690.7692313908.2051284066.666667...4444444444
344407.6923083984.6153854384.1025643832.8205134644.6153853883.0769234686.1538463910.2564104063.076923...4444444444
454407.1794873978.9743594382.5641033832.3076924647.6923083878.9743594685.6410263903.0769234057.948718...4444444444
\n", 590 | "

5 rows × 36 columns

\n", 591 | "
\n", 592 | " \n", 602 | " \n", 603 | " \n", 640 | "\n", 641 | " \n", 665 | "
\n", 666 | "
\n", 667 | " " 668 | ] 669 | }, 670 | "metadata": {}, 671 | "execution_count": 8 672 | } 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "source": [ 678 | "#remove non eeg channels\n", 679 | "Epilepsy=[i.iloc[:,1:15] for i in Epilepsy]\n", 680 | "Control=[i.iloc[:,1:15] for i in Control]" 681 | ], 682 | "metadata": { 683 | "id": "aq6cOHyVWJ0-" 684 | }, 685 | "execution_count": 9, 686 | "outputs": [] 687 | }, 688 | { 689 | "cell_type": "markdown", 690 | "source": [ 691 | "# Convert to MNE object" 692 | ], 693 | "metadata": { 694 | "id": "QOcnawOcPYt7" 695 | } 696 | }, 697 | { 698 | "cell_type": "code", 699 | "source": [ 700 | "import mne\n", 701 | "def convertDF2MNE(sub):\n", 702 | " info = mne.create_info(list(sub.columns), ch_types=['eeg'] * len(sub.columns), sfreq=128)\n", 703 | " info.set_montage('standard_1020')\n", 704 | " data=mne.io.RawArray(sub.T, info)\n", 705 | " data.set_eeg_reference()\n", 706 | " data.filter(l_freq=0.1,h_freq=45)\n", 707 | " epochs=mne.make_fixed_length_epochs(data,duration=5,overlap=1)\n", 708 | " epochs=epochs.drop_bad()\n", 709 | " \n", 710 | " return epochs" 711 | ], 712 | "metadata": { 713 | "id": "nNnQ8nZVWVkF" 714 | }, 715 | "execution_count": 10, 716 | "outputs": [] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "source": [ 721 | "%%capture\n", 722 | "#Convert each dataframe to mne object\n", 723 | "Epilepsy=[convertDF2MNE(i) for i in Epilepsy]\n", 724 | "Control=[convertDF2MNE(i) for i in Control]" 725 | ], 726 | "metadata": { 727 | "id": "1a-tFpz6XaW3" 728 | }, 729 | "execution_count": 11, 730 | "outputs": [] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "source": [ 735 | "%%capture\n", 736 | "#concatenate the epochs\n", 737 | "Epilepsy_epochs=mne.concatenate_epochs(Epilepsy)\n", 738 | "Control_epochs=mne.concatenate_epochs(Control)\n" 739 | ], 740 | "metadata": { 741 | "id": "isk-LSSLBEhj" 742 | }, 743 | "execution_count": 12, 744 | "outputs": [] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "source": [ 749 | "# Create labels and groups" 750 | ], 751 | "metadata": { 752 | "id": "4B4R_ltbPcuT" 753 | } 754 | }, 755 | { 756 | "cell_type": "code", 757 | "source": [ 758 | "Epilepsy_group=np.concatenate([[i]*len(Epilepsy[i]) for i in range(len(Epilepsy))])#create a list of list where each sub list corresponds to subject_no\n", 759 | "Control_group=np.concatenate([[i]*len(Control[i]) for i in range(len(Control))])#create a list of list where each sub list corresponds to subject_no\n", 760 | "\n", 761 | "Epilepsy_label=np.concatenate([[0]*len(Epilepsy[i]) for i in range(len(Epilepsy))])\n", 762 | "Control_label=np.concatenate([[1]*len(Control[i]) for i in range(len(Control))])" 763 | ], 764 | "metadata": { 765 | "id": "FDzPSY159zq5" 766 | }, 767 | "execution_count": 13, 768 | "outputs": [] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "source": [ 773 | "Epilepsy_group.shape,Control_group.shape,Epilepsy_label.shape,Control_label.shape" 774 | ], 775 | "metadata": { 776 | "colab": { 777 | "base_uri": "https://localhost:8080/" 778 | }, 779 | "id": "tMDKoBnBAOqH", 780 | "outputId": "633e31e3-60f4-40b2-a2f2-38e38b7b36fd" 781 | }, 782 | "execution_count": 14, 783 | "outputs": [ 784 | { 785 | "output_type": "execute_result", 786 | "data": { 787 | "text/plain": [ 788 | "((3995,), (3461,), (3995,), (3461,))" 789 | ] 790 | }, 791 | "metadata": {}, 792 | "execution_count": 14 793 | } 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "source": [ 799 | "#combine data\n", 800 | 
"data=mne.concatenate_epochs([Epilepsy_epochs,Control_epochs])\n", 801 | "group=np.concatenate((Epilepsy_group,Control_group))\n", 802 | "label=np.concatenate((Epilepsy_label,Control_label))\n", 803 | "print(len(data),len(group),len(label))" 804 | ], 805 | "metadata": { 806 | "colab": { 807 | "base_uri": "https://localhost:8080/" 808 | }, 809 | "id": "0qpTA318CONT", 810 | "outputId": "3cf4ae86-cbaf-4fab-b663-920e12558fe2" 811 | }, 812 | "execution_count": 15, 813 | "outputs": [ 814 | { 815 | "output_type": "stream", 816 | "name": "stdout", 817 | "text": [ 818 | "Not setting metadata\n", 819 | "7456 matching events found\n", 820 | "No baseline correction applied\n", 821 | "7456 7456 7456\n" 822 | ] 823 | } 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "source": [ 829 | "# Feature Extraction - Power spectral density\n", 830 | "The power spectral density of a signal is a measure of how much power the signal has at each different frequency. The power spectral density (power spectrum) reflects the ‘frequency content’ of the signal or the distribution of signal power over frequency. \n" 831 | ], 832 | "metadata": { 833 | "id": "SnjeFhVltiRR" 834 | } 835 | }, 836 | { 837 | "cell_type": "code", 838 | "source": [ 839 | "# source: https://mne.tools/stable/auto_tutorials/clinical/60_sleep.html#sphx-glr-auto-tutorials-clinical-60-sleep-py\n", 840 | "from mne.time_frequency import psd_welch\n", 841 | "def eeg_power_band(epochs):\n", 842 | " \"\"\"EEG relative power band feature extraction.\n", 843 | "\n", 844 | " This function takes an ``mne.Epochs`` object and creates EEG features based\n", 845 | " on relative power in specific frequency bands that are compatible with\n", 846 | " scikit-learn.\n", 847 | "\n", 848 | " Parameters\n", 849 | " ----------\n", 850 | " epochs : Epochs\n", 851 | " The data.\n", 852 | "\n", 853 | " Returns\n", 854 | " -------\n", 855 | " X : numpy array of shape [n_samples, 5]\n", 856 | " Transformed data.\n", 857 | " \"\"\"\n", 858 | " # specific frequency bands\n", 859 | " FREQ_BANDS = {\"delta\": [0.5, 4.5],\n", 860 | " \"theta\": [4.5, 8.5],\n", 861 | " \"alpha\": [8.5, 11.5],\n", 862 | " \"sigma\": [11.5, 15.5],\n", 863 | " \"beta\": [15.5, 30],\n", 864 | " \"gamma\": [30, 45],\n", 865 | " }\n", 866 | "\n", 867 | " psds, freqs = psd_welch(epochs, picks='eeg', fmin=0.5, fmax=45)# Compute the PSD using the Welch method\n", 868 | " psds /= np.sum(psds, axis=-1, keepdims=True) # Normalize the PSDs\n", 869 | "\n", 870 | " X = []#For each frequency band, compute the mean PSD in that band\n", 871 | " for fmin, fmax in FREQ_BANDS.values():\n", 872 | " psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)# Compute the mean PSD in each frequency band.\n", 873 | " X.append(psds_band)\n", 874 | "\n", 875 | " return np.concatenate(X, axis=1)#Concatenate the mean PSDs for each band into a single feature vector" 876 | ], 877 | "metadata": { 878 | "id": "dhK4vRC4n-4c" 879 | }, 880 | "execution_count": 16, 881 | "outputs": [] 882 | }, 883 | { 884 | "cell_type": "markdown", 885 | "source": [ 886 | "# Classification 5-fold" 887 | ], 888 | "metadata": { 889 | "id": "0bAKzjyKPh7u" 890 | } 891 | }, 892 | { 893 | "cell_type": "code", 894 | "source": [ 895 | "from sklearn.ensemble import RandomForestClassifier\n", 896 | "from sklearn.model_selection import cross_val_score" 897 | ], 898 | "metadata": { 899 | "id": "dtbouUjgrDbi" 900 | }, 901 | "execution_count": 17, 902 | "outputs": [] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "source": [ 907 | 
"%%capture\n", 908 | "features=[]\n", 909 | "for d in range(len(data)):#get features from each epoch and save in a list\n", 910 | " features.append(eeg_power_band(data[d]))" 911 | ], 912 | "metadata": { 913 | "id": "8PsOWqLaKn6K" 914 | }, 915 | "execution_count": 18, 916 | "outputs": [] 917 | }, 918 | { 919 | "cell_type": "code", 920 | "source": [ 921 | "#convert list to array\n", 922 | "features=np.concatenate(features)\n", 923 | "features.shape" 924 | ], 925 | "metadata": { 926 | "colab": { 927 | "base_uri": "https://localhost:8080/" 928 | }, 929 | "id": "m59hZOOWM01E", 930 | "outputId": "27ed2e47-b66f-473c-962f-99bce3f478a3" 931 | }, 932 | "execution_count": 19, 933 | "outputs": [ 934 | { 935 | "output_type": "execute_result", 936 | "data": { 937 | "text/plain": [ 938 | "(7456, 84)" 939 | ] 940 | }, 941 | "metadata": {}, 942 | "execution_count": 19 943 | } 944 | ] 945 | }, 946 | { 947 | "cell_type": "code", 948 | "source": [ 949 | "#do 5 fold cross validation\n", 950 | "clf=RandomForestClassifier()\n", 951 | "accuracies=cross_val_score(clf, features,label,groups=group,cv=5)\n", 952 | "print('Five fold accuracies',accuracies)\n", 953 | "print('Average accuracy',np.mean(accuracies))" 954 | ], 955 | "metadata": { 956 | "id": "UwOOeONDranq", 957 | "colab": { 958 | "base_uri": "https://localhost:8080/" 959 | }, 960 | "outputId": "193f3bf9-2fa1-4bd6-91fd-1b24e13f317f" 961 | }, 962 | "execution_count": 20, 963 | "outputs": [ 964 | { 965 | "output_type": "stream", 966 | "name": "stdout", 967 | "text": [ 968 | "Five fold accuracies [0.71380697 0.69953052 0.63782696 0.73105298 0.69014085]\n", 969 | "Average accuracy 0.6944716556712932\n" 970 | ] 971 | } 972 | ] 973 | }, 974 | { 975 | "cell_type": "markdown", 976 | "source": [ 977 | "# Tips to improve it further\n", 978 | "1. Try different classifier\n", 979 | "2. try tuning the parameters of classifiers\n", 980 | "3. 
Try feature elimination " 981 | ], 982 | "metadata": { 983 | "id": "eHVFF_cmPBMu" 984 | } 985 | }, 986 | { 987 | "cell_type": "code", 988 | "source": [ 989 | "" 990 | ], 991 | "metadata": { 992 | "id": "aRS5y0zTRm06" 993 | }, 994 | "execution_count": 20, 995 | "outputs": [] 996 | } 997 | ] 998 | } -------------------------------------------------------------------------------- /video_classification_end2end.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 17, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.insert(0, \"../pytorchvideo\")\n", 11 | "from pytorchvideo.data import LabeledVideoDataset,Kinetics, make_clip_sampler\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "#https://www.kaggle.com/datasets/mohamedmustafa/real-life-violence-situations-dataset" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "metadata": { 27 | "execution": { 28 | "iopub.execute_input": "2022-09-15T20:18:04.995677Z", 29 | "iopub.status.busy": "2022-09-15T20:18:04.995237Z", 30 | "iopub.status.idle": "2022-09-15T20:18:06.789586Z", 31 | "shell.execute_reply": "2022-09-15T20:18:06.788360Z", 32 | "shell.execute_reply.started": "2022-09-15T20:18:04.995639Z" 33 | } 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "import numpy as np\n", 38 | "from pytorch_lightning import seed_everything, LightningModule, Trainer\n", 39 | "import torch.nn as nn\n", 40 | "import torch\n", 41 | "from torch.utils.data.dataloader import DataLoader\n", 42 | "from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor\n", 43 | "from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau,CosineAnnealingWarmRestarts,OneCycleLR,CosineAnnealingLR\n", 44 | "import torchvision\n", 45 | "import pandas as pd\n", 46 | "import numpy as np\n", 47 | "from glob import glob\n", 48 | "from PIL import Image\n", 49 | "import cv2\n", 50 | "import os\n", 51 | "from torch.utils.data import DataLoader, Dataset,ConcatDataset,default_collate\n", 52 | "from sklearn.model_selection import KFold,GroupShuffleSplit,GroupKFold,LeaveOneGroupOut\n", 53 | "from torchmetrics import MeanAbsoluteError\n", 54 | "from sklearn.utils import shuffle\n", 55 | "import shutil\n", 56 | "from sklearn.model_selection import train_test_split\n", 57 | "from torchaudio import transforms as TA\n", 58 | "from sklearn.metrics import classification_report\n", 59 | "import torchmetrics" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": { 66 | "execution": { 67 | "iopub.execute_input": "2022-09-15T20:18:06.794066Z", 68 | "iopub.status.busy": "2022-09-15T20:18:06.792689Z", 69 | "iopub.status.idle": "2022-09-15T20:18:06.803825Z", 70 | "shell.execute_reply": "2022-09-15T20:18:06.802108Z", 71 | "shell.execute_reply.started": "2022-09-15T20:18:06.794021Z" 72 | } 73 | }, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "'1.11.0+cu113'" 79 | ] 80 | }, 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "import torch\n", 88 | "torch.__version__" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "non=glob('NonViolence/*')\n", 98 | "vio=glob('Violence/*')\n", 99 | 
"label=[0]*len(non)+[1]*len(vio)\n", 100 | "df=pd.DataFrame(zip(non+vio,label),columns=['file','label'])" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "train_df,val_df=train_test_split(df,test_size=0.25)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 7, 115 | "metadata": { 116 | "execution": { 117 | "iopub.execute_input": "2022-09-15T20:18:06.872186Z", 118 | "iopub.status.busy": "2022-09-15T20:18:06.871763Z", 119 | "iopub.status.idle": "2022-09-15T20:18:06.889484Z", 120 | "shell.execute_reply": "2022-09-15T20:18:06.888400Z", 121 | "shell.execute_reply.started": "2022-09-15T20:18:06.872133Z" 122 | } 123 | }, 124 | "outputs": [ 125 | { 126 | "name": "stderr", 127 | "output_type": "stream", 128 | "text": [ 129 | "/home/talha/venv/lib/python3.8/site-packages/torchvision/transforms/_functional_video.py:6: UserWarning: The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in 0.14. Please use the 'torchvision.transforms.functional' module instead.\n", 130 | " warnings.warn(\n", 131 | "/home/talha/venv/lib/python3.8/site-packages/torchvision/transforms/_transforms_video.py:25: UserWarning: The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in 0.14. Please use the 'torchvision.transforms' module instead.\n", 132 | " warnings.warn(\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "from pytorchvideo.data import LabeledVideoDataset,Kinetics, make_clip_sampler\n", 138 | "\n", 139 | "from pytorchvideo.transforms import (\n", 140 | " ApplyTransformToKey,\n", 141 | " Normalize,\n", 142 | " RandomShortSideScale,\n", 143 | "# RemoveKey,\n", 144 | "# ShortSideScale,\n", 145 | " UniformTemporalSubsample,\n", 146 | " Permute\n", 147 | ")\n", 148 | "\n", 149 | "from torchvision.transforms import (\n", 150 | " Compose,\n", 151 | " Lambda,\n", 152 | " RandomCrop,\n", 153 | " RandomHorizontalFlip,\n", 154 | " Resize\n", 155 | ")\n", 156 | "\n", 157 | "from torchvision.transforms._transforms_video import (\n", 158 | " CenterCropVideo,\n", 159 | " NormalizeVideo,\n", 160 | ")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 8, 166 | "metadata": { 167 | "execution": { 168 | "iopub.execute_input": "2022-09-15T20:18:06.892088Z", 169 | "iopub.status.busy": "2022-09-15T20:18:06.891199Z", 170 | "iopub.status.idle": "2022-09-15T20:18:06.897433Z", 171 | "shell.execute_reply": "2022-09-15T20:18:06.896290Z", 172 | "shell.execute_reply.started": "2022-09-15T20:18:06.892051Z" 173 | } 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "#tuneable params\n", 178 | "num_video_samples=20\n", 179 | "video_duration=2\n", 180 | "model_name='efficient_x3d_xs'\n", 181 | "batch_size=8\n", 182 | "scheduler='cosine'\n", 183 | "clipmode='random'\n", 184 | "img_size=224" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "metadata": { 191 | "execution": { 192 | "iopub.execute_input": "2022-09-15T20:18:06.901622Z", 193 | "iopub.status.busy": "2022-09-15T20:18:06.900753Z", 194 | "iopub.status.idle": "2022-09-15T20:18:06.910916Z", 195 | "shell.execute_reply": "2022-09-15T20:18:06.909562Z", 196 | "shell.execute_reply.started": "2022-09-15T20:18:06.901576Z" 197 | } 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler,labeled_video_dataset\n", 202 | "from torchvision.transforms import 
ColorJitter,RandomAdjustSharpness,RandomAutocontrast\n", 203 | "video_transform = Compose(\n", 204 | " [\n", 205 | " ApplyTransformToKey(\n", 206 | " key=\"video\",\n", 207 | " transform=Compose(\n", 208 | " [\n", 209 | " UniformTemporalSubsample(num_video_samples),\n", 210 | " Lambda(lambda x: x / 255.0),\n", 211 | " Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),\n", 212 | " #Determines the shorter spatial dim of the video (i.e. width or height) and scales it to the given size\n", 213 | " RandomShortSideScale(min_size=img_size+16, max_size=img_size+32),\n", 214 | " CenterCropVideo(img_size),\n", 215 | " RandomHorizontalFlip(p=0.5),\n", 216 | " ]\n", 217 | " ),\n", 218 | " ),\n", 219 | " ]\n", 220 | " )\n", 221 | "\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": { 228 | "execution": { 229 | "iopub.execute_input": "2022-09-15T20:18:06.924208Z", 230 | "iopub.status.busy": "2022-09-15T20:18:06.923758Z", 231 | "iopub.status.idle": "2022-09-15T20:18:06.958966Z", 232 | "shell.execute_reply": "2022-09-15T20:18:06.955103Z", 233 | "shell.execute_reply.started": "2022-09-15T20:18:06.924171Z" 234 | } 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "train_dataset=labeled_video_dataset(val_df,\n", 239 | " clip_sampler=make_clip_sampler(clipmode, video_duration),\\\n", 240 | " transform=video_transform, decode_audio=False\n", 241 | " )\n", 242 | " \n", 243 | "train_loader=DataLoader(train_dataset,batch_size=4,\n", 244 | " num_workers=0,\n", 245 | " pin_memory=True)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 11, 251 | "metadata": { 252 | "execution": { 253 | "iopub.execute_input": "2022-09-15T20:18:06.971415Z", 254 | "iopub.status.busy": "2022-09-15T20:18:06.970920Z", 255 | "iopub.status.idle": "2022-09-15T20:18:12.962845Z", 256 | "shell.execute_reply": "2022-09-15T20:18:12.961746Z", 257 | "shell.execute_reply.started": "2022-09-15T20:18:06.971365Z" 258 | } 259 | }, 260 | "outputs": [], 261 | "source": [ 262 | "batch=next(iter(train_loader))" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 12, 268 | "metadata": { 269 | "execution": { 270 | "iopub.execute_input": "2022-09-15T20:18:12.964898Z", 271 | "iopub.status.busy": "2022-09-15T20:18:12.964470Z", 272 | "iopub.status.idle": "2022-09-15T20:18:12.972822Z", 273 | "shell.execute_reply": "2022-09-15T20:18:12.971638Z", 274 | "shell.execute_reply.started": "2022-09-15T20:18:12.964831Z" 275 | } 276 | }, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "(torch.Size([4, 3, 20, 224, 224]), torch.Size([4, 1]))" 282 | ] 283 | }, 284 | "execution_count": 12, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | } 288 | ], 289 | "source": [ 290 | "batch['video'].shape,batch['label'].shape" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 13, 296 | "metadata": { 297 | "execution": { 298 | "iopub.execute_input": "2022-09-15T20:18:12.975981Z", 299 | "iopub.status.busy": "2022-09-15T20:18:12.975072Z", 300 | "iopub.status.idle": "2022-09-15T20:18:13.003089Z", 301 | "shell.execute_reply": "2022-09-15T20:18:13.001856Z", 302 | "shell.execute_reply.started": "2022-09-15T20:18:12.975944Z" 303 | } 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "from time import sleep\n", 308 | "import torchvision.models as models\n", 309 | "import timm\n", 310 | "class OurModel(LightningModule):\n", 311 | " def __init__(self):\n", 312 | " super(OurModel,self).__init__()\n", 313 | "\n", 314 | " 
self.scheduler=scheduler\n", 315 | "\n", 316 | " \n", 317 | " self.video_model = torch.hub.load('facebookresearch/pytorchvideo', model_name, pretrained=True)\n", 318 | " self.video_model.projection.model=nn.Linear(in_features=2048, out_features=1000, bias=True)\n", 319 | " \n", 320 | " \n", 321 | " self.relu=nn.ReLU()\n", 322 | " self.linear=nn.Linear(1000,1)\n", 323 | " \n", 324 | " self.lr=1e-3\n", 325 | " self.batch_size=batch_size\n", 326 | " self.numworker=6\n", 327 | " \n", 328 | " self.metric = torchmetrics.Accuracy()\n", 329 | " self.criterion=nn.BCEWithLogitsLoss()\n", 330 | " \n", 331 | " def forward(self,video):\n", 332 | " x=self.video_model(video)\n", 333 | " x=self.relu(x)\n", 334 | " x=self.linear(x)\n", 335 | " return x\n", 336 | "\n", 337 | " def configure_optimizers(self):\n", 338 | " opt=torch.optim.AdamW(params=self.parameters(),lr=self.lr )\n", 339 | " if self.scheduler=='cosine':\n", 340 | " scheduler=CosineAnnealingLR(opt,T_max=10, eta_min=1e-6, last_epoch=-1)\n", 341 | " return {'optimizer': opt,'lr_scheduler':scheduler}\n", 342 | " elif self.scheduler=='reduce':\n", 343 | " scheduler=ReduceLROnPlateau(opt,mode='min', factor=0.5, patience=5)\n", 344 | " return {'optimizer': opt,'lr_scheduler':scheduler,'monitor':'val_loss'}\n", 345 | " elif self.scheduler=='warm':\n", 346 | " scheduler=CosineAnnealingWarmRestarts(opt,T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)\n", 347 | " return {'optimizer': opt,'lr_scheduler':scheduler}\n", 348 | " elif self.scheduler=='cycle':\n", 349 | " opt=torch.optim.AdamW(params=self.parameters(),lr=1e-6 )\n", 350 | " scheduler=OneCycleLR(opt,max_lr=1e-2,epochs=15,steps_per_epoch=len(self.train_df)//self.batch_size//4)\n", 351 | " lr_scheduler = {'scheduler': scheduler, 'interval': 'step'}\n", 352 | " return {'optimizer': opt, 'lr_scheduler': lr_scheduler}\n", 353 | " elif self.scheduler=='lambda':\n", 354 | " lambda1 = lambda epoch: 0.9 ** epoch\n", 355 | " scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda1)\n", 356 | " return {'optimizer': opt, 'lr_scheduler': scheduler}\n", 357 | " elif self.scheduler=='constant':\n", 358 | " return opt\n", 359 | " \n", 360 | " def train_dataloader(self):\n", 361 | " dataset=labeled_video_dataset(train_df,\n", 362 | " clip_sampler=make_clip_sampler(clipmode, video_duration),\\\n", 363 | " transform=video_transform, decode_audio=False)\n", 364 | " \n", 365 | " loader=DataLoader(dataset,batch_size=self.batch_size,\n", 366 | " num_workers=self.numworker,\n", 367 | " pin_memory=True)\n", 368 | " return loader\n", 369 | "\n", 370 | " def training_step(self,batch,batch_idx):\n", 371 | " video,label=batch['video'],batch['label']\n", 372 | "# label=label.ravel().to(torch.int64)\n", 373 | " out = self(video)\n", 374 | " loss=self.criterion(out,label)\n", 375 | " metric=self.metric(out,label.to(torch.int64))\n", 376 | " return {'loss':loss,'metric':metric.detach()}\n", 377 | "\n", 378 | " def training_epoch_end(self, outputs):\n", 379 | " loss=torch.stack([x[\"loss\"] for x in outputs]).mean().cpu().numpy().round(2)\n", 380 | " metric=torch.stack([x[\"metric\"] for x in outputs]).mean().cpu().numpy().round(2)\n", 381 | " self.log('train_loss', loss,batch_size=self.batch_size)\n", 382 | " self.log('train_metric', metric,batch_size=self.batch_size)\n", 383 | " print('training loss ',self.current_epoch,loss,metric)\n", 384 | " \n", 385 | " def val_dataloader(self):\n", 386 | " dataset=labeled_video_dataset(val_df,\n", 387 | " clip_sampler=make_clip_sampler(clipmode, video_duration),\\\n", 388 | " 
transform=video_transform, decode_audio=False)\n", 389 | " \n", 390 | " loader=DataLoader(dataset,batch_size=self.batch_size,\n", 391 | " num_workers=self.numworker,\n", 392 | " pin_memory=True)\n", 393 | " return loader\n", 394 | " \n", 395 | " def validation_step(self,batch,batch_idx):\n", 396 | " video,label=batch['video'],batch['label']\n", 397 | " out = self(video)\n", 398 | " loss=self.criterion(out,label)\n", 399 | " metric=self.metric(out,label.to(torch.int64))\n", 400 | " return {'loss':loss,'metric':metric.detach()}\n", 401 | "\n", 402 | " def validation_epoch_end(self, outputs):\n", 403 | " loss=torch.stack([x[\"loss\"] for x in outputs]).mean().cpu().numpy().round(2)\n", 404 | " metric=torch.stack([x[\"metric\"] for x in outputs]).mean().cpu().numpy().round(2)\n", 405 | " print('validation loss ',self.current_epoch,loss,metric)\n", 406 | " self.log('val_loss', loss,batch_size=self.batch_size)\n", 407 | " self.log('val_metric',metric,batch_size=self.batch_size)\n", 408 | " \n", 409 | " def test_dataloader(self):\n", 410 | " dataset=labeled_video_dataset(val_df,\n", 411 | " clip_sampler=make_clip_sampler(clipmode, video_duration),\\\n", 412 | " transform=video_transform, decode_audio=False)\n", 413 | " \n", 414 | " loader=DataLoader(dataset,batch_size=self.batch_size,\n", 415 | " num_workers=self.numworker,\n", 416 | " pin_memory=True)\n", 417 | " return loader\n", 418 | "\n", 419 | " \n", 420 | " def test_step(self, batch, batch_idx):\n", 421 | " video,label=batch['video'],batch['label']\n", 422 | " out = self(video)\n", 423 | " return { 'label': label.detach(), 'pred': out.detach()}\n", 424 | "\n", 425 | " def test_epoch_end(self, outputs):\n", 426 | " label = torch.cat([x['label'] for x in outputs]).cpu().numpy()\n", 427 | " pred = torch.cat([x['pred'] for x in outputs]).cpu().numpy()\n", 428 | " pred=np.where(pred>0.5,1,0)\n", 429 | " print(classification_report(label, pred))\n" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "execution": { 437 | "iopub.status.busy": "2022-09-15T20:18:13.004769Z" 438 | } 439 | }, 440 | "outputs": [], 441 | "source": [] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 14, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "checkpoint_callback = ModelCheckpoint(monitor='val_loss',dirpath='checkpoints',\n", 450 | " filename='file',save_last=True)\n", 451 | "lr_monitor = LearningRateMonitor(logging_interval='epoch')\n" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "import wandb\n", 461 | "wandb.login()\n", 462 | "from pytorch_lightning.loggers import WandbLogger\n", 463 | "wandb_logger = WandbLogger(project=\"violence\")" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 15, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "name": "stderr", 473 | "output_type": "stream", 474 | "text": [ 475 | "Using cache found in /home/talha/.cache/torch/hub/facebookresearch_pytorchvideo_main\n", 476 | "Global seed set to 0\n", 477 | "Using 16bit native Automatic Mixed Precision (AMP)\n", 478 | "GPU available: True, used: True\n", 479 | "TPU available: False, using: 0 TPU cores\n", 480 | "IPU available: False, using: 0 IPUs\n", 481 | "HPU available: False, using: 0 HPUs\n" 482 | ] 483 | } 484 | ], 485 | "source": [ 486 | "model=OurModel()\n", 487 | "seed_everything(0)\n", 488 | "trainer = Trainer(max_epochs=30, \n", 489 | "# 
deterministic=True,\n", 490 | " accelerator='gpu', devices=-1,\n", 491 | " precision=16,\n", 492 | " accumulate_grad_batches=2,\n", 493 | " enable_progress_bar = False,\n", 494 | " num_sanity_val_steps=0,\n", 495 | " callbacks=[lr_monitor,checkpoint_callback],\n", 496 | "# limit_train_batches=5,\n", 497 | "# limit_val_batches=1,\n", 498 | "# logger=wandb_logger\n", 499 | "\n", 500 | " )" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 16, 506 | "metadata": {}, 507 | "outputs": [ 508 | { 509 | "name": "stderr", 510 | "output_type": "stream", 511 | "text": [ 512 | "Missing logger folder: /media/talha/data/image/classification/video_classification/Real Life Violence Dataset/lightning_logs\n", 513 | "/home/talha/venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:611: UserWarning: Checkpoint directory /media/talha/data/image/classification/video_classification/Real Life Violence Dataset/checkpoints exists and is not empty.\n", 514 | " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", 515 | "Restoring states from the checkpoint path at checkpoints/last.ckpt\n", 516 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 517 | "\n", 518 | " | Name | Type | Params\n", 519 | "--------------------------------------------------\n", 520 | "0 | video_model | EfficientX3d | 5.0 M \n", 521 | "1 | relu | ReLU | 0 \n", 522 | "2 | linear | Linear | 1.0 K \n", 523 | "3 | metric | Accuracy | 0 \n", 524 | "4 | criterion | BCEWithLogitsLoss | 0 \n", 525 | "--------------------------------------------------\n", 526 | "5.0 M Trainable params\n", 527 | "0 Non-trainable params\n", 528 | "5.0 M Total params\n", 529 | "10.049 Total estimated model params size (MB)\n", 530 | "Restored all states from the checkpoint file at checkpoints/last.ckpt\n", 531 | "2022-09-26 15:57:11.460252: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n" 532 | ] 533 | } 534 | ], 535 | "source": [ 536 | "trainer.fit(model,\n", 537 | "# ckpt_path='checkpoints/last.ckpt'\n", 538 | " )" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "val_res=trainer.validate(model)" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": null, 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "trainer.test(model)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "metadata": {}, 563 | "outputs": [], 564 | "source": [ 565 | "wandb_logger.experiment.save('notebook.ipynb')" 566 | ] 567 | } 568 | ], 569 | "metadata": { 570 | "kernelspec": { 571 | "display_name": "Python 3 (ipykernel)", 572 | "language": "python", 573 | "name": "python3" 574 | }, 575 | "language_info": { 576 | "codemirror_mode": { 577 | "name": "ipython", 578 | "version": 3 579 | }, 580 | "file_extension": ".py", 581 | "mimetype": "text/x-python", 582 | "name": "python", 583 | "nbconvert_exporter": "python", 584 | "pygments_lexer": "ipython3", 585 | "version": "3.8.10" 586 | } 587 | }, 588 | "nbformat": 4, 589 | "nbformat_minor": 4 590 | } 591 | --------------------------------------------------------------------------------
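The training pipeline above draws one random 2-second clip per video pass (`make_clip_sampler(clipmode, video_duration)` with `clipmode='random'`). For validation and test it is often preferable to cover each video deterministically; a minimal sketch reusing the notebook's `val_df`, `video_transform` and `video_duration`, with pytorchvideo's 'uniform' sampler, which tiles each video start-to-end in fixed-length clips:

```python
from pytorchvideo.data import labeled_video_dataset, make_clip_sampler
from torch.utils.data import DataLoader

# deterministic clip coverage for evaluation (vs. the 'random' sampler used in training)
eval_dataset = labeled_video_dataset(
    val_df,
    clip_sampler=make_clip_sampler('uniform', video_duration),
    transform=video_transform,
    decode_audio=False,
)
eval_loader = DataLoader(eval_dataset, batch_size=8, num_workers=2, pin_memory=True)
```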
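Circling back to the "Tips to improve it further" list that closes eeg_epilepsy.ipynb: a hedged sketch of tips 1-2 (a different classifier, parameter tuning) built on that notebook's `features`, `label` and `group` arrays. Note that `cross_val_score(..., cv=5)` there expands to StratifiedKFold, which ignores the `groups=` argument, so epochs from one subject can land in both train and test folds; an explicit GroupKFold keeps the estimate subject-wise (tip 3 could be explored analogously with, e.g., sklearn.feature_selection.RFE):

```python
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV, GroupKFold

# tune an alternative classifier with subject-wise folds, so no subject leaks across folds
param_grid = {'n_estimators': [100, 300], 'learning_rate': [0.05, 0.1]}
search = GridSearchCV(GradientBoostingClassifier(), param_grid, cv=GroupKFold(n_splits=5))
search.fit(features, label, groups=group)   # features / label / group from the notebook
print(search.best_params_, search.best_score_)
```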