├── LICENSE
├── README.rst
├── examples
│   ├── CNN_STFT Example.ipynb
│   ├── RCNN.ipynb
│   └── utils.py
├── models
│   ├── __init__.py
│   ├── cnn_stft.py
│   ├── deep_cnn.py
│   ├── lstm.py
│   ├── lstm_stft.py
│   ├── model.py
│   ├── rcnn.py
│   ├── shallow_cnn.py
│   ├── utils.py
│   └── vanilla_rnn.py
└── requirements.txt

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 The gumpy developers

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
Deep Learning models for Brain-Computer Interfaces
==================================================

This repository contains deep learning models that can be used to decode EEG
and EMG signals for brain-computer interfaces (BCIs). Some of the models depend
on functionality provided by ``gumpy``, a Python toolbox which contains several
signal and feature processing routines that are commonly used for BCIs.

:license: MIT License
:contributions: Please use github (www.github.com/gumpy-bci/gumpy-deeplearning) and see below
:issues: Please use the issue tracker on github (www.github.com/gumpy-bci/gumpy-deeplearning/issues)


Documentation
=============

You can find additional documentation for gumpy on www.gumpy.org.


Contributing
============

If you wish to contribute to gumpy's development, clone one of gumpy's
repositories from GitHub, start coding, test if everything works as expected,
and finally submit patches or open merge requests, preferably in this order.

Please make sure that you follow PEP8, or have a look at the formatting of
gumpy's code, and include proper documentation both in your commit messages and
in the source code. We use Google-style docstrings and auto-generate parts of
the documentation with Sphinx.
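A minimal example of the docstring style we expect (the function is hypothetical, for illustration only):

.. code:: python

    def notch(data, cutoff, axis=0):
        """Remove one frequency component from ``data``.

        Args:
            data (numpy.ndarray): Raw signal of shape (samples, channels).
            cutoff (float): Frequency to remove, in Hz.
            axis (int): Axis along which to filter.

        Returns:
            numpy.ndarray: The notch-filtered signal.
        """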

gumpy core developers and contributors
======================================
* Zied Tayeb
* Nicolai Waniek, www.github.com/rochus
* Nejla Ghaboosi
* Juri Fedjaev
* Leonard Rychly


How to cite gumpy
=================

Zied Tayeb, Nicolai Waniek, Juri Fedjaev, Nejla Ghaboosi, Leonard Rychly,
Christian Widderich, Christoph Richter, Jonas Braun, Matteo Saveriano, Gordon
Cheng, and Jörg Conradt. "gumpy: A Python Toolbox Suitable for Hybrid
Brain-Computer Interfaces"


.. code:: latex

    @Article{gumpy2018,
      Title   = {gumpy: A Python Toolbox Suitable for Hybrid Brain-Computer Interfaces},
      Author  = {Tayeb, Zied and Waniek, Nicolai and Fedjaev, Juri and Ghaboosi, Nejla and Rychly, Leonard and Widderich, Christian and Richter, Christoph and Braun, Jonas and Saveriano, Matteo and Cheng, Gordon and Conradt, Jorg},
      Year    = {2018},
      Journal = {}
    }


Additional References
=====================

* www.gumpy.org: gumpy's main website. You can find links to datasets here.
* www.github.com/gumpy-bci/gumpy: gumpy's main github repository
* www.github.com/gumpy-bci/gumpy-deeplearning: gumpy's deep learning models for BCI
* https://github.com/gumpy-bci/gumpy-realtime: gumpy's real-time BCI module with several online demos
* https://www.youtube.com/channel/UCdarvfot4Ustk2UCmCp62sw: gumpy's YouTube channel
* https://www.youtube.com/watch?v=M68GeL8PafE


License
=======

* All code in this repository is published under the MIT License.
  For more details see the LICENSE file.
--------------------------------------------------------------------------------
/examples/CNN_STFT Example.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ConvNet Architecture for Decoding EEG MI Data using Spectrogram Representations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "In case gumpy is not installed as a module, we need to specify the path to ``gumpy``. In addition, we wish to configure jupyter notebooks and any backend properly. Note that it may take some time for ``gumpy`` to load due to the number of dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import os; os.environ[\"THEANO_FLAGS\"] = \"device=gpu0\"\n",
    "import os.path\n",
    "from datetime import datetime\n",
    "import sys\n",
    "sys.path.append('../../gumpy')\n",
    "\n",
    "import gumpy\n",
    "import numpy as np\n",
    "import scipy.io\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "To use the models provided by `gumpy-deeplearning`, we have to set the path to the models directory and import it. If you installed `gumpy-deeplearning` as a module, this step may not be required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('..')\n",
    "import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility functions\n",
    "\n",
    "The examples for ``gumpy-deeplearning`` ship with a few tiny helper functions. For instance, there's one that tells you the versions of the currently installed keras and kapre. ``keras`` is required by ``gumpy-deeplearning``, while ``kapre``\n",
    "is used to compute spectrograms.\n",
    "\n",
    "In addition, the utility functions contain a method ``load_preprocess_data`` to load and preprocess data. Its usage will be shown further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "utils.print_version_info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup parameters for the model and data\n",
    "Before we jump into the processing, we first wish to specify some parameters (e.g. frequencies) that we know about the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEBUG = True\n",
    "CLASS_COUNT = 2\n",
    "DROPOUT = 0.2  # dropout rate (fraction of units to drop)\n",
    "\n",
    "# parameters for filtering data\n",
    "FS = 250\n",
    "LOWCUT = 2\n",
    "HIGHCUT = 60\n",
    "ANTI_DRIFT = 0.5\n",
    "CUTOFF = 50.0  # freq to be removed from signal (Hz) for notch filter\n",
    "Q = 30.0       # quality factor for notch filter\n",
    "W0 = CUTOFF/(FS/2)\n",
    "AXIS = 0\n",
    "\n",
    "# set random seed\n",
    "SEED = 42\n",
    "KFOLD = 5"
   ]
  },
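A quick sanity check of the filter constants above: scipy-style notch design expects the notch frequency normalized by the Nyquist frequency, which is where ``W0 = CUTOFF/(FS/2)`` comes from. A sketch of the arithmetic (assuming the gumpy wrapper forwards these values to ``scipy.signal.iirnotch``):

    from scipy import signal

    FS, CUTOFF, Q = 250, 50.0, 30.0
    W0 = CUTOFF / (FS / 2)         # 50 / 125 = 0.4 of the Nyquist frequency
    b, a = signal.iirnotch(W0, Q)  # second-order IIR notch at 50 Hz

    # the -3 dB bandwidth of the notch is W0/Q, i.e. about 1.7 Hz here
    print(W0 / Q * (FS / 2))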
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load raw data\n",
    "\n",
    "Before training and testing a model, we need some data. The following code shows how to load a dataset using ``gumpy``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify the location of the GrazB datasets\n",
    "data_dir = '../../Data/Graz'\n",
    "subject = 'B01'\n",
    "\n",
    "# initialize the data-structure, but do _not_ load the data yet\n",
    "grazb_data = gumpy.data.GrazB(data_dir, subject)\n",
    "\n",
    "# now that the dataset is set up, we can load the data. This is handled by the\n",
    "# utility function, which first loads the data and then filters it using a\n",
    "# notch and a band-pass filter before returning the training data.\n",
    "x_train, y_train = utils.load_preprocess_data(grazb_data, True, LOWCUT, HIGHCUT, W0, Q, ANTI_DRIFT, CLASS_COUNT, CUTOFF, AXIS, FS)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Augment data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_augmented, y_augmented = gumpy.signal.sliding_window(data = x_train[:,:,:],\n",
    "                                                       labels = y_train[:,:],\n",
    "                                                       window_sz = 4 * FS,\n",
    "                                                       n_hop = FS // 10,\n",
    "                                                       n_start = FS * 1)\n",
    "x_subject = x_augmented\n",
    "y_subject = y_augmented\n",
    "x_subject = np.rollaxis(x_subject, 2, 1)"
   ]
  },
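The sliding window multiplies the number of training examples considerably: with ``window_sz = 4 * FS = 1000`` samples, ``n_hop = FS // 10 = 25`` and ``n_start = 250``, a trial of ``T`` samples yields about ``(T - n_start - window_sz) // n_hop + 1`` windows. A sketch of that arithmetic (the exact count depends on gumpy's implementation, and ``T`` is a hypothetical trial length):

    FS = 250
    window_sz = 4 * FS   # 1000 samples = 4 s windows
    n_hop = FS // 10     # 25 samples = 100 ms hop
    n_start = FS * 1     # skip the first second

    T = 2000             # hypothetical trial length (8 s)
    n_windows = (T - n_start - window_sz) // n_hop + 1
    print(n_windows)     # 31 windows per trial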
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "from models import CNN_STFT\n",
    "\n",
    "# define KFOLD-fold cross validation test harness\n",
    "kfold = StratifiedKFold(n_splits = KFOLD, shuffle = True, random_state = SEED)\n",
    "cvscores = []\n",
    "ii = 1\n",
    "for train, test in kfold.split(x_subject, y_subject[:, 0]):\n",
    "    print('Run ' + str(ii) + '...')\n",
    "    model_name_str = 'GRAZ_CNN_STFT_3layer' + '_run_' + str(ii)\n",
    "\n",
    "    # initialize and create the model\n",
    "    model = CNN_STFT(model_name_str)\n",
    "    model.create_model(x_subject.shape[1:], dropout = DROPOUT, print_summary = False)\n",
    "\n",
    "    # fit model. If you specify monitor=True, the model creates its own callbacks\n",
    "    # (see models/model.py) and writes its state to an HDF5 file\n",
    "    model.fit(x_subject[train], y_subject[train], monitor = True,\n",
    "              epochs = 100,\n",
    "              batch_size = 256,\n",
    "              verbose = 0,\n",
    "              validation_split = 0.1)\n",
    "\n",
    "    # evaluate the model\n",
    "    print('Evaluating model on test set...')\n",
    "    scores = model.evaluate(x_subject[test], y_subject[test], verbose = 0)\n",
    "    print(\"Result on test set: %s: %.2f%%\" % (model.metrics_names[1], scores[1] * 100))\n",
    "    cvscores.append(scores[1] * 100)\n",
    "    ii += 1\n",
    "\n",
    "# print some evaluation statistics and write results to file\n",
    "print(\"%.2f%% (+/- %.2f%%)\" % (np.mean(cvscores), np.std(cvscores)))\n",
    "cv_all_subjects = np.asarray(cvscores)\n",
    "print('Saving CV values to file....')\n",
    "np.savetxt('GRAZ_CV_' + 'CNN_STFT_3layer_' + str(DROPOUT) + 'do' + '.csv',\n",
    "           cv_all_subjects, delimiter = ',', fmt = '%2.4f')\n",
    "print('CV values successfully saved!\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and reload the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import kapre\n",
    "from keras.models import load_model\n",
    "\n",
    "# the kapre layers are custom objects and must be passed to load_model\n",
    "model.model.save('CNN_STFTmonitoring.h5')  # creates an HDF5 file\n",
    "model2 = load_model('CNN_STFTmonitoring.h5',\n",
    "                    custom_objects={'Spectrogram': kapre.time_frequency.Spectrogram,\n",
    "                                    'Normalization2D': kapre.utils.Normalization2D})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# New predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hold-out data: here we simply reuse the test fold of the last CV run\n",
    "X_test = x_subject[test]\n",
    "Y_test = y_subject[test]\n",
    "\n",
    "# Method 1 for predictions: using predict\n",
    "y_pred = model2.predict(X_test, batch_size=64, verbose=1)\n",
    "Y_pred = np.argmax(y_pred, axis=1)\n",
    "Y_true = np.argmax(Y_test, axis=1)\n",
    "accuracy = (len(Y_true) - np.count_nonzero(Y_pred - Y_true) + 0.0) / len(Y_true)\n",
    "print(accuracy)\n",
    "\n",
    "\n",
    "# Method 2 for predictions: using evaluate (only prints loss and accuracy on the test data)\n",
    "score, acc = model2.evaluate(X_test, Y_test, batch_size=64)\n",
    "print('\\nTest score:', score)\n",
    "print('Test accuracy:', acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
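The hand-rolled accuracy in the prediction cell above counts agreeing labels via the non-zero entries of the difference vector; with numpy, the same number is simply the mean of element-wise matches (a sketch with made-up label vectors):

    import numpy as np

    Y_pred = np.array([0, 1, 1, 0, 1])  # hypothetical predicted class indices
    Y_true = np.array([0, 1, 0, 0, 1])  # hypothetical ground-truth indices

    # identical to (len(Y_true) - np.count_nonzero(Y_pred - Y_true)) / len(Y_true)
    print(np.mean(Y_pred == Y_true))    # 0.8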
/examples/RCNN.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# R-CNN Architecture for Decoding EEG MI Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import os; os.environ[\"THEANO_FLAGS\"] = \"device=gpu0\"\n",
    "import os.path\n",
    "from datetime import datetime\n",
    "import sys\n",
    "sys.path.append('../../gumpy')\n",
    "\n",
    "import gumpy\n",
    "import numpy as np\n",
    "import scipy.io\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use the models provided by gumpy-deeplearning, we have to set the path to the models directory and import it. If you installed gumpy-deeplearning as a module, this step may not be required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('..')\n",
    "import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The examples for gumpy-deeplearning ship with a few tiny helper functions. For instance, there's one that tells you the versions of the currently installed keras and kapre. keras is required by gumpy-deeplearning.\n",
    "In addition, the utility functions contain a method load_preprocess_data to load and preprocess data. Its usage will be shown further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "utils.print_version_info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup parameters for the model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEBUG = True\n",
    "\n",
    "# The RCNN works directly on the raw signals instead of on spectrograms, so\n",
    "# its input has to be reshaped differently than for the spectrogram-based\n",
    "# models; activate this flag to enable the additional reshaping below.\n",
    "RCNN_FLAG = True\n",
    "\n",
    "CLASS_COUNT = 2\n",
    "DROPOUT = 0.2  # dropout rate (fraction of units to drop)\n",
    "\n",
    "# parameters for filtering data\n",
    "FS = 250\n",
    "LOWCUT = 2\n",
    "HIGHCUT = 60\n",
    "ANTI_DRIFT = 0.5\n",
    "CUTOFF = 50.0  # freq to be removed from signal (Hz) for notch filter\n",
    "Q = 30.0       # quality factor for notch filter\n",
    "W0 = CUTOFF/(FS/2)\n",
    "AXIS = 0\n",
    "\n",
    "# set random seed\n",
    "SEED = 42\n",
    "KFOLD = 5"
   ]
  },
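For reference, this is what the reshaping controlled by ``RCNN_FLAG`` does to the array, shape-wise, so that each window becomes a (1, samples, channels) "image" matching the RCNN's input layer (hypothetical sizes: N windows, C channels, T samples):

    import numpy as np

    x = np.zeros((100, 3, 1000))  # (N windows, C channels, T samples)
    x = np.rollaxis(x, 2, 1)      # -> (100, 1000, 3): samples before channels
    x = x[:, np.newaxis, :, :]    # -> (100, 1, 1000, 3): singleton image plane
    print(x.shape)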
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load raw data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../../grazdata'\n",
    "subject = 'B01'\n",
    "\n",
    "# initialize the data-structure, but do _not_ load the data yet\n",
    "grazb_data = gumpy.data.GrazB(data_dir, subject)\n",
    "\n",
    "# now that the dataset is set up, we can load the data. This is handled by the\n",
    "# utility function, which first loads the data and then filters it using a\n",
    "# notch and a band-pass filter before returning the training data.\n",
    "x_train, y_train = utils.load_preprocess_data(grazb_data, True, LOWCUT, HIGHCUT, W0, Q, ANTI_DRIFT, CLASS_COUNT, CUTOFF, AXIS, FS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Augment data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_augmented, y_augmented = gumpy.signal.sliding_window(data = x_train[:,:,:],\n",
    "                                                       labels = y_train[:,:],\n",
    "                                                       window_sz = 4 * FS,\n",
    "                                                       n_hop = FS // 10,\n",
    "                                                       n_start = FS * 1)\n",
    "x_subject = x_augmented\n",
    "y_subject = y_augmented\n",
    "x_subject = np.rollaxis(x_subject, 2, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "from models import RCNN\n",
    "\n",
    "# define KFOLD-fold cross validation test harness\n",
    "kfold = StratifiedKFold(n_splits=KFOLD, shuffle=True, random_state=SEED)\n",
    "cvscores = []\n",
    "ii = 1\n",
    "\n",
    "# the RCNN consumes each window as a (1, samples, channels) image; remember\n",
    "# the original (channels, samples) shape for create_model, and reshape once\n",
    "# before cross validation rather than on every fold\n",
    "input_shape = x_subject.shape[1:]\n",
    "if RCNN_FLAG:\n",
    "    x_subject = np.rollaxis(x_subject, 2, 1)\n",
    "    x_subject = x_subject[:, np.newaxis, :, :]\n",
    "\n",
    "for train, test in kfold.split(x_subject, y_subject[:, 0]):\n",
    "    print('Run ' + str(ii) + '...')\n",
    "    model_name_str = 'GRAZ_RCNN' + '_run_' + str(ii)\n",
    "\n",
    "    # initialize and create the model\n",
    "    model = RCNN(model_name_str)\n",
    "    model.create_model(input_shape, print_summary=False, class_count = CLASS_COUNT)\n",
    "\n",
    "    # fit model. If you specify monitor=True, the model creates its own callbacks\n",
    "    # and writes its state to an HDF5 file\n",
    "    model.fit(x_subject[train], y_subject[train], monitor=True,\n",
    "              epochs=100,\n",
    "              batch_size=256,\n",
    "              verbose=0,\n",
    "              validation_split=0.1)\n",
    "\n",
    "    # evaluate the model\n",
    "    print('Evaluating model on test set...')\n",
    "    scores = model.evaluate(x_subject[test], y_subject[test], verbose=0)\n",
    "    print(\"Result on test set: %s: %.2f%%\" % (model.metrics_names[1], scores[1] * 100))\n",
    "    cvscores.append(scores[1] * 100)\n",
    "    ii += 1\n",
    "\n",
    "# print some evaluation statistics and write results to file\n",
    "print(\"%.2f%% (+/- %.2f%%)\" % (np.mean(cvscores), np.std(cvscores)))\n",
    "cv_all_subjects = np.asarray(cvscores)\n",
    "print('Saving CV values to file....')\n",
    "np.savetxt('GRAZ_CV_' + 'RCNN_' + str(DROPOUT) + 'do' + '.csv',\n",
    "           cv_all_subjects, delimiter=',', fmt='%2.4f')\n",
    "print('CV values successfully saved!\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
/examples/utils.py:
--------------------------------------------------------------------------------
import gumpy
import numpy as np
from datetime import datetime
import kapre
import keras
import keras.utils as ku


def load_preprocess_data(data, debug, lowcut, highcut, w0, Q, anti_drift, class_count, cutoff, axis, fs):
    """Load and preprocess data.

    The routine loads data with the use of gumpy's Dataset objects, and
    subsequently applies some post-processing filters to improve the data.
    """
    # TODO: improve documentation

    data_loaded = data.load()

    if debug:
        print('Band-pass filtering the data in frequency range from %.1f Hz to %.1f Hz... '
              % (lowcut, highcut))

    data_notch_filtered = gumpy.signal.notch(data_loaded.raw_data, cutoff, axis)
    data_hp_filtered = gumpy.signal.butter_highpass(data_notch_filtered, anti_drift, axis)
    data_bp_filtered = gumpy.signal.butter_bandpass(data_hp_filtered, lowcut, highcut, axis)

    # Split data into classes.
    # TODO: as soon as gumpy.utils.extract_trials2 is merged with the
    # regular extract_trials, change here accordingly!
    class1_mat, class2_mat = gumpy.utils.extract_trials2(data_bp_filtered, data_loaded.trials,
                                                         data_loaded.labels, data_loaded.trial_total,
                                                         fs, nbClasses = 2)

    # concatenate data for training and create labels
    x_train = np.concatenate((class1_mat, class2_mat))
    labels_c1 = np.zeros((class1_mat.shape[0], ))
    labels_c2 = np.ones((class2_mat.shape[0], ))
    y_train = np.concatenate((labels_c1, labels_c2))

    # for categorical crossentropy
    y_train = ku.to_categorical(y_train)

    print("Data loaded and processed successfully!")
    return x_train, y_train


def print_version_info():
    now = datetime.now()

    print('%s/%s/%s' % (now.year, now.month, now.day))
    print('Keras version: {}'.format(keras.__version__))
    if keras.backend._BACKEND == 'tensorflow':
        import tensorflow
        print('Keras backend: {}: {}'.format(keras.backend._backend, tensorflow.__version__))
    else:
        import theano
        print('Keras backend: {}: {}'.format(keras.backend._backend, theano.__version__))
    print('Keras image dim ordering: {}'.format(keras.backend.image_dim_ordering()))
    print('Kapre version: {}'.format(kapre.__version__))
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
# import models for convenience here

from .cnn_stft import CNN_STFT
from .deep_cnn import Deep_CNN
from .lstm import LSTM
from .lstm_stft import LSTM_STFT
from .shallow_cnn import Shallow_CNN
from .vanilla_rnn import Vanilla_RNN
from .rcnn import RCNN
--------------------------------------------------------------------------------
/models/cnn_stft.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.layers import BatchNormalization, Dropout, Conv2D, MaxPooling2D

import kapre
from kapre.utils import Normalization2D
from kapre.time_frequency import Spectrogram


class CNN_STFT(KerasModel):

    def create_model(self, input_shape, dropout=0.5, print_summary=False):

        # basis of the CNN_STFT is a Sequential network
        model = Sequential()

        # spectrogram creation using STFT
        model.add(Spectrogram(n_dft = 128, n_hop = 16, input_shape = input_shape,
                              return_decibel_spectrogram = False, power_spectrogram = 2.0,
                              trainable_kernel = False, name = 'static_stft'))
        model.add(Normalization2D(str_axis = 'freq'))

        # Conv Block 1
        model.add(Conv2D(filters = 24, kernel_size = (12, 12),
                         strides = (1, 1), name = 'conv1',
                         padding = 'same'))
        model.add(BatchNormalization(axis = 1))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size = (2, 2), strides = (2, 2), padding = 'valid',
                               data_format = 'channels_last'))

        # Conv Block 2
        model.add(Conv2D(filters = 48, kernel_size = (8, 8),
                         name = 'conv2', padding = 'same'))
        model.add(BatchNormalization(axis = 1))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size = (2, 2), strides = (2, 2), padding = 'valid',
                               data_format = 'channels_last'))

        # Conv Block 3
        model.add(Conv2D(filters = 96, kernel_size = (4, 4),
                         name = 'conv3', padding = 'same'))
        model.add(BatchNormalization(axis = 1))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size = (2, 2), strides = (2, 2),
                               padding = 'valid',
                               data_format = 'channels_last'))
        model.add(Dropout(dropout))

        # classifier
        model.add(Flatten())
        model.add(Dense(2))  # two classes only
        model.add(Activation('softmax'))

        if print_summary:
            print(model.summary())

        # compile the model
        model.compile(loss = 'categorical_crossentropy',
                      optimizer = 'adam',
                      metrics = ['accuracy'])

        # assign model and return
        self.model = model
        return model
--------------------------------------------------------------------------------
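Shape intuition for the Spectrogram front-end above: with ``n_dft = 128`` the one-sided spectrum has 128/2 + 1 = 65 frequency bins, and an N-sample window produces roughly 1 + (N - n_dft)/n_hop time frames per channel. A sketch of the arithmetic (kapre's exact padding may shift the frame count slightly):

    n_dft, n_hop = 128, 16
    n_samples = 1000  # one 4 s window at 250 Hz

    n_freq_bins = n_dft // 2 + 1                 # 65
    n_frames = 1 + (n_samples - n_dft) // n_hop  # 55
    print(n_freq_bins, n_frames)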
/models/deep_cnn.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense, Conv2D, Dropout, BatchNormalization, \
    Reshape, Activation, Flatten, AveragePooling2D, Conv3D, MaxPool2D


class Deep_CNN(KerasModel):

    def create_model(self, augmented_data=False, print_summary=False, downsampled=False):
        CLASS_COUNT = 2

        model = Sequential()
        if augmented_data and downsampled:
            input_shape = (3, 1280, 1)
            # Conv Pool Block 1
            model.add(Conv2D(input_shape=input_shape, filters=25, kernel_size=(1, 10), strides=(1, 1),
                             padding='valid', activation='linear'))
            model.add(Reshape(target_shape=(3, 1271, 25, 1)))
            model.add(Conv3D(filters=25, kernel_size=(3, 1, 25),
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(Flatten())
            model.add(Reshape(target_shape=(1271, 25, 1)))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1), data_format='channels_last'))
            model.add(Dropout(0.5))

            # Conv Pool Block 2
            model.add(Conv2D(filters=50, kernel_size=(10, 25)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(138, 50, 1)))
            model.add(Dropout(0.5))

            # Conv Pool Block 3
            model.add(Conv2D(filters=100, kernel_size=(10, 50)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(43, 100, 1)))
            model.add(Dropout(0.5))

        if augmented_data and not downsampled:
            input_shape = (3, 1024, 1)
            # Conv Pool Block 1
            model.add(Conv2D(input_shape=input_shape, filters=25, kernel_size=(1, 10), strides=(1, 1),
                             padding='valid', activation='linear'))
            model.add(Reshape(target_shape=(3, 1015, 25, 1)))
            model.add(Conv3D(filters=25, kernel_size=(3, 1, 25),
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(Flatten())
            model.add(Reshape(target_shape=(1015, 25, 1)))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1), data_format='channels_last'))
            model.add(Dropout(0.5))

            # Conv Pool Block 2
            model.add(Conv2D(filters=50, kernel_size=(10, 25)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(109, 50, 1)))
            model.add(Dropout(0.5))

            # Conv Pool Block 3
            model.add(Conv2D(filters=100, kernel_size=(10, 50)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(33, 100, 1)))
            model.add(Dropout(0.5))

        if not augmented_data and not downsampled:
            input_shape = (3, 2560, 1)

            # Conv Pool Block 1
            model.add(Conv2D(input_shape=input_shape, filters=25, kernel_size=(1, 10), strides=(1, 1),
                             padding='valid', activation='linear'))
            model.add(Reshape(target_shape=(3, 2551, 25, 1)))
            model.add(Conv3D(filters=25, kernel_size=(3, 1, 25),
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(Flatten())
            model.add(Reshape(target_shape=(2551, 25, 1)))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1), data_format='channels_last'))
            model.add(Dropout(0.5))

            # Conv Pool Block 2
            model.add(Conv2D(filters=50, kernel_size=(10, 25)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(280, 50, 1)))
            model.add(Dropout(0.5))

            # Conv Pool Block 3
            model.add(Conv2D(filters=100, kernel_size=(10, 50)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(90, 100, 1)))
            model.add(Dropout(0.5))

        if not augmented_data and downsampled:
            # input_shape was missing in this branch; (3, 1280, 1) is inferred
            # from the Reshape below (1280 - 10 + 1 = 1271)
            input_shape = (3, 1280, 1)
            # Conv Pool Block 1
            model.add(Conv2D(input_shape=input_shape, filters=25, kernel_size=(1, 10), strides=(1, 1),
                             padding='valid', activation='linear'))
            model.add(Reshape(target_shape=(3, 1271, 25, 1)))
            model.add(Conv3D(filters=25, kernel_size=(3, 1, 25),
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(Flatten())
            model.add(Reshape(target_shape=(1271, 25, 1)))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1), data_format='channels_last'))
            model.add(Dropout(0.5))

            # Conv Pool Block 2
            model.add(Conv2D(filters=50, kernel_size=(10, 25)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(138, 50, 1)))
            model.add(Dropout(0.5))

            # Conv Pool Block 3
            model.add(Conv2D(filters=100, kernel_size=(10, 50)))
            model.add(BatchNormalization())
            model.add(Activation(activation='elu'))
            model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))
            model.add(Flatten())
            model.add(Reshape(target_shape=(43, 100, 1)))
            model.add(Dropout(0.5))

        # Conv Pool Block 4
        model.add(Conv2D(filters=200, kernel_size=(10, 100)))
        model.add(BatchNormalization())
        model.add(Activation(activation='elu'))
        model.add(MaxPool2D(pool_size=(3, 1), strides=(3, 1)))

        # Softmax for classification
        model.add(Flatten())
        model.add(Dense(CLASS_COUNT))
        model.add(Activation('softmax'))

        if print_summary:
            print(model.summary())

        # compile the model
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

        # assign and return
        self.model = model
        return model
--------------------------------------------------------------------------------
/models/lstm.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense, LSTM as _LSTM


class LSTM(KerasModel):

    def create_model(self, input_shape, num_hidden_neurons=128,
                     num_layers=1, dropout=0.2, recurrent_dropout=0.2,
                     print_summary=False):

        model = Sequential()
        if num_layers > 1:
            for i in range(1, num_layers, 1):
                model.add(_LSTM(num_hidden_neurons, input_shape=input_shape,
                                return_sequences=True, dropout=dropout, recurrent_dropout=recurrent_dropout))
            model.add(_LSTM(num_hidden_neurons))
        else:
            model.add(_LSTM(num_hidden_neurons, input_shape=input_shape, dropout=dropout,
                            recurrent_dropout=recurrent_dropout))
        model.add(Dense(2, activation='softmax'))

        if print_summary:
            print(model.summary())

        # compile the model
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

        # assign and return
        self.model = model
        return model
--------------------------------------------------------------------------------
/models/lstm_stft.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense, LSTM, Dropout, Permute, Reshape
from kapre.time_frequency import Spectrogram


class LSTM_STFT(KerasModel):
    def create_model(self, input_shape, num_hidden_neurons=128,
                     num_layers=1, n_dft=128, n_hop=16, dropout=0.0, recurrent_dropout=0.0,
                     print_summary=False):
        model = Sequential()
        # STFT layer
        model.add(Spectrogram(n_dft=n_dft, n_hop=n_hop, input_shape=input_shape,
                              return_decibel_spectrogram=False, power_spectrogram=2.0,
                              trainable_kernel=False, name='static_stft'))

        model.add(Permute((1, 3, 2)))      # put the time axis before frequency
        model.add(Reshape((64, 65 * 3)))   # one feature vector per STFT frame

        if num_layers > 1:
            for i in range(1, num_layers, 1):
                model.add(LSTM(num_hidden_neurons, return_sequences=True,
                               dropout=dropout,
                               recurrent_dropout=recurrent_dropout))
            model.add(LSTM(num_hidden_neurons))
        else:
            model.add(LSTM(num_hidden_neurons))

        model.add(Dropout(dropout))
        model.add(Dense(2, activation='softmax'))

        if print_summary:
            print(model.summary())

        # compile the model
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

        # assign and return
        self.model = model
        return model
--------------------------------------------------------------------------------
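The ``Permute``/``Reshape`` pair in ``LSTM_STFT`` turns the spectrogram into a sequence the LSTM can consume: the hard-coded ``Reshape((64, 65 * 3))`` assumes 3 channels, 65 = 128/2 + 1 frequency bins and 64 STFT frames, i.e. one 195-dimensional feature vector per time step. A numpy sketch of the intended rearrangement (sizes assumed from those constants):

    import numpy as np

    spec = np.zeros((3, 65, 64))         # (channels, freq bins, frames) after the STFT layer
    seq = np.transpose(spec, (2, 0, 1))  # time first: (64, 3, 65)
    seq = seq.reshape(64, 65 * 3)        # one 195-dim feature vector per frame
    print(seq.shape)                     # (64, 195)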
/models/model.py:
--------------------------------------------------------------------------------
import os
from abc import ABC, abstractmethod
import keras
from keras.models import model_from_json
from keras.callbacks import CSVLogger, ModelCheckpoint


class Model(ABC):
    """An abstract deep learning model.

    The abstract class functions as a facade for the backend. Although
    gumpy-deeplearning currently uses keras, it is possible that future releases
    may use different front- or backends. The Model ABC should represent the
    baseline for any such model.

    For more information about the reason behind ``Model``, see https://xkcd.com/927/

    """

    def __init__(self, name):
        self.name = name
        self.model = None

    @abstractmethod
    def create_model(self):
        pass

    @abstractmethod
    def fit(self):
        pass

    @abstractmethod
    def evaluate(self):
        pass

    @abstractmethod
    def from_json(self):
        pass


class KerasModel(Model):
    """ABC for Models that rely on keras.

    The ABC provides an implementation to generate callbacks to monitor the
    model and write the data to HDF5 files. The function ``fit`` simply forwards
    to keras' ``fit``, but will enable monitoring if wanted.

    """

    def __init__(self, name):
        super(KerasModel, self).__init__(name)
        self.callbacks = None


    def get_callbacks(self):
        """Returns callbacks to monitor the model."""

        # save the best weights in an HDF5 file and log training in a CSV file
        model_file = self.name + '_monitoring' + '.h5'
        checkpoint = ModelCheckpoint(model_file, monitor = 'val_loss',
                                     verbose = 0, save_best_only = True, mode = 'min')
        log_file = self.name + '.csv'
        csv_logger = CSVLogger(log_file, append = True, separator = ';')
        callbacks_list = [csv_logger, checkpoint]

        self.callbacks = callbacks_list
        return callbacks_list


    def fit(self, x, y, monitor=True, **kwargs):
        # TODO: allow user to specify filename
        if monitor and (self.callbacks is None):
            self.get_callbacks()

        if self.callbacks is not None:
            self.model.fit(x, y, **kwargs, callbacks=self.callbacks)
        else:
            self.model.fit(x, y, **kwargs)


    def evaluate(self, x, y, **kwargs):
        return self.model.evaluate(x, y, **kwargs)


    def from_json(self, model_file_name=None):
        try:
            # set the model_file_name if it is not passed to the function
            if model_file_name is None:
                model_file_name = self.name

            # load trained model
            model_path = model_file_name + ".json"
            if not os.path.isfile(model_path):
                raise IOError('file "%s" does not exist' % (model_path))
            model = model_from_json(open(model_path).read())

            # load weights of trained model
            model_weight_path = model_file_name + ".hdf5"
            if not os.path.isfile(model_weight_path):
                raise IOError('file "%s" does not exist' % (model_weight_path))
            model.load_weights(model_weight_path)

            return model
        except IOError as err:
            print(err)
            return None
--------------------------------------------------------------------------------
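Putting ``KerasModel`` to use: a subclass only has to implement ``create_model``, and ``fit(..., monitor=True)`` lazily builds the ``CSVLogger``/``ModelCheckpoint`` callbacks named after the model. A usage sketch (the data arrays and shapes are placeholders):

    from models import LSTM

    model = LSTM('my_lstm')                    # writes my_lstm.csv and my_lstm_monitoring.h5
    model.create_model(input_shape=(3, 1000))  # placeholder input shape
    model.fit(x_train, y_train, monitor=True,  # monitor=True wires up the callbacks
              epochs=10, batch_size=64, validation_split=0.1)
    loss, acc = model.evaluate(x_test, y_test)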
/models/rcnn.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Model
from keras.layers import Input, Dense, Activation, Flatten, Dropout
from keras.layers import BatchNormalization, Conv2D, MaxPooling2D, add
from keras.layers.advanced_activations import PReLU


class RCNN(KerasModel):

    def RCL(self, l, a):
        """Stack of recurrent convolutional layers (RCLs) on input tensor ``l``.

        In an RCL, a convolution is applied repeatedly to the evolving state
        and summed with the block's feed-forward output. ``a`` is the number
        of output classes.
        """

        # first convolutional layer
        conv1 = Conv2D(filters=128, kernel_size=(1, 9), strides=(1, 1), padding='same', data_format='channels_last',
                       kernel_initializer='he_normal')(l)
        bn1 = BatchNormalization(epsilon=0.000001)(conv1)
        relu1 = PReLU()(bn1)
        pool1 = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding='valid', data_format='channels_last')(relu1)
        drop1 = Dropout(0)(pool1)

        # start first RCL block: feed-forward convolution, stored for the recurrence
        conv2 = Conv2D(filters=128, kernel_size=(1, 1), padding='same', kernel_initializer='he_normal')(drop1)
        bn2 = BatchNormalization(axis=1, epsilon=0.000001)(conv2)
        relu2 = PReLU()(bn2)

        # first recurrent step (conv2a's weights are copied into the later steps)
        conv2a = Conv2D(filters=128, kernel_size=(1, 9), padding='same', kernel_initializer='he_normal')
        conv2aa = conv2a(relu2)
        merged2a = add([conv2, conv2aa])

        # second recurrent step
        bn2a = BatchNormalization(axis=1, epsilon=0.000001)(merged2a)
        relu2a = PReLU()(bn2a)
        conv2b = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv2a.get_weights())(relu2a)
        merged2b = add([conv2, conv2b])

        # third recurrent step
        bn2b = BatchNormalization(axis=1, epsilon=0.000001)(merged2b)
        relu2b = PReLU()(bn2b)
        conv2c = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv2a.get_weights())(relu2b)
        merged2c = add([conv2, conv2c])

        bn2c = BatchNormalization(axis=1, epsilon=0.000001)(merged2c)
        relu2c = PReLU()(bn2c)
        pool2 = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding='valid', data_format='channels_last')(relu2c)
        drop2 = Dropout(0.2)(pool2)

        # second RCL block
        conv3 = Conv2D(filters=128, kernel_size=(1, 1), padding='same')(drop2)
        bn3 = BatchNormalization(axis=1, epsilon=0.000001)(conv3)
        relu3 = PReLU()(bn3)
        conv3a = Conv2D(filters=128, kernel_size=(1, 9), padding='same', kernel_initializer='he_normal')
        conv3aa = conv3a(relu3)
        merged3a = add([conv3, conv3aa])

        bn3a = BatchNormalization(axis=1, epsilon=0.000001)(merged3a)
        relu3a = PReLU()(bn3a)
        conv3b = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv3a.get_weights())(relu3a)
        merged3b = add([conv3, conv3b])

        bn3b = BatchNormalization(axis=1, epsilon=0.000001)(merged3b)
        relu3b = PReLU()(bn3b)
        conv3c = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv3a.get_weights())(relu3b)
        merged3c = add([conv3, conv3c])

        bn3c = BatchNormalization(axis=1, epsilon=0.000001)(merged3c)
        relu3c = PReLU()(bn3c)
        pool3 = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding='valid', data_format='channels_last')(relu3c)
        drop3 = Dropout(0.2)(pool3)

        # third RCL block
        conv4 = Conv2D(filters=128, kernel_size=(1, 1), padding='same', kernel_initializer='he_normal')(drop3)
        bn4 = BatchNormalization(axis=1, epsilon=0.000001)(conv4)
        relu4 = PReLU()(bn4)
        conv4a = Conv2D(filters=128, kernel_size=(1, 9), padding='same')
        conv4aa = conv4a(relu4)
        merged4a = add([conv4, conv4aa])

        bn4a = BatchNormalization(axis=1, epsilon=0.000001)(merged4a)
        relu4a = PReLU()(bn4a)
        conv4b = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv4a.get_weights())(relu4a)
        merged4b = add([conv4, conv4b])
        bn4b = BatchNormalization(axis=1, epsilon=0.000001)(merged4b)
        relu4b = PReLU()(bn4b)
        conv4c = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv4a.get_weights())(relu4b)
        merged4c = add([conv4, conv4c])

        bn4c = BatchNormalization(axis=1, epsilon=0.000001)(merged4c)
        relu4c = PReLU()(bn4c)
        pool4 = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding='valid', data_format='channels_last')(relu4c)
        drop4 = Dropout(0.2)(pool4)

        # fourth RCL block
        conv5 = Conv2D(filters=128, kernel_size=(1, 1), padding='same')(drop4)
        bn5 = BatchNormalization(axis=1, epsilon=0.000001)(conv5)
        relu5 = PReLU()(bn5)
        conv5a = Conv2D(filters=128, kernel_size=(1, 9), padding='same')
        conv5aa = conv5a(relu5)
        merged5a = add([conv5, conv5aa])

        bn5a = BatchNormalization(axis=1, epsilon=0.000001)(merged5a)
        relu5a = PReLU()(bn5a)
        conv5b = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv5a.get_weights())(relu5a)
        merged5b = add([conv5, conv5b])

        bn5b = BatchNormalization(axis=1, epsilon=0.000001)(merged5b)
        relu5b = PReLU()(bn5b)
        conv5c = Conv2D(filters=128, kernel_size=(1, 9), padding='same', weights=conv5a.get_weights())(relu5b)
        merged5c = add([conv5, conv5c])

        bn5c = BatchNormalization(axis=1, epsilon=0.000001)(merged5c)
        relu5c = PReLU()(bn5c)
        # pool5 = MaxPooling2D(pool_size=(1, 4), strides=(1, 4), padding='valid', data_format='channels_last')(relu5c)
        drop5 = Dropout(0.2)(relu5c)

        gate = Activation('sigmoid')(drop5)

        # flatten and classify
        flat = Flatten()(gate)
        out = Dense(a, activation='softmax')(flat)

        return out

    def create_model(self, input_shape, print_summary=False, class_count=2):
        """Create a new RCNN model instance.

        ``input_shape`` is the (channels, samples) shape of a single window.
        """

        changed_shape = (1, input_shape[1], input_shape[0])
        input_1 = Input(changed_shape)
        output = self.RCL(input_1, class_count)
        model = Model(inputs=input_1, outputs=output)

        if print_summary:
            print(model.summary())

        model.compile(loss='categorical_crossentropy',
                      optimizer='RMSprop',
                      metrics=['accuracy'])
        self.model = model
        return model
--------------------------------------------------------------------------------
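Note that the ``weights=convXa.get_weights()`` pattern in rcnn.py only copies the recurrent convolution's initial weights into fresh layers; the weights are not tied afterwards. If genuine weight sharing across the recurrent steps is wanted, reusing a single layer object achieves it. A sketch of one such block (an alternative, not the repository's implementation):

    from keras.layers import Conv2D, BatchNormalization, PReLU, add

    def rcl_block(x, filters=128, steps=3):
        """One recurrent-convolutional block with a truly shared recurrent conv."""
        feed_forward = Conv2D(filters, (1, 1), padding='same')(x)
        recurrent = Conv2D(filters, (1, 9), padding='same')  # reused every step
        state = PReLU()(BatchNormalization()(feed_forward))
        for _ in range(steps):
            state = add([feed_forward, recurrent(state)])
            state = PReLU()(BatchNormalization()(state))
        return state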
/models/shallow_cnn.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense, Conv2D, Dropout, BatchNormalization, \
    Reshape, Activation, Flatten, AveragePooling2D, Conv3D


class Shallow_CNN(KerasModel):

    def create_model(self, augmented_data=True, print_summary=False, downsampled=False):

        CLASS_COUNT = 2
        model = Sequential()

        if augmented_data and downsampled:
            # Conv Block 1
            model.add(Conv2D(input_shape=(3, 512, 1), filters=40, kernel_size=(1, 25), strides=(1, 1),
                             padding='valid', activation=None))
            model.add(Reshape(target_shape=(3, 488, 40, 1)))
            model.add(Dropout(0.5))

            # Conv Block 2
            model.add(Conv3D(filters=40, kernel_size=(3, 1, 40), padding='valid',
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(keras.backend.square))  # custom squaring activation function
            model.add(Flatten())
            model.add(Reshape(target_shape=(488, 40, 1)))
            model.add(Dropout(0.5))

            # Pooling
            model.add(AveragePooling2D(pool_size=(75, 1), strides=(15, 1), data_format='channels_last'))
            model.add(Activation(keras.backend.log))  # custom log activation function

        elif augmented_data and not downsampled:
            # Conv Block 1
            model.add(Conv2D(input_shape=(3, 1024, 1), filters=40, kernel_size=(1, 25), strides=(1, 1),
                             padding='valid', activation=None))
            model.add(Reshape(target_shape=(3, 1000, 40, 1)))
            model.add(Dropout(0.5))

            # Conv Block 2
            model.add(Conv3D(filters=40, kernel_size=(3, 1, 40), padding='valid',
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(keras.backend.square))  # custom squaring activation function
            model.add(Flatten())
            model.add(Reshape(target_shape=(1000, 40, 1)))
            model.add(Dropout(0.5))

            # Pooling
            model.add(AveragePooling2D(pool_size=(75, 1), strides=(15, 1), data_format='channels_last'))
            model.add(Activation(keras.backend.log))  # custom log activation function

        else:
            # Conv Block 1
            model.add(Conv2D(input_shape=(3, 1280, 1), filters=40, kernel_size=(1, 25), strides=(1, 1),
                             padding='valid', activation=None))
            model.add(Reshape(target_shape=(3, 1256, 40, 1)))
            model.add(Dropout(0.5))

            # Conv Block 2
            model.add(Conv3D(filters=40, kernel_size=(3, 1, 40), padding='valid',
                             data_format='channels_last'))
            model.add(BatchNormalization())
            model.add(Activation(keras.backend.square))  # custom squaring activation function
            model.add(Flatten())
            model.add(Reshape(target_shape=(1256, 40, 1)))
            model.add(Dropout(0.5))

            # Pooling
            model.add(AveragePooling2D(pool_size=(75, 1), strides=(15, 1), data_format='channels_last'))
            model.add(Activation(keras.backend.log))  # custom log activation function

        # Classification
        model.add(Flatten())
        model.add(Dense(CLASS_COUNT))
        model.add(Activation('softmax'))

        if print_summary:
            print(model.summary())

        # compile the model
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

        # assign and return
        self.model = model
        return model
--------------------------------------------------------------------------------
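The square, average-pool, log chain in ``Shallow_CNN`` computes log band-power features in the spirit of the shallow ConvNet of Schirrmeister et al. (2017): averaging squared activations over a 75-sample window gives the mean signal power in that window, and the log compresses its dynamic range. What a single unit computes, as a numpy sketch:

    import numpy as np

    x = np.random.randn(1256)  # hypothetical temporal activations of one filter
    window, stride = 75, 15
    log_power = [np.log(np.mean(x[i:i + window] ** 2))
                 for i in range(0, len(x) - window + 1, stride)]
    print(len(log_power))      # 79 pooled log-power features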
/models/utils.py:
--------------------------------------------------------------------------------
import gumpy
import numpy as np
from datetime import datetime
import kapre
import keras
import keras.utils as ku


def load_preprocess_data(data, debug, lowcut, highcut, w0, Q, anti_drift, class_count, cutoff, axis, fs):
    """Load and preprocess data.

    The routine loads data with the use of gumpy's Dataset objects, and
    subsequently applies some post-processing filters to improve the data.
    """
    # TODO: improve documentation

    data_loaded = data.load()

    if debug:
        print('Band-pass filtering the data in frequency range from %.1f Hz to %.1f Hz... '
              % (lowcut, highcut))

    data_notch_filtered = gumpy.signal.notch(data_loaded.raw_data, cutoff, axis)
    data_hp_filtered = gumpy.signal.butter_highpass(data_notch_filtered, anti_drift, axis)
    data_bp_filtered = gumpy.signal.butter_bandpass(data_hp_filtered, lowcut, highcut, axis)

    # Split data into classes.
    # TODO: as soon as gumpy.utils.extract_trials2 is merged with the
    # regular extract_trials, change here accordingly!
    class1_mat, class2_mat = gumpy.utils.extract_trials2(data_bp_filtered, data_loaded.trials,
                                                         data_loaded.labels, data_loaded.trial_total,
                                                         fs, nbClasses = 2)

    # concatenate data for training and create labels
    x_train = np.concatenate((class1_mat, class2_mat))
    labels_c1 = np.zeros((class1_mat.shape[0], ))
    labels_c2 = np.ones((class2_mat.shape[0], ))
    y_train = np.concatenate((labels_c1, labels_c2))

    # for categorical crossentropy
    y_train = ku.to_categorical(y_train)

    print("Data loaded and processed successfully!")
    return x_train, y_train


def print_version_info():
    now = datetime.now()
    print('%s/%s/%s' % (now.year, now.month, now.day))
    print('Keras version: {}'.format(keras.__version__))
    if keras.backend._BACKEND == 'tensorflow':
        import tensorflow
        print('Keras backend: {}: {}'.format(keras.backend._backend, tensorflow.__version__))
    else:
        import theano
        print('Keras backend: {}: {}'.format(keras.backend._backend, theano.__version__))
    print('Keras image dim ordering: {}'.format(keras.backend.image_dim_ordering()))
    print('Kapre version: {}'.format(kapre.__version__))
--------------------------------------------------------------------------------
/models/vanilla_rnn.py:
--------------------------------------------------------------------------------
from .model import KerasModel
import keras
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense


class Vanilla_RNN(KerasModel):

    def create_model(self, input_shape, num_hidden_neurons=128, num_layers=1, print_summary=False):
        self.model = Sequential()
        if num_layers > 1:
            for i in range(1, num_layers, 1):
                self.model.add(SimpleRNN(num_hidden_neurons, input_shape=input_shape, return_sequences=True))
            self.model.add(SimpleRNN(num_hidden_neurons))
        else:
            self.model.add(SimpleRNN(num_hidden_neurons, input_shape=input_shape))

        self.model.add(Dense(2, activation='softmax'))

        if print_summary:
            print(self.model.summary())

        # compile the model
        self.model.compile(loss='categorical_crossentropy',
                           optimizer='adam',
                           metrics=['accuracy'])

        return self.model
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gumpy-bci/gumpy-deeplearning/53e9f59c3a1035aa58d7e6d269adac6e794bed84/requirements.txt
--------------------------------------------------------------------------------