├── .gitignore
├── LICENSE
├── README.rst
├── doc
│   ├── Makefile
│   └── source
│       ├── ChangeLog.rst
│       ├── apidoc.rst
│       ├── conf.py
│       ├── index.rst
│       ├── installation.rst
│       └── module_overview.rst
├── examples
│   └── notebooks
│       ├── EEG-motor-imagery-CSP.ipynb
│       ├── EEG-motor-imagery.ipynb
│       └── EMG-script.ipynb
├── gumpy
│   ├── __init__.py
│   ├── classification
│   │   ├── __init__.py
│   │   ├── classifier.py
│   │   └── common.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── graz.py
│   │   ├── khushaba.py
│   │   ├── nst.py
│   │   └── nst_emg.py
│   ├── features.py
│   ├── plot.py
│   ├── signal.py
│   ├── split.py
│   ├── utils.py
│   └── version.py
├── setup.cfg
├── setup.py
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 | .ipynb_checkpoints
4 | gumpy.egg-info
5 | build/
6 | dist/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 The gumpy developers
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | gumpy
2 | =====
3 |
4 | ``gumpy`` is a Python 3 toolbox to develop Brain-Computer Interfaces (BCI).
5 |
6 | ``gumpy`` contains implementations of several functions that are commonly used
7 | during EEG and EMG decoding. For this purpose it heavily relies on other
8 | numerical and scientific libraries, for instance ``numpy``, ``scipy``, or
9 | ``scikit-learn``, to name just a few. In fact, ``gumpy`` mostly wraps existing
10 | functions in such a way that researchers working in the field can quickly
11 | perform data analysis and implement novel classifiers. Moreover, one of
12 | ``gumpy``'s design principles was to make it easily extendable.
13 |
14 | :license: MIT License
15 | :contributions: Please use github (www.github.com/gumpy-bci/gumpy) and see below
16 | :issues: Please use the issue tracker on github (www.github.com/gumpy-bci/gumpy/issues)
17 |
18 |
19 | Documentation
20 | =============
21 |
22 | You can find documentation for gumpy either on www.gumpy.org or in subfolder
23 | ``doc``. For examples, see the folder ``examples``.
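
Quickstart
==========

A minimal sketch of the typical workflow, distilled from the notebooks in
``examples``. The data location and subject ID below are placeholders that you
need to adapt to your setup; the classification call is commented out because
it assumes that features have already been extracted and split.

.. code:: python

    import gumpy

    # Initialize the dataset wrapper; this does not load any data yet
    grazb_data = gumpy.data.GrazB('path/to/Graz', 'B01')
    # Load the dataset
    grazb_data.load()

    # Band-pass filter the raw signals between 1 and 35 Hz
    btr_data = gumpy.signal.butter_bandpass(grazb_data, lo=1, hi=35)

    # After feature extraction and splitting (see the notebooks), any
    # registered classifier can be trained and evaluated, for example:
    # results, clf = gumpy.classify('RandomForest', X_train, Y_train, X_test, Y_test)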
24 |
25 |
26 | Contributing
27 | ============
28 |
29 | If you wish to contribute to gumpy's development, clone the main repository from
30 | github and start coding, test if everything works as expected, and finally
31 | submit patches or open merge requests. Preferably in this order.
32 |
33 | Please make sure that you follow PEP8, or have a look at the formatting of
34 | gumpy's code, and include proper documentation both in your commit messages as
35 | well as the source code. We use Google docstrings for formatting, and
36 | auto-generate parts of the documentation with sphinx.
37 |
38 |
39 | gumpy core developers and contributors
40 | ======================================
41 | * Zied Tayeb
42 | * Nicolai Waniek, www.github.com/rochus
43 | * Juri Fedjaev
44 | * Nejla Ghaboosi
45 | * Leonard Rychly
46 |
47 |
48 | How to cite gumpy
49 | =================
50 |
51 | Zied Tayeb, Nicolai Waniek, Juri Fedjaev, Nejla Ghaboosi, Leonard Rychly,
52 | Christian Widderich, Christoph Richter, Jonas Braun, Matteo Saveriano, Gordon
53 | Cheng, and Jörg Conradt. "gumpy: A Python Toolbox Suitable for Hybrid
54 | Brain-Computer Interfaces"
55 |
56 |
57 | .. code:: latex
58 |
59 |     @Article{gumpy2018,
60 |       Title = {gumpy: A Python Toolbox Suitable for Hybrid Brain-Computer Interfaces},
61 |       Author = {Tayeb, Zied and Waniek, Nicolai and Fedjaev, Juri and Ghaboosi, Nejla and Rychly, Leonard and Widderich, Christian and Richter, Christoph and Braun, Jonas and Saveriano, Matteo and Cheng, Gordon and Conradt, Jorg},
62 |       Year = {2018},
63 |       Journal = {}
64 |     }
65 |
66 |
67 | Additional References
68 | =====================
69 |
70 | * www.gumpy.org: gumpy's main website. You can find links to datasets here
71 | * www.github.com/gumpy-bci/gumpy: gumpy's main github repository
72 | * www.github.com/gumpy-bci/gumpy-deeplearning: gumpy's deep learning models for BCI
73 | * https://github.com/gumpy-bci/gumpy-realtime: gumpy's real-time BCI module with several online demos
74 | * https://www.youtube.com/channel/UCdarvfot4Ustk2UCmCp62sw: gumpy's Youtube channel
75 | * https://www.youtube.com/watch?v=M68GeL8PafE
76 |
77 |
78 | License
79 | =======
80 |
81 | * All code in this repository is published under the MIT License.
82 |   For more details see the LICENSE file.
83 |
84 |
85 |
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS    =
6 | SPHINXBUILD   = python -msphinx
7 | SPHINXPROJ    = gumpy
8 | SOURCEDIR     = source
9 | BUILDDIR      = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /doc/source/ChangeLog.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 5 | -------------------------------------------------------------------------------- /doc/source/apidoc.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ============================ 3 | 4 | 5 | Classification 6 | -------------- 7 | 8 | .. automodule:: gumpy.classification.classifier 9 | :members: 10 | 11 | .. automodule:: gumpy.classification.common 12 | :members: 13 | 14 | 15 | Feature Extraction 16 | ------------------ 17 | .. automodule:: gumpy.features 18 | :members: 19 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # gumpy documentation build configuration file, created by 5 | # sphinx-quickstart on Mon Jan 29 16:56:59 2018. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | # import os 21 | # import sys 22 | # sys.path.insert(0, os.path.abspath('.')) 23 | 24 | import os, sys 25 | sys.path.append('..') 26 | 27 | import sphinx_rtd_theme 28 | 29 | 30 | 31 | # -- General configuration ------------------------------------------------ 32 | 33 | # If your documentation needs a minimal Sphinx version, state it here. 34 | # 35 | # needs_sphinx = '1.0' 36 | 37 | # Add any Sphinx extension module names here, as strings. They can be 38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 39 | # ones. 40 | extensions = ['sphinx.ext.autodoc', 41 | 'sphinx.ext.doctest', 42 | 'sphinx.ext.todo', 43 | 'sphinx.ext.viewcode', 44 | 'sphinx.ext.githubpages', 45 | 'sphinx.ext.napoleon'] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # The suffix(es) of source filenames. 51 | # You can specify multiple suffix as a list of string: 52 | # 53 | # source_suffix = ['.rst', '.md'] 54 | source_suffix = '.rst' 55 | 56 | # The master toctree document. 57 | master_doc = 'index' 58 | 59 | # General information about the project. 60 | project = 'gumpy' 61 | copyright = '2018, the gumpy developers' 62 | author = 'the gumpy developers' 63 | 64 | # The version info for the project you're documenting, acts as replacement for 65 | # |version| and |release|, also used in various other places throughout the 66 | # built documents. 67 | # 68 | # The short X.Y version. 69 | version = '1.0' 70 | # The full version, including alpha/beta/rc tags. 71 | release = '1.0' 72 | 73 | # The language for content autogenerated by Sphinx. Refer to documentation 74 | # for a list of supported languages. 
75 | #
76 | # This is also used if you do content translation via gettext catalogs.
77 | # Usually you set "language" from the command line for these cases.
78 | language = None
79 |
80 | # List of patterns, relative to source directory, that match files and
81 | # directories to ignore when looking for source files.
82 | # These patterns also affect html_static_path and html_extra_path
83 | exclude_patterns = []
84 |
85 | # The name of the Pygments (syntax highlighting) style to use.
86 | pygments_style = 'sphinx'
87 |
88 | # If true, `todo` and `todoList` produce output, else they produce nothing.
89 | todo_include_todos = True
90 |
91 |
92 | # -- Options for HTML output ----------------------------------------------
93 |
94 | # The theme to use for HTML and HTML Help pages.  See the documentation for
95 | # a list of builtin themes.
96 | #
97 | html_theme = 'sphinx_rtd_theme'
98 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
99 |
100 | # Theme options are theme-specific and customize the look and feel of a theme
101 | # further. For a list of options available for each theme, see the
102 | # documentation.
103 | #
104 | # html_theme_options = {}
105 |
106 | # Add any paths that contain custom static files (such as style sheets) here,
107 | # relative to this directory. They are copied after the builtin static files,
108 | # so a file named "default.css" will overwrite the builtin "default.css".
109 | html_static_path = ['_static']
110 |
111 | # Custom sidebar templates, must be a dictionary that maps document names
112 | # to template names.
113 | #
114 | # This is required for the alabaster theme
115 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
116 | html_sidebars = {
117 |     '**': [
118 |         'about.html',
119 |         'navigation.html',
120 |         'relations.html',  # needs 'show_related': True theme option to display
121 |         'searchbox.html',
122 |         'donate.html',
123 |     ]
124 | }
125 |
126 |
127 | # -- Options for HTMLHelp output ------------------------------------------
128 |
129 | # Output file base name for HTML help builder.
130 | htmlhelp_basename = 'gumpydoc'
131 |
132 |
133 | # -- Options for LaTeX output ---------------------------------------------
134 |
135 | latex_elements = {
136 |     # The paper size ('letterpaper' or 'a4paper').
137 |     #
138 |     # 'papersize': 'letterpaper',
139 |
140 |     # The font size ('10pt', '11pt' or '12pt').
141 |     #
142 |     # 'pointsize': '10pt',
143 |
144 |     # Additional stuff for the LaTeX preamble.
145 |     #
146 |     # 'preamble': '',
147 |
148 |     # Latex figure (float) alignment
149 |     #
150 |     # 'figure_align': 'htbp',
151 | }
152 |
153 | # Grouping the document tree into LaTeX files. List of tuples
154 | # (source start file, target name, title,
155 | #  author, documentclass [howto, manual, or own class]).
156 | latex_documents = [
157 |     (master_doc, 'gumpy.tex', 'gumpy Documentation',
158 |      'the gumpy developers', 'manual'),
159 | ]
160 |
161 |
162 | # -- Options for manual page output ---------------------------------------
163 |
164 | # One entry per manual page. List of tuples
165 | # (source start file, name, description, authors, manual section).
166 | man_pages = [
167 |     (master_doc, 'gumpy', 'gumpy Documentation',
168 |      [author], 1)
169 | ]
170 |
171 |
172 | # -- Options for Texinfo output -------------------------------------------
173 |
174 | # Grouping the document tree into Texinfo files. List of tuples
175 | # (source start file, target name, title, author,
176 | #  dir menu entry, description, category)
177 | texinfo_documents = [
178 |     (master_doc, 'gumpy', 'gumpy Documentation',
179 |      author, 'gumpy', 'A Python 3 toolbox to develop Brain-Computer Interfaces.',
180 |      'Miscellaneous'),
181 | ]
182 |
183 |
184 |
185 |
--------------------------------------------------------------------------------
/doc/source/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to gumpy's documentation!
2 | =================================
3 |
4 | `gumpy` is a Python 3 toolbox to develop Brain-Computer Interfaces (BCI).
5 |
6 | `gumpy` contains implementations of several functions that are commonly used
7 | during EEG and EMG decoding. For this purpose it heavily relies on other
8 | numerical and scientific libraries, for instance `numpy`, `scipy`, or
9 | `scikit-learn`, to name just a few. In fact, `gumpy` mostly wraps existing
10 | functions in such a way that researchers working in the field can quickly
11 | perform data analysis and implement novel classifiers. Moreover, one of
12 | `gumpy`'s design principles was to make it easily extendable.
13 |
14 | :license: MIT License
15 | :contributions: Please use github (www.github.com/gumpy-bci/gumpy) and see below
16 | :issues: Please use the issue tracker on github (www.github.com/gumpy-bci/gumpy/issues)
17 |
18 |
19 | How to cite gumpy
20 | =================
21 |
22 | "gumpy: A Python Toolbox Suitable for Hybrid Brain-Computer Interfaces".
23 | Zied Tayeb, Nicolai Waniek, Juri Fedjaev, Nejla Ghaboosi, Leonard Rychly,
24 | Christian Widderich, Christoph Richter, Jonas Braun, Matteo Saveriano, Gordon Cheng, and Jörg Conradt. 2018
25 |
26 | .. code:: latex
27 |
28 |     @Article{gumpy2018,
29 |       Title = {gumpy: A Python Toolbox Suitable for Hybrid Brain-Computer Interfaces},
30 |       Author = {Tayeb, Zied and Waniek, Nicolai and Fedjaev, Juri and Ghaboosi, Nejla and Rychly, Leonard and Widderich, Christian and Richter, Christoph and Braun, Jonas and Saveriano, Matteo and Cheng, Gordon and Conradt, Jorg},
31 |       Year = {2018},
32 |       Journal = {}
33 |     }
34 |
35 |
36 | User and developer guides
37 | =========================
38 |
39 | .. toctree::
40 |    :maxdepth: 2
41 |    :caption: Contents:
42 |    :numbered:
43 |
44 |    installation
45 |    module_overview
46 |    apidoc
47 |    ChangeLog
48 |
49 |
50 | Indices and tables
51 | ==================
52 |
53 | * :ref:`genindex`
54 | * :ref:`modindex`
55 | * :ref:`search`
56 |
57 |
--------------------------------------------------------------------------------
/doc/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
--------------------------------------------------------------------------------
/doc/source/module_overview.rst:
--------------------------------------------------------------------------------
1 | Module Overview
2 | ===============
--------------------------------------------------------------------------------
/examples/notebooks/EEG-motor-imagery-CSP.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Preparation \n",
8 | "## Append to path and import packages\n",
9 | "In case gumpy is not installed as a package, you may have to specify the path to the gumpy directory"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%reset\n",
19 | "%matplotlib inline\n",
20 | "import matplotlib.pyplot as plt \n",
21 | "import sys, os, os.path\n",
22 | "sys.path.append('/.../gumpy')"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "## Import gumpy\n",
30 | "This may take a while, as gumpy has several dependencies that will be loaded automatically"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "import numpy as np\n",
40 | "import gumpy"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "## Select workflow\n",
48 | "Select the actions you want to perform"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "# Import data\n",
56 | "To import data, you have to specify the directory in which your data is stored. For the example given here, the data is in the subfolder ``../EEG-Data/Graz_data/data``. \n",
57 | "Then, one of the classes that subclass from ``dataset`` can be used to load the data. In the example, we will use the GrazB dataset, for which ``gumpy`` already includes a corresponding class. If you have different data, simply subclass from ``gumpy.dataset.Dataset``."
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "# First specify the location of the data and some \n",
67 | "# identifier that is exposed by the dataset (e.g. subject)\n",
68 | "\n",
69 | "data_base_dir = '/.../.../Data'\n",
70 | "\n",
71 | "grazb_base_dir = os.path.join(data_base_dir, 'Graz')\n",
72 | "subject = 'B01'\n",
73 | "\n",
74 | "# The next line first initializes the data structure. \n",
75 | "# Note that this does not yet load the data! In custom implementations\n",
76 | "# of a dataset, this should be used to prepare file transfers, \n",
77 | "# for instance check if all files are available, etc.\n",
78 | "grazb_data = gumpy.data.GrazB(grazb_base_dir, subject)\n",
79 | "\n",
80 | "# Finally, load the dataset\n",
81 | "grazb_data.load()"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "The abstract class can print some information about the contained data. This is a convenience function that allows quick inspection of the data, as long as all necessary fields are provided in the subclassed variant."
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "grazb_data.print_stats()\n",
98 | "# labels = grazb_data.labels"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "# Postprocess data\n",
106 | "Usually it is necessary to postprocess the raw data before you can properly use it. ``gumpy`` provides several methods to easily do so, or provides implementations that can be adapted to your needs.\n",
107 | "\n",
108 | "Most methods internally use other Python toolkits, for instance ``sklearn``, which is heavily used throughout ``gumpy``. This makes it easy to extend ``gumpy`` with custom filters. In addition, you may have to manipulate the raw data directly, as shown in the following example."
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "metadata": {},
114 | "source": [
115 | "## Common average re-referencing the data to Cz\n",
116 | "Some data needs to be re-referenced to a certain electrode. Because this may depend on your dataset, there is no common function provided by ``gumpy`` to do so. However, if sub-classed according to the documentation, you can access the raw data directly, as in the following example."
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "if False:\n",
126 | "    grazb_data.raw_data[:, 0] -= 2 * grazb_data.raw_data[:, 1]\n",
127 | "    grazb_data.raw_data[:, 2] -= 2 * grazb_data.raw_data[:, 2]"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "## Notch and Band-Pass Filters\n",
135 | "``gumpy`` ships with several filters already implemented. They accept either raw data to be filtered, or a subclass of ``Dataset``. In the latter case, ``gumpy`` will automatically convert all channels using parameters extracted from the dataset."
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "# this returns a butter-bandpass filtered version of the entire dataset\n",
145 | "btr_data = gumpy.signal.butter_bandpass(grazb_data, lo=1, hi=35)\n",
146 | "\n",
147 | "# it is also possible to use filters on individual electrodes using \n",
148 | "# the .raw_data field of a dataset. The example here will remove a certain frequency\n",
149 | "# from a single electrode using a Notch filter. This example also demonstrates\n",
150 | "# that parameters will be forwarded to the internal call to the filter, in this\n",
151 | "# case the scipy implementation iirnotch (Note that iirnotch is only available\n",
152 | "# in recent versions of scipy, and thus disabled in this example by default)\n",
153 | "\n",
154 | "# frequency to be removed from the signal\n",
155 | "if False:\n",
156 | "    f0 = 50 \n",
157 | "    # quality factor\n",
158 | "    Q = 50 \n",
159 | "    # get the cutoff frequency\n",
160 | "    w0 = f0/(grazb_data.sampling_freq/2) \n",
161 | "    # apply the notch filter\n",
162 | "    btr_data = gumpy.signal.notch(btr_data[:, 0], w0, Q)"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {},
168 | "source": [
169 | "## Normalization\n",
170 | "Many datasets require normalization. ``gumpy`` provides functions to normalize either using a mean computation or via min/max computation. As with the filters, this function accepts either an instance of ``Dataset``, or raw_data. In fact, it can be used for postprocessing any row-wise data in a numpy matrix."
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "if False:\n",
180 | "    # normalize the data first\n",
181 | "    norm_data = gumpy.signal.normalize(grazb_data, 'mean_std')\n",
182 | "    # let's see some statistics\n",
183 | "    print(\"\"\"Normalized Data:\n",
184 | "    Mean    = {:.3f}\n",
185 | "    Min     = {:.3f}\n",
186 | "    Max     = {:.3f}\n",
187 | "    Std.Dev = {:.3f}\"\"\".format(\n",
188 | "    np.nanmean(norm_data), np.nanmin(norm_data), np.nanmax(norm_data), np.nanstd(norm_data)\n",
189 | "    ))"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 | "# Plotting and Feature Extraction\n",
197 | "\n",
198 | "Certainly you will wish to plot your results. ``gumpy`` provides several functions that show how to implement visualizations. For this purpose it heavily relies on ``matplotlib``, ``pandas``, and ``seaborn``. The following examples will show several of the implemented signal processing methods as well as their corresponding plotting functions. Moreover, the examples will show you how to extract features.\n",
199 | "\n",
200 | "That said, let's start with a simple visualization where we access the filtered data from above to show you how to access the data and plot it."
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": [
209 | "\n",
210 | "# Plot after filtering with a butter bandpass (ignore normalization)\n",
211 | "plt.figure()\n",
212 | "plt.clf()\n",
213 | "plt.plot(btr_data[grazb_data.trials[0]: grazb_data.trials[1], 0], label='C3')\n",
214 | "plt.plot(btr_data[grazb_data.trials[0]: grazb_data.trials[1], 1], alpha=0.7, label='C4')\n",
215 | "plt.plot(btr_data[grazb_data.trials[0]: grazb_data.trials[1], 2], alpha=0.7, label='Cz')\n",
216 | "plt.legend()\n",
217 | "plt.title(\"Filtered Data\")"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {},
223 | "source": [
224 | "## EEG band visualization\n",
225 | "Using ``gumpy``'s filters and the provided method, it is easy to filter and subsequently plot the EEG bands of a trial."
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "metadata": {},
232 | "outputs": [],
233 | "source": [
234 | "# determine the trial that we wish to plot\n",
235 | "n_trial = 120\n",
236 | "# now specify the alpha and beta cutoff frequencies\n",
237 | "lo_a, lo_b = 7, 16\n",
238 | "hi_a, hi_b = 13, 24\n",
239 | "\n",
240 | "# first step is to filter the data\n",
241 | "flt_a = gumpy.signal.butter_bandpass(grazb_data, lo=lo_a, hi=hi_a)\n",
242 | "flt_b = gumpy.signal.butter_bandpass(grazb_data, lo=lo_b, hi=hi_b)\n",
243 | "\n",
244 | "# finally we can visualize the data\n",
245 | "gumpy.plot.EEG_bandwave_visualizer(grazb_data, flt_a, n_trial, lo_a, hi_a)\n",
246 | "gumpy.plot.EEG_bandwave_visualizer(grazb_data, flt_b, n_trial, lo_b, hi_b)"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | "## Extract trials\n",
254 | "Now we wish to extract the trials from the data. This operation may heavily depend on your dataset, and thus we cannot guarantee that the function works for your specific dataset. However, the function ``gumpy.utils.extract_trials`` can be used as a guideline for how to extract the trials you wish to examine."
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "# retrieve the trials from the filtered data. This requires that the function\n",
264 | "# knows the number of trials, labels, etc. when only passed a (filtered) data matrix\n",
265 | "trial_marks = grazb_data.trials\n",
266 | "labels = grazb_data.labels\n",
267 | "sampling_freq = grazb_data.sampling_freq\n",
268 | "epochs = gumpy.utils.extract_trials(grazb_data, trials=trial_marks, labels=labels, sampling_freq=sampling_freq)\n",
269 | "# it is also possible to pass an instance of Dataset and filtered data.\n",
270 | "# gumpy will then infer all necessary details from the dataset\n",
271 | "data_class_b = gumpy.utils.extract_trials(grazb_data, flt_b)\n",
272 | "\n",
273 | "# similar to other functions, this one allows passing an entire instance of Dataset\n",
274 | "# to operate on the raw data\n",
275 | "data_class1 = gumpy.utils.extract_trials(grazb_data)\n"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "metadata": {},
281 | "source": [
282 | "## Visualize the classes\n",
283 | "Given the extracted trials from above, we can proceed to visualize the average power of a class. Again, this depends on the specific data and thus you may have to adapt the function accordingly."
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "metadata": {},
290 | "outputs": [],
291 | "source": [
292 | "# specify some cutoff values for the visualization\n",
293 | "lowcut_a, highcut_a = 14, 30\n",
294 | "# and also an interval to display\n",
295 | "interval_a = [0, 8]\n",
296 | "# visualize logarithmic power?\n",
297 | "logarithmic_power = False\n",
298 | "\n",
299 | "# visualize the extracted trial from above\n",
300 | "gumpy.plot.average_power(epochs, lowcut_a, highcut_a, interval_a, grazb_data.sampling_freq, logarithmic_power)"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "collapsed": true
307 | },
308 | "source": [
309 | "## Wavelet transform\n",
310 | "``gumpy`` relies on ``pywt`` to compute wavelet transforms. Furthermore, it contains convenience functions to visualize the results of the discrete wavelet transform as shown in the example below for the Graz dataset and the classes extracted above."
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {},
317 | "outputs": [],
318 | "source": [
319 | "# As with most functions, you can pass arguments to a \n",
320 | "# gumpy function that will be forwarded to the backend.\n",
321 | "# In this example the decomposition level is mandatory, while the \n",
322 | "# mother wavelet that can be passed is optional\n",
323 | "level = 6\n",
324 | "wavelet = 'db4'\n",
325 | "\n",
326 | "# now we can retrieve the dwt for the different channels\n",
327 | "mean_coeff_ch0_c1 = gumpy.signal.dwt(data_class1[0], level=level, wavelet=wavelet)\n",
328 | "mean_coeff_ch1_c1 = gumpy.signal.dwt(data_class1[1], level=level, wavelet=wavelet)\n",
329 | "mean_coeff_ch0_c2 = gumpy.signal.dwt(data_class1[3], level=level, wavelet=wavelet)\n",
330 | "mean_coeff_ch1_c2 = gumpy.signal.dwt(data_class1[4], level=level, wavelet=wavelet)\n",
331 | "\n",
332 | "# gumpy's signal.dwt function returns the approximation of the \n",
333 | "# coefficients as first result, and all the coefficient details as list\n",
334 | "# as second return value (this is in contrast to the backend, which returns\n",
335 | "# the entire set of coefficients as a single list)\n",
336 | "approximation_C3 = mean_coeff_ch0_c2[0]\n",
337 | "approximation_C4 = mean_coeff_ch1_c2[0]\n",
338 | "\n",
339 | "# as mentioned in the comment above, the list of details are in the second\n",
340 | "# return value of gumpy.signal.dwt. Here we save them to additional variables\n",
341 | "# to improve clarity\n",
342 | "details_c3_c1 = mean_coeff_ch0_c1[1]\n",
343 | "details_c4_c1 = mean_coeff_ch1_c1[1]\n",
344 | "details_c3_c2 = mean_coeff_ch0_c2[1]\n",
345 | "details_c4_c2 = mean_coeff_ch1_c2[1]\n",
346 | "\n",
347 | "# gumpy provides a function to plot the dwt results. You must pass three lists,\n",
348 | "# i.e. the approximations, the detail coefficients, as well as the data labels,\n",
349 | "# so that gumpy can automatically generate appropriate titles and labels.\n",
350 | "# you can pass an additional class string that will be incorporated into the title.\n",
351 | "# the function returns a matplotlib axis object in case you want to further\n",
352 | "# customize the plot.\n",
353 | "gumpy.plot.dwt(\n",
354 | "    [approximation_C3, approximation_C4],\n",
355 | "    [details_c3_c1, details_c4_c1],\n",
356 | "    ['C3, c1', 'C4, c1'], level, grazb_data.sampling_freq, 'Class: Left')\n"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {},
362 | "source": [
363 | "## DWT reconstruction and visualization\n",
364 | "Often a user wants to reconstruct the power spectrum of a dwt and visualize the results. The functions will return a list of all the reconstructed signals as well as a handle to the figure."
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": null,
370 | "metadata": {},
371 | "outputs": [],
372 | "source": [
373 | "gumpy.plot.reconstruct_without_approx(\n",
374 | "    [details_c3_c2[4], details_c4_c2[4]], \n",
375 | "    ['C3-c1', 'C4-c1'], level=6)\n",
376 | "\n",
377 | "gumpy.plot.reconstruct_without_approx(\n",
378 | "    [details_c3_c1[5], details_c4_c1[5]], \n",
379 | "    ['C3-c1', 'C4-c1'], level=6)"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": null,
385 | "metadata": {},
386 | "outputs": [],
387 | "source": [
388 | "gumpy.plot.reconstruct_with_approx(\n",
389 | "    [details_c3_c1[5], details_c4_c1[5]],\n",
390 | "    ['C3', 'C4'], wavelet=wavelet)"
391 | ]
392 | },
393 | {
394 | "cell_type": "markdown",
395 | "metadata": {},
396 | "source": [
397 | "## Welch's Power Spectral Density estimate\n",
398 | "Estimating the power spectral density according to Welch's method is similar to the power reconstruction shown above."
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "metadata": {},
405 | "outputs": [],
406 | "source": [
407 | "# the function gumpy.plot.welch_psd returns the power densities as \n",
408 | "# well as a handle to the figure. You can also pass a figure in if you \n",
409 | "# wish to modify the plot\n",
410 | "fig = plt.figure()\n",
411 | "plt.title('Customized plot')\n",
412 | "ps, fig = gumpy.plot.welch_psd(\n",
413 | "    [details_c3_c1[4], details_c3_c2[4]],\n",
414 | "    ['C3 - c1', 'C4 - c1'],\n",
415 | "    grazb_data.sampling_freq, fig=fig)\n",
416 | "\n",
417 | "ps, fig = gumpy.plot.welch_psd(\n",
418 | "    [details_c4_c1[4], details_c4_c2[4]],\n",
419 | "    ['C3 - c1', 'C4 - c1'],\n",
420 | "    grazb_data.sampling_freq)"
421 | ]
422 | },
423 | {
424 | "cell_type": "markdown",
425 | "metadata": {},
426 | "source": [
427 | "## Alpha and Beta sub-bands\n",
428 | "Using gumpy's functions you can quickly define feature extractors.
The following examples will demonstrate how you can use the predefined filters" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": null, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "def alpha_subBP_features(data):\n", 438 | " # filter data in sub-bands by specification of low- and high-cut frequencies\n", 439 | " alpha1 = gumpy.signal.butter_bandpass(data, 8.5, 11.5, order=6)\n", 440 | " alpha2 = gumpy.signal.butter_bandpass(data, 9.0, 12.5, order=6)\n", 441 | " alpha3 = gumpy.signal.butter_bandpass(data, 9.5, 11.5, order=6)\n", 442 | " alpha4 = gumpy.signal.butter_bandpass(data, 8.0, 10.5, order=6)\n", 443 | "\n", 444 | " # return a list of sub-bands\n", 445 | " return [alpha1, alpha2, alpha3, alpha4]\n", 446 | "\n", 447 | "alpha_bands = np.array(alpha_subBP_features(btr_data))" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "def beta_subBP_features(data):\n", 457 | " beta1 = gumpy.signal.butter_bandpass(data, 14.0, 30.0, order=6)\n", 458 | " beta2 = gumpy.signal.butter_bandpass(data, 16.0, 17.0, order=6)\n", 459 | " beta3 = gumpy.signal.butter_bandpass(data, 17.0, 18.0, order=6)\n", 460 | " beta4 = gumpy.signal.butter_bandpass(data, 18.0, 19.0, order=6)\n", 461 | " return [beta1, beta2, beta3, beta4]\n", 462 | "\n", 463 | "beta_bands = np.array(beta_subBP_features(btr_data))" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "## Extract features without considering class information\n", 471 | "The following examples show how the sub-bands can be used to extract features. This also shows how the fields of the dataset can be accessed, and how to write methods specific to your data using a mix of gumpy's and numpy's functions." 
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | "### Method 1: logarithmic sub-band power"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": null,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "def powermean(data, trial, fs, w):\n",
488 | "    return np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],0],2).mean(), \\\n",
489 | "           np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],1],2).mean(), \\\n",
490 | "           np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],2],2).mean()\n",
491 | "\n",
492 | "def log_subBP_feature_extraction(alpha, beta, trials, fs, w):\n",
493 | "    # number of features per trial (5 sub-band powers x 3 channels)\n",
494 | "    n_features = 15\n",
495 | "    # initialize the feature matrix\n",
496 | "    X = np.zeros((len(trials), n_features))\n",
497 | "    \n",
498 | "    # Extract features\n",
499 | "    for t, trial in enumerate(trials):\n",
500 | "        power_c31, power_c41, power_cz1 = powermean(alpha[0], trial, fs, w)\n",
501 | "        power_c32, power_c42, power_cz2 = powermean(alpha[1], trial, fs, w)\n",
502 | "        power_c33, power_c43, power_cz3 = powermean(alpha[2], trial, fs, w)\n",
503 | "        power_c34, power_c44, power_cz4 = powermean(alpha[3], trial, fs, w)\n",
504 | "        power_c31_b, power_c41_b, power_cz1_b = powermean(beta[0], trial, fs, w)\n",
505 | "        \n",
506 | "        X[t, :] = np.array(\n",
507 | "            [np.log(power_c31), np.log(power_c41), np.log(power_cz1),\n",
508 | "             np.log(power_c32), np.log(power_c42), np.log(power_cz2),\n",
509 | "             np.log(power_c33), np.log(power_c43), np.log(power_cz3), \n",
510 | "             np.log(power_c34), np.log(power_c44), np.log(power_cz4),\n",
511 | "             np.log(power_c31_b), np.log(power_c41_b), np.log(power_cz1_b)])\n",
512 | "\n",
513 | "    return X"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": null,
519 | "metadata": {},
520 | "outputs": [],
521 | "source": [
522 | "if False:\n",
523 | "    w1 = [0,125]\n",
524 | "    w2 = [125,256]\n",
525 | "    w3 = [256,512]\n",
526 | "    w4 = [512,512+256]\n",
527 | "    \n",
528 | "    # extract the features\n",
529 | "    features1 = log_subBP_feature_extraction(\n",
530 | "        alpha_bands, beta_bands, \n",
531 | "        grazb_data.trials, grazb_data.sampling_freq,\n",
532 | "        w1)\n",
533 | "    features2 = log_subBP_feature_extraction(\n",
534 | "        alpha_bands, beta_bands, \n",
535 | "        grazb_data.trials, grazb_data.sampling_freq,\n",
536 | "        w2) \n",
537 | "    features3 = log_subBP_feature_extraction(\n",
538 | "        alpha_bands, beta_bands, \n",
539 | "        grazb_data.trials, grazb_data.sampling_freq,\n",
540 | "        w3) \n",
541 | "    features4 = log_subBP_feature_extraction(\n",
542 | "        alpha_bands, beta_bands, \n",
543 | "        grazb_data.trials, grazb_data.sampling_freq,\n",
544 | "        w4) \n",
545 | "    print(features4.shape)\n",
546 | "\n",
547 | "    # concatenate and normalize the features\n",
548 | "    features = np.concatenate((features1, features2, features3, features4), axis=1)\n",
549 | "    features -= np.mean(features)\n",
550 | "    features = gumpy.signal.normalize(features, 'mean_std')\n",
551 | "    features = gumpy.signal.normalize(features, 'min_max')\n",
552 | "\n",
553 | "    # print shape to quickly check if everything is as expected\n",
554 | "    print(features.shape)"
555 | ]
556 | },
557 | {
558 | "cell_type": "markdown",
559 | "metadata": {},
560 | "source": [
561 | "### Method 2: Discrete Wavelet Transform (DWT)"
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": null,
567 | "metadata": {},
568 | "outputs": [],
569 | "source": [
570 | "def dwt_features(data, trials, level, sampling_freq, w, n, wavelet): \n",
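"    # For each trial, a window w (placed four seconds after trial onset) is\n",
"    # wavelet-decomposed per channel (C3, C4, Cz); the std, mean power and\n",
"    # mean of the detail coefficients at index n form the 9 features per trial.\n",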
571 | "    import pywt\n",
572 | "    \n",
573 | "    # number of features per trial\n",
574 | "    n_features = 9 \n",
575 | "    # allocate memory to store the features\n",
576 | "    X = np.zeros((len(trials), n_features))\n",
577 | "\n",
578 | "    # Extract Features\n",
579 | "    for t, trial in enumerate(trials):\n",
580 | "        signals = data[trial + sampling_freq*4 + (w[0]) : trial + sampling_freq*4 + (w[1])]\n",
581 | "        coeffs_c3 = pywt.wavedec(data = signals[:,0], wavelet=wavelet, level=level)\n",
582 | "        coeffs_c4 = pywt.wavedec(data = signals[:,1], wavelet=wavelet, level=level)\n",
583 | "        coeffs_cz = pywt.wavedec(data = signals[:,2], wavelet=wavelet, level=level)\n",
584 | "\n",
585 | "        X[t, :] = np.array([\n",
586 | "            np.std(coeffs_c3[n]), np.mean(coeffs_c3[n]**2), \n",
587 | "            np.std(coeffs_c4[n]), np.mean(coeffs_c4[n]**2),\n",
588 | "            np.std(coeffs_cz[n]), np.mean(coeffs_cz[n]**2), \n",
589 | "            np.mean(coeffs_c3[n]),\n",
590 | "            np.mean(coeffs_c4[n]), \n",
591 | "            np.mean(coeffs_cz[n])])\n",
592 | "    \n",
593 | "    return X\n"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": null,
599 | "metadata": {},
600 | "outputs": [],
601 | "source": [
602 | "# We'll work with the data that was postprocessed using a butter bandpass\n",
603 | "# filter further above\n",
604 | "\n",
605 | "# to see it work, enable here. We'll use the log-power features further\n",
606 | "# below, though\n",
607 | "if False:\n",
608 | "    w = [0, 256]\n",
609 | "    \n",
610 | "    # extract the features\n",
611 | "    trials = grazb_data.trials\n",
612 | "    fs = grazb_data.sampling_freq\n",
613 | "    features1 = np.array(dwt_features(btr_data, trials, 5, fs, w, 3, \"db4\"))\n",
614 | "    features2 = np.array(dwt_features(btr_data, trials, 5, fs, w, 4, \"db4\"))\n",
615 | "\n",
616 | "    # concatenate and normalize the features \n",
617 | "    features = np.concatenate((features1, features2), axis=1)\n",
618 | "    features -= np.mean(features)\n",
619 | "    features = gumpy.signal.normalize(features, 'min_max')\n",
620 | "    \n",
621 | "    print(features.shape)"
622 | ]
623 | },
624 | {
625 | "cell_type": "markdown",
626 | "metadata": {},
627 | "source": [
628 | "### Split the data\n",
629 | "Now that we extracted features (and reduced the dimensionality), we can split the data for \n",
630 | "test and training purposes."
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": null,
636 | "metadata": {},
637 | "outputs": [],
638 | "source": [
639 | "# gumpy exposes several methods to split a dataset, as shown in the examples:\n",
640 | "if 1: \n",
641 | "    split_features = np.array(gumpy.split.normal(features, labels, test_size=0.2))\n",
642 | "    X_train = split_features[0]\n",
643 | "    X_test = split_features[1]\n",
644 | "    Y_train = split_features[2]\n",
645 | "    Y_test = split_features[3]\n",
646 | "    X_train.shape\n",
647 | "if 0: \n",
648 | "    n_splits = 5\n",
649 | "    split_features = np.array(gumpy.split.time_series_split(features, labels, n_splits)) \n",
650 | "if 0: \n",
651 | "    split_features = np.array(gumpy.split.normal(PCA, labels, test_size=0.2))\n",
652 | "    \n",
653 | "# ShuffleSplit: Random permutation cross-validator \n",
654 | "if 0: \n",
655 | "    split_features = gumpy.split.shuffle_Split(features, labels, n_splits=10, test_size=0.2, random_state=0)\n",
656 | "    \n",
657 | "# # Stratified K-Folds cross-validator\n",
658 | "# # Stratification is the process of rearranging the data so as to ensure each fold is a good representative of the whole \n",
659 | "if 0: \n",
660 | "    split_features = gumpy.split.stratified_KFold(features, labels, n_splits=3)\n",
661 | "    \n",
662 | "# Stratified ShuffleSplit cross-validator \n",
663 | "# Repeated random sub-sampling validation\n",
664 | "if 0: \n",
665 | "    split_features = gumpy.split.stratified_shuffle_Split(features, labels, n_splits=10, test_size=0.3, random_state=0)\n",
666 | "\n",
667 | "\n",
668 | "# # the functions return a list with the data according to the following example\n",
669 | "# X_train = split_features[0]\n",
670 | "# X_test = split_features[1]\n",
671 | "# Y_train = split_features[2]\n",
672 | "# Y_test = split_features[3]\n",
673 | "# X_train.shape"
674 | ]
675 | },
676 | {
677 | "cell_type": "markdown",
678 | "metadata": {},
679 | "source": [
680 | "## Extract features considering class information"
681 | ]
682 | },
683 | {
684 | "cell_type": "markdown",
685 | "metadata": {},
686 | "source": [
687 | "### Method 3: Common Spatial Patterns (CSP)\n",
688 | "#### I - Common prerequisites"
689 | ]
690 | },
691 | {
692 | "cell_type": "code",
693 | "execution_count": null,
694 | "metadata": {},
695 | "outputs": [],
696 | "source": [
697 | "def extract_features(epoch, spatial_filter):\n",
698 | "    feature_matrix = np.dot(spatial_filter, epoch)\n",
699 | "    variance = np.var(feature_matrix, axis=1)\n",
700 | "    if np.all(variance == 0):\n",
701 | "        return np.zeros(spatial_filter.shape[0])\n",
702 | "    features = np.log(variance/np.sum(variance))\n",
703 | "    return features\n",
704 | "\n",
705 | "def select_filters(filters, num_components = None, verbose = False):\n",
706 | "    if verbose:\n",
707 | "        print(\"Incoming filters:\", filters.shape, \"\\nNumber of components:\", num_components if num_components else \" all\")\n",
708 | "    if num_components is None:\n",
709 | "        return filters\n",
710 | "    assert num_components <= filters.shape[0]/2, \"The requested number of components is too high\"\n",
711 | "    selection = list(range(0, num_components)) + list(range(filters.shape[0] - num_components, filters.shape[0]))\n",
712 | "    reduced_filters = filters[selection,:]\n",
713 | "    return reduced_filters\n",
714 | "\n",
715 | "# Select the number of used spatial components\n",
716 | "n_components = None  # assign None for all components\n",
717 | "\n",
718 | "# Rearrange epochs to (trials x channels x timesteps)\n",
719 | "epochs1 = np.swapaxes(np.array([epochs[0],epochs[1],epochs[2]]),1,0)\n",
720 | "epochs2 = np.swapaxes(np.array([epochs[3],epochs[4],epochs[5]]),1,0)\n",
721 | "print(\"Number of trials per class:\",epochs1.shape[0],\"|\",epochs2.shape[0])\n",
722 | "#print(epochs1.shape)\n",
723 | "#print(epochs2.shape)\n",
724 | "\n",
725 | "# Remove invalid epochs\n",
726 | "invalid_entries_1 = np.where([np.all(epoch == 0) for epoch in epochs1])\n",
727 | "invalid_entries_2 = np.where([np.all(epoch == 0) for epoch in epochs2])\n",
728 | "print(\"Number of invalid trials per class:\",len(invalid_entries_1[0]),\"|\",len(invalid_entries_2[0]))\n",
729 | "#print(invalid_entries_1)\n",
730 | "#print(invalid_entries_2)\n",
731 | "epochs1 = np.delete(epochs1, invalid_entries_1, 0)\n",
732 | "epochs2 = np.delete(epochs2, invalid_entries_2, 0)\n",
733 | "print(\"Number of trials per class after cleanup:\",epochs1.shape[0],\"|\",epochs2.shape[0])\n",
734 | "\n",
735 | "# Concatenate the trials\n",
736 | "epochs_re = np.concatenate((epochs1, epochs2), axis=0)\n",
737 | "print(\"Dataset:\", epochs_re.shape)\n",
738 | "\n",
739 | "# Update the label vector\n",
740 | "labels = np.ones(epochs_re.shape[0])\n",
741 | "labels[:epochs1.shape[0]] = 0\n",
742 | "\n",
743 | "# Split data\n",
744 | "epochs_train, epochs_test, y_train, y_test = gumpy.split.stratified_shuffle_Split(epochs_re, labels, n_splits=10, test_size=0.2, random_state=0)\n",
745 | "print(\"Training data:\", epochs_train.shape, \"Testing data:\", epochs_test.shape)"
746 | ]
747 | },
748 | {
749 | "cell_type": "markdown",
750 | "metadata": {},
751 | "source": [
752 | "#### II A - Standard implementation"
753 | ]
754 | },
755 | {
756 | "cell_type": "code",
757 | "execution_count": null,
758 | "metadata": {},
759 | "outputs": [],
760 | "source": [
761 | "if False: \n",
762 | "    # Generate the spatial filters for the training data\n",
763 | "    temp_filters = np.asarray(gumpy.features.CSP(epochs_train[np.where(y_train==0)], epochs_train[np.where(y_train==1)]))\n",
764 | "    #print(temp_filters.shape)\n",
765 | "    spatial_filter = select_filters(temp_filters[0], n_components)\n",
766 | "    #print(spatial_filter.shape)\n",
767 | "    \n",
768 | "    # Extract the CSPs\n",
769 | "    features_train = np.array([extract_features(epoch, spatial_filter) for epoch in epochs_train])\n",
770 | "    #print(features_train.shape) \n",
771 | "    \n",
772 | "    features_test = np.array([extract_features(epoch, spatial_filter) for epoch in epochs_test])"
773 | ]
774 | },
775 | {
776 | "cell_type": "markdown",
777 | "metadata": {},
778 | "source": [
779 | "#### II B - Filter Bank CSP"
780 | ]
781 | },
782 | {
783 | "cell_type": "code",
784 | "execution_count": null,
785 | "metadata": {},
786 | "outputs": [],
787 | "source": [
788 | "if True:\n",
789 | "    # Apply the filter bank\n",
790 | "    freqs = [[1,4],[4,8],[8,13],[13,22],[22,30]] # delta, theta, alpha, low beta, high beta \n",
791 | "    #freqs = [[1,4],[4,8],[8,13],[13,22],[22,30],[1,30]]\n",
792 | "    #freqs = [[1,4],[4,8],[8,12],[12,16],[16,20],[20,24],[24,28],[28,32]]\n",
793 | "    #freqs = [[1,3],[3,5],[5,7],[7,9],[9,11],[11,13],[13,15],[15,17],[17,19],[19,21],[21,23],[23,25],[25,27],[27,29],[29,31],[31,33],[33,35]]\n",
794 | "    x_train_fb = [gumpy.signal.butter_bandpass(epochs_train, lo=f[0], hi=f[1]) for f in freqs]\n",
795 | "    x_test_fb = [gumpy.signal.butter_bandpass(epochs_test, lo=f[0], hi=f[1]) for f in freqs]\n",
796 | "\n",
797 | "    # Generate the spatial filters for the training data\n",
798 | "    temp_filters = [np.asarray(gumpy.features.CSP(x_train_fb[f][np.where(y_train==0)], x_train_fb[f][np.where(y_train==1)])) for f in range(len(freqs))]\n",
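"    # CSP is fitted per frequency band on the training trials of the two\n",
"    # classes; element [0] of each per-band result is the spatial filter\n",
"    # matrix, from which components can then be chosen via select_filters\n",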
799 | "    if n_components is not None:\n",
800 | "        spatial_filters = [select_filters(temp_filters[f][0], n_components) for f in range(len(freqs))]\n",
801 | "    else:\n",
802 | "        spatial_filters = [temp_filters[f][0] for f in range(len(freqs))]\n",
803 | "\n",
804 | "    # Extract the CSPs\n",
805 | "    features_train_tmp = [np.array([extract_features(epoch, spatial_filters[f]) for epoch in x_train_fb[f]]) for f in range(len(freqs))]\n",
806 | "    features_train = np.concatenate(features_train_tmp, axis=1)\n",
807 | "    features_test_tmp = [np.array([extract_features(epoch, spatial_filters[f]) for epoch in x_test_fb[f]]) for f in range(len(freqs))]\n",
808 | "    features_test = np.concatenate(features_test_tmp, axis=1)\n",
809 | "\n",
810 | "    print(features_train.shape)\n",
811 | "    print(features_test.shape)"
812 | ]
813 | },
814 | {
815 | "cell_type": "markdown",
816 | "metadata": {},
817 | "source": [
818 | "#### III - Common postprocessing"
819 | ]
820 | },
821 | {
822 | "cell_type": "code",
823 | "execution_count": null,
824 | "metadata": {},
825 | "outputs": [],
826 | "source": [
827 | "if True:\n",
828 | "    # Feature normalization\n",
829 | "    features_train = gumpy.signal.normalize(features_train, 'mean_std')\n",
830 | "    features_test = gumpy.signal.normalize(features_test, 'mean_std')\n",
831 | "    X_train = gumpy.signal.normalize(features_train, 'min_max')\n",
832 | "    X_test = gumpy.signal.normalize(features_test, 'min_max')\n",
833 | "    Y_train = y_train\n",
834 | "    Y_test = y_test\n",
835 | "    \n",
836 | "# Debugging output\n",
837 | "#usable_epochs = np.concatenate((epochs1, epochs2), axis=0)\n",
838 | "#print(usable_epochs.shape)\n",
839 | "#print(features.shape)"
840 | ]
841 | },
842 | {
843 | "cell_type": "markdown",
844 | "metadata": {},
845 | "source": [
846 | "# Select features "
847 | ]
848 | },
849 | {
850 | "cell_type": "markdown",
851 | "metadata": {},
852 | "source": [
853 | "## Sequential Feature Selection Algorithm\n",
854 | "``gumpy`` provides a generic function with which you can select features. For a list of the implemented selectors please have a look at the function documentation.\n"
855 | ]
856 | },
857 | {
858 | "cell_type": "code",
859 | "execution_count": null,
860 | "metadata": {},
861 | "outputs": [],
862 | "source": [
863 | "if False:\n",
864 | "    Results = []\n",
865 | "    classifiers = []\n",
866 | "    Accuracy = []\n",
867 | "    Final_results = {}\n",
868 | "    for model in gumpy.classification.available_classifiers:\n",
869 | "        print(model)\n",
870 | "        feature_idx, cv_scores, algorithm, sfs, clf = gumpy.features.sequential_feature_selector(X_train, Y_train, model, (6, 10), 5, 'SFFS')\n",
871 | "        classifiers.append(model)\n",
872 | "        Accuracy.append(cv_scores*100)\n",
873 | "        Final_results[model] = cv_scores*100\n",
874 | "    print(Final_results)"
875 | ]
876 | },
877 | {
878 | "cell_type": "markdown",
879 | "metadata": {},
880 | "source": [
881 | "## PCA \n",
882 | "``gumpy`` provides a wrapper around sklearn to reduce the dimensionality via PCA in a straightforward manner."
883 | ]
884 | },
885 | {
886 | "cell_type": "code",
887 | "execution_count": null,
888 | "metadata": {},
889 | "outputs": [],
890 | "source": [
891 | "PCA = gumpy.features.PCA_dim_red(features, 0.95)"
892 | ]
893 | },
894 | {
895 | "cell_type": "markdown",
896 | "metadata": {},
897 | "source": [
898 | "## Plotting features\n",
899 | "``gumpy`` wraps 3D plotting of features into a single line"
900 | ]
901 | },
902 | {
903 | "cell_type": "code",
904 | "execution_count": null,
905 | "metadata": {},
906 | "outputs": [],
907 | "source": [
908 | "gumpy.plot.PCA(\"3D\", features, split_features[0], split_features[2])"
909 | ]
910 | },
911 | {
912 | "cell_type": "markdown",
913 | "metadata": {},
914 | "source": [
915 | "# Validation and test accuracies "
916 | ]
917 | },
918 | {
919 | "cell_type": "code",
920 | "execution_count": null,
921 | "metadata": {},
922 | "outputs": [],
923 | "source": [
924 | "# SVM, RF, KNN, NB, LR, QLDA, LDA\n",
925 | "from sklearn.model_selection import cross_val_score\n",
926 | "feature_idx, cv_scores, algorithm, sfs, clf = gumpy.features.sequential_feature_selector(X_train, Y_train, 'RandomForest', (6, 10), 10, 'SFFS')\n",
927 | "\n",
928 | "feature = X_train[:, feature_idx]\n",
929 | "scores = cross_val_score(clf, feature, Y_train, cv=10)\n",
930 | "\n",
931 | "\n",
932 | "print(\"Validation Accuracy: %0.2f (+/- %0.2f)\" % (scores.mean(), scores.std()))\n",
933 | "clf.fit(feature, Y_train)\n",
934 | "feature1 = X_test[:, feature_idx]\n",
935 | "clf.predict(feature1)\n",
936 | "f = clf.score(feature1, Y_test)\n",
937 | "print(\"Test Accuracy:\", f)"
938 | ]
939 | },
940 | {
941 | "cell_type": "markdown",
942 | "metadata": {},
943 | "source": [
944 | "# Voting classifiers "
945 | ]
946 | },
947 | {
948 | "cell_type": "markdown",
949 | "metadata": {},
950 | "source": [
951 | "Because `gumpy.classification.vote` uses `sklearn.ensemble.VotingClassifier` as backend, it is possible to specify different methods for the voting such as 'soft'. In addition, the method can be told to first extract features via `mlxtend.feature_selection.SequentialFeatureSelector` before classification."
952 | ]
953 | },
954 | {
955 | "cell_type": "code",
956 | "execution_count": null,
957 | "metadata": {},
958 | "outputs": [],
959 | "source": [
960 | "if True:\n",
961 | "    result, _ = gumpy.classification.vote(X_train, Y_train, X_test, Y_test, 'soft', False, (6,12))\n",
962 | "    print(\"Classification result for soft voting classifier\")\n",
963 | "    print(result)\n",
964 | "    print(\"Accuracy: \", result.accuracy)"
965 | ]
966 | },
967 | {
968 | "cell_type": "code",
969 | "execution_count": null,
970 | "metadata": {},
971 | "outputs": [],
972 | "source": [
973 | "if True:\n",
974 | "    result, _ = gumpy.classification.vote(X_train, Y_train, X_test, Y_test, 'hard', False, (6,12))\n",
975 | "    print(\"Classification result for hard voting classifier\")\n",
976 | "    print(result)\n",
977 | "    print(\"Accuracy: \", result.accuracy)"
978 | ]
979 | },
980 | {
981 | "cell_type": "markdown",
982 | "metadata": {
983 | "collapsed": true
984 | },
985 | "source": [
986 | "## Voting Classifier with feature selection\n",
987 | "`gumpy` can automatically use all classifiers that are known in `gumpy.classification.available_classifiers` in a voting classifier (for more details see `sklearn.ensemble.VotingClassifier`). In case you developed a custom classifier and registered it using the `@register_classifier` decorator, it will be automatically used as well."
988 | ]
989 | },
990 | {
991 | "cell_type": "code",
992 | "execution_count": null,
993 | "metadata": {},
994 | "outputs": [],
995 | "source": [
996 | "result, _ = gumpy.classification.vote(X_train, Y_train, X_test, Y_test, 'soft', True, (6,12))\n",
997 | "print(\"Classification result for soft voting classifier\")\n",
998 | "print(result)\n",
999 | "print(\"Accuracy: \", result.accuracy)"
1000 | ]
1001 | },
1002 | {
1003 | "cell_type": "markdown",
1004 | "metadata": {},
1005 | "source": [
1006 | "## Classification without feature selection "
1007 | ]
1008 | },
1009 | {
1010 | "cell_type": "code",
1011 | "execution_count": null,
1012 | "metadata": {},
1013 | "outputs": [],
1014 | "source": [
1015 | "if False:\n",
1016 | "    Results = []\n",
1017 | "    classifiers = []\n",
1018 | "    Accuracy = []\n",
1019 | "    Final_results = {}\n",
1020 | "    for model in gumpy.classification.available_classifiers:\n",
1021 | "        results, clf = gumpy.classify(model, X_train, Y_train, X_test, Y_test)\n",
1022 | "        print(model)\n",
1023 | "        print(results)\n",
1024 | "        classifiers.append(model)\n",
1025 | "        Accuracy.append(results.accuracy) \n",
1026 | "        Final_results[model] = results.accuracy\n",
1027 | "    print(Final_results)"
1028 | ]
1029 | },
1030 | {
1031 | "cell_type": "markdown",
1032 | "metadata": {},
1033 | "source": [
1034 | "## Confusion Matrix\n",
1035 | "One of the ideas behind ``gumpy`` is to provide users the means to quickly examine their data. Therefore, gumpy mostly wraps existing libraries. This makes it possible to show data with ease, while still being able to modify the plots in any way the underlying libraries allow:"
1036 | ]
1037 | },
1038 | {
1039 | "cell_type": "code",
1040 | "execution_count": null,
1041 | "metadata": {},
1042 | "outputs": [],
1043 | "source": [
1044 | "# Method 1\n",
1045 | "gumpy.plot.confusion_matrix(Y_test, results.pred)\n",
1046 | "\n",
1047 | "# Method 2 (template: fill in your own values before running)\n",
1048 | "# gumpy.plot.plot_confusion_matrix(path='...', cm=..., normalize=False, target_names=['...'], title='...')"
1049 | ] 1050 | } 1051 | ], 1052 | "metadata": { 1053 | "kernelspec": { 1054 | "display_name": "Python 3", 1055 | "language": "python", 1056 | "name": "python3" 1057 | }, 1058 | "language_info": { 1059 | "codemirror_mode": { 1060 | "name": "ipython", 1061 | "version": 3 1062 | }, 1063 | "file_extension": ".py", 1064 | "mimetype": "text/x-python", 1065 | "name": "python", 1066 | "nbconvert_exporter": "python", 1067 | "pygments_lexer": "ipython3", 1068 | "version": "3.6.2" 1069 | } 1070 | }, 1071 | "nbformat": 4, 1072 | "nbformat_minor": 2 1073 | } 1074 | -------------------------------------------------------------------------------- /examples/notebooks/EMG-script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Preparation \n", 8 | "## Append to path and import\n", 9 | "In case gumpy is not installed as a package, you may have to specify the path to the gumpy directory" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 5, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%matplotlib inline\n", 19 | "import matplotlib.pyplot as plt \n", 20 | "\n", 21 | "import sys, os, os.path\n", 22 | "sys.path.append('../..')" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## import gumpy\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "This may take a while, as gumpy has several dependencies that will be loaded automatically" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import numpy as np\n", 46 | "import gumpy\n", 47 | "import warnings\n", 48 | "warnings.filterwarnings(\"ignore\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "# Import data\n", 56 | "To import data, you have to specify the directory in which your data is stored. For the example given here, the data is expected relative to ``base_dir``. \n", 57 | "Then, one of the classes that subclass from ``dataset`` can be used to load the data. In this example, we will use the NST_EMG dataset, for which ``gumpy`` already includes a corresponding class. If you have different data, simply subclass from ``gumpy.dataset.Dataset``." 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 6, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "" 69 | ] 70 | }, 71 | "execution_count": 6, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "# First specify the location of the data and some \n", 78 | "# identifier that is exposed by the dataset (e.g. subject)\n", 79 | "base_dir = '../..'\n", 80 | "subject = 'S1'\n", 81 | "\n", 82 | "# The next line first initializes the data structure. \n", 83 | "# Note that this does not yet load the data!
In custom implementations\n", 84 | "# of a dataset, this should be used to prepare file transfers, \n", 85 | "# for instance check if all files are available, etc.\n", 86 | "data_low = gumpy.data.NST_EMG(base_dir, subject, 'low')\n", 87 | "data_high = gumpy.data.NST_EMG(base_dir, subject, 'high') \n", 88 | "\n", 89 | "# Finally, load the dataset\n", 90 | "data_low.load()\n", 91 | "data_high.load()\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "The abstract class makes it possible to print some information about the contained data. This is a convenience function that allows quick inspection of the data, as long as all necessary fields are provided in the subclassed variant." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 7, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "Data identification: NST_EMG-S1\n", 111 | "EMG-data shape: (217613, 8)\n", 112 | "Trials data shape: (36,)\n", 113 | "Labels shape: (48,)\n", 114 | "Total length of single trial: 10\n", 115 | "Sampling frequency of EMG data: 512\n", 116 | "Interval for motor imagery in trial: [5, 10]\n", 117 | "Classes possible: [0. 1. 2. 3.]\n", 118 | "----------\n", 119 | "Data identification: NST_EMG-S1\n", 120 | "EMG-data shape: (217618, 8)\n", 121 | "Trials data shape: (36,)\n", 122 | "Labels shape: (48,)\n", 123 | "Total length of single trial: 10\n", 124 | "Sampling frequency of EMG data: 512\n", 125 | "Interval for motor imagery in trial: [5, 10]\n", 126 | "Classes possible: [0. 1. 2. 3.]\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "data_low.print_stats()\n", 132 | "print('----------')\n", 133 | "data_high.print_stats()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "# Signal Filtering" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "#bandpass\n", 150 | "lowcut=20\n", 151 | "highcut=255\n", 152 | "#notch\n", 153 | "f0=50\n", 154 | "Q=50\n", 155 | "\n", 156 | "flt_low = gumpy.signal.butter_bandpass(data_low, lowcut, highcut)\n", 157 | "flt_low = gumpy.signal.notch(flt_low, cutoff=f0, Q=Q)\n", 158 | "\n", 159 | "trialsLow = gumpy.utils.getTrials(data_low, flt_low)\n", 160 | "trialsLowBg = gumpy.utils.getTrials(data_low, flt_low, True)\n", 161 | "\n", 162 | "flt_high = gumpy.signal.butter_bandpass(data_high, lowcut, highcut)\n", 163 | "flt_high = gumpy.signal.notch(flt_high, cutoff=f0, Q=Q)\n", 164 | "\n", 165 | "trialsHigh = gumpy.utils.getTrials(data_high, flt_high)\n", 166 | "trialsHighBg = gumpy.utils.getTrials(data_high, flt_high, True)\n", 167 | "\n" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "# Data Visualization " 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 9, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD8CAYAAACVZ8iyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJzt3Xd8FGX+B/DPNwkhdELvhI4gPdJFelc4K1hQT8U70Tv11MOOYkFOPQ8LHth/iuLZQAHpRaUZeofQCSV0kBaSPL8/dnYzuzuzvc/n/XrxYnd2ZvaZze7znaeLUgpERGRtSdFOABERRR+DARERMRgQERGDARERgcGAiIjAYEBERGAwICIiMBgQEREYDIiICEBKtBPgq0qVKqmMjIxoJ4OIKG6sWrXqmFKqsi/7xk0wyMjIQFZWVrSTQUQUN0Rkr6/7spqIiIgYDIiIiMGAiIjAYEBERGAwICIiMBgQEREYDIiICAwGlnLqfB5mrD8U7WQQUQxiMLCQUVNWY9SU1cg5dSHaSSGiGMNgYCE5J21BIC+/MMopIaJYw2BAREQMBkRExGBgSUqpaCeBiGIMg4GFiEi0k0BEMYrBwEJYIiAiMwwGFsQSAhG5YjCwIJYQiMgVg4GFsERARGYYDIiIiMGAiIhCFAxE5CMRyRWRjbptY0QkR0TWav8G6l57UkSyRWSbiPQLRRrIO7YVEJGZUJUMPgHQ32D7v5VSrbV/MwFARJoBGAaguXbMeyKSHKJ0UIzJLyjEJ7/txuUCzodEFMtCEgyUUksAnPBx9yEAvlJKXVJK7QaQDaB9KNJBnkWjAfmLFfsw5sfN+PDX3RF/byLyXbjbDB4UkfVaNVK6tq0mgP26fQ5o29yIyEgRyRKRrKNHj4Y5qeSrLYfOoMkzs3D49EWv+569eNnpfyKKTeEMBhMBNADQGsAhAG/4ewKl1CSlVKZSKrNy5cqhTh8F6LNle3EpvxDztx6JdlKIKETCFgyUUkeUUgVKqUIAk1FUFZQDoLZu11raNooQNiMTkauwBQMRqa57+icA9p5G0wEME5HiIlIPQCMAK8OVDiIi8i4lFCcRkS8BdAdQSUQOAHgeQHcRaQ3bjegeAPcDgFJqk4h8DWAzgHwAo5RSBaFIB/kmGuOQ2auVKLaFJBgopYYbbP7Qw/4vA3g5FO9N0eNLBs8pMIjiA0cgW1CwN+n+5O8c6EYUHxgMLCSa9+gsIBDFNgYD8htv9okSD4NBgth88AwmL9nlcR/m4URkhsEgBq3edxKdXp2PM36M2h044Re8PHNLGFNVJJAqH5YmiGIbg0EMenPOdhw6fRHr9p8K6XmjUW3P3kRE8YHBgIiIGAwocL7U/LBrKVF8YDAgvwVS8cPaIqLYxmBgQcHerPNenyjxMBgksBnrD2HZzuNhO78/N/usLSKKbSGZm4jCI9gMdNSU1QCAPeMGOW0PVZUN83eixMGSQRzILyjElBX7UFAYuez3Ql4BPvp1NwoN3tOfWPL6nO2hSxQRhQ1LBnHgo99245WZW1GoFG7vWDfo8/lS4nh9zjZ8+OtuVClbHINb1gj6PYkotrFkEAdOnreNRD59Ich1hP24pbe/1/k8LjVBZAUMBmTIp/YKtgoTJQwGgxgUtj75AeTdHB5AZA0MBhQ4jiQjShgMBlbCvJuITDAYJJhQzQWkDOqU9p84j/yCwpCc3274pOV4Y862kJ6TiPzHYBAHortcpe3dj5y5iKvHL8S4WVuLXgxB4Fm26zjeXpAd9HmIKDgMBjHINY8NfZ8d/894/I88AMCv2cfYVECUgEISDETkIxHJFZGNum0VRGSuiOzQ/k/XtouITBCRbBFZLyJtQ5EGsglXb0+jaiMiShyhKhl8AqC/y7bRAOYrpRoBmK89B4ABABpp/0YCmBiiNCQM1zvvQG7EL3us2/fhjCZ5P1cuI0pMIQkGSqklAE64bB4C4FPt8acAhuq2f6ZslgMoLyLVQ5GORFdYqLD/xHmP+9jz8Nf0dftBiFbWf8eHK5AxekaU3p3IesLZZlBVKXVIe3wYQFXtcU0A+3X7HdC2kRfvL96Jq8cvRHbuWa/7bjvifZ9gPTttk3ODcgj9suNYWM5LRMYi0oCsbP0d/a50FpGRIpIlIllHjx4NQ8riyzltnqCcUxeDPJP/9f9mbRHvL94ZZFqIKBaEMxgcsVf/aP/nattzANTW7VdL2+ZGKTVJKZWplMqsXLlyGJOaODyNM4hmbT+bn4liWziDwXQAd2qP7wQwTbd9hNarqCOA07rqJNIZ8dFKrNl30m17oAPL/DmKmTeRtYSqa+mXAJYBaCIiB0TkHgDjAPQRkR0AemvPAWAmgF0AsgFMBvBAKNKQqMb+tNmv/X3LxIvKCHn5hZi2NsfnACMAJIAyBvsgEcW2kCxuo5QabvJSL4N9FYBRoXhfK1i97xQ61q8Y4rMWZfxvzduO9xbtRMnUFPRpVtXDMUSUyDgCOU7dOHEpXp25xa9jjO7OD5+xNUabLZxjNKwgkAFof1zK9/sYIoocBoM4YJQhZ+09if8u2eW2PdARyGN/2oy8/KKBav60S3y5ch8yRs/A+TzzDP+zZXsDSxgRRQSDQRwIJIP3tZ++vf7/9IXL+GGtYacu5/3Fvc1g4iJb99KjZy/5mUrnc/zzm/UBH09EwWEwiEPbwzSgrKDQPOp4CkihmLfotZ+3YmrWfu87ElFYMBjEINeM17Wa6JWZ5qN+g8mYCw1yfH+mIgqklxERxQYGA4vTZ/b6WODXmIQQDkr4z7wdOHvRuDE7GL3eWIQWz88O+XmJEgWDQQwK5u7e14y5oFDhs2V7/G40FjEvLVzML8Dz0zZ67TmUnXsWd3y4AhcvF7i99u952z2WfAK18+g5nGWPJiJTIRlnQPFlQ85pPDhlDbYedm578DUEmcWM/1u2F/+3fC9KFU/BE/2bmh4/Zvpm/Jp9DL/vcZ3o1uaCh15JRBQeDAYWYl+L4JGp6wxfLzRoQPanHSBfO77ASwnDW8knmlNhHDlzEVXLpkUxBUTRwWqiGBS21cq8ZtI2Ww6dwdzNR1xe872nUbANyeG6fm/mbj6CDq/Mx6Jtud53JkowDAZxwJ/M1ai+fvKSXWj09EzsPHrO47H2gsGA//yC83nO9fn2DNpTWuzjDIJdDM3fWJBfUIilO4Nf/2DtftukgBsOnA76XETxhsEghiilkJdf6HZn7E+DstHEdi/P3ILLBd7P4fNkdSaZ/bwtR4xfMDuPD0HuX7OLGpMPnDyPV2ZucavOmjB/B26dvAIrdxu3QfibHs7YSlbEYBBDPlu2F42fmYUjZwNfvMZsjqFICuVog3cXFi2e8/ev1mLSkl1Yn+N8574j9w8AxiOgj5y5iM0Hz/j0XvYgF61qKqJoYgNyDJm+7iAAuK1zHM3BXA9PXYuhbYpWJQ22CkjPbPoLsxKKfYS00eA4M13GLXA0bHvDIXNkZSwZxAFvvXOc9i1UeMnPNRC8UcpzRdWxP/ybk8h+Od+sOmD8uslxSVpubdTryYyvgcD5/V
k0IOthySCG2O9MXfN++0Rwvvhlx7GAF5Nfve8kHvpyjdv271bn4Kf1B02Pu6QbuAbYSg/T1uagTFqAXy+TvDhJK5aY5e9Bl1q0E7CaiKyIwSAG+ZsXnQnR9A0zNxw23L7l0Bks3HbU5/MIBH//aq3b9nmbj6B3EAvoJDky6/Dk1o5gHJazE8U2VhPFucJChZZj5oT1PfIKiu78fbn5NrtDv/ezLGTnng34ztt+Xk/VZvM2H8GP68xLMb6cn8iKWDKIQf7c+frTnhAofR557I+8oBaqOXvR+1QTMzYcMtyeZFCNszHnNGZtLCrR3PtZFgDg2lY1Ak4j64nIilgyiCGOro3+HBOWlDj7v+VFmX/OqQtBnSuA9lzsPGrrOmrU9XPC/B1BpUeP4wzIyhgMYogjM/IjN9JX4YSLvxm4pwB1w8SluJjvPlupJ33/vQSAvgG5KEGekubaRRdwDmyukjjOgCyMwSAGLNh6BBmjZzgWp/fH2J+2hCFFwXnXS++nTTm+DQKzs48v8FZycg1CV49f6LbPsz9sNF0pruj8jAZkPQwGMWDKCttyj/sM7mS92XPM83xD0eBp+cxgiEFvokDu4vPyPZemWDIgKwp7MBCRPSKyQUTWikiWtq2CiMwVkR3a/+nhTkeiikQ1UaywV+P8+ZPfdVuNc+6nvt/g9/kdwcbvI4niX6RKBj2UUq2VUpna89EA5iulGgGYrz2nAKzaezLaSfBboAHMXg3kS8Fjyop95ucRYNysrbhP63lERNHrWjoEQHft8acAFgH4Z5TSQnHg21UHDAe+BVql8/5i30d1E1lBJEoGCsAcEVklIiO1bVWVUvbO5IcBBD4slSzhH/8zXp0tVFU6X6zYi5km4xuIrCASJYOuSqkcEakCYK6IOK12rpRSImL4m9aCx0gAqFOnTvhT6gOlFA6fuYjq5UpEOykUQk9/v9HxmA3IZEVhLxkopXK0/3MBfA+gPYAjIlIdALT/DdcZVEpNUkplKqUyK1euHO6k+uTzFfvQ6dUF2JjD1bDilbcpwRUUzl3Kx+HTga8rQRRvwhoMRKSUiJSxPwbQF8BGANMB3KntdieAaeFMRygt33kcALA7Brt0WsnqfSfdAnIo5xa6YeJSdHx1fuhOSBTjwl1NVBXA91qXvRQAU5RSP4vI7wC+FpF7AOwFcHOY0xEy9gFJnNQsuq5/bykAoEcT/0uMvvztth42HphGlKjCGgyUUrsAtDLYfhxAr3C+d7j4sjA8RY5+8Zq3F2SH/PxKKcf4A6JExhHIAWL+EBt+yy5ayGeTj2sd++OFH22rxi3clotsba1lokTEYOCnopIBxYJwzHxx5kLRNNufLN0DALj749/R+83FoX8zohjBYOAnthkkvi9Xmo9etrqvVu7DsEnLop0MCgMGAz+Fug96zqkLlppfKBbM3Xwk4GMLCxXO53lfoCdRjf5uA5bvOhHtZARkxa7jWLTNsBc7Ji7aiaw98XldocJgELDgiwar9p5Al3ELsGS77+sLU/DenLvd8fiSn2srvDJzC5o9NxsXL/t3XDxbuDUX+QY3LAu2HjHcHmuOnr2Ehdtyccuk5bjr498N93nt56248X1rl3gYDPxkLxiEoprob1+6LxpPkeXvwLL/rToAALiQZ41gsGhbLu7+5He3nlpLth/Fnz/JCulKc/5atfcENhzwPvjzlknLcLdJEPDVgq1HcPX4BX7fPMQTBgM/sQE5sazZd8qv/e1rIVilzejo2UsAgP0nndfaOPaHbfu0dQeRMXqG6YJB4XTDxGW49p1fve6366j7ANE9x85h00HfZxF4btom7D9xAblnLvmVxmU7j+PEuTyv+63aexI/69byjgYGA7/ZosFqPzMRik0PT/WvdHYhzqqH5m85gk9+2x3w8Y6SsMntz97jtiARy1WdroG7sFCh++uLMGiCeyBZsv0oCv3oonbyXB6em7bRcMEkpRSGT16OWycv93qeGyYuxV8+X+V4fuyPSxEvhVguGFy8XIBnftiAU+e9R2sj9pIBp0BOHCt2Hff4+mfL9rhti6VBh4WFymn1N717Ps3CGG2sRDBcM1T354Jzl/IxYf6OmG9HWLzDPHCN+GglPjIInmKyPvars7bgs2V78eO6g8gvKMRy3XfJHlMCGc2e+dI8/PXz1X4fFwzLBYMf1uTg8+X7MH72toCO54SWieeWSZ7v3J6btsn0tUXbcrFUN/AtEAWFCi/8uAkHT10I6Pj6T83EDROXOm37ZtUBnPShesIuL7/QeLlSky+8azAUAM2fn403527H9HUHHdvPXryM7FzfM8O1+08FNO/XgZPn3W7wMkbPwCsz3dcIv+ilvWe/wfKz9ut1XR/bPgK+UCm8NW8Hhk1ajlV7Tzi26dOy2M/S04KtuX79DYNluWBg//MUFPiXrb+/eCc+/HW32x3YmYuXA08L50qOW/ZM4a6Pf8etH6xwbD9z8TIyRs/Ah7/6XjWTtecEPv5tD/7xddGaDbM3HUbG6BmOOntv9NWWu4+dw2P/W4eHvlxjuv8jU9ciY/QMx/PGz8zCXR+vdNvPMa7Gp1TY2KvS8vIL0WLMHPR+cwl2HnUfvb0x57Qj47Qb+u5v6PH6Ij/ezabrawvRbfxCKKWQe+ai47c1ackuv8/ljyTHutxwjFC3/+1dg+udH63EuUv5fvVEazN2rml32FCzXDCwf6kVbEVrXzPkcbO2YuxPm53uC2ZtOISWY+Zg7X62H1iN2dfm+B+2OznXqqUdR8466qI3HTztlCHY8wz9neTny/cCAEb+X5ZTpm33wBer8MEvxhmdvf4696x5T6nv1+S4bftlh3kJx1uDuf51+3Xq67yPnHFPy+C3f8UNE5fhhzU52OZHVcqFvAK0HTvXbfuZi/l4f/EutH9lPn5Y6359dt5+8Z8u24snvrEF5s+X7zUs+WWMnoGM0TPwjda7TEEhSctNZ244jF5vLDL8jjR/fjaaPvszxkzf5CgB7T3uuSQUqfZJSwWD7Nw/8IbWx1wp4KqX56HekzNx3Tu/+lzPqf8D/7bT9iXZcMC/P9bGnNPIGD0DBzlfftzKManSKZZsyxUv6xoUtxw6gz7/XoJ3Fmbj1Pk8DJrwK/7x9TqcPm9cqvxx3UFHxmzW22nmhsN4aYZ7FQhgXr8N2ALEaz8XrS/19e/78baue6hrFYlZ0PMUHOx3xE6HesiBH566Fv3eWuK0bZa26tzJc3l4c+52tHlxDv49dzt2HzuH3cfOmfbQWbzddhf9yNSiUpZrUn25//s66wC2Hj6DZ37YiFs/WIF92ufi6dgk3Yey8+g5p+Du6pOle3DXxyuxcvcJXPOvRU6vRWsMi2WCwcac0xjx4QpHsVsBOKbdxa0/cBqnL7j/MPPyC/HBL7twWRco9H9eez3i/pMXMGb6JmSMnoE+Psxf89N6Lq8Y7wa/7dwT5VJ+AX7eeAjFkm0/qcu6KgL7WIbV+07i4mXbd2nGhkNo9eIcpwbHFbtPYPPBM4bVO2Y9XB4zWA7UniXtMJhYb/S3GzBxUVHnhye+Xe+4QQKAq8cvN
MyM7CUes+f6DPd/qw4gY/QMnL3oPlJbKYU35243DKar9p4sSud3GwAA93++ChPm78DJ85fxn/k7AqpCcv3o5m72rQtn/7d+cdvW/fVF+FYrDehNWrLL7XftKRgAtvzl5v+6D3S7zofusuFgiWCwau8JDH77V6c78W9c/qCztD6+p89fxrlLti/xR7/txksztqDR07Mc++mrldZpJYJJS3Y5JjTbkfsHCgoVFmw94lYFdeyPS2jx/GxsPhT62TUp8vSZ5r9+3oa/fL4aK3bb6sALChUWbsvF1sPOf+tzLlNZDJu0HPtOFFUTDJzgngEBwJGzF/Hp0j1u53P9Hntz2YcSsL6u2/5o/tZc7NLV+7tOyTH5l6I2Evvssfr2jsXbj+Km95di6+GzmDB/Bx74wr2njL4R3H5ztu+4e2OuayOuni89d35YW9TAvcugLcMbo/W4dxqMZTAKhnqHDGoGbGM2nNMUqX5rkVgDOeoOnPTeS+OZHzaiT7Oq6PDKfJQvWQxrn+uL85fc/5j6etX1JqMfJ/+yC+NmbcXrN7VC7fQSKJmagha1ymHpzuM4eyk/pvtkk+/0d+X2QVmntR4tJ87lOUa9vndbWwC2KoZeb7iXHP/57Qav79Xp1QWOx0tH9zTc592F2SgsVPhihflEe0l+jpbT38/01KXdtX3B6E5/i+6m579aQ+5CrTF0n5d6csAWcA4btDW4Xp++VH/KpOrNTM83FmPPuEF+HeOrbuMXhuW84WKJYDBrg2/Fwv9l7Qdg+0I1enomHuzRKKD3GzfLVierzyxWPtUrhnqmUyjoqwVmb7JNfmcvGejZ74KXeRnP4KvO4xYYbv+Xh+7S936ahY71KyDJhy/hubx8NH9+Nt6/vR1OmozHMbpOV09+5x7kxv9sS+NJHzJto5HDADDFJRi0emGO13N58vXv+4M63kx+iOZXvxihwWcSL90bMzMzVVZWVkDHGvXGiIaO9SvE7YyPlBh6Nq2CBVs9d1Ucf2NLPPHN+gilyNzw9rXx5crwZNTxJtDSi4isUkpl+rJvwrcZrPezp084MRBQtHkLBABiIhAAYCCIsIQPBo/4OfcMEZEVJXwwMGrlJyIiZwkfDIiIyLuoBQMR6S8i20QkW0RGh+M94qVxnIgo2qISDEQkGcC7AAYAaAZguIg0C8P7hPqUREQJKVolg/YAspVSu5RSeQC+AjAkSmkhIrK8aAWDmgD0/cYOaNuIiCgKYroBWURGikiWiGQdPcopHIiIwiVawSAHQG3d81raNidKqUlKqUylVGblypUjljgiIquJVjD4HUAjEaknIqkAhgGYHqW0EBFZXlQmqlNK5YvIgwBmA0gG8JFSynyhWSIiCquozVqqlJoJYGa03p+IiIrEdAMyERFFRsIHgz+1YY9VIiJvEj4Y3NiuVrSTQEQU8xI+GPiqUulU09dSfFkeiogoDMoUj0zTbsIHA1+z8QWPdcenf26PVrXKub028+9XhzZRREGoVLp4tJNAEZScHJmb0YQPBnblSxYzfW1wy+oom1YM1zSujC/u6+j2erkS5scSRVrp4sl4ZtAVAIAqZdwDQ9NqZYJ+jx5N/BvkmZpclJV093DssKtqm77WqX5Fp+dNqgZ2HX2aVQ3oOABIK+acJT5/bTO3dHlSs3wJv97vsb6Nve7z0V1X+XXOQCV+MNCCatNqZTCoRXXH5v/e0c7x+J1b2zoeF09x/0iKJQf3Mf1nWGvsemUgWhqUOih+bR3b3+n52CHNPe7/916NQvK+b97SGte2qgEAcJ2k/ZO7r8LPD3czPK59vQpu2+7qnIFH+zhnSBOGt8EHd16FPeMGoUSxZMNzJbtUnb6ou/YPRmQiXXfzVaNcmuNxWZMbq+ta1cBVGelO22Y/0g1l09yrSL57oLPT8/duK/r93tSuFiaPyMSTA5oavg8A/KOPeQbcprZzGu7uUg8lU40/AyOPeDi3kQd7NnL7LCfe1hbPDi6axLlhldJ+nTNQCR8MWtUqj8ZVS2P0gCvw3LVFH3C3Rra7lxGd6jrtb1QgSwmgmHa9rhfTkNY1kZQkKNTWV5gwvA2+vr+T3+ek2KL/EQ9vXwc3ZdZ2yhRd3dqhDiYMbxPUe755cyu0rZPu+J4qBTzer4njdXtwGHBlNbdj00sWw5LHe7htd13247pWNRzX9lCvhgCAWzJrO5UWXH8RSbrPIiU5CUm66eNfv6kVPrn7KvwwqgtuzjQuGZitPJJicCPWtk469owbhL/1bIgrqpdFrfSiu/HH+9s+i/uvaWB61z2ghfNnoy/V6H/rD/W0Xbv9d1u/cim3c93icj2BdFhZ8kTR32Teo9dgQIvquLV9Hce2pAhNxZ/wwaBU8RTMeeQatK5dHvrPNCkJ2P7SAIy51vnHa7QGgv3L0rVhJZ/f95XrW7hts//o6lcqZXiXRvFF/0159foWSCuW7LGNSing2pbVMeXeDtj0Qj/H9oEt3DNu0/e0v4HjjRRG9WiIsUOvBADUq2jLsN67rS2mjuyI3a8OxOs3tdIOEdSpWBJbx/bH89qNkVLKkdkZeaB7Q+wZNwiv3dgSH9/d3j0dAPaMG+SWm9t/R0Nb10CH+hXRvUkVtK5d3uNdrlEqOjewVdHYq8X0Hu3bBLP+frVTMBPdX+DBno1QoZR7xxD9/lfWLIt/3dTS8fzOThmOx//o28QpXU8PvAI7Xh6ApwYWlTpGD2iK+pXcg4Q3b93SGnMesZXgjDqvlEhNdnRciVT/lYQPBnr6CJskgtSUJKc7Gtt29+OSkwTzHu2GSSPaub9oIs2geF2ofavsyfjPsNY+n49ij4jgxSHNnTMD3Xds8ePdMe/Raxz10AoKIoLODSuhlK6HSMVSvjcIlyluq2YpmWo7/prGVQAAt3eogw1j+iJDS4uIoEP9ihARRzWHPWlpxZId38WkJHFkdmXSUjBtVBeP7//VyI6YeFtbp0zXfm0AcENb253x5BHtcF2rGnjz5tZu1SBGlFIY0tp9TJD9WHujeSmDKhtP6xm6vnOHehVQt2LR3+unh67GkNY18fpNrbDy6V7obdDeYA8eSSIolpyEwS1rOF4rVAqDW9VwO8aboW1qorHWJqIvmRgFBpYMwkD/kSabfMD2O5reVxR9KZJE0LBKGZRMTUHpILp52RvWKrM3SFzq1bSK0/MkAUZ0ysCCx7ob7l+3Yik0rFIaLWuVBwCkJDn/3Gb8rSt+/WcPR0bqyR0d62L8jS3R6wpbGkoXT8EvT/TAq1oJVERQJs24Pr59vQpITUnCvVfXd2zLLyjU0iSO3O7ervXRqnZ5j+noWL8iBrSoDrP8yZ6vtamTjgnD27jdbOmtebaP0/OGVUrbShk6j/Vtgs4NKqJ3s6pY/WwfLHuql9t5PC1vq0/no30aY+r9nZBq0C54Y7taqFImTbsG10DnfC595qwAPGzSFrT2uT6G293TKNj96kBsebE/ypdM1W13v4ZwslQw0PP0AS97sifeubWNo1eG/ruxXPdlvK1DHddD8Xi/JhjVo4HheR/r2wTLn+yFKmXTDF+n2FS+ZDHs
GTfIrZeMp2VVb9V9Nybd0Q7/vaMdKrv0/GleoxxqpZd0q7PXm3JvBwBAryuq4ObM2k7vWbtCScOMzVWl0sWx/aUBaFe3qHH02lY1UKl0cdzaoW5RKcGPTMf10gNZbjy9VKqjDcXs8NoVSmLKfR1RungKKpRKRVmTgGfnfmctBo9smlUva3iO35/ujWVP9nQ8twcb+2ev/5yUgmnAK18yFVPu7YCpI209FB/v18TtO+BImwhKmDRUu5bCwiVqE9VFg/2PmVYsyeMPuXo5W4PUlPs6Ytvhs0776ksGL1zXHF+s2AcAeGnolZi+7iBG9Whoet7kJEG1cgwE8cb+1x/evg4KChXG/LjZ6776zLF8yVT0a+69XeCF65rjrXnbcfL8Zce2zg0rYevY/obVjsGoUb4Esp7pDcAWGN5ZmI0But52vhy/6+g5x3PH3bMPGde9XeshTyv1UT7sAAATEUlEQVSZ2LttVy0T+O+iRc1yuL5NTTzQo6Hb79rsZ775xX5uJTU713aG/ALb1dnr8CuXKY7Muuk4deEyKhq0Seh11toZ7SWem9rVwq5j5zwd4jCkdU18s+pAxNoMrBUMtP99/WFVKJWKTg2M+xi7znl0e8e6uL2jc8+k2Q93wx+X8s3TE6nyHwXlDq1RMSU5CXd1qYcjZy9h4qKdhvva7/xqlvc9c7NnpElJYtiNOdSBwFWTamXcqme8+fK+jujwynxH6blOhZIAgCuqex8b8Mxgfa++Snjrltbob9D7yVcpyUl48xbj9jf9L2yo7jdrb3Pxhb19p6jtRfDNXzt7OsRUlbJpPtcMjLu+BZ4eeIVhj6pwsFQwsI/kq6t9cQO1dWx/n8YeNPFx8E/fZlUxZ/ORoNJE4ZFZNx2P9HauE/5n/6b4Z3/jfux9m1XFpDvaoadL+4Iv4unWoGrZNOx8ZaAjzV0aVsJPD3VF8xrGVS9mRMQpk7ZrFKK+9X/t3gAv/Lg5qNLV+Btb4vs1OWjtpT1FLxRT2KQkJyHdS8kjlCwVDMqmFcPkEZloW8f3P6oR+5eqsDCAilId+9cl2EFtFLj6lUs5VXe4EvGvBCci6OtDlZBeIPXtscC1ofXKmqEZVLnm2T4hKw3d3aUe7u5SL6hzVCiVinu6+n6Otc/18akHVayxXC7Up1lVVAxRbx57HuE68MTf4yl2RSKjvrtLBiqUSkXfIKZRSCTppVJNG1PjQfmSqaY9u2KZpUoGoSYi2PxiPxRPCe6L60vXQgqPJBH8rWdDTFiQbfh6JP4yjauWwWqtmyVvEChaGAyC5E9DlKtIdRkjc5NHZKJepVLmwSBKdThv3NQKJ8/nReW9yZoYDKLIXiKI1zrjRFDPy1QC0frTdGlYid2QKaIs12YQS+z9l9mAHLv0c9VEgr2XUnqp+KtzpvjGkkEU2Sft6tG0MqavOxjl1JCr2Q9387l7cKhc37YWrm/LpVop8hgMoujKmuWw5tk+SC+Vikemrot2csgFG3PJVyuf6hVfA0UMhK1+QkTGiEiOiKzV/g3UvfakiGSLyDYR6efpPInOPqjkm79wfYNQ8GchEm/YlkO+qlI2zTHRXbwKd2X1v5VSrbV/MwFARJoBGAagOYD+AN4TkfjtVBwimRkVDFdZI2N/7W48GaA/3ry5VQhSQpQYopH7DAHwlVLqklJqN4BsAO29HGMJsXwnWiOGeraM6FQXV3tZaMiXYOFtzhdWE5GVhDsYPCgi60XkIxGxz59bE8B+3T4HtG2WF6uDz1KSBLXSg5vPKZQEQF2tS6jRUoQAcH+3+obbichYUMFAROaJyEaDf0MATATQAEBrAIcAvBHA+UeKSJaIZB09ejSYpMaFWC0ZiMRWoKqVXhI1y5fAphf6RbzrJ1GiCqo3kVKqty/7ichkAD9pT3MA6CfzqaVtMzr/JACTACAzMzN2cqMw8bQWbTTF2jiIP2uThpXysOqcL6O79aOLq5dLw6HTF4NPHFGcCmdvIv1KGX8CsFF7PB3AMBEpLiL1ADQCsDJc6YgnQU6CGjbFkpNiqtSinxHSdboIVbTKiuEa03UqlDRclH32I91QvqRtoFcfbcI4bwuXECWScI4zGC8irWEb0b8HwP0AoJTaJCJfA9gMIB/AKKVUQRjTQUFK0S2aHmvM0iViWykqJSkJ+YWF+PtXawEAS57ogYe+XIPs3D+c9i+bVgzVyqbh1PnLeLh3I/zrxpZO69ESJbqwlQyUUncopVoopVoqpa5TSh3SvfayUqqBUqqJUmpWuNIQb6bc1yHaSTAUyyuyuZZYXJM6qGV1DGlt659QtaxvU5cniTAQkOXEVmWwxXVu4Lm7ZLREOxb88kQPn/e1BwfXJH/718746aGrfTqWyIoYDMgrXxZterRP47C9f20Py5SaVxM5J7pd3XTH+sRE5I7BgLwSiNd5/QXAVRnpHvcBgGo+LAbuT0nELF2eTmH2WrRLQETRxGAQBxqYDKwys+PlASF9/yQxvwMf3r6O4/EX93bE5/d4bvfo48PSjrtfHYTUAKbm6Nm0CjN0ogAxGMSoAVcWLaperoR/c9uHelyAiKB74yqGr1XQzbufmpKE0mkh6qDmEn0yKnoeAT20dQ1MvL1tUPX+bDMgK2MwiFETb2/neBwLvXke6tnQ4+u+JjHQS1n0uHEjsj0Dr1i6uNNa1L68DzN/oiIMBgnGbK6eYCX50oocBe20doqrGzn3xPKU0V9RvSwA26hjvRiIuURRw2AQB/zJoxb8o3tA7+Fa1+9v1VSo+ToXUts66dg6tj+6NzGuxjJyf7f6+GFUF3SoXzHQ5BElHAYDC3jhuuZe9ynhsiiM/i75gR626aCfHdwMY65t5rSfv1UtvgY2o/PqG6v10oq5L4fh6S4/KUnQunZ5n96TyCoYDCygYwB3wPqM8bYOdQEA93St56hiceyn/e/artGyVjlse6m/43koViB7eeiVIe8pRUQ2DAZxIJbqsn29eRbAqUH38X5Ngn7vpCQJ6wyqsfQ5E0Uag0Ec8DQC11djh17pZQ/nbD7YjNE1aDimifDxxIHW2Iy/sSXqViyJtBT/SyKsJiIrYzCIAy95zcjdNdNV5ygo3NGxrl/HK+U8VbR+u15Jrb7eqN4+Gq5tVQOLH+8RVO8nlhDIisI5hTWFSMlU//9M3z3QGU2f/Tmo953zSDes23/KaZtrL5/7utWHCPwONt54m/4inFhCICtiMEgAJYol48Jl5yUh/L9Td74dFgEaVC6NBpXdF4JxfZ8HezYyOUsRff5aJi0FZy/mAwCeGtgUVcqk4eGpa/1Mb+ixREBWxmoi0pisGOZ5N39fhgiw8LHujufdm1TB0DY1vaYuElgiICtjMIgxX97XEd8/0DnayTAVivyyUuniaOSy9OTDvRuZ7B15LCGQFTEYxJhODSqiTR3vU0GHw1cjO+L+bvUBBJ4hmlYTudx2X9uqBgBbYACAh3s7r4fAm3SiyGIwIIeO9SviL9c08LhPmzrl0apWOdPXzTLxato8QLXTbd1kH+rZEJte6IcKMbjoPKuLyIoYDBJAKKo1fM0AS6amYNq
DXf0+/6AW1fHx3Vfhrs4ZAGzjDUoVZ/8FoljBX2MMe+fWNmharUzIz3tbhzr4YsW+oM4x429dkXvmktt281XEBD38mEwumthmQFbEYBDDBresEZLziEsW7WkVMfuEdYNbVvd4zuY1yqF5aJJHRDEgqGoiEblJRDaJSKGIZLq89qSIZIvINhHpp9veX9uWLSKjg3l/8k3jqp7HCuilFUvG2uf64IXr/B/1HErRrLdnmwFZUbBtBhsBXA9giX6jiDQDMAxAcwD9AbwnIskikgzgXQADADQDMFzblwIwslt9TLqjndf9XOcD8pbZlS+ZajgVBRElrqCqiZRSWwDDyceGAPhKKXUJwG4RyQbQXnstWym1SzvuK23fzcGkw6qeGnhFtJPgpoy2BnJDLyOXYxnbDMiKwtVmUBPAct3zA9o2ANjvst15iS2Ka/Url8bn93RAu7rRGStBRIHxGgxEZB6AagYvPa2Umhb6JDm990gAIwGgTh3jVa4oNEJZTd7VZT3ieMM2A7Iir8FAKdU7gPPmAKite15L2wYP243eexKASQCQmZnJn6iJRKnVeLxfE/x38U6nbUse7xGl1BBZS7gGnU0HMExEiotIPQCNAKwE8DuARiJST0RSYWtknh6mNFjGn9rGxkRvwRrVoyHWj+nntK1queIRTwfbDMiKgu1a+icROQCgE4AZIjIbAJRSmwB8DVvD8M8ARimlCpRS+QAeBDAbwBYAX2v7ko/KFE9Boyql8fHdVzm2RbsbKBHFv2B7E30P4HuT114G8LLB9pkAZgbzvlb27OBmuPmq2k7bQtENNFbryV0HzEVCrH4WROHEEcgJYuFj3XEpv8D7jojuKmJEFJsYDOKNyY1yvUqlgjst68kd+FmQFXHW0ngTgpv6Hk3jY8I4gBkzUaQwGFhQ9yZV8MOoLk7bWHNUhJ8FWRGriRLY4se7Iy+/0PC15Di55Y6PVBLFPwaDeONH7li3YnDtCFZzudAWOIslMwSR9bCaKN6EqQojVnsYGUyCGDYFhbbPICWJPwuyHn7riTTVytrWaU4rlhzllBBFHquJ4g1rMMJm4u3tsGzncVQrlxbtpBBFHIMBhdXbw9vgQp5vg+GMRDL2VSiVikFelvskSlQMBhRW17biQslE8YBtBnHib70aAQC6Naoc5ZREVpz0gCWKeywZxIlH+zTGo30aRzsZRJSgGAws5L3b2mLGhkPRToZfItm1lMjKGAwsZGCL6hjYwriBNDZHGRBRpLDNgIiIGAysqlmNsri9Yx1ULRv5ZSWJKPYwGFhUcpLgpaEtOH8REQFgMLC8G9vWAhD84jhEFN/YgGxxN19V221N5ViQmpyEvALj6beJKPQYDCgm/fhQV/yy42i0k0FkGQwGFJOaVCuDJtXKRDsZRJbBNgMiIgouGIjITSKySUQKRSRTtz1DRC6IyFrt3/u619qJyAYRyRaRCcIhpkREURdsyWAjgOsBLDF4badSqrX27y+67RMB3Aegkfavf5BpICKiIAUVDJRSW5RS23zdX0SqAyirlFqubOssfgZgaDBpICKi4IWzzaCeiKwRkcUicrW2rSaAA7p9DmjbiIgoirz2JhKReQCqGbz0tFJqmslhhwDUUUodF5F2AH4Qkeb+Jk5ERgIYCQB16tTx93AiIvKR12CglOrt70mVUpcAXNIerxKRnQAaA8gBUEu3ay1tm9l5JgGYBACZmZmcWJOIKEzCUk0kIpVFJFl7XB+2huJdSqlDAM6ISEetF9EIAGalCyIiihCxteMGeLDInwC8DaAygFMA1iql+onIDQBeBHAZQCGA55VSP2rHZAL4BEAJALMAPKR8SISIHAWwN8CkVgJwLMBj442VrhWw1vVa6VoBa11vuK61rlLKp7VygwoG8UJEspRSmd73jH9WulbAWtdrpWsFrHW9sXCtHIFMREQMBkREZJ1gMCnaCYggK10rYK3rtdK1Ata63qhfqyXaDIiIyDOrlAyIiMiDhA4GItJfRLZpM6SOjnZ6AiUiH4lIrohs1G2rICJzRWSH9n+6tl202WCzRWS9iLTVHXOntv8OEbkzGtfijYjUFpGFIrJZmxH379r2hLteEUkTkZUisk671he07fVEZIV2TVNFJFXbXlx7nq29nqE715Pa9m0i0i86V+QbEUnWpqr5SXuekNcrInu0GZrXikiWti12v8dKqYT8ByAZwE4A9QGkAlgHoFm00xXgtXQD0BbARt228QBGa49HA3hNezwQtvEbAqAjgBXa9goAdmn/p2uP06N9bQbXWh1AW+1xGQDbATRLxOvV0lxae1wMwArtGr4GMEzb/j6Av2qPHwDwvvZ4GICp2uNm2ve7OIB62vc+OdrX5+G6HwUwBcBP2vOEvF4AewBUctkWs9/jRC4ZtAeQrZTapZTKA/AVgCFRTlNAlFJLAJxw2TwEwKfa409RNPvrEACfKZvlAMqLbbbYfgDmKqVOKKVOApiLGJw+XCl1SCm1Wnt8FsAW2CYzTLjr1dL8h/a0mPZPAegJ4Bttu+u12j+DbwD00kbyDwHwlVLqklJqN4Bs2L7/MUdEagEYBOAD7bkgga/XQMx+jxM5GNQEsF/3PNFmSK2qbNN7AMBhAFW1x2bXHXefh1Yt0Aa2O+aEvF6tymQtgFzYfug7AZxSSuVru+jT7bgm7fXTACoiTq5V8xaAJ2CbmQCwpT9Rr1cBmCMiq8Q26SYQw99jroGcAJRSSkQSqluYiJQG8C2Ah5VSZ0S3IF4iXa9SqgBAaxEpD+B7AE2jnKSwEZHBAHKVbfLK7tFOTwR0VUrliEgVAHNFZKv+xVj7HidyySAHQG3dc48zpMahI1ox0r5oUK623ey64+bzEJFisAWCL5RS32mbE/Z6AUApdQrAQgCdYKsisN+o6dPtuCbt9XIAjiN+rrULgOtEZA9s1bY9AfwHCXq9Sqkc7f9c2AJ9e8Tw9ziRg8HvABppPRVSYWuAmh7lNIXSdAD2ngV3omj21+kARmi9EzoCOK0VS2cD6Csi6VoPhr7atpii1Ql/CGCLUupN3UsJd71im923vPa4BIA+sLWRLARwo7ab67XaP4MbASxQtlbG6QCGab1v6sE2S/DKyFyF75RSTyqlaimlMmD7PS5QSt2GBLxeESklImXsj2H7/m1ELH+Po93iHs5/sLXQb4etHvbpaKcniOv4ErYFgy7DVmd4D2x1p/MB7AAwD0AFbV8B8K52zRsAZOrO82fYGtuyAdwd7esyudausNW1rgewVvs3MBGvF0BLAGu0a90I4Dlte33YMrdsAP8DUFzbnqY9z9Zer68719PaZ7ANwIBoX5sP194dRb2JEu56tWtap/3bZM9/Yvl7zBHIRESU0NVERETkIwYDIiJiMCAiIgYDIiICgwEREYHBgIiIwGBARERgMCAiIgD/DxWDVobfWHrAAAAAAElFTkSuQmCC\n", 185 | "text/plain": [ 186 | "
" 187 | ] 188 | }, 189 | "metadata": {}, 190 | "output_type": "display_data" 191 | }, 192 | { 193 | "data": { 194 | "text/plain": [ 195 | "512" 196 | ] 197 | }, 198 | "execution_count": 9, 199 | "metadata": {}, 200 | "output_type": "execute_result" 201 | } 202 | ], 203 | "source": [ 204 | "i=15\n", 205 | "plt.figure()\n", 206 | "plt.plot(flt_low[data_low.trials[i]: data_low.trials[i+1],4])\n", 207 | "#plt.plot(data_low.raw_data[data_low.trials[i]: data_low.trials[i+1],2])\n", 208 | "plt.show()\n", 209 | "data_high.sampling_freq" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "# Features Extraction" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 9, 222 | "metadata": {}, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "array([[0.00317433, 0.00347712, 0.00351526, ..., 0.0009686 , 0.00128107,\n", 228 | " 0.00107977],\n", 229 | " [0.00362845, 0.00595219, 0.00753056, ..., 0.00087038, 0.00078948,\n", 230 | " 0.00062584],\n", 231 | " [0.00309323, 0.00401737, 0.00336045, ..., 0.0010687 , 0.00079988,\n", 232 | " 0.00072598],\n", 233 | " ...,\n", 234 | " [0.00212149, 0.00193713, 0.0020841 , ..., 0.00049861, 0.0005653 ,\n", 235 | " 0.00119055],\n", 236 | " [0.00195615, 0.00182604, 0.00189547, ..., 0.00057526, 0.00064521,\n", 237 | " 0.00068972],\n", 238 | " [0.00162141, 0.00206098, 0.00338333, ..., 0.00065747, 0.00061476,\n", 239 | " 0.00061072]])" 240 | ] 241 | }, 242 | "execution_count": 9, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "\n", 249 | "\n", 250 | "window_size = 0.2\n", 251 | "window_shift = 0.05\n", 252 | "\n", 253 | "highRMSfeatures = gumpy.features.RMS_features_extraction(data_high, trialsHigh, window_size, window_shift)\n", 254 | "highRMSfeaturesBg = gumpy.features.RMS_features_extraction(data_high, trialsHighBg, window_size, window_shift)\n", 255 | "lowRMSfeatures = gumpy.features.RMS_features_extraction(data_high, trialsLow, window_size, window_shift)\n", 256 | "lowRMSfeaturesBg = gumpy.features.RMS_features_extraction(data_high, trialsLowBg, window_size, window_shift)\n", 257 | "\n", 258 | "\n", 259 | "\n", 260 | "# Constructing Classification arrays\n", 261 | "X_tot = np.vstack((highRMSfeatures, lowRMSfeatures))\n", 262 | "y_tot = np.hstack((np.ones((highRMSfeatures.shape[0])),\n", 263 | " np.zeros((lowRMSfeatures.shape[0]))))\n", 264 | " \n", 265 | "X_totSig = np.vstack((highRMSfeatures, highRMSfeaturesBg, lowRMSfeatures, lowRMSfeaturesBg))\n", 266 | "# Normalizing the features\n", 267 | "X_totSig = X_totSig/np.linalg.norm(X_totSig)\n", 268 | "\n", 269 | "y_totSig = np.hstack((data_high.labels, \n", 270 | " data_low.labels))\n", 271 | "\n", 272 | "X_totSig" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 12, 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "data": { 282 | "text/plain": [ 283 | "array([2., 1., 2., 2., 1., 0., 1., 2., 0., 1., 0., 0., 2., 1., 2., 1., 0.,\n", 284 | " 1., 1., 0., 0., 2., 0., 2., 1., 2., 0., 1., 2., 0., 0., 1., 1., 0.,\n", 285 | " 2., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 0., 1., 1.,\n", 286 | " 0., 2., 2., 2., 2., 0., 0., 1., 1., 2., 1., 2., 0., 0., 0., 0., 2.,\n", 287 | " 1., 2., 1., 1., 1., 1., 2., 2., 1., 1., 0., 0., 2., 2., 0., 0., 3.,\n", 288 | " 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.])" 289 | ] 290 | }, 291 | "execution_count": 12, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | 
"labels=y_totSig\n", 298 | "labels" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "# Normalization of features " 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 13, 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "name": "stdout", 315 | "output_type": "stream", 316 | "text": [ 317 | "Normalized Data:\n", 318 | " Mean = 0.000\n", 319 | " Min = -2.636\n", 320 | " Max = 5.390\n", 321 | " Std.Dev = 1.000\n" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "features =X_totSig\n", 327 | "labels =y_totSig\n", 328 | "\n", 329 | "# normalize the data first\n", 330 | "features = gumpy.signal.normalize(features, 'mean_std')\n", 331 | "# let's see some statistics\n", 332 | "print(\"\"\"Normalized Data:\n", 333 | " Mean = {:.3f}\n", 334 | " Min = {:.3f}\n", 335 | " Max = {:.3f}\n", 336 | " Std.Dev = {:.3f}\"\"\".format(\n", 337 | " np.nanmean(features),np.nanmin(features),np.nanmax(features),np.nanstd(features)\n", 338 | "))" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "## Splitting data for training \n", 346 | "Now that we extracted features (and reduced the dimensionality), we can split the data for test and training purposes." 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 10, 352 | "metadata": {}, 353 | "outputs": [ 354 | { 355 | "data": { 356 | "text/plain": [ 357 | "array([[0.00241969, 0.0022337 , 0.0023087 , ..., 0.00084206, 0.00076152,\n", 358 | " 0.00103709],\n", 359 | " [0.00316464, 0.00331519, 0.00239209, ..., 0.00194006, 0.0019764 ,\n", 360 | " 0.00227144],\n", 361 | " [0.00433969, 0.00631524, 0.00652404, ..., 0.00064123, 0.00075638,\n", 362 | " 0.00059114],\n", 363 | " ...,\n", 364 | " [0.00233438, 0.0031688 , 0.00322296, ..., 0.00297931, 0.00107865,\n", 365 | " 0.00169824],\n", 366 | " [0.00252272, 0.00243595, 0.00187793, ..., 0.00090753, 0.00114431,\n", 367 | " 0.00146587],\n", 368 | " [0.00252272, 0.00243595, 0.00187793, ..., 0.00090753, 0.00114431,\n", 369 | " 0.00146587]])" 370 | ] 371 | }, 372 | "execution_count": 10, 373 | "metadata": {}, 374 | "output_type": "execute_result" 375 | } 376 | ], 377 | "source": [ 378 | "\n", 379 | "# gumpy exposes several methods to split a dataset, as shown in the examples:\n", 380 | "if 1: \n", 381 | " split_features = np.array(gumpy.split.normal(features, labels,test_size=0.2))\n", 382 | "if 0: \n", 383 | " n_splits=5\n", 384 | " split_features = np.array(gumpy.split.time_series_split(features, labels, n_splits)) \n", 385 | "if 0: \n", 386 | " split_features = np.array(gumpy.split.normal(PCA, labels, test_size=0.2))\n", 387 | " \n", 388 | "#ShuffleSplit: Random permutation cross-validator \n", 389 | "if 0: \n", 390 | " split_features = gumpy.split.shuffle_Split(features, labels, n_splits=10,test_size=0.2,random_state=0)\n", 391 | " \n", 392 | "# #Stratified K-Folds cross-validator\n", 393 | "# #Stratification is the process of rearranging the data as to ensure each fold is a good representative of the whole \n", 394 | "if 0: \n", 395 | " split_features = gumpy.split.stratified_KFold(features, labels, n_splits=3)\n", 396 | " \n", 397 | "#Stratified ShuffleSplit cross-validator \n", 398 | "if 0: \n", 399 | " split_features = gumpy.split.stratified_shuffle_Split(features, labels, n_splits=10,test_size=0.3,random_state=0)\n", 400 | "\n", 401 | "\n", 402 | "# the functions return a list with the data according to the following example\n", 403 | "X_train = split_features[0]\n", 
404 | "X_test = split_features[1]\n", 405 | "Y_train = split_features[2]\n", 406 | "Y_test = split_features[3]\n" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "#SVM, RF, KNN, NB, LR, QLDA, LDA\n", 416 | "from sklearn.cross_validation import cross_val_score\n", 417 | "feature_idx, cv_scores, algorithm,sfs, clf = gumpy.features.sequential_feature_selector(X_train, Y_train, 'SVM',(6, 30), 3, 'SFFS')\n", 418 | "\n", 419 | "feature=X_train[:,feature_idx]\n", 420 | "# features=features[:,feature_idx]\n", 421 | "scores = cross_val_score(clf, feature, Y_train, cv=3)\n", 422 | "# scores = cross_val_score(clf, features, labels, cv=5)\n", 423 | "\n", 424 | "\n", 425 | "print(\"Validation Accuracy: %0.2f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))\n", 426 | "clf.fit(feature, Y_train)\n", 427 | "feature1=X_test[:,feature_idx]\n", 428 | "feature1.shape\n", 429 | "clf.predict(feature1)\n", 430 | "f=clf.score(feature1, Y_test)\n", 431 | "print(\"Test Accuracy:\",f )" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "# Gesture Classification" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "\n", 448 | "\n", 449 | "if __name__ == '__main__':\n", 450 | " # Posture Classification\n", 451 | " classifiers = []\n", 452 | " Accuracy=[]\n", 453 | " Final_results = {}\n", 454 | " for model in gumpy.classification.available_classifiers:\n", 455 | " print (model)\n", 456 | " feature_idx, cv_scores, algorithm, sfs, clf = gumpy.features.sequential_feature_selector(features, labels,model,(6, 25), 3, 'SFFS')\n", 457 | " classifiers.append(model)\n", 458 | " Accuracy.append (cv_scores*100) \n", 459 | " Final_results[model]= cv_scores*100\n", 460 | " print (Final_results)" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": {}, 466 | "source": [ 467 | "## Force classification" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "if __name__ == '__main__':\n", 477 | " # Posture Classification\n", 478 | " classifiers = []\n", 479 | " Accuracy=[]\n", 480 | " Final_results = {}\n", 481 | " for model in gumpy.classification.available_classifiers:\n", 482 | " print (model)\n", 483 | " feature_idx, cv_scores, algorithm = gumpy.features.sequential_feature_selector(X_tot, y_tot, model,(6, 25), 10, 'SFFS')\n", 484 | " classifiers.append(model)\n", 485 | " Accuracy.append (cv_scores*100) \n", 486 | " Final_results[model]= cv_scores*100\n", 487 | "print (Final_results)" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "## Classification without the feature selection method" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "classifiers = []\n", 504 | "Accuracy=[]\n", 505 | "Final_results = {}\n", 506 | "for model in gumpy.classification.available_classifiers:\n", 507 | " results, clf = gumpy.classify(model, X_train, Y_train, X_test, Y_test)\n", 508 | " print (model)\n", 509 | " classifiers.append(model)\n", 510 | " Accuracy.append (results.accuracy) \n", 511 | " Final_results[model]= results.accuracy\n", 512 | "print (Final_results)" 513 | ] 514 | } 515 | ], 516 | "metadata": { 517 | "kernelspec": { 518 | 
"display_name": "Python 3", 519 | "language": "python", 520 | "name": "python3" 521 | }, 522 | "language_info": { 523 | "codemirror_mode": { 524 | "name": "ipython", 525 | "version": 3 526 | }, 527 | "file_extension": ".py", 528 | "mimetype": "text/x-python", 529 | "name": "python", 530 | "nbconvert_exporter": "python", 531 | "pygments_lexer": "ipython3", 532 | "version": "3.5.2" 533 | } 534 | }, 535 | "nbformat": 4, 536 | "nbformat_minor": 2 537 | } 538 | -------------------------------------------------------------------------------- /gumpy/__init__.py: -------------------------------------------------------------------------------- 1 | # import all gumpy submodules that should be loaded automatically 2 | import gumpy.classification 3 | import gumpy.data 4 | import gumpy.plot 5 | import gumpy.signal 6 | import gumpy.utils 7 | import gumpy.features 8 | import gumpy.split 9 | 10 | # fetch into gumpy-scope so that users don't have to specify the entire 11 | # namespace 12 | from gumpy.classification import classify 13 | 14 | # retrieve the gumpy version (required for package manager) 15 | from gumpy.version import __version__ 16 | -------------------------------------------------------------------------------- /gumpy/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # import all functions and ABCs for convenience 2 | from .classifier import classify, vote, Classifier, available_classifiers, register_classifier 3 | 4 | # import all the classifiers for convenience 5 | from .common import SVM, LDA, RandomForest, NaiveBayes, KNN, LogisticRegression, MLP 6 | -------------------------------------------------------------------------------- /gumpy/classification/classifier.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from sklearn.metrics import classification_report 3 | from sklearn.ensemble import VotingClassifier 4 | from mlxtend.feature_selection import SequentialFeatureSelector as SFS 5 | import numpy as np 6 | 7 | class ClassifierError(Exception): 8 | pass 9 | 10 | 11 | class Classifier(ABC): 12 | """ 13 | Abstract base class representing a classifier. 14 | 15 | All classifiers should subclass from this baseclass. All subclasses need to 16 | implement `run()`, which will be called for the classification. Additional 17 | arguments to the initialization should be captured via `**kwargs`. For an 18 | example, see the SVM classifier. 19 | 20 | In case a classifier auto-tunes its hyperparameters (for instance with the 21 | help of a grid search) but should avoid this behavior during voting 22 | classification or feature selection, a set of static options should be 23 | obtainable in form of a key-value dictionary using the ``static_opts`` 24 | member function which will subsequently be passed to ``__init__``. Note that 25 | this function has to be defined with the staticmethod decorator. The 26 | Classifier provides an empty static_opts implementation. For an example of a 27 | customization, see the SVM classifier which should not perform grid search 28 | during voting classification or feature selection. 29 | 30 | """ 31 | 32 | def __init__(self): 33 | pass 34 | 35 | 36 | @staticmethod 37 | def static_opts(ftype, **kwargs): 38 | """Return a kwargs dict for voting classification or feature computation. 39 | 40 | For more information see the documentation of the Classifier class. 
For additional 41 | information about the passed keyword arguments see the corresponding documentation 42 | in 43 | - ``gumpy.classification.classifier.vote`` 44 | - ``gumpy.features.sequential_feature_selector`` 45 | 46 | Args: 47 | ftype (str): function type for which the options are requested. 48 | One of the following: 'vote', 'sequential_feature_selector' 49 | **kwargs (dict): Additional arguments, depends on the function type 50 | 51 | Returns: 52 | A kwargs dictionary that can be passed to ``__init__`` 53 | """ 54 | return {} 55 | 56 | 57 | @abstractmethod 58 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 59 | """ 60 | Run a classification. 61 | 62 | Args: 63 | self: reference to object 64 | X_train: training data (values) 65 | Y_train: training data (labels) 66 | X_test: evaluation data (values) 67 | Y_test: evaluation data (labels) 68 | **kwargs: Any additional arguments that may be passed to a classifier 69 | 70 | Returns: 71 | 2-element tuple containing 72 | 73 | - **ClassificationResult**: Object with all the classification results 74 | - **Classifier**: Reference to the classifier 75 | 76 | """ 77 | return None, self 78 | 79 | 80 | def __call__(self, X_train, Y_train, X_test, Y_test, **kwargs): 81 | return self.run(X_train, Y_train, X_test, Y_test, **kwargs) 82 | 83 | 84 | class ClassificationResult: 85 | """ 86 | The result of a classification run. 87 | 88 | The result includes the accuracy of the classification, a reference to the y 89 | data, as well as the prediction. 90 | 91 | """ 92 | 93 | def __init__(self, test, pred): 94 | self.test = test 95 | self.pred = pred 96 | self.n_correct = len(np.where(test - pred == 0)[0]) 97 | self.accuracy = (self.n_correct / len(pred)) * 100.0 98 | self.report = classification_report(self.test, self.pred) 99 | 100 | def __str__(self): 101 | return self.report 102 | 103 | 104 | 105 | # list of known classifiers. 106 | available_classifiers = {} 107 | 108 | 109 | def register_classifier(cls): 110 | """Automatically register a class in the classifiers dictionary. 111 | 112 | This function should be used as a decorator. 113 | 114 | Args: 115 | cls: subclass of `gumpy.classification.Classifier` that should be 116 | registered to gumpy. 117 | 118 | Returns: 119 | The class that was passed as argument 120 | 121 | Raises: 122 | ClassifierError: This error will be raised when a classifier is 123 | registered with a name that is already used. 124 | 125 | """ 126 | if cls.__name__ in available_classifiers: 127 | raise ClassifierError("Classifier {name} already exists in available_classifiers".format(name=cls.__name__)) 128 | 129 | available_classifiers[cls.__name__] = cls 130 | return cls 131 | 132 | 133 | 134 | def classify(c, *args, **kwargs): 135 | """Classify EEG data given a certain classifier. 136 | 137 | The classifier can be specified by a string or be passed as an object. The 138 | latter option is useful if a classifier has to be called repeatedly and the 139 | instantiation is computationally expensive. 140 | 141 | Additional arguments for the classifier instantiation can be passed in 142 | kwargs as a dictionary with name `opts`. They will be forwarded to the 143 | classifier on construction. If the classifier was passed as object, this 144 | will be ignored. 145 | 146 | Args: 147 | c (str or object): The classifier.
Either specified by the classifier 148 | name, or passed as object 149 | X_train: training data (values) 150 | Y_train: training data (labels) 151 | X_test: evaluation data (values) 152 | Y_test: evaluation data (labels) 153 | **kwargs: additional arguments that may be passed on to the classifier. If the 154 | classifier is selected via string/name, you can pass options to the 155 | classifier by a dict with the name `opts`, e.g. `classify('SVM', 156 | opts={'a': 1})`. 157 | 158 | Returns: 159 | 2-element tuple containing 160 | 161 | - **ClassificationResult**: The result of the classification. 162 | - **Classifier**: The classifier that was used during the classification. 163 | 164 | Raises: 165 | ClassifierError: If the classifier is unknown or classification fails. 166 | 167 | Examples: 168 | >>> import gumpy 169 | >>> result, clf = gumpy.classify("SVM", X_train, Y_train, X_test, Y_test) 170 | 171 | """ 172 | 173 | if isinstance(c, str): 174 | if c not in available_classifiers: 175 | raise ClassifierError("Unknown classifier {c}".format(c=c.__repr__())) 176 | 177 | # instantiate the classifier 178 | opts = kwargs.pop('opts', None) 179 | if opts is not None: 180 | clf = available_classifiers[c](**opts) 181 | else: 182 | clf = available_classifiers[c]() 183 | return clf.run(*args, **kwargs) 184 | 185 | elif isinstance(c, Classifier): 186 | return c.run(*args, **kwargs) 187 | 188 | # invalid argument passed to the function 189 | raise ClassifierError("Unknown classifier {c}".format(c=c.__repr__())) 190 | 191 | 192 | 193 | def vote(X_train, Y_train, X_test, Y_test, voting_type, feature_selection, k_features): 194 | """Invocation of a soft voting/majority rule classification. 195 | 196 | This is a wrapper around `sklearn.ensemble.VotingClassifier` which 197 | automatically uses all classifiers that are known to `gumpy` in 198 | `gumpy.classification.available_classifiers`. 199 | 200 | Args: 201 | X_train: training data (values) 202 | Y_train: training data (labels) 203 | X_test: evaluation data (values) 204 | Y_test: evaluation data (labels) 205 | voting_type (str): either of 'soft' or 'hard'. See the 206 | sklearn.ensemble.VotingClassifier documentation for more details feature_selection (bool): if True, sequential feature selection is performed via mlxtend before the voting classifier is fitted k_features: number of features to select, forwarded to mlxtend.feature_selection.SequentialFeatureSelector 207 | 208 | Returns: 209 | 2-element tuple containing 210 | 211 | - **ClassificationResult**: The result of the classification. 212 | - **Classifier**: The instance of `sklearn.ensemble.VotingClassifier` 213 | that was used during the classification. 214 | 215 | """ 216 | 217 | k_cross_val = 10 218 | N_JOBS=-1 219 | 220 | clfs = [] 221 | for classifier in available_classifiers: 222 | # determine kwargs such that the classifiers get initialized with 223 | # proper default settings.
This avoids cross-validation, for instance 224 | opts = available_classifiers[classifier].static_opts('vote', X_train=X_train) 225 | 226 | # retrieve instance 227 | cobj = available_classifiers[classifier](**opts) 228 | clfs.append((classifier, cobj.clf)) 229 | 230 | # instantiate the VotingClassifier 231 | soft_vote_clf = VotingClassifier(estimators=clfs, voting=voting_type) 232 | 233 | if feature_selection: 234 | sfs = SFS(soft_vote_clf, 235 | k_features, 236 | forward=True, 237 | floating=True, 238 | verbose=2, 239 | scoring='accuracy', 240 | cv=k_cross_val, 241 | n_jobs=N_JOBS) 242 | sfs = sfs.fit(X_train, Y_train) 243 | X_train = sfs.transform(X_train) 244 | X_test = sfs.transform(X_test) 245 | 246 | soft_vote_clf.fit(X_train, Y_train) 247 | Y_pred = soft_vote_clf.predict(X_test) 248 | return ClassificationResult(Y_test, Y_pred), soft_vote_clf 249 | 250 | 251 | 252 | # TODO: what to do with this old code? adopt it similar to `vote` above? 253 | # def cross_validation_classification (classifier,X, y, k): 254 | # 255 | # 256 | # k_cross_val = 10 257 | # N_JOBS=4 258 | # if classifier == "SVM": 259 | # parameters_svm = [{'kernel': ['rbf', 'sigmoid', 'poly'], 260 | # 'C': [1e1, 1e2, 1e3, 1e4], 261 | # 'gamma': [1e4, 1e3, 1e2, 1, 1e-1, 1e-2], 262 | # 'degree': [2,3,4]}] 263 | # clf = GridSearchCV(svm.SVC(max_iter=1e6), 264 | # parameters_svm, cv=k_cross_val) 265 | # 266 | # elif classifier == "LDA": 267 | # 268 | # from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 269 | # clf = LinearDiscriminantAnalysis() 270 | # 271 | # elif classifier == "Random Forest": 272 | # 273 | # parameters_rf = [{'n_estimators': [10, 100, 1000], 274 | # 'criterion': ['gini', 'entropy']}] 275 | # clf = GridSearchCV(RandomForestClassifier(n_jobs=N_JOBS), 276 | # parameters_rf, cv=k_cross_val) 277 | # 278 | # elif classifier == "Naive Bayes": # Without Feed Back 279 | # clf = GaussianNB() 280 | # 281 | # elif classifier == "KNN": # Without Feed Back 282 | # from sklearn import neighbors 283 | # clf = neighbors.KNeighborsClassifier(n_neighbors=5) 284 | # 285 | # elif classifier == "Logistic regression": # Without Feed Back 286 | # from sklearn.linear_model import LogisticRegression 287 | # clf = LogisticRegression(C=100) 288 | # 289 | # elif classifier == "MLP": 290 | # from sklearn.neural_network import MLPClassifier 291 | # 292 | # clf = MLPClassifier(solver='lbfgs', alpha=1e-5,hidden_layer_sizes=(X.shape[1],X.shape[1]), random_state=1) 293 | # 294 | # elif classifier == "LDA_with_shrinkage": 295 | # from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 296 | # clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto') 297 | # 298 | # 299 | # 300 | # kfold = KFold(n_splits = k, random_state = 777) 301 | # 302 | # results = cross_val_score(clf, X, y, cv = kfold) 303 | # 304 | # # VISULALISATION 305 | # 306 | # 307 | # print('Accuracy Score') 308 | # print('Avearge %: ', results.mean()*100) 309 | # print('Standard deviation: ', results.std()) 310 | -------------------------------------------------------------------------------- /gumpy/classification/common.py: -------------------------------------------------------------------------------- 1 | """Implementations of common classifiers. 2 | 3 | The implementations rely mostly on scikit-learn. They use default parameters 4 | that were found to work on most datasets. 
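 Each classifier registers itself via the ``@register_classifier`` decorator, so an instance can either be created directly (for example ``SVM(cross_validation=False)``) or requested by name through ``gumpy.classify('SVM', ...)``.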
5 | """ 6 | 7 | # TODO: check consistency in variable naming 8 | # TODO: implement unit tests 9 | 10 | from .classifier import Classifier, ClassificationResult, register_classifier 11 | 12 | # selectively import relevant sklearn classes. Prepend them with _ to avoid 13 | # confusion with classes specified in this module 14 | from sklearn.svm import SVC as _SVC 15 | from sklearn.model_selection import GridSearchCV 16 | from sklearn import neighbors 17 | from sklearn.neural_network import MLPClassifier as _MLPClassifier 18 | from sklearn.linear_model import LogisticRegression as _LogisticRegression 19 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as _LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis as _QuadraticDiscriminantAnalysis 20 | from sklearn.naive_bayes import GaussianNB as _GaussianNB 21 | from sklearn.ensemble import RandomForestClassifier as _RandomForestClassifier 22 | from sklearn.tree import DecisionTreeClassifier as _DecisionTreeClassifier 23 | 24 | 25 | # some 'good' default values across different classifiers 26 | 27 | 28 | @register_classifier 29 | class SVM(Classifier): 30 | """Support Vector Machine classifier for EEG data. 31 | 32 | """ 33 | 34 | def __init__(self, **kwargs): 35 | """Initialize the SVM classifier. 36 | 37 | All keyword arguments that are not listed will be forwarded to the 38 | underlying classifier. In this case, it is sklearn.SVC. For instance, 39 | if you pass an argument ``probability=True``, this will be forwarded 40 | to the initialization of SVC. 41 | 42 | Keyword arguments 43 | ----------------- 44 | max_iter: int, default = 1e6 45 | number of iterations during hyper-parameter tuning 46 | k_cross_val: int, default = 5 47 | number cross-validations (k-fold) 48 | cross_validation: Boolean, default = True 49 | Enable k-fold cross validation for hyper-parameter tuning. If False, 50 | the the SVM will use `probability=True` if not specified otherwise 51 | in kwargs. 52 | """ 53 | super(SVM, self).__init__() 54 | 55 | # initialize some default values for the SVM backend 56 | self.max_iter = kwargs.pop('max_iter', 1e6) 57 | 58 | # parameters for k cross validation / hyper-parameter tuning 59 | self.params = [{ 60 | 'kernel': ['rbf', 'sigmoid', 'poly'], 61 | 'C': [1e1, 1e2, 1e3, 1e4], 62 | 'gamma': [1e4, 1e3, 1e2, 1, 1e-1, 1e-2], 63 | 'degree': [2,3,4]}] 64 | self.k_cross_val = kwargs.pop('k_cross_val', 5) 65 | 66 | # initialize the classifier using grid search to find optimal parameters 67 | # via cross validation 68 | if kwargs.pop('cross_validation', True): 69 | self.clf = GridSearchCV(_SVC(max_iter=self.max_iter, **kwargs), 70 | self.params, 71 | cv=self.k_cross_val) 72 | else: 73 | probability = kwargs.pop('probability', True) 74 | self.clf = _SVC(max_iter=self.max_iter, probability=probability, **kwargs) 75 | 76 | 77 | @staticmethod 78 | def static_opts(ftype, **kwargs): 79 | """Returns default options for voting classification. 80 | 81 | This will avoid grid search during initialization. 82 | """ 83 | return {'cross_validation': False} 84 | 85 | 86 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 87 | self.clf.fit(X_train, Y_train.astype(int)) 88 | Y_pred = self.clf.predict(X_test) 89 | result = ClassificationResult(Y_test, Y_pred) 90 | return result, self 91 | 92 | 93 | 94 | @register_classifier 95 | class KNN(Classifier): 96 | """ 97 | """ 98 | 99 | def __init__(self, **kwargs): 100 | """Initialize a K Nearest Neighbors (KNN) classifier. 
101 | 102 | All additional keyword arguments will be forwarded to the underlying 103 | classifier, which is here ``sklearn.neighbors.KNeighborsClassifier``. 104 | 105 | Keyword Arguments 106 | ----------------- 107 | n_neighbors: int, default 5 108 | Number of neighbors 109 | """ 110 | 111 | super(KNN, self).__init__() 112 | self.nneighbors = kwargs.pop('n_neighbors', 5) 113 | self.clf = neighbors.KNeighborsClassifier(n_neighbors=self.nneighbors, **kwargs) 114 | 115 | 116 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 117 | self.clf.fit(X_train, Y_train.astype(int)) 118 | Y_pred = self.clf.predict(X_test) 119 | return ClassificationResult(Y_test, Y_pred), self 120 | 121 | 122 | 123 | @register_classifier 124 | class LDA(Classifier): 125 | """Linear Discriminant Analysis classifier. 126 | 127 | """ 128 | 129 | def __init__(self, **kwargs): 130 | super(LDA, self).__init__() 131 | self.clf = _LinearDiscriminantAnalysis(**kwargs) 132 | 133 | 134 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 135 | self.clf.fit(X_train, Y_train.astype(int)) 136 | Y_pred = self.clf.predict(X_test) 137 | return ClassificationResult(Y_test, Y_pred), self 138 | 139 | 140 | @register_classifier 141 | class Tree(Classifier): 142 | """Decision Tree 143 | 144 | """ 145 | 146 | def __init__(self, **kwargs): 147 | super(Tree, self).__init__() 148 | self.clf = _DecisionTreeClassifier(**kwargs) 149 | 150 | 151 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 152 | self.clf.fit(X_train, Y_train.astype(int)) 153 | Y_pred = self.clf.predict(X_test) 154 | return ClassificationResult(Y_test, Y_pred), self 155 | 156 | 157 | @register_classifier 158 | class LogisticRegression(Classifier): 159 | """ 160 | """ 161 | 162 | def __init__(self, **kwargs): 163 | """Initialize a Logistic Regression Classifier. 164 | 165 | Additional keyword arguments will be passed to the classifier 166 | initialization which is ``sklearn.linear_model.LogisticRegression`` 167 | here. 168 | 169 | Keyword Arguments 170 | ----------------- 171 | C: int, default = 100 Inverse of the regularization strength. 172 | """ 173 | super(LogisticRegression, self).__init__() 174 | self.C = kwargs.pop("C", 100) 175 | self.clf = _LogisticRegression(C=self.C, **kwargs) 176 | 177 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 178 | self.clf.fit(X_train, Y_train.astype(int)) 179 | Y_pred = self.clf.predict(X_test) 180 | return ClassificationResult(Y_test, Y_pred), self 181 | 182 | 183 | 184 | @register_classifier 185 | class MLP(Classifier): 186 | """ 187 | """ 188 | 189 | def __init__(self, **kwargs): 190 | """Initialize an MLP classifier. 191 | 192 | If no hidden layer size is passed, the classifier is not fully 193 | constructed here and the MLP will only be built during `run`, where the 194 | layer sizes can be deduced from the training data. If, however, the hidden layer size is specified, the MLP is constructed immediately. 195 | 196 | Keyword Arguments 197 | ----------------- 198 | solver: default = ``lbfgs`` 199 | The internal solver for weight optimization. 200 | alpha: default = ``1e-5`` 201 | Regularization parameter. 202 | random_state: int or None 203 | Seed used to initialize the random number generator. default = 1, 204 | can be None. 205 | hidden_layer_sizes: tuple 206 | The sizes of the hidden layers. 207 | """ 208 | 209 | super(MLP, self).__init__() 210 | 211 | # TODO: why lbfgs and not adam?
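 # (per the sklearn docs, lbfgs tends to converge faster and perform better on the small datasets that are typical for BCI recordings, while adam pays off on large datasets)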
212 | self.solver = kwargs.pop('solver', 'lbfgs') 213 | self.alpha = kwargs.pop('alpha', 1e-5) 214 | self.random_state = kwargs.pop('random_state', 1) 215 | 216 | # determine if the MLP can be initialized or not 217 | self.clf = None 218 | self.hidden_layer_sizes = kwargs.pop('hidden_layer_sizes', -1) 219 | if self.hidden_layer_sizes != -1: 220 | self.initMLPClassifier(**kwargs) 221 | 222 | 223 | @staticmethod 224 | def static_opts(ftype, **kwargs): 225 | """Sets options that are required during voting and feature selection runs. 226 | 227 | """ 228 | 229 | opts = dict() 230 | 231 | if ftype == 'sequential_feature_selector': 232 | # check if we got the features 233 | features = kwargs.pop('features', None) 234 | if features is not None: 235 | opts['hidden_layer_sizes'] = (features.shape[0], features.shape[1]) 236 | 237 | if ftype == 'vote': 238 | # check if we got the training data 239 | X_train = kwargs.pop('X_train', None) 240 | if X_train is not None: 241 | # TODO: check dimensions! 242 | opts['hidden_layer_sizes'] = (X_train.shape[1], X_train.shape[1]) 243 | 244 | return opts 245 | 246 | 247 | def initMLPClassifier(self, **kwargs): 248 | self.hidden_layer_sizes = kwargs.pop('hidden_layer_sizes', self.hidden_layer_sizes) 249 | self.clf = _MLPClassifier(solver=self.solver, 250 | alpha=self.alpha, 251 | hidden_layer_sizes=self.hidden_layer_sizes, 252 | random_state=self.random_state, 253 | **kwargs) 254 | 255 | 256 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 257 | """Run the MLP classifier. 258 | 259 | In case the user did not specify layer sizes during 260 | initialization, the run method will automatically deduce 261 | the size from the input arguments. 262 | """ 263 | if self.clf is None: 264 | self.hidden_layer_sizes = (X_train.shape[1], X_train.shape[1]) 265 | self.initMLPClassifier(**kwargs) 266 | 267 | self.clf.fit(X_train, Y_train.astype(int)) 268 | Y_pred = self.clf.predict(X_test) 269 | return ClassificationResult(Y_test, Y_pred), self 270 | 271 | 272 | 273 | @register_classifier 274 | class NaiveBayes(Classifier): 275 | """ 276 | """ 277 | 278 | def __init__(self, **kwargs): 279 | super(NaiveBayes, self).__init__() 280 | self.clf = _GaussianNB(**kwargs) 281 | 282 | 283 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 284 | self.clf.fit(X_train, Y_train.astype(int)) 285 | Y_pred = self.clf.predict(X_test) 286 | return ClassificationResult(Y_test, Y_pred), self 287 | 288 | 289 | 290 | @register_classifier 291 | class RandomForest(Classifier): 292 | """ 293 | """ 294 | 295 | def __init__(self, **kwargs): 296 | """Initialize a RandomForest classifier. 297 | 298 | All keyword arguments that are not listed will be forwarded to the 299 | underlying classifier. In this case, it is ``sklearn.ensemble.RandomForestClassifier``. 300 | 301 | Keyword Arguments 302 | ----------------- 303 | n_jobs: int, default = 4 304 | Number of jobs for the RandomForestClassifier 305 | k_cross_val: int, default = 5 306 | Number of cross-validation folds in hyper-parameter tuning. 307 | cross_validation: Boolean, default True 308 | Enable k-fold cross validation for hyper-parameter tuning. If set to 309 | False, the criterion will be `gini` and 10 estimators will be used 310 | if not specified otherwise in kwargs. 311 | """ 312 | # TODO: document that all additional kwargs will be passed to the 313 | # RandomForestClassifier!
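 # For example, RandomForest(cross_validation=False, n_estimators=100, max_depth=8) forwards max_depth straight to the sklearn backend.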
314 | 315 | super(RandomForest, self).__init__() 316 | 317 | self.n_jobs = kwargs.pop("n_jobs", 4) 318 | self.params = [{ 319 | 'n_estimators': [10, 100, 1000], 320 | 'criterion': ['gini', 'entropy']}] 321 | self.k_cross_val = kwargs.pop('k_cross_val', 5) 322 | 323 | # initialize the classifier, which will be optimized using k cross 324 | # validation during fitting 325 | if kwargs.pop('cross_validation', True): 326 | self.clf = GridSearchCV(_RandomForestClassifier(n_jobs=self.n_jobs, **kwargs), 327 | self.params, 328 | cv=self.k_cross_val) 329 | else: 330 | # default arguments to use if not specified otherwise 331 | # TODO: move to static_opts? 332 | criterion = kwargs.pop('criterion', 'gini') 333 | n_estimators = kwargs.pop('n_estimators', 10) 334 | self.clf = _RandomForestClassifier(criterion=criterion, n_estimators=n_estimators, n_jobs=self.n_jobs, **kwargs) 335 | 336 | 337 | @staticmethod 338 | def static_opts(ftype, **kwargs): 339 | """Returns default options for voting classification. 340 | 341 | This will avoid grid search during initialization. 342 | """ 343 | return {'cross_validation': False} 344 | 345 | 346 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 347 | self.clf.fit(X_train, Y_train.astype(int)) 348 | Y_pred = self.clf.predict(X_test) 349 | result = ClassificationResult(Y_test, Y_pred) 350 | return result, self 351 | 352 | 353 | 354 | @register_classifier 355 | class QuadraticLDA(Classifier): 356 | def __init__(self, **kwargs): 357 | super(QuadraticLDA, self).__init__() 358 | self.clf = _QuadraticDiscriminantAnalysis(**kwargs) 359 | 360 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 361 | self.clf.fit(X_train, Y_train.astype(int)) 362 | Y_pred = self.clf.predict(X_test) 363 | return ClassificationResult(Y_test, Y_pred), self 364 | 365 | 366 | 367 | @register_classifier 368 | class ShrinkingLDA(Classifier): 369 | def __init__(self, **kwargs): 370 | """Initializes a ShrinkingLDA classifier. 371 | 372 | Additional arguments will be forwarded to the underlying classifier 373 | instantiation, which is 374 | ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis`` here. 375 | 376 | Keyword Arguments 377 | ----------------- 378 | solver: string, default = lsqr 379 | Solver used in LDA 380 | shrinkage: string, default = 'auto' 381 | 382 | """ 383 | super(ShrinkingLDA, self).__init__() 384 | self.solver = kwargs.pop('solver', 'lsqr') 385 | self.shrinkage = kwargs.pop('shrinkage', 'auto') 386 | self.clf = _LinearDiscriminantAnalysis(solver=self.solver, shrinkage=self.shrinkage, **kwargs) 387 | 388 | def run(self, X_train, Y_train, X_test, Y_test, **kwargs): 389 | self.clf.fit(X_train, Y_train.astype(int)) 390 | Y_pred = self.clf.predict(X_test) 391 | return ClassificationResult(Y_test, Y_pred), self 392 | 393 | -------------------------------------------------------------------------------- /gumpy/data/__init__.py: -------------------------------------------------------------------------------- 1 | # import for convenience. 
Individual dataset implementations are kept separate, as
2 | # they may require several subroutines that would otherwise clutter the namespace
3 | from .dataset import Dataset
4 | from .nst import NST
5 | from .graz import GrazB
6 | from .khushaba import Khushaba
7 | from .nst_emg import NST_EMG
--------------------------------------------------------------------------------
/gumpy/data/dataset.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import numpy as np
3 | 
4 | class DatasetError(Exception):
5 |     pass
6 | 
7 | 
8 | class Dataset(ABC):
9 |     """
10 |     Abstract base class representing a dataset.
11 | 
12 |     All datasets should subclass from this base class and need to implement the
13 |     `load` function. Initializing a dataset and actually loading its data are
14 |     separate steps, since the latter may take significant time, depending on
15 |     where the data is coming from. This also allows implementing different
16 |     handlers for the remote end where the data originates, e.g. downloads from a server.
17 | 
18 |     When subclassing from Dataset it is helpful to set the fields `data_type`,
19 |     `data_name`, and `data_id`. For more information on these fields, see for
20 |     instance the implementation in :func:`gumpy.data.graz.GrazB.__init__`.
21 | 
22 |     """
23 | 
24 | 
25 |     def __init__(self, **kwargs):
26 |         """Initialize a dataset."""
27 |         pass
28 | 
29 | 
30 |     @abstractmethod
31 |     def load(self, **kwargs):
32 |         """Load the data and prepare it for usage.
33 | 
34 |         gumpy expects the EEG/EMG trial data to be in the following format:
35 | 
36 |         ===========================================> time
37 |             |                                   |
38 |         trial_start                         trial_end
39 |             |<------------trial_len------------>|
40 |                     |<---MotorImagery--->|
41 | 
42 | 
43 |         Consequently, the class members need to adhere to the following structure:
44 | 
45 |         .raw_data       (n_samples, n_channels)  all channels
46 |         .trials         (,n_trials)
47 |         .labels         (,n_labels)
48 |         .trial_len      scalar
49 |         .sampling_freq  scalar
50 |         .mi_interval    [mi_start, mi_end] within a trial, in seconds
51 | 
52 |         Arrays, such as `.raw_data`, have to be accessible using bracket
53 |         notation `[]`. You can provide a custom implementation, however the
54 |         easiest way is to use numpy ndarrays to store the data.
55 | 
56 |         For an example implementation, have a look at `gumpy.data.nst.NST`.
57 |         """
58 |         return self
59 | 
60 | 
61 |     def print_stats(self):
62 |         """Commodity function to print information about the dataset.
63 | 
64 |         This method uses the fields that need to be implemented when
65 |         subclassing. For more information about the fields that need to be
66 |         implemented, see :func:`gumpy.data.dataset.Dataset.load` and
67 |         :func:`gumpy.data.dataset.Dataset.__init__`.
68 |         """
69 | 
70 |         print("Data identification: {name}-{id}".format(name=self.data_name, id=self.data_id))
71 |         print("{type}-data shape: {shape}".format(type=self.data_type, shape=self.raw_data.shape))
72 |         print("Trials data shape: ", self.trials.shape)
73 |         print("Labels shape: ", self.labels.shape)
74 |         print("Total length of single trial: ", self.trial_total)
75 |         print("Sampling frequency of {type} data: {freq}".format(type=self.data_type, freq=self.sampling_freq))
76 |         print("Interval for motor imagery in trial: ", self.mi_interval)
77 |         print('Classes possible: ', np.unique(self.labels))
78 | 
79 | 
80 | 
--------------------------------------------------------------------------------
/gumpy/data/graz.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset, DatasetError
2 | import os
3 | import numpy as np
4 | import scipy.io
5 | 
6 | 
7 | class GrazB(Dataset):
8 |     """A GrazB dataset.
9 | 
10 |     A GrazB dataset consists of two .mat files, `<id>T.mat` and `<id>E.mat`,
11 |     that reside in a common base directory. The user needs to pass this base
12 |     directory as well as the identifier upon instantiation.
13 | 
14 |     """
15 | 
16 |     def __init__(self, base_dir, identifier, **kwargs):
17 |         """Initialize a GrazB dataset without loading it.
18 | 
19 |         Args:
20 |             base_dir (str): The path to the base directory in which the GrazB dataset resides.
21 |             identifier (str): String identifier for the dataset, e.g. `B01`
22 |             **kwargs: Arbitrary keyword arguments (unused).
23 | 
24 |         """
25 | 
26 |         super(GrazB, self).__init__(**kwargs)
27 | 
28 |         self.base_dir = base_dir
29 |         self.data_id = identifier
30 |         self.data_dir = base_dir
31 |         self.data_type = 'EEG'
32 |         self.data_name = 'GrazB'
33 | 
34 |         # parameters of the GrazB dataset
35 |         # length of a trial (in seconds)
36 |         self.trial_len = 8
37 |         # motor imagery appears in interval (in seconds)
38 |         self.mi_interval = [4, 7]
39 |         # idle period prior to start of signal (in seconds)
40 |         self.trial_offset = 0
41 |         # total length of a trial (in seconds)
42 |         self.trial_total = self.trial_len
43 |         # sampling frequency (in Hz)
44 |         self.expected_freq_s = 250
45 | 
46 |         # the graz dataset is split into T and E files
47 |         self.fT = os.path.join(self.data_dir, "{id}T.mat".format(id=self.data_id))
48 |         self.fE = os.path.join(self.data_dir, "{id}E.mat".format(id=self.data_id))
49 | 
50 |         for f in [self.fT, self.fE]:
51 |             if not os.path.isfile(f):
52 |                 raise DatasetError("GrazB Dataset ({id}) file '{f}' unavailable".format(id=self.data_id, f=f))
53 | 
54 |         # variables to store data
55 |         self.raw_data = None
56 |         self.labels = None
57 |         self.trials = None
58 |         self.sampling_freq = None
59 | 
60 | 
61 |     def load(self, **kwargs):
62 |         """Load a dataset.
63 | 
64 |         Args:
65 |             **kwargs: Arbitrary keyword arguments (unused).
66 | 
67 |         Returns:
68 |             Instance of the dataset (i.e. `self`).
69 | 
70 |         """
71 | 
72 | 
73 |         mat1 = scipy.io.loadmat(self.fT)['data']
74 |         #mat2 = scipy.io.loadmat(folder_dir + file_dir2)['data']
75 |         # dict_keys(['__header__', '__globals__', '__version__', 'data'])
76 | 
77 |         # load the training data (the T file)
78 |         data_bt = []
79 |         labels_bt = []
80 |         trials_bt = []
81 |         n_experiments = 3
82 |         for i in range(n_experiments):
83 |             data = mat1[0,i][0][0][0]
84 |             trials = mat1[0,i][0][0][1]
85 |             labels = mat1[0,i][0][0][2] - 1
86 |             # TODO: fs shadows self.fs? do we need to store this somewhere?
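            # each cell mat1[0,i] wraps a MATLAB struct; the fields used below
            # appear, in order, as [0]=X (samples x channels), [1]=trial onsets,
            # [2]=labels, [3]=sampling frequency, [5]=artifact flags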
87 |             fs = mat1[0,i][0][0][3].flatten()[0]
88 |             if fs != self.expected_freq_s:
89 |                 raise DatasetError("GrazB Dataset ({id}) Sampling Frequencies don't match (expected {f1}, got {f2})".format(id=self.data_id, f1=self.expected_freq_s, f2=fs))
90 |             artifacts = mat1[0,i][0][0][5]
91 |             # remove artifacts
92 |             artifact_idxs = np.where(artifacts == 1)[0]
93 |             trials = np.delete(trials, artifact_idxs)
94 |             labels = np.delete(labels, artifact_idxs)
95 |             # collect the cleaned data
96 |             data_bt.append(data)
97 |             labels_bt.append(labels)
98 |             trials_bt.append(trials)
99 | 
100 |         # add the length of the previous data sets to adjust the trial start points
101 |         trials_bt[1] += data_bt[0].shape[0]
102 |         trials_bt[2] += data_bt[0].shape[0] + data_bt[1].shape[0]
103 | 
104 |         # concatenate all data matrices, trials, and labels
105 |         data_bt = np.concatenate((data_bt[0], data_bt[1], data_bt[2]))
106 |         trials_bt = np.concatenate((trials_bt[0], trials_bt[1], trials_bt[2]))
107 |         labels_bt = np.concatenate((labels_bt[0], labels_bt[1], labels_bt[2]))
108 | 
109 |         self.raw_data = data_bt[:,:3]
110 |         self.trials = trials_bt
111 |         self.labels = labels_bt
112 |         self.sampling_freq = self.expected_freq_s
113 | 
114 |         return self
115 | 
116 | 
--------------------------------------------------------------------------------
/gumpy/data/khushaba.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset, DatasetError
2 | import os
3 | import numpy as np
4 | import scipy.io
5 | 
6 | 
7 | # TODO: BROKEN!
8 | class Khushaba(Dataset):
9 |     """A Khushaba dataset.
10 | 
11 |     A Khushaba dataset consists of several .mat files (one per trial, class,
12 |     and force level) that reside within a subject-specific subdirectory. The
13 |     user needs to pass a base-directory as well as the identifier upon instantiation.
14 | 
15 |     """
16 | 
17 |     def __init__(self, base_dir, identifier, class_labels=[], **kwargs):
18 |         """Initialize a Khushaba dataset without loading it.
19 | 
20 |         Args:
21 |             base_dir (str): The path to the base directory in which the Khushaba dataset resides.
22 |             identifier (str): String identifier for the dataset, e.g. ``S1``
23 |             class_labels (list): A list of class labels
24 |             **kwargs: Additional keyword arguments (unused)
25 | 
26 |         """
27 | 
28 |         super(Khushaba, self).__init__(**kwargs)
29 | 
30 |         self.base_dir = base_dir
31 |         self.data_id = identifier
32 |         self.data_dir = os.path.join(self.base_dir, self.data_id)
33 |         self.data_type = 'EMG'
34 |         self.data_name = 'Khushaba'
35 | 
36 |         self._class_labels = ['Ball', 'ThInd', 'ThIndMid', 'Ind', 'LRMI', 'Th']
37 |         self._force_levels = ['high', 'low', 'med']
38 |         # class labels requested by the user
39 |         if not isinstance(class_labels, list):
40 |             raise ValueError('Required list of class labels (`class_labels`)')
41 | 
42 |         self.class_labels = class_labels
43 | 
44 |         # all Khushaba datasets have the same configuration and parameters
45 | 
46 |         # length of a trial after trial_sample (in seconds)
47 |         self.trial_len = None
48 |         # idle period prior to trial start (in seconds)
49 |         self.trial_offset = None
50 |         # total time of the trial
51 |         self.trial_total = None #self.trial_offset + self.trial_len
52 |         # interval of motor imagery within trial_t (unknown for this dataset)
53 |         self.mi_interval = None
54 | 
55 |         # additional variables to store data as expected by the ABC
56 |         self.raw_data = None
57 |         self.trials = None
58 |         self.labels = None
59 |         self.sampling_freq = 2000
60 | 
61 | 
62 |     def load(self, **kwargs):
63 |         """Loads a Khushaba dataset.
64 | 
65 |         For more information about the returned values, see
66 |         :meth:`gumpy.data.Dataset.load`
67 |         """
68 | 
69 |         self.trials = ()
70 |         self.labels = ()
71 | 
72 |         for class_name in self.class_labels:
73 |             classTrials, label_list = self.getClassTrials(class_name)
74 |             self.trials = self.trials + (classTrials,)
75 | 
76 |             for trial in classTrials:  # iterate only the newly loaded trials
77 |                 if self.raw_data is None:
78 |                     self.raw_data = trial
79 |                 else:
80 |                     self.raw_data = np.concatenate((self.raw_data, trial))
81 | 
82 |             self.labels = self.labels + (label_list,)
83 | 
84 |         return self
85 | 
86 | 
87 |     def getClassTrials(self, class_name):
88 |         """Return all class trials and labels.
89 | 
90 |         Args:
91 |             class_name (str): The class name for which the trials should be returned
92 | 
93 |         Returns:
94 |             A 2-tuple containing
95 | 
96 |             - **trials**: A list of all trials of `class_name`
97 |             - **labels**: A list of corresponding labels for the trials
98 | 
99 |         """
100 |         Results = []
101 |         label_list = []
102 | 
103 |         for force_level in self._force_levels:
104 |             path = os.path.join(self.base_dir, '{}_Force Exp'.format(self.data_id), '{}_{}'.format(class_name, force_level))
105 | 
106 |             for i in range(1,6):
107 |                 file = os.path.join(path, '{}_{}_{}_t{}.mat'.format(self.data_id, class_name, force_level, i))
108 | 
109 |                 trial = scipy.io.loadmat(file)['t{}'.format(i)]
110 | 
111 |                 Results.append(trial)
112 |                 label_list.append(self._class_labels.index(class_name))
113 | 
114 |         return Results, label_list
--------------------------------------------------------------------------------
/gumpy/data/nst.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset, DatasetError
2 | import os
3 | import numpy as np
4 | import scipy.io
5 | 
6 | 
7 | class NST(Dataset):
8 |     """An NST dataset.
9 | 
10 |     An NST dataset usually consists of three files that are within a specific
11 |     subdirectory. The implementation follows this structuring, i.e. the user
12 |     needs to pass a base-directory as well as the identifier upon instantiation.
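
    Example (the path is a placeholder)::

        dataset = NST('/path/to/NST', 'S1', n_classes=2)
        dataset.load().print_stats()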
13 | 14 | """ 15 | 16 | def __init__(self, base_dir, identifier, **kwargs): 17 | """Initialize an NST dataset without loading it. 18 | 19 | Args: 20 | base_dir (str): The path to the base directory in which the NST dataset resides. 21 | identifier (str): String identifier for the dataset, e.g. ``S1`` 22 | **kwargs: Additional keyword arguments: n_classes (int, default=3): number of classes to fetch. 23 | 24 | """ 25 | 26 | super(NST, self).__init__(**kwargs) 27 | 28 | self.base_dir = base_dir 29 | self.data_id = identifier 30 | self.data_dir = os.path.join(self.base_dir, self.data_id) 31 | self.data_type = 'EEG' 32 | self.data_name = 'NST' 33 | 34 | # number of classes in the dataset 35 | self.n_classes = kwargs.pop('n_classes', 3) 36 | 37 | # all NST datasets have the same configuration and parameters 38 | # length of a trial after trial_sample (in seconds) 39 | self.trial_len = 4 40 | # idle period prior to trial start (in seconds) 41 | self.trial_offset = 4 42 | # total time of the trial 43 | self.trial_total = self.trial_offset + self.trial_len+2 44 | # interval of motor imagery within trial_t (in seconds) 45 | self.mi_interval = [self.trial_offset, self.trial_offset + self.trial_len] 46 | 47 | # additional variables to store data as expected by the ABC 48 | self.raw_data = None 49 | self.trials = None 50 | self.labels = None 51 | self.sampling_freq = None 52 | 53 | # TODO: change the files on disk, don't check in here... 54 | # the first few sessions had a different file type 55 | self.f0 = os.path.join(self.data_dir, 'Run1.mat') 56 | self.f1 = os.path.join(self.data_dir, 'Run2.mat') 57 | self.f2 = os.path.join(self.data_dir, 'Run3.mat') 58 | 59 | # check if files are available 60 | for f in [self.f0, self.f1, self.f2]: 61 | if not os.path.isfile(f): 62 | raise DatasetError("NST Dataset ({id}) file '{f}' unavailable".format(id=self.data_id, f=f)) 63 | 64 | 65 | def load(self, **kwargs): 66 | """Loads an NST dataset. 
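
        The three runs (``Run1.mat`` to ``Run3.mat``) are concatenated; the
        trial onsets of runs 2 and 3 are shifted by the lengths of the
        preceding runs.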
67 | 
68 |         For more information about the returned values, see
69 |         :meth:`gumpy.data.Dataset.load`
70 |         """
71 |         mat1 = scipy.io.loadmat(self.f0)
72 |         mat2 = scipy.io.loadmat(self.f1)
73 |         mat3 = scipy.io.loadmat(self.f2)
74 | 
75 |         fs = mat1['Fs'].flatten()[0]
76 |         # read matlab data
77 |         raw_data1 = mat1['X'][:,0:3]
78 |         raw_data2 = mat2['X'][:,0:3]
79 |         raw_data3 = mat3['X'][:,0:3]
80 | 
81 |         # extract labels
82 |         labels1 = mat1['Y'].flatten() - 1
83 |         labels2 = mat2['Y'].flatten() - 1
84 |         labels3 = mat3['Y'].flatten() - 1
85 | 
86 |         # prepare trial data
87 |         trials1 = mat1['trial'].flatten() - fs*self.trial_offset
88 |         trials2 = mat2['trial'].flatten() - fs*self.trial_offset
89 |         trials3 = mat3['trial'].flatten() - fs*self.trial_offset
90 |         trials2 += raw_data1.T.shape[1]
91 |         trials3 += raw_data1.T.shape[1] + raw_data2.T.shape[1]
92 | 
93 |         # concatenate matrices
94 |         self.raw_data = np.concatenate((raw_data1, raw_data2, raw_data3))
95 |         self.labels = np.concatenate((labels1, labels2, labels3))
96 |         self.trials = np.concatenate((trials1, trials2, trials3))
97 |         self.sampling_freq = fs
98 |         if self.n_classes == 2: # Remove class 3 if desired
99 |             c3_idxs = np.where(self.labels==2)[0]
100 |             self.labels = np.delete(self.labels, c3_idxs)
101 |             self.trials = np.delete(self.trials, c3_idxs)
102 | 
103 |         return self
--------------------------------------------------------------------------------
/gumpy/data/nst_emg.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset, DatasetError
2 | import os
3 | import numpy as np
4 | import scipy.io
5 | 
6 | 
7 | class NST_EMG(Dataset):
8 |     """An NST_EMG dataset.
9 | 
10 |     An NST_EMG dataset usually consists of three files that are within a specific
11 |     subdirectory. The implementation follows this structuring, i.e. the user
12 |     needs to pass a base-directory as well as the identifier upon instantiation.
13 | 
14 |     If you require a copy of the data, please contact one of the gumpy authors.
15 | 
16 |     """
17 | 
18 |     def __init__(self, base_dir, identifier, force_level, **kwargs):
19 |         """Initialize an NST_EMG dataset without loading it.
20 |         Args:
21 |             base_dir (str): The path to the base directory in which the NST_EMG dataset resides.
22 |             identifier (str): String identifier for the dataset, e.g. ``S1``
23 |             force_level (str): Force level to load, either ``'high'`` or ``'low'``.
24 |             **kwargs: Additional keyword arguments: n_classes (int, default=3): number of classes to fetch.
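
        Example (the path is a placeholder)::

            dataset = NST_EMG('/path/to/NST_EMG', 'S1', 'low').load()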
25 | 26 | """ 27 | 28 | super(NST_EMG, self).__init__(**kwargs) 29 | 30 | self.base_dir = base_dir 31 | self.data_id = identifier 32 | self.force_level = force_level 33 | self.data_dir = os.path.join(self.base_dir, self.data_id) 34 | self.data_type = 'EMG' 35 | self.data_name = 'NST_EMG' 36 | 37 | self.electrodePairList = [(0, 2), (1, 3), (4, 6), (5,7)] 38 | self.channel = [] 39 | self.trialSignalOffset = (0.5,5.5) 40 | self.trialBgOffset = (5.5,10.5) 41 | self.trialForceOffset = (5,10) 42 | self.duration = 5 43 | 44 | # number of classes in the dataset 45 | self.n_classes = kwargs.pop('n_classes', 3) 46 | 47 | # all NST_EMG datasets have the same configuration and parameters 48 | # length of a trial after trial_sample (in seconds) 49 | self.trial_len = 5 50 | # idle period prior to trial start (in seconds) 51 | self.trial_offset = 5 52 | # total time of the trial 53 | self.trial_total = self.trial_offset + self.trial_len 54 | # interval of motor imagery within trial_t (in seconds) 55 | self.mi_interval = [self.trial_offset, self.trial_offset + self.trial_len] 56 | 57 | # additional variables to store data as expected by the ABC 58 | self.raw_data = None 59 | self.trials = None 60 | self.labels = None 61 | self.sampling_freq = None 62 | 63 | file_list_highForce = [] 64 | file_list_lowForce = [] 65 | 66 | # S1 67 | if self.data_id == 'S1': 68 | file_list_highForce = ['session_14_26_15_01_2018.mat', 'session_14_35_15_01_2018.mat', 'session_14_43_15_01_2018.mat'] 69 | file_list_lowForce = ['session_15_00_15_01_2018.mat', 'session_15_08_15_01_2018.mat', 'session_15_16_15_01_2018.mat'] 70 | 71 | # S2 72 | elif self.data_id == 'S2': 73 | file_list_highForce = ['session_14_51_10_01_2018.mat', 'session_15_10_10_01_2018.mat', 'session_15_10_10_01_2018.mat'] 74 | file_list_lowForce = ['session_15_25_10_01_2018.mat', 'session_15_32_10_01_2018.mat', 'session_15_45_10_01_2018.mat'] 75 | 76 | # S3 77 | elif self.data_id == 'S3': 78 | file_list_highForce = ['session_13_04_16_01_2018.mat', 'session_13_10_16_01_2018.mat', 'session_13_18_16_01_2018.mat'] 79 | file_list_lowForce = ['session_13_26_16_01_2018.mat', 'session_13_31_16_01_2018.mat', 'session_13_35_16_01_2018.mat'] 80 | 81 | # S4 82 | elif self.data_id == 'S4': 83 | file_list_highForce = ['session_13_36_09_03_2018', 'session_13_39_09_03_2018'] 84 | file_list_lowForce = ['session_13_42_09_03_2018', 'session_13_44_09_03_2018'] 85 | 86 | # S4 87 | if self.force_level == 'high': 88 | self.fileList = file_list_highForce 89 | elif self.force_level == 'low': 90 | self.fileList = file_list_lowForce 91 | 92 | 93 | def load(self, **kwargs): 94 | """Load an NST_EMG dataset. 
95 | 
96 |         For more information about the returned values, see
97 |         :meth:`gumpy.data.Dataset.load`
98 |         """
99 | 
100 |         trial_len = 5           # sec (length of a trial after trial_sample)
101 |         trial_offset = 5        # idle period prior to trial start [sec]
102 |         self.trial_total = trial_offset + trial_len    # total length of trial
103 |         self.mi_interval = [trial_offset, trial_offset+trial_len]  # interval of motor imagery within trial_t [sec]
104 | 
105 |         matrices = []
106 |         raw_data_ = []
107 |         labels_ = []
108 |         trials_ = []
109 |         forces_ = []
110 | 
111 |         for file in self.fileList:
112 |             try:
113 |                 fname = os.path.join(self.data_dir, file)
114 |                 if os.path.isfile(fname):
115 |                     matrices.append(scipy.io.loadmat(fname))
116 |             except Exception as e:
117 |                 print('An exception occurred while reading file {}: {}'.format(file, e))
118 | 
119 | 
120 |         # read matlab data
121 |         for matrix in matrices:
122 |             raw_data_.append(matrix['X'][:,:])
123 |             labels_.append(matrix['Y'][:])
124 |             trials_.append(matrix['trial'][:])
125 | 
126 |             #forces_.append(matrix['force'][:].T)
127 | 
128 |             size_X = len(matrix['X'][:,0])
129 |             size_force = np.shape(matrix['force'][:])[1]
130 | 
131 |             # the force recording is shorter than the raw data; left-pad it
132 |             # with zeros so that both have the same length
133 | 
134 |             offset = size_X - size_force
135 |             f = np.zeros((1, size_X))
136 |             f[0, offset:] = matrix['force'][:]
137 |             forces_.append(f.T)
138 | 
139 | 
140 |             #forces_.append(matrix['force'][:])
141 | 
142 |         # adjust the trial start points by the length of the preceding runs
143 |         for i in range(1, len(trials_)):
144 |             trials_[i] += raw_data_[i-1].T.shape[1]
145 | 
146 |         # combine matrices together
147 |         self.raw_data = np.concatenate(tuple(raw_data_))
148 |         self.labels = np.concatenate(tuple(labels_))
149 |         self.trials = np.concatenate(tuple(trials_))
150 |         self.forces = np.concatenate(tuple(forces_))
151 | 
152 |         # reset points higher than the maximum force intensity to 0
153 |         self.forces[self.forces > 20] = 0
154 |         # reset points lower than 0 to 0
155 |         self.forces[self.forces < 0] = 0
156 | 
157 |         # remove class 3
158 |         c3_idxs = np.where(self.labels==3)[0]
159 |         self.labels = np.delete(self.labels, c3_idxs)
160 |         self.trials = np.delete(self.trials, c3_idxs)
161 | 
162 |         self.labels = np.hstack((self.labels, 3*np.ones(int(self.trials.shape[0]/3))))
163 | 
164 | 
165 |         self.sampling_freq = matrices[0]['Fs'].flatten()[0]
166 | 
167 |         return self
--------------------------------------------------------------------------------
/gumpy/features.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | from .classification import available_classifiers
4 | import matplotlib.pyplot as plt
5 | import sklearn.decomposition
6 | from sklearn.pipeline import make_pipeline
7 | from sklearn.preprocessing import StandardScaler
8 | from mlxtend.feature_selection import SequentialFeatureSelector as SFS
9 | from mlxtend.plotting import plot_sequential_feature_selection as plot_sfs
10 | import numpy as np
11 | import scipy.linalg as la
12 | import pywt
13 | 
14 | 
15 | def sequential_feature_selector(features, labels, classifier, k_features, kfold, selection_type, plot=True, **kwargs):
16 |     """Sequential feature selection to reduce the number of features.
17 | 
18 |     The function reduces a d-dimensional feature space to a k-dimensional
19 |     feature space by sequential feature selection. The features are selected
20 |     using ``mlxtend.feature_selection.SequentialFeatureSelector`` which
21 |     essentially selects or removes a feature from the d-dimensional input space
22 |     until the preferred size is reached.
23 | 
24 |     The function will pass ``ftype='sequential_feature_selector'`` and forward
25 |     ``features`` on to the classifier's ``static_opts`` method.
26 | 
27 |     Args:
28 |         features: The original d-dimensional feature space
29 |         labels: corresponding labels
30 |         classifier (str or object): The classifier which should be used for
31 |             feature selection. This can be either a string (name of a classifier
32 |             known to gumpy) or an instance of a classifier which adheres
33 |             to the sklearn classifier interface.
34 |         k_features (int): Number of features to select
35 |         kfold (int): k-fold cross validation
36 |         selection_type (str): One of ``SFS`` (Sequential Forward Selection),
37 |             ``SBS`` (Sequential Backward Selection), ``SFFS`` (Sequential Forward
38 |             Floating Selection), ``SBFS`` (Sequential Backward Floating Selection)
39 |         plot (bool): Plot the results of the dimensionality reduction
40 |         **kwargs: Additional keyword arguments that will be passed to the
41 |             Classifier instantiation
42 | 
43 |     Returns:
44 |         A 5-element tuple containing
45 | 
46 |         - **feature index**: Index of features in the remaining set
47 |         - **cv_scores**: cross validation scores during classification
48 |         - **algorithm**: Algorithm that was used for search
49 |         - **sfs**, **clf**: the fitted selector and the backend classifier
50 |     """
51 | 
52 |     # retrieve the appropriate classifier
53 |     if isinstance(classifier, str):
54 |         if not (classifier in available_classifiers):
55 |             raise ValueError("Unknown classifier {c}".format(c=repr(classifier)))
56 | 
57 |         kwopts = kwargs.pop('opts', dict())
58 | 
59 |         # retrieve the options that we need to forward to the classifier
60 |         # TODO: should we forward all arguments to sequential_feature_selector ?
61 |         opts = available_classifiers[classifier].static_opts('sequential_feature_selector', features=features)
62 |         opts.update(kwopts)
63 | 
64 | 
65 |         # XXX: now merged into the static_opts invocation. TODO: test
66 |         # if classifier == 'SVM':
67 |         #     opts['cross_validation'] = kwopts.pop('cross_validation', False)
68 |         # elif classifier == 'RandomForest':
69 |         #     opts['cross_validation'] = kwopts.pop('cross_validation', False)
70 |         # elif classifier == 'MLP':
71 |         #     # TODO: check if the dimensions are correct here
72 |         #     opts['hidden_layer_sizes'] = (features.shape[1], features.shape[2])
73 |         # get all additional entries for the options
74 |         # opts.update(kwopts)
75 | 
76 |         # retrieve a classifier object
77 |         classifier_obj = available_classifiers[classifier](**opts)
78 | 
79 |         # extract the backend classifier
80 |         clf = classifier_obj.clf
81 |     else:
82 |         # if we received a classifier object we'll just use this one
83 |         clf = classifier.clf
84 | 
85 | 
86 |     if selection_type == 'SFS':
87 |         algorithm = "Sequential Forward Selection (SFS)"
88 |         sfs = SFS(clf, k_features, forward=True, floating=False,
89 |                   verbose=2, scoring='accuracy', cv=kfold, n_jobs=-1)
90 | 
91 |     elif selection_type == 'SBS':
92 |         algorithm = "Sequential Backward Selection (SBS)"
93 |         sfs = SFS(clf, k_features, forward=False, floating=False,
94 |                   verbose=2, scoring='accuracy', cv=kfold, n_jobs=-1)
95 | 
96 |     elif selection_type == 'SFFS':
97 |         algorithm = "Sequential Forward Floating Selection (SFFS)"
98 |         sfs = SFS(clf, k_features, forward=True, floating=True,
99 |                   verbose=2, scoring='accuracy', cv=kfold, n_jobs=-1)
100 | 
101 |     elif selection_type == 'SBFS':
102 |         algorithm = "Sequential Backward Floating Selection (SBFS)"
103 |         sfs = SFS(clf, k_features, forward=False, floating=True,
104 |                   verbose=2, scoring='accuracy', cv=kfold, n_jobs=-1)
105 | 
106 |     else:
107 |         raise Exception("Unknown selection type '{}'".format(selection_type))
108 | 
109 | 
110 |     pipe = make_pipeline(StandardScaler(), sfs)
111 |     pipe.fit(features, labels)
112 |     subsets = sfs.subsets_
113 |     feature_idx = sfs.k_feature_idx_
114 |     cv_scores = sfs.k_score_
115 | 
116 |     if plot:
117 |         fig1 = plot_sfs(sfs.get_metric_dict(), kind='std_dev')
118 |         plt.ylim([0.5, 1])
119 |         plt.title(algorithm)
120 |         plt.grid()
121 |         plt.show()
122 | 
123 |     return feature_idx, cv_scores, algorithm, sfs, clf
124 | 
125 | 
126 | # TODO: improve description of argument. I have no clue what exactly I should
127 | # pass to the function!
128 | def CSP(tasks):
129 |     """This function extracts Common Spatial Pattern (CSP) filters.
130 | 
131 |     Args:
132 |         For N tasks, N arrays are passed to CSP each with dimensionality (# of
133 |         trials of task N) x (feature vector)
134 | 
135 |     Returns:
136 |         A tuple of spatial filter matrices, one per task.
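
        Example (shape sketch; ``task1`` and ``task2`` are hypothetical arrays
        of shape (n_trials, n_channels, n_samples))::

            sf1, sf2 = CSP((task1, task2))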
137 | 138 | """ 139 | if len(tasks) < 2: 140 | print("Must have at least 2 tasks for filtering.") 141 | return (None,) * len(tasks) 142 | else: 143 | filters = () 144 | # CSP algorithm 145 | # For each task x, find the mean variance matrices Rx and not_Rx, which will be used to compute spatial filter SFx 146 | iterator = range(0,len(tasks)) 147 | for x in iterator: 148 | # Find Rx 149 | Rx = covarianceMatrix(tasks[x][0]) 150 | for t in range(1,len(tasks[x])): 151 | Rx += covarianceMatrix(tasks[x][t]) 152 | Rx = Rx / len(tasks[x]) 153 | 154 | # Find not_Rx 155 | count = 0 156 | not_Rx = Rx * 0 157 | for not_x in [element for element in iterator if element != x]: 158 | for t in range(0,len(tasks[not_x])): 159 | not_Rx += covarianceMatrix(tasks[not_x][t]) 160 | count += 1 161 | not_Rx = not_Rx / count 162 | 163 | # Find the spatial filter SFx 164 | SFx = spatialFilter(Rx,not_Rx) 165 | filters += (SFx,) 166 | 167 | # Special case: only two tasks, no need to compute any more mean variances 168 | if len(tasks) == 2: 169 | filters += (spatialFilter(not_Rx,Rx),) 170 | break 171 | return filters 172 | 173 | 174 | # covarianceMatrix takes a matrix A and returns the covariance matrix, scaled by the variance 175 | def covarianceMatrix(A): 176 | """This function computes the covariance Matrix 177 | 178 | Args: 179 | A: 2D matrix 180 | 181 | Returns: 182 | A 2D covariance matrix scaled by the variance 183 | """ 184 | #Ca = np.dot(A,np.transpose(A))/np.trace(np.dot(A,np.transpose(A))) 185 | Ca = np.cov(A) 186 | return Ca 187 | 188 | 189 | def spatialFilter(Ra,Rb): 190 | """This function extracts spatial filters 191 | 192 | Args: 193 | Ra, Rb: Covariance matrices Ra and Rb 194 | 195 | Returns: 196 | A 2D spatial filter matrix 197 | """ 198 | 199 | R = Ra + Rb 200 | E,U = la.eig(R) 201 | 202 | # CSP requires the eigenvalues E and eigenvector U be sorted in descending order 203 | ord = np.argsort(E) 204 | ord = ord[::-1] # argsort gives ascending order, flip to get descending 205 | E = E[ord] 206 | U = U[:,ord] 207 | 208 | # Find the whitening transformation matrix 209 | P = np.dot(np.sqrt(la.inv(np.diag(E))),np.transpose(U)) 210 | 211 | # The mean covariance matrices may now be transformed 212 | Sa = np.dot(P,np.dot(Ra,np.transpose(P))) 213 | Sb = np.dot(P,np.dot(Rb,np.transpose(P))) 214 | 215 | # Find and sort the generalized eigenvalues and eigenvector 216 | E1,U1 = la.eig(Sa,Sb) 217 | ord1 = np.argsort(E1) 218 | ord1 = ord1[::-1] 219 | E1 = E1[ord1] 220 | U1 = U1[:,ord1] 221 | 222 | # The projection matrix (the spatial filter) may now be obtained 223 | SFa = np.dot(np.transpose(U1),P) 224 | #return SFa.astype(np.float32) 225 | return SFa 226 | 227 | 228 | def PCA_dim_red(features, var_desired): 229 | """Dimensionality reduction of features using PCA. 
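
    The smallest number of leading principal components whose cumulative share
    of the explained variance reaches ``var_desired`` is kept.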
230 | 
231 |     Args:
232 |         features (matrix (2d np.array)): The feature matrix
233 |         var_desired (float): desired preserved variance
234 | 
235 |     Returns:
236 |         features with reduced dimensions
237 | 
238 |     """
239 |     # PCA
240 |     pca = sklearn.decomposition.PCA(n_components=features.shape[1]-1)
241 |     pca.fit(features)
242 |     var_sum = pca.explained_variance_ratio_.sum()
243 |     var = 0
244 |     for n, v in enumerate(pca.explained_variance_ratio_):
245 |         var += v
246 |         if var / var_sum >= var_desired:
247 |             features_reduced = sklearn.decomposition.PCA(n_components=n+1).fit_transform(features)
248 |             return features_reduced
249 | 
250 | 
251 | def RMS_features_extraction(data, trial_list, window_size, window_shift):
252 |     """Extract RMS features from data
253 | 
254 |     Args:
255 |         data: 2D (time points, Channels)
256 |         trial_list: list of the trials
257 |         window_size: size of the window used for extracting features
258 |         window_shift: size of the overlap
259 | 
260 |     Returns:
261 |         The features matrix (trials, features)
262 |     """
263 |     if window_shift > window_size:
264 |         raise ValueError("window_shift > window_size")
265 | 
266 |     # import gumpy's signal module locally (it is not imported at module level)
267 |     from . import signal as gumpy_signal
268 | 
269 |     fs = data.sampling_freq
270 | 
271 |     n_features = int(data.duration/(window_size-window_shift))
272 | 
273 |     X = np.zeros((len(trial_list), n_features*4))
274 | 
275 |     t = 0
276 |     for trial in trial_list:
277 |         # x3 is the worst of all with 43.3% average performance
278 |         x1 = gumpy_signal.rms(trial[0], fs, window_size, window_shift)
279 |         x2 = gumpy_signal.rms(trial[1], fs, window_size, window_shift)
280 |         x3 = gumpy_signal.rms(trial[2], fs, window_size, window_shift)
281 |         x4 = gumpy_signal.rms(trial[3], fs, window_size, window_shift)
282 |         x = np.concatenate((x1, x2, x3, x4))
283 |         X[t, :] = x
284 |         t += 1
285 |     return X
286 | 
287 | 
288 | def dwt_features(data, trials, level, sampling_freq, w, n, wavelet):
289 |     """Extract discrete wavelet features
290 | 
291 |     Args:
292 |         data: 2D (time points, Channels)
293 |         trials: Trials vector
294 |         level: level of the DWT decomposition
295 |         sampling_freq: sampling frequency
296 |         w: 2-element window (in samples) relative to 4s after trial start
297 |         n: index of the coefficient array to compute the features from
298 |         wavelet: mother wavelet, e.g. ``db4``
299 | 
300 |     Returns:
301 |         The features matrix (Nbre trials, Nbre features)
302 |     """
303 | 
304 |     # number of features per trial
305 |     n_features = 9
306 |     # allocate memory to store the features
307 |     X = np.zeros((len(trials), n_features))
308 | 
309 |     # extract features
310 |     for t, trial in enumerate(trials):
311 |         signals = data[trial + sampling_freq*4 + (w[0]) : trial + sampling_freq*4 + (w[1])]
312 |         coeffs_c3 = pywt.wavedec(data = signals[:,0], wavelet=wavelet, level=level)
313 |         coeffs_c4 = pywt.wavedec(data = signals[:,1], wavelet=wavelet, level=level)
314 |         coeffs_cz = pywt.wavedec(data = signals[:,2], wavelet=wavelet, level=level)
315 | 
316 |         X[t, :] = np.array([
317 |             np.std(coeffs_c3[n]), np.mean(coeffs_c3[n]**2),
318 |             np.std(coeffs_c4[n]), np.mean(coeffs_c4[n]**2),
319 |             np.std(coeffs_cz[n]), np.mean(coeffs_cz[n]**2),
320 |             np.mean(coeffs_c3[n]),
321 |             np.mean(coeffs_c4[n]),
322 |             np.mean(coeffs_cz[n])])
323 | 
324 |     return X
325 | 
326 | 
327 | def alpha_subBP_features(data):
328 |     """Extract alpha bands
329 | 
330 |     Args:
331 |         data: 2D (time points, Channels)
332 | 
333 |     Returns:
334 |         The alpha sub-bands
335 |     """
336 |     # import gumpy's signal module locally (it is not imported at module level)
337 |     from . import signal as gumpy_signal
338 | 
339 |     # filter data in sub-bands by specification of low- and high-cut frequencies
340 |     alpha1 = gumpy_signal.butter_bandpass(data, 8.5, 11.5, order=6)
341 |     alpha2 = gumpy_signal.butter_bandpass(data, 9.0, 12.5, order=6)
342 |     alpha3 = gumpy_signal.butter_bandpass(data, 9.5, 11.5, order=6)
343 |     alpha4 = gumpy_signal.butter_bandpass(data, 8.0, 10.5, order=6)
344 | 
345 |     # return a list of sub-bands
346 |     return [alpha1, alpha2, alpha3, alpha4]
347 | 
348 | 
349 | def beta_subBP_features(data):
350 |     """Extract beta bands
351 | 
352 |     Args:
353 |         data: 2D (time points, Channels)
354 | 
355 |     Returns:
356 |         The beta sub-bands
357 |     """
358 |     # import gumpy's signal module locally (it is not imported at module level)
359 |     from . import signal as gumpy_signal
360 | 
361 |     beta1 = gumpy_signal.butter_bandpass(data, 14.0, 30.0, order=6)
362 |     beta2 = gumpy_signal.butter_bandpass(data, 16.0, 17.0, order=6)
363 |     beta3 = gumpy_signal.butter_bandpass(data, 17.0, 18.0, order=6)
364 |     beta4 = gumpy_signal.butter_bandpass(data, 18.0, 19.0, order=6)
365 |     return [beta1, beta2, beta3, beta4]
366 | 
367 | 
368 | def powermean(data, trial, fs, w):
369 |     """Compute the mean power of the data
370 | 
371 |     Args:
372 |         data: 2D (time points, Channels)
373 |         trial: trial vector
374 |         fs: sampling frequency
375 |         w: window
376 | 
377 |     Returns:
378 |         The mean power
379 |     """
380 |     return np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],0],2).mean(), \
381 |            np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],1],2).mean(), \
382 |            np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],2],2).mean()
383 | 
384 | 
385 | def log_subBP_feature_extraction(alpha, beta, trials, fs, w):
386 |     """Extract the log power of alpha and beta bands
387 | 
388 |     Args:
389 |         alpha: filtered data in the alpha band
390 |         beta: filtered data in the beta band
391 |         trials: trial vector
392 |         fs: sampling frequency
393 |         w: window
394 | 
395 |     Returns:
396 |         The features matrix
397 |     """
398 |     # number of features combined for all trials
399 |     n_features = 15
400 |     # initialize the feature matrix
401 |     X = np.zeros((len(trials), n_features))
402 | 
403 |     # extract features
404 |     for t, trial in enumerate(trials):
405 |         power_c31, power_c41, power_cz1 = powermean(alpha[0], trial, fs, w)
406 |         power_c32, power_c42, power_cz2 = powermean(alpha[1], trial, fs, w)
407 |         power_c33, power_c43, power_cz3 = powermean(alpha[2], trial, fs, w)
408 |         power_c34, power_c44, power_cz4 = powermean(alpha[3], trial, fs, w)
409 |         power_c31_b, power_c41_b, power_cz1_b = powermean(beta[0], trial, fs, w)
410 | 
411 |         X[t, :] = np.array(
412 |             [np.log(power_c31), np.log(power_c41), np.log(power_cz1),
413 |              np.log(power_c32), np.log(power_c42), np.log(power_cz2),
414 |              np.log(power_c33), np.log(power_c43), np.log(power_cz3),
415 |              np.log(power_c34), np.log(power_c44), np.log(power_cz4),
416 |              np.log(power_c31_b), np.log(power_c41_b), np.log(power_cz1_b)])
417 | 
418 |     return X
419 | 
420 | 
--------------------------------------------------------------------------------
/gumpy/plot.py:
--------------------------------------------------------------------------------
1 | """Functions for plotting EEG processing results.
2 | """
3 | 
4 | import numpy as np
5 | import matplotlib as mpl
6 | import matplotlib.pyplot as plt
7 | from mpl_toolkits.mplot3d import Axes3D
8 | import sklearn.metrics as skm
9 | import seaborn as sns
10 | import pandas as pd
11 | import pywt
12 | import scipy.signal
13 | import sklearn.decomposition
14 | from matplotlib.gridspec import GridSpec
15 | from pylab import rcParams
16 | import itertools
17 | 
18 | 
19 | def plot_confusion_matrix(path, cm, target_names, title='Confusion matrix', cmap=None, normalize=True):
20 |     """Produces a plot for a confusion matrix and saves it to file.
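
    With ``normalize=True`` (the default) the cell annotations show row-wise
    proportions rather than raw counts.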
21 | 22 | Args: 23 | path (str): Filename of produced plot 24 | cm (ndarray): confusion matrix from sklearn.metrics.confusion_matrix 25 | target_names ([str]): given classification classes such as [0, 1, 2] the 26 | class names, for example: ['high', 'medium', 'low'] 27 | title (str): the text to display at the top of the matrix 28 | cmap: the gradient of the values displayed from matplotlib.pyplot.cm see 29 | http://matplotlib.org/examples/color/colormaps_reference.html 30 | plt.get_cmap('jet') or plt.cm.Blues 31 | normalize (bool): if False, plot the raw numbers. If True, plot the 32 | proportions 33 | 34 | Example: 35 | plot_confusion_matrix(cm = cm, # confusion matrix created by 36 | # sklearn.metrics.confusion_matrix 37 | normalize = True, # show proportions 38 | target_names = y_labels_vals, # list of names of the classes 39 | title = best_estimator_name) # title of graph 40 | 41 | References: 42 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html 43 | 44 | """ 45 | 46 | accuracy = np.trace(cm) / float(np.sum(cm)) 47 | misclass = 1 - accuracy 48 | 49 | if cmap is None: 50 | cmap = plt.get_cmap('Blues') 51 | 52 | fig = plt.figure(figsize=(8, 6)) 53 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 54 | plt.title(title) 55 | plt.colorbar() 56 | 57 | if target_names is not None: 58 | tick_marks = np.arange(len(target_names)) 59 | plt.xticks(tick_marks, target_names, rotation=45) 60 | plt.yticks(tick_marks, target_names) 61 | 62 | if normalize: 63 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 64 | 65 | thresh = cm.max() / 1.5 if normalize else cm.max() / 2 66 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 67 | if normalize: 68 | plt.text(j, i, "{:0.4f}".format(cm[i, j]), 69 | horizontalalignment="center", 70 | color="white" if cm[i, j] > thresh else "black") 71 | else: 72 | plt.text(j, i, "{:,}".format(cm[i, j]), 73 | horizontalalignment="center", 74 | color="white" if cm[i, j] > thresh else "black") 75 | 76 | 77 | plt.tight_layout() 78 | plt.ylabel('True label') 79 | #plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass)) 80 | plt.show() 81 | fig.savefig(path) 82 | 83 | 84 | # TODO: check formatting (whitespaces, etc) 85 | # TODO: check all variable names and improve them 86 | def ROC_curve(Y_pred, Y_test, fig=None): 87 | Y_score = np.array(Y_pred) 88 | # The following were moved inside the function call (roc_curve) to avoid 89 | # potential side effects of this functin 90 | # Y_score -=1 91 | # Y_test -=1 92 | 93 | # print (roc_auc_score(y_test, y_score)) 94 | 95 | fpr, tpr, _ = sklearn.metrics.roc_curve(Y_test - 1, Y_score - 1) 96 | 97 | # plotting 98 | if fig is None: 99 | fig = plt.figure() 100 | plt.plot(fpr, tpr, color= 'red', lw = 2) 101 | plt.plot([0, 1], [0, 1], color='navy', lw=2) 102 | plt.xlim([0.0, 1.0]) 103 | plt.ylim([0.0, 1.05]) 104 | plt.xlabel('False Positive Rate') 105 | plt.ylabel('True Positive Rate') 106 | plt.title('Roc curve') 107 | plt.legend(loc="lower right") 108 | plt.show() 109 | 110 | 111 | def confusion_matrix(true_labels, predicted_labels, cmap=plt.cm.Blues): 112 | cm = skm.confusion_matrix(true_labels, predicted_labels) 113 | # TODO: 114 | # print(cm) 115 | # Show confusion matrix in a separate window ? 
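    # (plot_confusion_matrix above provides a normalized, annotated variant of
    # this plot that is also written to disk.)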
116 |     plt.matshow(cm, cmap=cmap)
117 |     plt.title('Confusion matrix')
118 |     plt.colorbar()
119 |     plt.ylabel('True label')
120 |     plt.xlabel('Predicted label')
121 |     plt.show()
122 | 
123 | 
124 | # TODO: permit the user to specify the figure where this plot shall appear
125 | def accuracy_results_plot(data_path, x_label='', y_label=''):
126 |     data = pd.read_csv(data_path, index_col=0)
127 |     # draw the accuracies as one box plot per column
128 |     sns.set(rc={"figure.figsize": (9, 6)})
129 |     ax = sns.boxplot(data=data)
130 |     ax.set_xlabel(x_label, fontsize=15)
131 |     ax.set_ylabel(y_label, fontsize=15)
132 |     plt.show()
133 | 
134 | 
135 | def reconstruct_without_approx(xs, labels, level, fig=None):
136 |     # reconstruct
137 |     rs = [pywt.upcoef('d', x, 'db4', level=level) for x in xs]
138 | 
139 |     # generate plot
140 |     if fig is None:
141 |         fig = plt.figure()
142 |     for i, r in enumerate(rs):
143 |         plt.plot((np.abs(r))**2, label="Power of reconstructed signal ({})".format(labels[i]))
144 | 
145 |     plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
146 |     return rs, fig
147 | 
148 | 
149 | def reconstruct_with_approx(cDs, labels, wavelet, fig=None):
150 |     rs = [pywt.idwt(cA=None, cD=cD, wavelet=wavelet) for cD in cDs]
151 | 
152 |     if fig is None:
153 |         fig = plt.figure()
154 | 
155 |     for i, r in enumerate(rs):
156 |         plt.plot((np.abs(r))**2, label="Power of reconstructed signal ({})".format(labels[i]))
157 |     plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
158 | 
159 |     return rs, fig
160 | 
161 | 
162 | def fft(x, fs, fig_fft=None, fig_psd=None):
163 |     # FFT and power spectral density of the signal
164 |     signal_fft = np.fft.fft(x)
165 |     signal_psd = np.abs(signal_fft)**2
166 |     freq = np.linspace(0, fs, len(signal_fft))
167 |     freq1 = np.linspace(0, fs, len(signal_psd))
168 | 
169 |     if fig_fft is None:
170 |         fig_fft = plt.figure()
171 |     plt.plot(freq, np.abs(signal_fft), label="fft")
172 | 
173 |     if fig_psd is None:
174 |         fig_psd = plt.figure()
175 |     plt.plot(freq1, signal_psd, label="PSD")
176 | 
177 |     return signal_fft, signal_psd, fig_fft, fig_psd
178 | 
179 | 
180 | def dwt(approx, details, labels, level, sampling_freq, class_str=None):
181 |     """
182 |     Plot the results of a DWT transform.
183 |     """
184 | 
185 |     fig, axis = plt.subplots(level+1, 1, figsize=(8, 8))
186 |     fig.tight_layout()
187 | 
188 |     # plot the approximation
189 |     for i, l in enumerate(labels):
190 |         axis[0].plot(approx[i], label=l)
191 |     axis[0].legend()
192 |     if class_str is None:
193 |         axis[0].set_title('DWT approximations (level={}, sampling-freq={}Hz)'.format(level, sampling_freq))
194 |     else:
195 |         axis[0].set_title('DWT approximations, {} (level={}, sampling-freq={}Hz)'.format(class_str, level, sampling_freq))
196 |     axis[0].set_ylabel('(A={})'.format(level))
197 | 
198 |     # build the rows of detail coefficients
199 |     for j in range(1, level+1):
200 |         for i, l in enumerate(labels):
201 |             axis[j].plot(details[i][j-1], label=l)
202 |         if class_str is None:
203 |             axis[j].set_title('DWT Coeffs (level={}, sampling-freq={}Hz)'.format(level, sampling_freq))
204 |         else:
205 |             axis[j].set_title('DWT Coeffs, {} (level={}, sampling-freq={}Hz)'.format(class_str, level, sampling_freq))
206 |         axis[j].legend()
207 |         axis[j].set_ylabel('(D={})'.format(j))
208 | 
209 |     return axis
210 | 
211 | 
212 | def welch_psd(xs, labels, sampling_freq, fig=None):
213 |     """Compute and plot the power spectrum density (PSD) using Welch's method.
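
    Args:
        xs (list): 1D signals to analyze
        labels (list): legend labels, one entry per signal
        sampling_freq (int): sampling frequency of the signals (in Hz)
        fig: matplotlib figure to plot into; a new figure is created if None

    Returns:
        A 2-tuple containing the list of power spectra and the figure.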
214 |     """
215 | 
216 |     fs = []
217 |     ps = []
218 |     for i, x in enumerate(xs):
219 |         f, p = scipy.signal.welch(x, sampling_freq, 'flattop', scaling='spectrum')
220 |         fs.append(f)
221 |         ps.append(p)
222 | 
223 |     if fig is None:
224 |         fig = plt.figure()
225 | 
226 |     plt.subplots_adjust(hspace=0.4)
227 |     for i, p in enumerate(ps):
228 |         plt.semilogy(fs[i], p.T, label=labels[i])
229 | 
230 |     plt.xlabel('frequency [Hz]')
231 |     plt.ylabel('PSD')
232 | 
233 |     plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
234 |     plt.grid()
235 |     plt.show()
236 | 
237 |     return ps, fig
238 | 
239 | 
240 | def artifact_removal(X, S, S_reconst, fig=None):
241 |     """Plot the results of an artifact removal.
242 | 
243 |     This function displays the results after artifact removal, for instance
244 |     performed via :func:`gumpy.signal.artifact_removal`.
245 | 
246 |     Parameters
247 |     ----------
248 |     X:
249 |         Observations
250 |     S:
251 |         True sources
252 |     S_reconst:
253 |         The reconstructed signal
254 |     """
255 | 
256 |     if fig is None:
257 |         fig = plt.figure()
258 | 
259 |     models = [X, S, S_reconst]
260 |     names = ['Observations (mixed signal)',
261 |              'True Sources',
262 |              'ICA recovered signals']
263 |     for ii, (model, name) in enumerate(zip(models, names), 1):
264 |         plt.subplot(3, 1, ii)
265 |         plt.title(name)
266 |         plt.plot(model)
267 |     plt.subplots_adjust(0.09, 0.04, 0.94, 0.94, 0.26, 0.46)
268 |     plt.show()
269 | 
270 | 
271 | def PCA_2D(X, X_train, Y_train, fig=None, colors=None):
272 |     # computation
273 |     pca_2comp = sklearn.decomposition.PCA(n_components=2)
274 |     X_2comp = pca_2comp.fit_transform(X)
275 | 
276 |     # color and figure initialization
277 |     if colors is None:
278 |         colors = ['red','cyan']
279 |     if fig is None:
280 |         fig = plt.figure()
281 | 
282 |     # plotting
283 |     fig.suptitle('2D - Data')
284 |     ax = fig.add_subplot(1,1,1)
285 |     ax.scatter(X_train.T[0], X_train.T[1], alpha=0.5,
286 |                c=Y_train, cmap=mpl.colors.ListedColormap(colors))
287 |     ax.set_xlabel('x1')
288 |     ax.set_ylabel('x2')
289 | 
290 | 
291 | def PCA_3D(X, X_train, Y_train, fig=None, colors=None):
292 |     # computation
293 |     pca_3comp = sklearn.decomposition.PCA(n_components=3)
294 |     X_3comp = pca_3comp.fit_transform(X)
295 | 
296 |     # color and figure initialization
297 |     if colors is None:
298 |         colors = ['red','cyan']
299 |     if fig is None:
300 |         fig = plt.figure()
301 | 
302 |     # plotting
303 |     fig.suptitle('3D - Data')
304 |     ax = fig.add_subplot(1,1,1, projection='3d')
305 |     ax.scatter(X_train.T[0], X_train.T[1], X_train.T[2], alpha=0.5,
306 |                c=Y_train, cmap=mpl.colors.ListedColormap(colors))
307 |     ax.set_xlabel('x1')
308 |     ax.set_ylabel('x2')
309 |     ax.set_zlabel('x3')
310 | 
311 | 
312 | # TODO: allow user to pass formatting control, e.g.
colors, cmap, etc 313 | def PCA(ttype, X, X_train, Y_train, fig=None, colors=None): 314 | plot_fns = {'2D': PCA_2D, '3D': PCA_3D} 315 | if not ttype in plot_fns: 316 | raise Exception("Transformation type '{ttype}' unknown".format(ttype=ttype)) 317 | plot_fns[ttype](X, X_train, Y_train, fig, colors) 318 | 319 | 320 | 321 | def EEG_bandwave_visualizer(data, band_wave, n_trial, lo, hi, fig=None): 322 | if not fig: 323 | fig = plt.figure() 324 | 325 | plt.clf() 326 | plt.plot(band_wave[data.trials[n_trial]-data.mi_interval[0]*data.sampling_freq : data.trials[n_trial]+data.mi_interval[0]*data.sampling_freq, 0], 327 | alpha=0.7, label='C3') 328 | plt.plot(band_wave[data.trials[n_trial]-data.mi_interval[0]*data.sampling_freq : data.trials[n_trial]+data.mi_interval[0]*data.sampling_freq, 1], 329 | alpha=0.7, label='C4') 330 | plt.plot(band_wave[data.trials[n_trial]-data.mi_interval[0]*data.sampling_freq : data.trials[n_trial]+data.mi_interval[0]*data.sampling_freq, 2], 331 | alpha=0.7, label='Cz') 332 | 333 | plt.legend() 334 | plt.title("Filtered data (Band wave {}-{})".format(lo, hi)) 335 | 336 | 337 | # TODO: check if this is too specific 338 | # TODO: documentation 339 | # TODO: units missing 340 | def average_power(data_class1, lowcut, highcut, interval, sampling_freq, logarithmic_power): 341 | fs = sampling_freq 342 | if logarithmic_power: 343 | power_c3_c1_a = np.log(np.power(data_class1[0], 2).mean(axis=0)) 344 | power_c4_c1_a = np.log(np.power(data_class1[1], 2).mean(axis=0)) 345 | power_cz_c1_a = np.log(np.power(data_class1[2], 2).mean(axis=0)) 346 | power_c3_c2_a = np.log(np.power(data_class1[3], 2).mean(axis=0)) 347 | power_c4_c2_a = np.log(np.power(data_class1[4], 2).mean(axis=0)) 348 | power_cz_c2_a = np.log(np.power(data_class1[5], 2).mean(axis=0)) 349 | else: 350 | power_c3_c1_a = np.power(data_class1[0], 2).mean(axis=0) 351 | power_c4_c1_a = np.power(data_class1[1], 2).mean(axis=0) 352 | power_cz_c1_a = np.power(data_class1[2], 2).mean(axis=0) 353 | power_c3_c2_a = np.power(data_class1[3], 2).mean(axis=0) 354 | power_c4_c2_a = np.power(data_class1[4], 2).mean(axis=0) 355 | power_cz_c2_a = np.power(data_class1[5], 2).mean(axis=0) 356 | 357 | # time indices 358 | t = np.linspace(interval[0],interval[1],len(power_c3_c1_a[fs*interval[0]:fs*interval[1]])) 359 | 360 | # first figure, left motor imagery 361 | plt.figure() 362 | plt.plot(t, power_c3_c1_a[fs*interval[0]:fs*interval[1]], c='blue', 363 | label='C3', alpha=0.7) 364 | plt.plot(t,power_c4_c1_a [fs*interval[0]:fs*interval[1]],c='red', 365 | label='C4', alpha=0.7) 366 | plt.legend() 367 | plt.xlabel('Time') 368 | if logarithmic_power: 369 | plt.ylabel('Logarithmic Power') 370 | else: 371 | plt.ylabel('Power') 372 | plt.title("Left motor imagery movements ".format(lowcut, highcut)) 373 | plt.show() 374 | 375 | # second figure, right motor imagery 376 | plt.figure() 377 | plt.clf() 378 | plt.plot(t, power_c3_c2_a[fs*interval[0] : fs*interval[1]], c='blue', label='C3', alpha=0.7) 379 | plt.plot(t, power_c4_c2_a[fs*interval[0] : fs*interval[1]], c='red', label='C4', alpha=0.7) 380 | plt.legend() 381 | plt.xlabel('Time') 382 | if logarithmic_power: 383 | plt.ylabel('Logarithmic Power') 384 | else: 385 | plt.ylabel('Power') 386 | plt.title("Right motor imagery movements".format(lowcut, highcut)) 387 | -------------------------------------------------------------------------------- /gumpy/signal.py: -------------------------------------------------------------------------------- 1 | """Signal processing utilities, filters, and data 
post-processing routines. 2 | 3 | 4 | Every filter comes in form of a pair: 5 | 1) filter class 6 | 2) filter commodity function 7 | 8 | The commodity functions internally create a filter class and invoke the 9 | corresponding ``process`` method. Often, however, usage requires to apply a 10 | filter multiple times. In this case, the filter classes should be used directly 11 | as this avoids redundant initialization of the filters. 12 | 13 | If possible, the filters are initialized with arguments that were found to be 14 | suitable for most EEG/EMG post-processing needs. Other arguments need to be 15 | passed when creating a filter class. The commodity functions forward all 16 | (unknown) arguments to the filter initialization. 17 | 18 | """ 19 | 20 | # TODO: description above. check if we really have a filter class for every 21 | # filter, or if we specify them 22 | 23 | from .data.dataset import Dataset 24 | 25 | import numpy as np 26 | import pandas as pd 27 | import scipy.signal 28 | from scipy.signal import butter, lfilter, freqz, iirnotch, filtfilt 29 | import scipy.stats 30 | import sklearn.decomposition 31 | import pywt 32 | 33 | 34 | class ButterBandpass: 35 | """Filter class for a Butterworth bandpass filter. 36 | 37 | """ 38 | 39 | def __init__(self, lowcut, highcut, order=4, fs=256): 40 | """Initialize the Butterworth bandpass filter. 41 | 42 | Args: 43 | lowcut (float): low cut-off frequency 44 | highcut (float): high cut-off frequency 45 | order (int): order of the Butterworth bandpass filter 46 | fs (int): sampling frequency 47 | 48 | """ 49 | self.lowcut = lowcut 50 | self.highcut = highcut 51 | self.order = order 52 | 53 | nyq = 0.5 * fs 54 | low = lowcut / nyq 55 | high = highcut / nyq 56 | self.b, self.a = scipy.signal.butter(order, [low, high], btype='bandpass') 57 | 58 | 59 | def process(self, data, axis=0): 60 | """Apply the filter to data along a given axis. 61 | 62 | Args: 63 | data (array_like): data to filter 64 | axis (int): along which data to filter 65 | 66 | Returns: 67 | ndarray: Result of the same shape as data 68 | 69 | """ 70 | return scipy.signal.filtfilt(self.b, self.a, data, axis) 71 | 72 | 73 | 74 | def butter_bandpass(data, lo, hi, axis=0, **kwargs): 75 | """Apply a Butterworth bandpass filter to some data. 76 | 77 | The function either takes an ``array_like`` object (e.g. numpy's ndarray) or 78 | an instance of a gumpy.data.Dataset subclass as first argument. 79 | 80 | Args: 81 | data (array_like or Dataset instance): input data. If this is an 82 | instance of a Dataset subclass, the sampling frequency will be extracted 83 | automatically. 84 | lo (float): low cutoff frequency. 85 | hi (float): high cutoff frequency. 86 | axis (int): along which axis of data the filter should be applied. Default = 0. 87 | **kwargs: Additional keyword arguments that will be passed to ``gumpy.signal.ButterBandstop``. 88 | 89 | Returns: 90 | array_like: data filtered long the specified axis. 91 | 92 | """ 93 | if isinstance(data, Dataset): 94 | flt = ButterBandpass(lo, hi, fs=data.sampling_freq, **kwargs) 95 | filtered = [flt.process(data.raw_data[:, i], axis) for i in range(data.raw_data.shape[1])] 96 | reshaped = [f.reshape(-1, 1) for f in filtered] 97 | return np.hstack(reshaped) 98 | else: 99 | flt = ButterBandpass(lo, hi, **kwargs) 100 | return flt.process(data, axis) 101 | 102 | 103 | 104 | class ButterHighpass: 105 | """Filter class for a Butterworth bandpass filter. 
106 | 107 | """ 108 | 109 | def __init__(self, cutoff, order=4, fs=256): 110 | """Initialize the Butterworth highpass filter. 111 | 112 | Args: 113 | cutoff (float): cut-off frequency 114 | order (int): order of the Butterworth bandpass filter 115 | fs (int): sampling frequency 116 | 117 | """ 118 | self.cutoff = cutoff 119 | self.order = order 120 | 121 | nyq = 0.5 * fs 122 | high = cutoff / nyq 123 | self.b, self.a = scipy.signal.butter(order, high, btype='highpass') 124 | 125 | 126 | def process(self, data, axis=0): 127 | """Apply the filter to data along a given axis. 128 | 129 | Args: 130 | data (array_like): data to filter 131 | axis (int): along which data to filter 132 | 133 | Returns: 134 | ndarray: Result of the same shape as data 135 | 136 | """ 137 | return scipy.signal.filtfilt(self.b, self.a, data, axis) 138 | 139 | 140 | 141 | def butter_highpass(data, cutoff, axis=0, **kwargs): 142 | """Apply a Butterworth highpass filter to some data. 143 | 144 | The function either takes an ``array_like`` object (e.g. numpy's ndarray) or 145 | an instance of a gumpy.data.Dataset subclass as first argument. 146 | 147 | Args: 148 | data (array_like or Dataset instance): input data. If this is an 149 | instance of a Dataset subclass, the sampling frequency will be extracted 150 | automatically. 151 | cutoff (float): cutoff frequency. 152 | axis (int): along which axis of data the filter should be applied. Default = 0. 153 | **kwargs: Additional keyword arguments that will be passed to ``gumpy.signal.ButterBandstop``. 154 | 155 | Returns: 156 | array_like: data filtered long the specified axis. 157 | 158 | """ 159 | 160 | if isinstance(data, Dataset): 161 | flt = ButterHighpass(cutoff, fs=data.sampling_freq, **kwargs) 162 | filtered = [flt.process(data.raw_data[:, i], axis) for i in range(data.raw_data.shape[1])] 163 | reshaped = [f.reshape(-1, 1) for f in filtered] 164 | return np.hstack(reshaped) 165 | else: 166 | flt = ButterHighpass(cutoff, **kwargs) 167 | return flt.process(data, axis) 168 | 169 | 170 | 171 | class ButterLowpass: 172 | """Filter class for a Butterworth lowpass filter. 173 | 174 | """ 175 | 176 | def __init__(self, cutoff, order=4, fs=256): 177 | """Initialize the Butterworth lowpass filter. 178 | 179 | Args: 180 | cutoff (float): cut-off frequency 181 | order (int): order of the Butterworth bandpass filter 182 | fs (int): sampling frequency 183 | 184 | """ 185 | 186 | self.cutoff = cutoff 187 | self.order = order 188 | 189 | nyq = 0.5 * fs 190 | low = cutoff / nyq 191 | self.b, self.a = scipy.signal.butter(order, low, btype='lowpass') 192 | 193 | def process(self, data, axis=0): 194 | """Apply the filter to data along a given axis. 195 | 196 | Args: 197 | data (array_like): data to filter 198 | axis (int): along which data to filter 199 | 200 | Returns: 201 | ndarray: Result of the same shape as data 202 | 203 | """ 204 | return scipy.signal.filtfilt(self.b, self.a, data, axis) 205 | 206 | 207 | 208 | def butter_lowpass(data, cutoff, axis=0, **kwargs): 209 | """Apply a Butterworth lowpass filter to some data. 210 | 211 | The function either takes an ``array_like`` object (e.g. numpy's ndarray) or 212 | an instance of a gumpy.data.Dataset subclass as first argument. 213 | 214 | Args: 215 | data (array_like or Dataset instance): input data. If this is an 216 | instance of a Dataset subclass, the sampling frequency will be extracted 217 | automatically. 218 | cutoff (float): cutoff frequency. 219 | axis (int): along which axis of data the filter should be applied. 


def _norm_min_max(data):
    return (data - np.min(data)) / (np.max(data) - np.min(data))



def _norm_mean_std(data):
    mean = np.mean(data, axis=0)
    std_dev = np.std(data, axis=0)
    return (data - mean) / std_dev



def normalize(data, normalization_type):
    """Normalize data.

    Normalize data either by shifting and rescaling the data to [0,1]
    (``min_max``) or by rescaling via mean and standard deviation
    (``mean_std``).

    Args:
        data (array_like): Input data
        normalization_type (str): One of ``mean_std``, ``min_max``

    Returns:
        ndarray: normalized data with same shape as ``data``

    Raises:
        Exception: if the normalization type is unknown.

    """
    norm_fns = {'mean_std': _norm_mean_std,
                'min_max': _norm_min_max
               }
    if normalization_type not in norm_fns:
        raise Exception("Normalization method '{m}' is not supported".format(m=normalization_type))
    if isinstance(data, Dataset):
        return norm_fns[normalization_type](data.raw_data)
    else:
        return norm_fns[normalization_type](data)
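

# A small sketch of both normalization modes on random data (the shapes and
# values are arbitrary):
def _demo_normalize():
    rng = np.random.RandomState(0)
    x = rng.randn(100, 3) * 5 + 2
    z = normalize(x, 'mean_std')  # zero mean, unit variance per column
    u = normalize(x, 'min_max')   # rescaled to [0, 1] over the whole array
    assert np.allclose(z.mean(axis=0), 0) and np.allclose(z.std(axis=0), 1)
    assert u.min() == 0.0 and u.max() == 1.0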


def EEG_mean_power(data):
    """Compute the mean power of the data along the first axis.

    """
    return np.power(data, 2).mean(axis=0)



#def bootstrap_resample(X, n=None):
#    """Resample data.
#
#    Args:
#        X (array_like): Input data from which to resample.
#        n (int): Number of elements to sample.
#
#    Returns:
#        ndarray: n elements sampled from X.
#
#    """
#    if isinstance(X, pd.Series):
#        X = X.copy()
#        X.index = range(len(X.index))
#
#    if n is None:
#        n = len(X)
#
#    resample_i = np.floor(np.random.rand(n)*len(X)).astype(int)
#    return np.array(X[resample_i])



def dwt(raw_eeg_data, level, **kwargs):
    """Multilevel Discrete Wavelet Transform (DWT).

    Compute the DWT for a raw EEG signal on multiple levels.

    Args:
        raw_eeg_data (array_like): input data
        level (int >= 0): decomposition levels
        **kwargs: Additional arguments that will be forwarded to
            ``pywt.wavedec``. Note that at least the ``wavelet`` argument is
            required.

    Returns:
        A 2-element tuple containing

        - **float**: mean value of the approximation coefficients
        - **list**: list of mean values for the individual detail coefficients

    """
    wt_coeffs = pywt.wavedec(data=raw_eeg_data, level=level, **kwargs)

    # approximation coefficients, e.g. A7: 0 Hz - 1 Hz for level=7 and fs=128 Hz
    cAL_mean = np.nanmean(wt_coeffs[0], axis=0)
    details = []

    # detail coefficients, e.g. D7: 1 Hz - 2 Hz for level=7 and fs=128 Hz
    for i in range(1, level + 1):
        cDL_mean = np.nanmean(wt_coeffs[i], axis=0)
        details.append(cDL_mean)

    return cAL_mean, details
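

# A usage sketch for the multilevel DWT on random data; the Daubechies-4
# wavelet is an assumption here, as ``pywt.wavedec`` requires a wavelet to be
# specified via the keyword arguments:
def _demo_dwt():
    rng = np.random.RandomState(0)
    eeg = rng.randn(512)
    approx_mean, detail_means = dwt(eeg, level=5, wavelet='db4')
    # one mean value per detail level
    assert len(detail_means) == 5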


def rms(signal, fs, window_size, window_shift):
    """Root Mean Square.

    Compute windowed RMS features of a signal.

    Args:
        signal (array_like): input signal
        fs (int): sampling frequency
        window_size (float): window length in seconds
        window_shift (float): overlap of consecutive windows in seconds; the
            hop size is ``window_size - window_shift``

    Returns:
        ndarray: the RMS value of each window
    """
    duration = len(signal)/fs
    n_features = int(duration/(window_size-window_shift))

    features = np.zeros(n_features)

    for i in range(n_features):
        idx1 = int((i*(window_size-window_shift))*fs)
        idx2 = int(((i+1)*window_size-i*window_shift)*fs)
        features[i] = np.sqrt(np.mean(np.square(signal[idx1:idx2])))

    return features
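

# A quick sketch of the windowed RMS extraction, assuming a 1 s signal at
# 256 Hz, 200 ms windows, and 50 ms overlap:
def _demo_rms():
    fs = 256
    t = np.linspace(0, 1, fs, endpoint=False)
    x = np.sin(2 * np.pi * 10 * t)
    feats = rms(x, fs, window_size=0.2, window_shift=0.05)
    # 1 s / (0.2 s - 0.05 s) -> 6 windows
    assert feats.shape == (6,)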


def correlation(x, y):
    """Compute the correlation between x and y using Pearson's r.

    Returns:
        tuple: Pearson's correlation coefficient and the two-tailed p-value.

    """
    return scipy.stats.pearsonr(x, y)



def artifact_removal(X, n_components=None, check_result=True):
    """Remove artifacts from data.

    The artifacts are detected via Independent Component Analysis (ICA) and
    subsequently removed. To plot the results, use
    :func:`gumpy.plot.artifact_removal`

    Args:
        X (array_like): Data to remove artifacts from
        n_components (int): Number of components for ICA. If None is passed, all will be used
        check_result (bool): Examine/test the ICA model by reverting the mixing.


    Returns:
        A 2-tuple containing

        - **ndarray**: The reconstructed signal without artifacts.
        - **ndarray**: The mixing matrix that was used by ICA.

    """

    ica = sklearn.decomposition.FastICA(n_components)
    S_reconst = ica.fit_transform(X)
    A_mixing = ica.mixing_
    if check_result:
        assert np.allclose(X, np.dot(S_reconst, A_mixing.T) + ica.mean_)

    return S_reconst, A_mixing
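

# A sketch of the ICA-based artifact removal on a synthetic two-channel
# mixture of two independent sources (all numbers are made up):
def _demo_artifact_removal():
    rng = np.random.RandomState(0)
    t = np.linspace(0, 8, 2000)
    # a sinusoidal source and a square-wave source, plus a little noise
    sources = np.c_[np.sin(2 * t), np.sign(np.sin(3 * t))]
    sources += 0.05 * rng.standard_normal(sources.shape)
    mixing = np.array([[1.0, 0.5], [0.5, 1.0]])
    X = sources @ mixing.T
    S, A = artifact_removal(X, n_components=2)
    assert S.shape == X.shape and A.shape == (2, 2)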


def sliding_window(data, labels, window_sz, n_hop, n_start=0, show_status=False):
    """Cut the trials of a data matrix into (possibly overlapping) windows.

    Args:
        data (ndarray): matrix to be processed, shape (n_trials, n_samples, feature_dim)
        labels (ndarray): label vector with one entry per trial
        window_sz (int): number of samples per window
        n_hop (int): number of samples between consecutive windows
        n_start (int): sample index at which the first window starts
        show_status (bool): print a progress message every 100 trials

    Returns:
        A 2-tuple containing

        - **ndarray**: output matrix of shape (None, window_sz, feature_dim)
        - **ndarray**: label vector with one entry per window

    """
    flag = 0
    for sample in range(data.shape[0]):
        tmp = np.array(
            [data[sample, i:i + window_sz, :] for i in np.arange(n_start, data.shape[1] - window_sz + n_hop, n_hop)])

        tmp_lab = np.array([labels[sample] for i in np.arange(n_start, data.shape[1] - window_sz + n_hop, n_hop)])

        if sample % 100 == 0 and show_status:
            print("Sample " + str(sample) + " processed!\n")

        if flag == 0:
            new_data = tmp
            new_lab = tmp_lab
            flag = 1
        else:
            new_data = np.concatenate((new_data, tmp))
            new_lab = np.concatenate((new_lab, tmp_lab))
    return new_data, new_lab
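

# A shape check for the sliding window, assuming 10 trials of 100 samples with
# 3 features each, 50-sample windows, and a hop of 25 samples:
def _demo_sliding_window():
    rng = np.random.RandomState(0)
    data = rng.randn(10, 100, 3)
    labels = np.arange(10)
    windows, window_labels = sliding_window(data, labels, window_sz=50, n_hop=25)
    # window starts per trial: arange(0, 100 - 50 + 25, 25) -> 0, 25, 50
    assert windows.shape == (30, 50, 3)
    assert window_labels.shape == (30,)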

--------------------------------------------------------------------------------
/gumpy/split.py:
--------------------------------------------------------------------------------

import sklearn.model_selection
import numpy as np
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit, cross_val_score, StratifiedKFold


def normal(X, labels, test_size):
    """Split a dataset into training and test parts.

    Args:
        X (numpy.ndarray): 2D features matrix
        labels: labels vector
        test_size: size of the test split (fraction or absolute number of samples)

    Returns:
        A 4-tuple (X_train, X_test, Y_train, Y_test)
    """
    Y = labels
    X_train, X_test, Y_train, Y_test = \
        sklearn.model_selection.train_test_split(X, Y,
                                                 test_size=test_size,
                                                 random_state=0)
    return X_train, X_test, Y_train, Y_test


def time_series_split(features, labels, n_splits):
    """Split a dataset into n splits that respect the temporal order.

    Note that only the last of the ``n_splits`` train/test splits is returned.
    """
    xx = sklearn.model_selection.TimeSeriesSplit(n_splits)
    for train_index, test_index in xx.split(features):
        X_train, X_test = features[train_index], features[test_index]
        y_train, y_test = labels[train_index], labels[test_index]

    return X_train, X_test, y_train, y_test


def stratified_KFold(features, labels, n_splits):
    """Stratified K-Folds cross-validator.

    Stratification is the process of rearranging the data so that each fold is
    a good representative of the whole while preserving the class balance.
    Note that only the last of the ``n_splits`` train/test splits is returned.
    """
    skf = StratifiedKFold(n_splits)
    skf.get_n_splits(features, labels)
    for train_index, test_index in skf.split(features, labels):
        X_train, X_test = features[train_index], features[test_index]
        Y_train, Y_test = labels[train_index], labels[test_index]
    return X_train, X_test, Y_train, Y_test


def stratified_shuffle_Split(features, labels, n_splits, test_size, random_state):
    """Stratified ShuffleSplit cross-validator.

    Note that only the last of the ``n_splits`` train/test splits is returned.
    """
    cv = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=random_state)
    for train_index, test_index in cv.split(features, labels):
        X_train = features[train_index]
        X_test = features[test_index]
        Y_train = labels[train_index]
        Y_test = labels[test_index]
    return X_train, X_test, Y_train, Y_test


def shuffle_Split(features, labels, n_splits, test_size, random_state):
    """ShuffleSplit: random permutation cross-validator.

    Note that only the last of the ``n_splits`` train/test splits is returned.
    """
    cv = ShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=random_state)
    for train_index, test_index in cv.split(features):
        X_train = features[train_index]
        X_test = features[test_index]
        Y_train = labels[train_index]
        Y_test = labels[test_index]
    return X_train, X_test, Y_train, Y_test
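

# A small usage sketch for the split helpers, assuming a random feature matrix
# with a balanced binary label vector:
def _demo_splits():
    rng = np.random.RandomState(0)
    X = rng.randn(40, 5)
    y = np.repeat([0, 1], 20)

    X_train, X_test, Y_train, Y_test = normal(X, y, test_size=0.25)
    assert X_train.shape == (30, 5) and X_test.shape == (10, 5)

    X_train, X_test, Y_train, Y_test = stratified_KFold(X, y, n_splits=4)
    # every stratified fold keeps the 50/50 class balance
    assert Y_test.mean() == 0.5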

--------------------------------------------------------------------------------
/gumpy/utils.py:
--------------------------------------------------------------------------------

"""Utility functions that may be used during data processing.

Because many datasets differ slightly, not all utility functions may work with
each dataset. However, the modifications are typically only minor, and thus the
functions provided within this module can be adapted easily.
"""

from .data.dataset import Dataset
import numpy as np
import warnings

# TODO: documentation


def extract_trials(data, filtered=None, trials=None, labels=None, sampling_freq=0):
    """Extract the class-1 and class-2 trials of the channels C3, Cz, and C4.

    Args:
        data (Dataset or array_like): dataset, or raw data matrix of shape
            (n_samples, n_channels)
        filtered (array_like): optional filtered version of the raw data
        trials (array_like): trial start indices; extracted from ``data`` if
            it is a Dataset instance
        labels (array_like): class labels; extracted from ``data`` if it is a
            Dataset instance
        sampling_freq (int): sampling frequency; only used if ``data`` is not
            a Dataset instance

    Returns:
        ndarray: the C3, C4, and Cz trials of class 1, followed by those of
        class 2.
    """
    if isinstance(data, Dataset) or (filtered is not None):
        # extract all necessary information from the dataset
        fs = data.sampling_freq
        labels = data.labels
        trial_len = data.trial_len
        trial_offset = data.trial_offset
        trials = data.trials

        # determine if to work on raw_data or if filtered information was
        # passed along
        if filtered is None:
            _data = data.raw_data
        else:
            _data = filtered
    else:
        _data = data
        fs = sampling_freq
        trial_len = 8
        trial_offset = 0


    # indices of class 1 and 2
    c1_idxs = np.where(labels == 0)[0]  # class 1 ('left') has label 0
    c2_idxs = np.where(labels == 1)[0]  # class 2 ('right') has label 1

    c1_trials = trials[c1_idxs]
    c2_trials = trials[c2_idxs]

    # init arrays (#trials, length_trial)
    raw_c3_c1_a = np.zeros((len(c1_idxs), fs*(trial_len+trial_offset)))
    raw_c4_c1_a = np.zeros((len(c1_idxs), fs*(trial_len+trial_offset)))
    raw_cz_c1_a = np.zeros((len(c1_idxs), fs*(trial_len+trial_offset)))

    raw_c3_c2_a = np.zeros((len(c2_idxs), fs*(trial_len+trial_offset)))
    raw_c4_c2_a = np.zeros((len(c2_idxs), fs*(trial_len+trial_offset)))
    raw_cz_c2_a = np.zeros((len(c2_idxs), fs*(trial_len+trial_offset)))

    # add EEG trial data to the arrays
    for i, (idx_c1, idx_c2) in enumerate(zip(c1_trials, c2_trials)):
        raw_c3_c1_a[i,:] = _data[idx_c1-(trial_offset*fs) : idx_c1+(trial_len*fs), 0]
        raw_c4_c1_a[i,:] = _data[idx_c1-(trial_offset*fs) : idx_c1+(trial_len*fs), 2]
        raw_cz_c1_a[i,:] = _data[idx_c1-(trial_offset*fs) : idx_c1+(trial_len*fs), 1]

        raw_c3_c2_a[i,:] = _data[idx_c2-(trial_offset*fs) : idx_c2+(trial_len*fs), 0]
        raw_c4_c2_a[i,:] = _data[idx_c2-(trial_offset*fs) : idx_c2+(trial_len*fs), 2]
        raw_cz_c2_a[i,:] = _data[idx_c2-(trial_offset*fs) : idx_c2+(trial_len*fs), 1]

    return np.array((raw_c3_c1_a, raw_c4_c1_a, raw_cz_c1_a, raw_c3_c2_a, raw_c4_c2_a, raw_cz_c2_a))


# TODO: merge extract_trials and extract_trials2
def extract_trials2(raw_data, trials, labels, trial_total, fs, nbClasses):
    """Extract trials from raw EEG data.

    Args:
        raw_data: raw EEG data, shape (n_samples, n_channels)
        trials: starting sample of each trial, shape (n_trials,)
        labels: corresponding label of each trial, shape (n_labels,)
        trial_total: total length of a trial in seconds (scalar)
        fs: sampling frequency in Hz (scalar)
        nbClasses: number of classes (2 or 3)
    """
    warnings.warn("Function extract_trials2 will be removed in the future.", PendingDeprecationWarning)

    # get class indices
    class1_idxs = np.where(labels == 0)[0]
    class2_idxs = np.where(labels == 1)[0]
    class3_idxs = np.where(labels == 2)[0]

    # init data arrays for each class
    # (n_trials, n_samples, n_channels)
    class1_data = np.zeros((len(class1_idxs), trial_total*fs, raw_data.shape[1]))
    class2_data = np.zeros((len(class2_idxs), trial_total*fs, raw_data.shape[1]))
    class3_data = np.zeros((len(class3_idxs), trial_total*fs, raw_data.shape[1]))

    # split data class 1
    for i, c1_idx in enumerate(class1_idxs):  # iterate over trials
        trial = raw_data[trials[c1_idx] : trials[c1_idx]+trial_total*fs]  # (n_samples, n_channels)
        class1_data[i,:,:] = trial
    # split data class 2
    for i, c2_idx in enumerate(class2_idxs):  # iterate over trials
        trial = raw_data[trials[c2_idx] : trials[c2_idx]+trial_total*fs]  # (n_samples, n_channels)
        class2_data[i,:,:] = trial
    # split data class 3
    if nbClasses == 3:
        for i, c3_idx in enumerate(class3_idxs):  # iterate over trials
            trial = raw_data[trials[c3_idx] : trials[c3_idx]+trial_total*fs]  # (n_samples, n_channels)
            class3_data[i,:,:] = trial


    if nbClasses == 2:
        return class1_data, class2_data
    else:
        return class1_data, class2_data, class3_data




def _retrieveTrialSlice(data, trialIndex, type='signal'):
    if type == 'signal':
        return slice(int(data.trials[trialIndex] +
            data.trialSignalOffset[0]*data.sampling_freq), int(data.trials[trialIndex] +
            data.trialSignalOffset[1]*data.sampling_freq))

    elif type == 'force':
        return slice(int(data.trials[trialIndex] +
            data.trialForceOffset[0]*data.sampling_freq), int(data.trials[trialIndex] +
            data.trialForceOffset[1]*data.sampling_freq))

    elif type == 'background':
        return slice(int(data.trials[trialIndex] +
            data.trialBgOffset[0]*data.sampling_freq), int(data.trials[trialIndex] +
            data.trialBgOffset[1]*data.sampling_freq))

    else:
        raise AttributeError('type should be "signal", "force", or "background".')



def _processData(data, type='signal'):
    if type == 'signal':
        return data
    elif type == 'force':
        try:
            return data/max(data)
        except ValueError:
            return data



def getTrials(data, filtered=None, background=False):
    data.channel = []

    raw_data = data.raw_data
    if filtered is not None:
        raw_data = filtered

    for pair in data.electrodePairList:
        data.channel.append(_processData(raw_data[:, pair[0]] -
                                         raw_data[:, pair[1]]))

    processedForces = _processData(data.forces, 'force')

    if background:
        return [(data.channel[0][_retrieveTrialSlice(data, i%3, 'background')],
                 data.channel[1][_retrieveTrialSlice(data, i%3, 'background')],
                 data.channel[2][_retrieveTrialSlice(data, i%3, 'background')],
                 data.channel[3][_retrieveTrialSlice(data, i%3, 'background')],
                 processedForces[_retrieveTrialSlice(data, i%3, 'force')])
                for i in range(int(data.trials.shape[0]/3))]
    else:
        return [(data.channel[0][_retrieveTrialSlice(data, i, 'signal')],
                 data.channel[1][_retrieveTrialSlice(data, i, 'signal')],
                 data.channel[2][_retrieveTrialSlice(data, i, 'signal')],
                 data.channel[3][_retrieveTrialSlice(data, i, 'signal')],
                 processedForces[_retrieveTrialSlice(data, i, 'force')])
                for i in range(data.trials.shape[0])]
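

# A shape-level sketch of extract_trials2 on synthetic data, assuming two
# classes, 2 s trials, and a 128 Hz sampling rate (all values invented):
def _demo_extract_trials2():
    fs, trial_total = 128, 2
    raw = np.random.randn(10 * fs, 3)      # (n_samples, n_channels)
    trials = np.array([0, 256, 512, 768])  # trial start samples
    labels = np.array([0, 1, 0, 1])
    c1, c2 = extract_trials2(raw, trials, labels, trial_total, fs, nbClasses=2)
    assert c1.shape == (2, trial_total * fs, 3)
    assert c2.shape == (2, trial_total * fs, 3)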

--------------------------------------------------------------------------------
/gumpy/version.py:
--------------------------------------------------------------------------------

__version__ = '0.5.0'

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------

[build_sphinx]
source_dir = doc/source
build_dir = doc/build
all_files = 1

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------

#!/usr/bin/env python

# Copyright (c) 2017-2018 The gumpy developers:
#
# 2017-2018 Zied Tayeb
# 2017-2018 Nicolai Waniek
# 2017-2018 Juri Fedjaev
# 2017-2018 Leonard Rychly
#

"""EEG signal processing and classification toolbox.

This toolbox provides signal processing functions and classes to work with BCI
datasets. Many of the functions internally call existing libraries for signal
processing or numerical computation such as ``numpy`` or ``scipy``. In these
cases the functions are called with parameters that were found to be suitable
for BCI computing and brain machine interfaces.

The name of the toolbox is a reference to the Gumby Brain Specialist sketch by
Monty Python.
"""

DISTNAME = 'gumpy'
DESCRIPTION = 'EEG signal processing and classification toolbox'
LONG_DESCRIPTION = __doc__
MAINTAINER = 'The gumpy developers'
MAINTAINER_MAIL = 'zied.tayeb@tum.de'
LICENSE = 'MIT'
URL = 'www.gumpy.org'

# extract the version from the source file
VERSION_DATA = {}
with open('gumpy/version.py') as fp:
    exec(fp.read(), VERSION_DATA)
VERSION = VERSION_DATA['__version__']


from setuptools import setup, find_packages

if __name__ == "__main__":
    setup(classifiers=[
            'Development Status :: 4 - Beta',
            'Intended Audience :: Science/Research',
            'License :: OSI Approved :: MIT License',
            'Programming Language :: Python :: 3',
            'Topic :: Scientific/Engineering :: Human Machine Interfaces',
          ],
          install_requires=[
            'numpy',
            'scipy',
            'scikit-learn',
            'seaborn',
            'pandas',
            'PyWavelets',
            'mlxtend',
          ],
          name=DISTNAME,
          version=VERSION,
          description=DESCRIPTION,
          long_description=LONG_DESCRIPTION,
          url=URL,
          maintainer=MAINTAINER,
          maintainer_email=MAINTAINER_MAIL,
          license=LICENSE,
          packages=find_packages(exclude=['tests*']),
          python_requires='>=3',
          zip_safe=False)

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------

import sys, os, os.path
sys.path.append('../../gumpy')


import numpy as np
import gumpy


# First specify the location of the data and some
# identifier that is exposed by the dataset (e.g. subject)
base_dir = '../Data/NST-EMG'
subject = 'S1'

# The next lines first initialize the data structures.
# Note that this does not yet load the data! In custom implementations
# of a dataset, this should be used to prepare file transfers,
# for instance to check if all files are available, etc.
data_low = gumpy.data.NST_EMG(base_dir, subject, 'low')
data_high = gumpy.data.NST_EMG(base_dir, subject, 'high')

# Finally, load the datasets
data_low.load()
data_high.load()

# Print information about the datasets
data_low.print_stats()
data_high.print_stats()


# Filter the signals
# bandpass
lowcut = 20
highcut = 255
# notch
f0 = 50
Q = 50

flt_low = gumpy.signal.butter_bandpass(data_low, lowcut, highcut)
flt_low = gumpy.signal.notch(flt_low, cutoff=f0, Q=Q)

trialsLow = gumpy.utils.getTrials(data_low, flt_low)
trialsLowBg = gumpy.utils.getTrials(data_low, flt_low, True)

flt_high = gumpy.signal.butter_bandpass(data_high, lowcut, highcut)
flt_high = gumpy.signal.notch(flt_high, cutoff=f0, Q=Q)

trialsHigh = gumpy.utils.getTrials(data_high, flt_high)
trialsHighBg = gumpy.utils.getTrials(data_high, flt_high, True)


# Create an RMS feature extraction function
def RMS_features_extraction(data, trialList, window_size, window_shift):
    if window_shift > window_size:
        raise ValueError("window_shift > window_size")

    fs = data.sampling_freq

    n_features = int(data.duration/(window_size-window_shift))

    X = np.zeros((len(trialList), n_features*4))

    t = 0
    for trial in trialList:
        # x3 is the worst of all with 43.3% average performance
        x1 = gumpy.signal.rms(trial[0], fs, window_size, window_shift)
        x2 = gumpy.signal.rms(trial[1], fs, window_size, window_shift)
        x3 = gumpy.signal.rms(trial[2], fs, window_size, window_shift)
        x4 = gumpy.signal.rms(trial[3], fs, window_size, window_shift)
        x = np.concatenate((x1, x2, x3, x4))
        X[t, :] = x
        t += 1
    return X


# Retrieve the features
window_size = 0.2
window_shift = 0.05

highRMSfeatures = RMS_features_extraction(data_high, trialsHigh, window_size, window_shift)
highRMSfeaturesBg = RMS_features_extraction(data_high, trialsHighBg, window_size, window_shift)
lowRMSfeatures = RMS_features_extraction(data_low, trialsLow, window_size, window_shift)
lowRMSfeaturesBg = RMS_features_extraction(data_low, trialsLowBg, window_size, window_shift)



# Construct the classification arrays
X_tot = np.vstack((highRMSfeatures, lowRMSfeatures))
y_tot = np.hstack((np.ones(highRMSfeatures.shape[0]),
                   np.zeros(lowRMSfeatures.shape[0])))

X_totSig = np.vstack((highRMSfeatures, highRMSfeaturesBg, lowRMSfeatures, lowRMSfeaturesBg))
X_totSig = X_totSig/np.linalg.norm(X_totSig)

#pHigh.labels = np.hstack((self.labels, 3*np.ones(self.trials.shape[0]/3)))
y_totSig = np.hstack((data_high.labels,
                      data_low.labels))



# Posture classification
(clf, sfs) = gumpy.features.sequential_feature_selector(X_totSig, y_totSig, 'SVM', (10,25), 3, 'SFFS')

# Force level classification
(clfF, sfsF) = gumpy.features.sequential_feature_selector(X_tot, y_tot, 'SVM', (10,25), 3, 'SFFS')

--------------------------------------------------------------------------------