├── tests
│   ├── __init__.py
│   ├── samples
│   │   └── axonio.abf
│   ├── test_stats.py
│   ├── test_beans.py
│   ├── utils.py
│   ├── test_io.py
│   ├── test_features.py
│   └── test_core.py
├── src
│   ├── spike_analysis
│   │   ├── __init__.py
│   │   ├── xcorr.py
│   │   ├── io_tools.py
│   │   ├── dashboard.py
│   │   └── basic.py
│   ├── spike_beans
│   │   ├── __init__.py
│   │   └── base.py
│   └── spike_sort
│       ├── io
│       │   ├── __init__.py
│       │   ├── export.py
│       │   └── neo_filters.py
│       ├── ui
│       │   ├── __init__.py
│       │   ├── _mpl_helpers.py
│       │   ├── zoomer.py
│       │   ├── manual_sort.py
│       │   └── plotting.py
│       ├── stats
│       │   ├── __init__.py
│       │   ├── diptst
│       │   │   ├── diptst.pyf
│       │   │   └── diptst.f
│       │   └── tests.py
│       ├── __init__.py
│       └── core
│           ├── __init__.py
│           ├── filters.py
│           ├── evaluate.py
│           ├── cluster.py
│           └── extract.py
├── setup.cfg
├── docs
│   ├── source
│   │   ├── _static
│   │   │   └── logo.png
│   │   ├── tutorials
│   │   │   ├── images_beans
│   │   │   │   ├── waves.png
│   │   │   │   ├── features.png
│   │   │   │   ├── browser_zoom.png
│   │   │   │   ├── browser_nozoom.png
│   │   │   │   ├── waves_2_deleted.png
│   │   │   │   └── waves_one_left.png
│   │   │   ├── images_manual
│   │   │   │   ├── tutorial_cells.png
│   │   │   │   ├── tutorial_spikes.png
│   │   │   │   ├── tutorial_clusters.png
│   │   │   │   └── tutorial_features.png
│   │   │   ├── index.rst
│   │   │   ├── tutorial_io.rst
│   │   │   └── tutorial_manual.rst
│   │   ├── modules
│   │   │   ├── components.rst
│   │   │   ├── evaluate.rst
│   │   │   ├── ui.rst
│   │   │   ├── extract.rst
│   │   │   ├── cluster.rst
│   │   │   ├── features.rst
│   │   │   └── io.rst
│   │   ├── _themes
│   │   │   ├── flask
│   │   │   │   ├── theme.conf
│   │   │   │   ├── layout.html
│   │   │   │   ├── relations.html
│   │   │   │   └── static
│   │   │   │       ├── small_flask.css
│   │   │   │       └── flasky.css_t
│   │   │   ├── README
│   │   │   ├── LICENSE
│   │   │   └── flask_theme_support.py
│   │   ├── pyplots
│   │   │   ├── tutorial_spikes.py
│   │   │   ├── tutorial_features.py
│   │   │   ├── tutorial_cells.py
│   │   │   └── tutorial_clusters.py
│   │   ├── index.rst
│   │   ├── datafiles.rst
│   │   ├── conf.py
│   │   ├── datastructures.rst
│   │   └── intro.rst
│   ├── README
│   └── Makefile
├── requirements.txt
├── .gitignore
├── readthedocs.yml
├── data
│   ├── gollum.inf
│   ├── gollum_export.inf
│   └── gen_tutorial_data.py
├── README
├── examples
│   ├── analysis
│   │   ├── cell_dashboard.py
│   │   └── cell_xcorr.py
│   └── sorting
│       ├── browse_data.py
│       ├── read_axon.py
│       ├── cluster_auto.py
│       ├── cluster_manual.py
│       └── cluster_beans.py
├── .travis.yml
├── conda-environment.yml
├── LICENSE
└── setup.py
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spike_analysis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spike_beans/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/spike_sort/io/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/spike_sort/ui/__init__.py:
--------------------------------------------------------------------------------
1 | from plotting import *
--------------------------------------------------------------------------------
/src/spike_sort/stats/__init__.py:
--------------------------------------------------------------------------------
1 | from tests import *
2 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [nosetests]
2 | verbosity=2
3 | with-doctest=1
4 | tests=tests
5 |
6 |
--------------------------------------------------------------------------------
/tests/samples/axonio.abf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/tests/samples/axonio.abf
--------------------------------------------------------------------------------
/docs/source/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/_static/logo.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_beans/waves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/waves.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_beans/features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/features.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy >= 1.4
2 | matplotlib >= 1.1.0
3 | tables >= 2.3.0
4 | scipy >= 0.9.0
5 | neo>=0.2.0
6 | scikits.learn
7 | pywavelets
8 |
--------------------------------------------------------------------------------
/src/spike_sort/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | from core import *
5 |
6 | from ui import plotting
7 | import io
8 |
--------------------------------------------------------------------------------
/docs/source/tutorials/images_beans/browser_zoom.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/browser_zoom.png
--------------------------------------------------------------------------------
/src/spike_sort/core/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | __all__ = ["extract", "features", "filters", "cluster", "evaluate"]
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.swo
3 | *.pyc
4 | *.so
5 | docs/build/
6 | build/
7 | /dist
8 | /src/SpikeSort.egg-info/**/*
9 | /src/SpikeSort.egg-info
10 |
--------------------------------------------------------------------------------
/docs/source/tutorials/images_beans/browser_nozoom.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/browser_nozoom.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_beans/waves_2_deleted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/waves_2_deleted.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_beans/waves_one_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/waves_one_left.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_manual/tutorial_cells.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_cells.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_manual/tutorial_spikes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_spikes.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_manual/tutorial_clusters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_clusters.png
--------------------------------------------------------------------------------
/docs/source/tutorials/images_manual/tutorial_features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_features.png
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | formats:
2 | - none
3 |
4 | conda:
5 | file: conda-environment.yml
6 |
7 | python:
8 | version: 2
9 | setup_py_install: true
10 |
--------------------------------------------------------------------------------
/docs/source/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | .. _tutorials-section:
2 |
3 | Tutorials
4 | =========
5 |
6 | .. toctree::
7 |
8 | tutorial_beans
9 | tutorial_manual
10 | tutorial_io
11 |
--------------------------------------------------------------------------------
/docs/source/modules/components.rst:
--------------------------------------------------------------------------------
1 | Components (:mod:`spike_beans.components`)
2 | ==========================================
3 |
4 | .. automodule:: spike_beans.components
5 | :show-inheritance:
6 |
7 |
--------------------------------------------------------------------------------
/docs/source/_themes/flask/theme.conf:
--------------------------------------------------------------------------------
1 | [theme]
2 | inherit = basic
3 | stylesheet = flasky.css
4 | pygments_style = flask_theme_support.FlaskyStyle
5 |
6 | [options]
7 | index_logo = ''
8 | index_logo_height = 120px
9 | touch_icon =
10 |
--------------------------------------------------------------------------------
/data/gollum.inf:
--------------------------------------------------------------------------------
1 | {
2 | "fspike":"{subject}/{ses_id}/{ses_id}-{el_id}{contact_id}.sp",
3 | "cell": "{subject}/spt/{ses_id}-{el_id}-{cell_id}.spt",
4 | "stim": "{subject}/{ses_id}/{ses_id}-10.spt",
5 | "n_contacts":4,
6 | "dirname":"{DATAPATH}",
7 | "FS":25000
8 | }
9 |
--------------------------------------------------------------------------------
/data/gollum_export.inf:
--------------------------------------------------------------------------------
1 | {
2 | "fspike":"{subject}/{ses_id}/{ses_id}-{el_id}{contact_id}.sp",
3 | "cell": "spike_sort/{ses_id}-{el_id}-{cell_id}.spt",
4 | "stim": "{subject}/{ses_id}/{ses_id}-10.spt",
5 | "n_contacts":4,
6 | "dirname":"{DATAPATH}",
7 | "FS":25000
8 | }
9 |
--------------------------------------------------------------------------------
/data/gen_tutorial_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | from spike_sort.io.filters import PyTablesFilter, BakerlabFilter
5 |
6 | in_dataset = "/Gollum/s5gollum01/el3"
7 | out_dataset = "/SubjectA/session01/el1/raw"
8 |
9 | in_filter = BakerlabFilter("gollum.inf")
10 | out_filter = PyTablesFilter("tutorial.h5")
11 |
12 | sp = in_filter.read_sp(in_dataset)
13 | out_filter.write_sp(sp, out_dataset)
14 |
15 | in_filter.close()
16 | out_filter.close()
17 |
--------------------------------------------------------------------------------
/docs/source/_themes/flask/layout.html:
--------------------------------------------------------------------------------
1 | {%- extends "basic/layout.html" %}
2 |
3 | {% block sidebarlogo %}
4 | {% endblock %}
5 | {% block header %}
6 | {{ super() }}
7 | {% endblock %}
8 | {% block relbar2 %} {% endblock %}
9 |
10 | {%- block footer %}
11 |
15 |
16 | {%- endblock %}
17 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | Spike sorting library implemented in Python/NumPy/PyTables
2 | ----------------------------------------------------------
3 |
4 | Project website: spikesort.org
5 |
6 | Requirements:
7 |
8 | * Python >= 2.6
9 | * PyTables
10 | * NumPy
11 | * matplotlib
12 |
13 | Optional:
14 |
15 | * scikits.learn -- clustering algorithms
16 | * neurotools -- spike train analysis
17 |
18 | Test dependencies:
19 |
20 | * all the above
21 | * hdf5-tools
22 |
23 | To see the library in action, see the examples folder.
24 |
--------------------------------------------------------------------------------
/docs/source/modules/evaluate.rst:
--------------------------------------------------------------------------------
1 | Evaluate (:mod:`spike_sort.evaluate`)
2 | =====================================
3 |
4 | .. currentmodule:: spike_sort.core.evaluate
5 |
6 |
7 | Utility functions to evaluate the quality of spike sorting
8 |
9 | .. autosummary::
10 |
11 | snr_spike
12 | snr_clust
13 | detect_noise
14 | calc_noise_threshold
15 | isolation_score
16 | calc_isolation_score
17 |
18 |
19 | Reference
20 | ---------
21 |
22 | .. automodule:: spike_sort.core.evaluate
23 | :members:
24 |
--------------------------------------------------------------------------------
/examples/analysis/cell_dashboard.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import os
5 | import matplotlib.pyplot as plt
6 |
7 | from spike_analysis import dashboard
8 | from spike_sort.io.filters import BakerlabFilter
9 |
10 | if __name__ == "__main__":
11 | cell = "/Gollum/s39gollum03/el1/cell1"
12 | path = os.path.join(os.pardir, os.pardir, 'data', 'gollum_export.inf')
13 | filt = BakerlabFilter(path)
14 |
15 | dashboard.show_cell(filt, cell)
16 |
17 | plt.show()
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/docs/source/modules/ui.rst:
--------------------------------------------------------------------------------
1 | User Interface (:mod:`spike_sort.ui`)
2 | =====================================
3 |
4 | .. currentmodule:: spike_sort.ui
5 |
6 | Plotting (:mod:`spike_sort.ui.plotting`)
7 | ----------------------------------------
8 |
9 | This module provides basic plotting capabilities using the matplotlib
10 | library.
11 |
12 | .. autosummary::
13 |
14 | ~plotting.plot_spikes
15 | ~plotting.plot_features
16 |
17 |
18 | Reference
19 | ---------
20 |
21 | .. automodule:: spike_sort.ui.plotting
22 | :members:
23 | :undoc-members:
24 |
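25 | Example
26 | -------
27 |
28 | A short plotting sketch (``sp_waves``, ``sp_feats`` and ``clust_idx`` are
29 | assumed to come from the extraction, feature and clustering steps shown in
30 | the tutorials)::
31 |
32 |     from spike_sort.ui import plotting
33 |
34 |     plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200)
35 |     plotting.plot_features(sp_feats, clust_idx)
36 |     plotting.show()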
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "2.7"
4 | virtualenv:
5 | system_site_packages: true
6 | before_install:
7 | - sudo apt-get update -qq
8 | - sudo apt-get install libhdf5-serial-dev python-numpy python-scipy libatlas-dev liblapack-dev gfortran python-tk python-matplotlib hdf5-tools python-tables cython tk-dev
9 | # command to install dependencies
10 | install:
11 | - "pip install numexpr cython --use-mirrors"
12 | - "pip install -r requirements.txt --use-mirrors"
13 | - "pip install ."
14 | # command to run tests
15 | script: nosetests -P
16 |
--------------------------------------------------------------------------------
/docs/README:
--------------------------------------------------------------------------------
1 | The documentation was created using Sphinx. You may create an HTML
2 | version of the documentation using:
3 |
4 | make html
5 |
6 | Alternatively, you can call the sphinx-build tool manually by running the
7 | following command from the docs directory:
8 |
9 | sphinx-build-2.6 -b html -d build/doctrees source build/html
10 |
11 | Note: In order to compile the figures in the tutorials you need to download
12 | the tutorial.h5 file and copy it to the spike_sort/data directory. In
13 | addition, matplotlib has to be installed at the time of compilation
14 | (it is also a dependency of spike_sort).
15 |
--------------------------------------------------------------------------------
/docs/source/_themes/flask/relations.html:
--------------------------------------------------------------------------------
1 | Related Topics
2 |
20 |
--------------------------------------------------------------------------------
/docs/source/modules/extract.rst:
--------------------------------------------------------------------------------
1 | Extract (:mod:`spike_sort.core.extract`)
2 | ========================================
3 |
4 | This module is a collection of functions for pre-processing of raw
5 | recordings (filtering and upsampling), detecting spikes and extracting
6 | spikes based on detected spike times.
7 |
8 | .. currentmodule:: spike_sort.core.extract
9 |
10 | .. autosummary::
11 |
12 | align_spikes
13 | detect_spikes
14 | extract_spikes
15 | filter_proxy
16 | merge_spikes
17 | merge_spiketimes
18 | remove_spikes
19 | resample_spikes
20 | split_cells
21 |
22 | Reference
23 | ---------
24 |
25 | .. automodule:: spike_sort.core.extract
26 | :members:
27 | :undoc-members:
28 |
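29 | Example
30 | -------
31 |
32 | A typical detect/align/extract sequence, as used in the tutorials (here
33 | ``raw`` is assumed to be a raw-signal structure returned by one of the
34 | :mod:`spike_sort.io` filters)::
35 |
36 |     from spike_sort import extract
37 |
38 |     spt = extract.detect_spikes(raw, contact=3, thresh='auto')
39 |     sp_win = [-0.2, 0.8]
40 |     spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10)
41 |     sp_waves = extract.extract_spikes(raw, spt, sp_win)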
--------------------------------------------------------------------------------
/examples/analysis/cell_xcorr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import os
5 | import itertools
6 |
7 | import matplotlib.pyplot as plt
8 |
9 | from spike_sort.io.filters import BakerlabFilter
10 | from spike_analysis import io_tools, xcorr
11 |
12 | cell_pattern = "/Gollum/s4gollum*/el*/cell*"
13 | path = os.path.join(os.pardir, os.pardir, 'data', 'gollum.inf')
14 | filt = BakerlabFilter(path)
15 |
16 | if __name__ == "__main__":
17 | cell_nodes = io_tools.list_cells(filt, cell_pattern)
18 | cells = [io_tools.read_dataset(filt, node) for node in cell_nodes]
19 | for data, name in itertools.izip(cells, cell_nodes):
20 | data['dataset'] = name
21 | xcorr.show_xcorr(cells)
22 | plt.show()
23 |
24 |
25 |
--------------------------------------------------------------------------------
/examples/sorting/browse_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | """
5 | Simple raw data browser.
6 |
7 | Keyboard shortcuts:
8 |
9 | +/- - zoom in/out
10 | """
11 |
12 | import spike_sort as sort
13 | from spike_sort.io.filters import PyTablesFilter
14 | from spike_sort.ui import spike_browser
15 | import os
16 |
17 | DATAPATH = os.environ['DATAPATH']
18 |
19 | if __name__ == "__main__":
20 | dataset = "/SubjectA/session01/el1"
21 | data_fname = os.path.join(DATAPATH, "tutorial.h5")
22 |
23 | io_filter = PyTablesFilter(data_fname)
24 | sp = io_filter.read_sp(dataset)
25 | spt = sort.extract.detect_spikes(sp, contact=3, thresh='auto')
26 |
27 | spike_browser.browse_data_tk(sp, spt, win=50)
28 |
--------------------------------------------------------------------------------
/docs/source/pyplots/tutorial_spikes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from spike_sort.io.filters import PyTablesFilter
5 | from spike_sort import extract
6 | from spike_sort import features
7 | from spike_sort import cluster
8 | from spike_sort.ui import plotting
9 | import os
10 |
11 | dataset = '/SubjectA/session01/el1'
12 | datapath = '../../../data/tutorial.h5'
13 |
14 | io_filter = PyTablesFilter(datapath)
15 | raw = io_filter.read_sp(dataset)
16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto')
17 |
18 | sp_win = [-0.2, 0.8]
19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10)
20 | sp_waves = extract.extract_spikes(raw, spt, sp_win)
21 | plotting.plot_spikes(sp_waves, n_spikes=200)
22 | plotting.show()
23 | io_filter.close()
24 |
--------------------------------------------------------------------------------
/docs/source/tutorials/tutorial_io.rst:
--------------------------------------------------------------------------------
1 | .. _io_tutorial:
2 |
3 | Reading custom data formats
4 | ===========================
5 |
6 | .. testsetup::
7 |
8 | .. testcleanup::
9 |
10 | SpikeSort provides a flexible interface with various data sources via
11 | so-called input-output `Filters`. The code available for download already
12 | contains a few filters that allow for reading (and in some cases writing)
13 | of several data formats, including:
14 |
15 | * raw binary data,
16 | * HDF5 files,
17 | * proprietary data formats (such as Axon `.abf`) supported by
18 | `Neo `_ library.
19 |
20 | In this tutorial we show you how to use the available filters to read and plot data
21 | from `.abf` files, and then we will convince you how easy it is to add support
22 | for custom formats by defining your own `Filter`.
23 |
24 |
25 |
26 |
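27 | As a preview, a custom `Filter` is simply a class exposing the read/write
28 | methods described in :ref:`io_filters`. A minimal sketch (the binary layout,
29 | file name and sampling parameters below are only illustrative assumptions,
30 | not part of the library)::
31 |
32 |     import numpy as np
33 |
34 |     class RawBinaryFilter(object):
35 |         """Read multi-channel int16 recordings from a flat binary file."""
36 |
37 |         def __init__(self, fname, FS=25000., n_contacts=4):
38 |             self.fname = fname
39 |             self.FS = FS
40 |             self.n_contacts = n_contacts
41 |
42 |         def read_sp(self, dataset=None):
43 |             # one row per contact, as in the other SpikeSort filters
44 |             sp_raw = np.fromfile(self.fname, dtype=np.int16)
45 |             sp_raw = sp_raw.reshape((self.n_contacts, -1))
46 |             return {"data": sp_raw, "FS": self.FS,
47 |                     "n_contacts": self.n_contacts}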
--------------------------------------------------------------------------------
/src/spike_sort/io/export.py:
--------------------------------------------------------------------------------
1 | def export_cells(io_filter, node_templ, spike_times, overwrite=False):
2 | """Export discriminated spike times of all cells to a file.
3 |
4 | Parameters
5 | ----------
6 | io_filter : object,
7 | read/write filter object (see :py:mod:`spike_sort.io.filters`)
8 | node_templ : string
9 | string identifying the dataset name. It will be passed to
10 | IOFilters.write_spt method. It can contain the
11 | `{cell_id}` placeholder that will be substituted by cell
12 | identifier.
13 | spike_times : dict
14 | dictionary in which keys are the cell IDs and values are spike
15 | times structures
16 | """
17 | for cell_id, spt_cell in spike_times.items():
18 | dataset = node_templ.format(cell_id=cell_id)
19 | io_filter.write_spt(spt_cell, dataset, overwrite=overwrite)
20 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. SortSpikes documentation master file, created by
2 | sphinx-quickstart on Thu Jan 27 14:09:31 2011.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to SpikeSort's documentation!
7 | ======================================
8 |
9 | Contents:
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 |
14 | intro
15 | tutorials/index
16 | datastructures
17 | datafiles
18 |
19 | Modules:
20 |
21 | .. toctree::
22 | :maxdepth: 1
23 |
24 | modules/extract
25 | modules/features
26 | modules/cluster
27 | modules/evaluate
28 | modules/io
29 | modules/ui
30 |
31 | modules/components
32 |
33 |
34 |
35 | Indices and tables
36 | ==================
37 |
38 | * :ref:`genindex`
39 | * :ref:`modindex`
40 | * :ref:`search`
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/docs/source/pyplots/tutorial_features.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from spike_sort.io.filters import PyTablesFilter
5 | from spike_sort import extract
6 | from spike_sort import features
7 | from spike_sort import cluster
8 | from spike_sort.ui import plotting
9 | import os
10 |
11 | dataset = '/SubjectA/session01/el1'
12 | datapath = '../../../data/tutorial.h5'
13 |
14 | io_filter = PyTablesFilter(datapath)
15 | raw = io_filter.read_sp(dataset)
16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto')
17 |
18 | sp_win = [-0.2, 0.8]
19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10)
20 | sp_waves = extract.extract_spikes(raw, spt, sp_win)
21 | sp_feats = features.combine(
22 | (
23 | features.fetP2P(sp_waves),
24 | features.fetPCA(sp_waves)
25 | )
26 | )
27 |
28 | plotting.plot_features(sp_feats)
29 | plotting.show()
30 |
--------------------------------------------------------------------------------
/docs/source/datafiles.rst:
--------------------------------------------------------------------------------
1 | Sample data files
2 | =================
3 |
4 |
5 | These files are provided only for demonstration purposes. The authors do
6 | not hold any responsibility for using these files for other purposes
7 | than testing and learning the functions of SpikeSort. Unless stated
8 | otherwise, use of these datasets in publications is allowed only with
9 | written consent of the authors.
10 |
11 |
12 | .. _tutorial_data:
13 |
14 | Tutorial data
15 | -------------
16 |
17 | `tutorial.h5 `_
18 |
19 | This is a single dataset of extracellular spikes recorded with a
20 | tetrode in a stimulus-driven paradigm. For more information see
21 | [Telenczuk2011]_.
22 |
23 | .. [Telenczuk2011] Telenczuk, Bartosz, Stuart N Baker, Andreas V M Herz, and Gabriel Curio. *“High-frequency EEG Covaries with Spike Burst Patterns Detected in Cortical Neurons.”* Journal of Neurophysiology **105**, no. 6 (2011): 2951–2959.
24 |
25 |
--------------------------------------------------------------------------------
/docs/source/pyplots/tutorial_cells.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from spike_sort.io.filters import PyTablesFilter
5 | from spike_sort import extract
6 | from spike_sort import features
7 | from spike_sort import cluster
8 | from spike_sort.ui import plotting
9 | import os
10 |
11 | dataset = '/SubjectA/session01/el1'
12 | datapath = '../../../data/tutorial.h5'
13 |
14 | io_filter = PyTablesFilter(datapath)
15 | raw = io_filter.read_sp(dataset)
16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto')
17 |
18 | sp_win = [-0.2, 0.8]
19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10)
20 | sp_waves = extract.extract_spikes(raw, spt, sp_win)
21 | sp_feats = features.combine(
22 | (
23 | features.fetP2P(sp_waves),
24 | features.fetPCA(sp_waves)
25 | )
26 | )
27 | clust_idx = cluster.cluster("gmm",sp_feats,4)
28 | plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200)
29 | plotting.show()
30 | io_filter.close()
31 |
--------------------------------------------------------------------------------
/docs/source/pyplots/tutorial_clusters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from spike_sort.io.filters import PyTablesFilter
5 | from spike_sort import extract
6 | from spike_sort import features
7 | from spike_sort import cluster
8 | from spike_sort.ui import plotting
9 | import os
10 |
11 | dataset = '/SubjectA/session01/el1'
12 | datapath = '../../../data/tutorial.h5'
13 |
14 | io_filter = PyTablesFilter(datapath)
15 | raw = io_filter.read_sp(dataset)
16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto')
17 |
18 | sp_win = [-0.2, 0.8]
19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10)
20 | sp_waves = extract.extract_spikes(raw, spt, sp_win)
21 | sp_feats = features.combine(
22 | (
23 | features.fetP2P(sp_waves),
24 | features.fetPCA(sp_waves)
25 | )
26 | )
27 |
28 | clust_idx = cluster.cluster("gmm",sp_feats,4)
29 | plotting.plot_features(sp_feats, clust_idx)
30 | plotting.show()
31 | io_filter.close()
32 |
--------------------------------------------------------------------------------
/docs/source/modules/cluster.rst:
--------------------------------------------------------------------------------
1 | Cluster (:mod:`spike_sort.cluster`)
2 | ===================================
3 |
4 | .. currentmodule:: spike_sort.core.cluster
5 |
6 | Module with clustering algorithms.
7 |
8 | Utility functions
9 | -----------------
10 |
11 | Spike sorting is usually done with the
12 | :py:func:`~spike_sort.core.cluster.cluster` function which takes as an
13 | argument one of the clustering methods (as a string).
14 |
15 | Other functions help to manipulate the results:
16 |
17 | .. autosummary::
18 |
19 | cluster
20 | split_cells
21 |
22 |
23 |
24 | Clustering methods
25 | ------------------
26 |
27 | Several different clustering methods are defined in the module. Each
28 | method should take at least one argument -- the features structure.
29 |
30 | .. autosummary::
31 |
32 | k_means_plus
33 | gmm
34 | manual
35 | none
36 | k_means
37 |
38 |
39 | Reference
40 | ---------
41 |
42 | .. automodule:: spike_sort.core.cluster
43 | :members:
44 |
--------------------------------------------------------------------------------
/examples/sorting/read_axon.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | matplotlib.use("TkAgg")
5 | matplotlib.interactive(True)
6 |
7 | from spike_beans import components, base
8 | from spike_sort.io import neo_filters
9 |
10 | ####################################
11 | # Adjust these fields for your needs
12 |
13 | sp_win = [-0.6, 0.8]
14 |
15 | url = 'https://portal.g-node.org/neo/axon/File_axon_1.abf'
16 | path = 'file_axon.abf'
17 |
18 | import urllib
19 | urllib.urlretrieve(url, path)
20 |
21 | io = neo_filters.NeoSource(path)
22 |
23 | base.register("SignalSource", io)
24 | base.register("SpikeMarkerSource",
25 | components.SpikeDetector(contact=0,
26 | thresh='auto',
27 | type='max',
28 | sp_win=sp_win,
29 | resample=1,
30 | align=True))
31 | base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
32 |
33 | browser = components.SpikeBrowser()
34 |
35 | browser.show()
36 |
--------------------------------------------------------------------------------
/tests/test_stats.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from spike_sort import stats
3 |
4 | from nose.tools import ok_
5 |
6 | class TestStats(object):
7 | def setup(self):
8 | n_samples = 1000
9 |
10 | # unimodal
11 | normal = np.random.randn(n_samples)
12 | self.normal = (normal - normal.mean()) / normal.std()
13 | uniform = np.random.uniform(size = n_samples)
14 | self.uniform = (uniform - uniform.mean()) / uniform.std()
15 | laplace = np.random.laplace(size = n_samples)
16 | self.laplace = (laplace - laplace.mean()) / laplace.std()
17 |
18 | # multimodal
19 | multimodal = np.hstack((self.normal, self.uniform + 4, self.laplace + 8))
20 | np.random.shuffle(multimodal)
21 | multimodal = multimodal[:n_samples]
22 | self.multimodal = (multimodal - multimodal.mean()) / multimodal.std()
23 |
24 | def test_multimodality_detection(self):
25 | data = [self.normal, self.uniform, self.laplace]
26 | tests = [stats.dip, stats.ks]
27 |
28 | detected = [test(self.multimodal) > test(dist) for dist in data for test in tests]
29 |
30 | ok_(np.array(detected).all())
31 |
--------------------------------------------------------------------------------
/src/spike_sort/stats/diptst/diptst.pyf:
--------------------------------------------------------------------------------
1 | ! -*- f90 -*-
2 | ! Note: the context of this file is case sensitive.
3 |
4 | python module _diptst ! in
5 | interface ! in :_diptst
6 | subroutine diptst1(x,n,dip,xl,xu,ifault,gcm,lcm,mn,mj,ddx,ddxsgn) ! in :_diptst:diptst.f
7 | real dimension(n),intent(in) :: x
8 | integer optional,check(len(x)>=n),intent(hide),depend(x) :: n=len(x)
9 | real intent(out) :: dip
10 | real intent(out) :: xl
11 | real intent(out) :: xu
12 | integer intent(out) :: ifault
13 | integer intent(hide,cache),dimension(n),depend(n) :: gcm
14 | integer intent(hide,cache),dimension(n),depend(n) :: lcm
15 | integer intent(hide,cache),dimension(n),depend(n) :: mn
16 | integer intent(hide,cache),dimension(n),depend(n) :: mj
17 | real intent(hide,cache),dimension(n),depend(n) :: ddx
18 | integer intent(hide,cache),dimension(n),depend(n) :: ddxsgn
19 | end subroutine diptst1
20 | end interface
21 | end python module _diptst
22 |
23 | ! This file was auto-generated with f2py (version:2).
24 | ! See http://cens.ioc.ee/projects/f2py2e/
25 |
--------------------------------------------------------------------------------
/conda-environment.yml:
--------------------------------------------------------------------------------
1 | name: spikesort-rtd
2 | dependencies:
3 | - alabaster=0.7.7=py27_0
4 | - babel=2.2.0=py27_0
5 | - cairo=1.12.18=6
6 | - cycler=0.10.0=py27_0
7 | - docutils=0.12=py27_0
8 | - fontconfig=2.11.1=5
9 | - freetype=2.5.5=0
10 | - hdf5=1.8.15.1=2
11 | - jinja2=2.8=py27_0
12 | - libgfortran=1.0=0
13 | - libpng=1.6.17=0
14 | - libxml2=2.9.2=0
15 | - markupsafe=0.23=py27_0
16 | - matplotlib=1.5.1=np110py27_0
17 | - mkl=11.3.1=0
18 | - numexpr=2.4.6=np110py27_1
19 | - numpy=1.10.4=py27_0
20 | - numpydoc=0.5=py27_1
21 | - openssl=1.0.2g=0
22 | - pip=8.0.3=py27_0
23 | - pixman=0.32.6=0
24 | - pycairo=1.10.0=py27_0
25 | - pygments=2.1.1=py27_0
26 | - pyparsing=2.0.3=py27_0
27 | - pyqt=4.11.4=py27_1
28 | - pytables=3.2.2=np110py27_0
29 | - python=2.7.11=0
30 | - python-dateutil=2.4.2=py27_0
31 | - pytz=2015.7=py27_0
32 | - qt=4.8.7=1
33 | - readline=6.2=2
34 | - scikit-learn=0.17.1=np110py27_0
35 | - scipy=0.17.0=np110py27_1
36 | - setuptools=20.1.1=py27_0
37 | - sip=4.16.9=py27_0
38 | - six=1.10.0=py27_0
39 | - snowballstemmer=1.2.1=py27_0
40 | - sphinx=1.3.5=py27_0
41 | - sphinx_rtd_theme=0.1.9=py27_0
42 | - sqlite=3.9.2=0
43 | - tk=8.5.18=0
44 | - wheel=0.29.0=py27_0
45 | - zlib=1.2.8=0
46 | - pip:
47 | - pywavelets==0.4.0
48 |
--------------------------------------------------------------------------------
/docs/source/_themes/README:
--------------------------------------------------------------------------------
1 | Flask Sphinx Styles
2 | ===================
3 |
4 | This repository contains sphinx styles for Flask and Flask related
5 | projects. To use this style in your Sphinx documentation, follow
6 | this guide:
7 |
8 | 1. put this folder as _themes into your docs folder. Alternatively
9 | you can also use git submodules to check out the contents there.
10 | 2. add this to your conf.py:
11 |
12 | sys.path.append(os.path.abspath('_themes'))
13 | html_theme_path = ['_themes']
14 | html_theme = 'flask'
15 |
16 | The following themes exist:
17 |
18 | - 'flask' - the standard flask documentation theme for large
19 | projects
20 | - 'flask_small' - small one-page theme. Intended to be used by
21 | very small addon libraries for flask.
22 |
23 | The following options exist for the flask_small theme:
24 |
25 | [options]
26 | index_logo = '' filename of a picture in _static
27 | to be used as replacement for the
28 | h1 in the index.rst file.
29 | index_logo_height = 120px height of the index logo
30 | github_fork = '' repository name on github for the
31 | "fork me" badge
32 |
--------------------------------------------------------------------------------
/docs/source/_themes/flask/static/small_flask.css:
--------------------------------------------------------------------------------
1 | /*
2 | * small_flask.css_t
3 | * ~~~~~~~~~~~~~~~~~
4 | *
5 | * :copyright: Copyright 2010 by Armin Ronacher.
6 | * :license: Flask Design License, see LICENSE for details.
7 | */
8 |
9 | body {
10 | margin: 0;
11 | padding: 20px 30px;
12 | }
13 |
14 | div.documentwrapper {
15 | float: none;
16 | background: white;
17 | }
18 |
19 | div.sphinxsidebar {
20 | display: block;
21 | float: none;
22 | width: 102.5%;
23 | margin: 50px -30px -20px -30px;
24 | padding: 10px 20px;
25 | background: #333;
26 | color: white;
27 | }
28 |
29 | div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p,
30 | div.sphinxsidebar h3 a {
31 | color: white;
32 | }
33 |
34 | div.sphinxsidebar a {
35 | color: #aaa;
36 | }
37 |
38 | div.sphinxsidebar p.logo {
39 | display: none;
40 | }
41 |
42 | div.document {
43 | width: 100%;
44 | margin: 0;
45 | }
46 |
47 | div.related {
48 | display: block;
49 | margin: 0;
50 | padding: 10px 0 20px 0;
51 | }
52 |
53 | div.related ul,
54 | div.related ul li {
55 | margin: 0;
56 | padding: 0;
57 | }
58 |
59 | div.footer {
60 | display: none;
61 | }
62 |
63 | div.bodywrapper {
64 | margin: 0;
65 | }
66 |
67 | div.body {
68 | min-height: 0;
69 | padding: 0;
70 | }
71 |
--------------------------------------------------------------------------------
/docs/source/modules/features.rst:
--------------------------------------------------------------------------------
1 | Calculate features (:mod:`spike_sort.features`)
2 | ===============================================
3 |
4 | .. currentmodule:: spike_sort.core.features
5 |
6 | Provides functions to calculate spike waveforms features.
7 |
8 | .. _features_doc:
9 |
10 | Features
11 | --------
12 |
13 | Functions starting with `fet` implement various features calculated
14 | from the spike waveshapes. They usually have one required argument, a
15 | :ref:`spike_wave` structure, and some take optional arguments (see
16 | below).
17 |
18 | Each of the functions returns a mapping structure (dictionary) with the following keys:
19 |
20 | * `data` -- an array of shape (n_spikes x n_features)
21 | * `names` -- a list of length n_features with feature labels
22 |
23 | The following features are implemented:
24 |
25 | .. autosummary::
26 |
27 | fetPCA
28 | fetP2P
29 | fetSpIdx
30 | fetSpTime
31 | fetSpProjection
32 |
33 | Tools
34 | -----
35 |
36 | This module provides a few tools to facilitate working with features
37 | data structure:
38 |
39 | .. autosummary::
40 |
41 | split_cells
42 | select
43 | combine
44 | normalize
45 |
46 | Auxiliary
47 | ---------
48 | .. autosummary::
49 |
50 | PCA
51 | add_mask
52 |
53 |
54 | Reference
55 | ---------
56 |
57 | .. automodule:: spike_sort.core.features
58 | :members:
59 |
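60 | Example
61 | -------
62 |
63 | A minimal sketch of computing and inspecting features (``sp_waves`` is
64 | assumed to be a spike waveform structure returned by
65 | :func:`spike_sort.core.extract.extract_spikes`)::
66 |
67 |     from spike_sort import features
68 |
69 |     sp_feats = features.combine((features.fetP2P(sp_waves),
70 |                                  features.fetPCA(sp_waves)))
71 |     print sp_feats['names']        # n_features feature labels
72 |     print sp_feats['data'].shape   # (n_spikes, n_features)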
--------------------------------------------------------------------------------
/examples/sorting/cluster_auto.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | """
5 | Based on raw recordings detect spikes, calculate features and do automatic
6 | clustering with gaussian mixture models.
7 | """
8 |
9 | import os
10 |
11 | import spike_sort as sort
12 | from spike_sort.io.filters import PyTablesFilter
13 | import spike_sort.ui.manual_sort
14 |
15 | DATAPATH = os.environ['DATAPATH']
16 |
17 | if __name__ == "__main__":
18 | h5_fname = os.path.join(DATAPATH, "tutorial.h5")
19 | h5filter = PyTablesFilter(h5_fname, 'a')
20 |
21 | dataset = "/SubjectA/session01/el1"
22 | sp_win = [-0.2, 0.8]
23 |
24 | sp = h5filter.read_sp(dataset)
25 | spt = sort.extract.detect_spikes(sp, contact=3, thresh='auto')
26 |
27 | spt = sort.extract.align_spikes(sp, spt, sp_win, type="max", resample=10)
28 | sp_waves = sort.extract.extract_spikes(sp, spt, sp_win)
29 | features = sort.features.combine(
30 | (sort.features.fetP2P(sp_waves),
31 | sort.features.fetPCA(sp_waves)),
32 | norm=True
33 | )
34 |
35 |
36 | clust_idx = sort.cluster.cluster("gmm",features,4)
37 |
38 | spike_sort.ui.plotting.plot_features(features, clust_idx)
39 | spike_sort.ui.plotting.figure()
40 | spike_sort.ui.plotting.plot_spikes(sp_waves, clust_idx,n_spikes=200)
41 |
42 |
43 | spike_sort.ui.plotting.show()
44 | h5filter.close()
45 |
--------------------------------------------------------------------------------
/src/spike_sort/ui/_mpl_helpers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | from matplotlib.axes import Axes
5 | from matplotlib.ticker import NullLocator
6 | from matplotlib.projections import register_projection
7 | import matplotlib.axis as maxis
8 |
9 | class NoTicksXAxis(maxis.XAxis):
10 | def reset_ticks(self):
11 | self._lastNumMajorTicks = 1
12 | self._lastNumMinorTicks = 1
13 | def set_clip_path(self, clippath, transform=None):
14 | pass
15 |
16 | class NoTicksYAxis(maxis.YAxis):
17 | def reset_ticks(self):
18 | self._lastNumMajorTicks = 1
19 | self._lastNumMinorTicks = 1
20 | def set_clip_path(self, clippath, transform=None):
21 | pass
22 |
23 | class ThinAxes(Axes):
24 | """Thin axes without spines and ticks to accelerate axes creation"""
25 |
26 | name = 'thin'
27 |
28 | def _init_axis(self):
29 | self.xaxis = NoTicksXAxis(self)
30 | self.yaxis = NoTicksYAxis(self)
31 |
32 | def cla(self):
33 | """
34 | Override to set up some reasonable defaults.
35 | """
36 | Axes.cla(self)
37 | self.xaxis.set_minor_locator(NullLocator())
38 | self.yaxis.set_minor_locator(NullLocator())
39 | self.xaxis.set_major_locator(NullLocator())
40 | self.yaxis.set_major_locator(NullLocator())
41 |
42 | def _gen_axes_spines(self):
43 | return {}
44 |
45 |
46 | # Now register the projection with matplotlib so the user can select
47 | # it.
48 | register_projection(ThinAxes)
49 |
--------------------------------------------------------------------------------
/docs/source/modules/io.rst:
--------------------------------------------------------------------------------
1 | File I/O (:mod:`spike_sort.io`)
2 | ===============================
3 |
4 | .. currentmodule:: spike_sort.io
5 |
6 | Functions for reading and writing datafiles.
7 |
8 | .. _io_filters:
9 |
10 | Read/Write Filters (:mod:`spike_sort.io.filters`)
11 | -------------------------------------------------
12 |
13 | Filters are basic backends for read/write operations. They offer the
14 | following methods:
15 |
16 | * :py:func:`read_spt` -- read event times (such as spike times)
17 | * :py:func:`read_sp` -- read raw spike waveforms
18 | * :py:func:`write_spt` -- write spike times
19 | * :py:func:`write_sp` -- write raw spike waveforms
20 |
21 | The `read_*` methods usually take one argument (`datapath`), but it is
22 | not required. The `write_*` methods take `datapath` and the data to be
23 | written.
24 |
25 | If you want to read/write your custom data format, it is enough to
26 | implement a class with these methods.
27 |
28 | The following filters are implemented:
29 |
30 | .. autosummary::
31 |
32 | filters.BakerlabFilter
33 | filters.PyTablesFilter
34 |
35 |
36 | Export tools (:mod:`spike_sort.io.export`)
37 | --------------------------------------------
38 |
39 | These tools take one of the `io.filters` as an argument and export
40 | data to the file using `write_spt` or `write_sp` methods.
41 |
42 | .. autosummary::
43 |
44 | export.export_cells
45 |
46 | Reference
47 | ---------
48 |
49 | .. automodule:: spike_sort.io.filters
50 | :members:
51 |
52 | .. automodule:: spike_sort.io.export
53 | :members:
54 |
55 |
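56 | Example
57 | -------
58 |
59 | A short export sketch (``io_filter`` stands for any open read/write filter
60 | and ``spt_cells`` for a dictionary mapping cell IDs to spike-time
61 | structures; both names are placeholders)::
62 |
63 |     from spike_sort.io import export
64 |
65 |     export.export_cells(io_filter,
66 |                         "/SubjectA/session01/el1/cell{cell_id}",
67 |                         spt_cells, overwrite=True)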
--------------------------------------------------------------------------------
/src/spike_analysis/xcorr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | def raise_exception(*args, **kwargs):
8 | raise NotImplementedError("This function requires NeuroTools")
9 |
10 | try:
11 | from NeuroTools.analysis import crosscorrelate
12 | except ImportError:
13 | crosscorrelate = raise_exception
14 |
15 | def show_xcorr(cells):
16 | n = len(cells)
17 | maxlag = 3
18 | bins = np.arange(-maxlag, maxlag, 0.01)
19 | ax = None
20 | for i in range(n):
21 | for j in range(i, n):
22 | if i != j:
23 | ax = plt.subplot(n, n, i + j * n + 1, sharey=ax)
24 |
25 | crosscorrelate(cells[i]['spt'], cells[j]['spt'], maxlag,
26 | display=ax,
27 | kwargs={"bins": bins})
28 | plt.axvline(0, color='k')
29 |
30 | if ax:
31 | ax.set_ylabel("")
32 | ax.set_yticks([])
33 | ax.set_xticks([])
34 | ax.set_xlabel("")
35 | else:
36 | ax_label = plt.subplot(n, n, i + j * n + 1, frameon=False)
37 | ax_label.set_xticks([])
38 | ax_label.set_yticks([])
39 | ax_label.text(0,0, cells[j]['dataset'],
40 | transform=ax_label.transAxes,
41 | rotation=45, ha='left', va='bottom')
42 |
43 | ax.set_xticks((-maxlag, maxlag))
44 | ax.set_yticks(ax.get_ylim())
45 |
--------------------------------------------------------------------------------
/examples/sorting/cluster_manual.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | """
5 | Based on raw recordings detect spikes, calculate features and do
6 | clustering by means of manual cluster-cutting.
7 | """
8 |
9 | import os
10 |
11 | import matplotlib
12 | matplotlib.use("TkAgg")
13 | matplotlib.interactive(True)
14 |
15 | import spike_sort as sort
16 | from spike_sort.io.filters import PyTablesFilter
17 |
18 | DATAPATH = os.environ['DATAPATH']
19 |
20 | if __name__ == "__main__":
21 | h5_fname = os.path.join(DATAPATH, "tutorial.h5")
22 | h5filter = PyTablesFilter(h5_fname, 'r')
23 |
24 | dataset = "/SubjectA/session01/el1"
25 | sp_win = [-0.2, 0.8]
26 |
27 | sp = h5filter.read_sp(dataset)
28 | spt = sort.extract.detect_spikes(sp, contact=3, thresh=300)
29 |
30 | spt = sort.extract.align_spikes(sp, spt, sp_win, type="max", resample=10)
31 | sp_waves = sort.extract.extract_spikes(sp, spt, sp_win)
32 | features = sort.features.combine(
33 | (
34 | sort.features.fetSpIdx(sp_waves),
35 | sort.features.fetP2P(sp_waves),
36 | sort.features.fetPCA(sp_waves)),
37 | norm=True
38 | )
39 |
40 | clust_idx = sort.ui.manual_sort.manual_sort(features,
41 | ['Ch0:P2P', 'Ch3:P2P'])
42 |
43 | clust, rest = sort.cluster.split_cells(spt, clust_idx, [1, 0])
44 |
45 | sort.ui.plotting.figure()
46 | sort.ui.plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200)
47 |
48 | raw_input('Press any key to exit...')
49 |
50 | h5filter.close()
51 |
--------------------------------------------------------------------------------
/src/spike_analysis/io_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 | import glob, re
4 | import os
5 |
6 | def read_dataset(filter, dataset):
7 | spt = filter.read_spt(dataset)
8 | stim_node = "/".join(dataset.split('/')[:-1]+['stim'])
9 | stim = filter.read_spt(stim_node)
10 | return {'spt': spt['data'], 'stim': stim['data'], 'ev': []}
11 |
12 | def list_cells(filter, dataset):
13 | """List all cells which fit the pattern given in dataset. Dataset can contain
14 | wildcards.
15 |
16 | Example:
17 |
18 | dataset = "/Subject/sSession01/el*/cell*"
19 | """
20 | regexp = "^/(?P<subject>[a-zA-z\*]+)/s(?P<ses_id>.+)/el(?P<el_id>[0-9\*]+)/?(?P<type>[a-zA-Z]+)?(?P<cell_id>[0-9\*]+)?$"
21 |
22 | conf = filter.conf_dict
23 | fpath = (conf['dirname'].format(DATAPATH=os.environ['DATAPATH'])+conf['cell'])
24 | rec_wildcard = re.match(regexp, dataset).groupdict()
25 | fname = fpath.format(**rec_wildcard)
26 |
27 | files = glob.glob(fname)
28 |
29 | rec_regexp = {"subject": "(?P<subject>[a-zA-z\*]+)",
30 | "ses_id": "(?P<ses_id>.+)",
31 | "el_id": "(?P<el_id>[0-9\*]+)",
32 | "cell_id": "(?P<cell_id>[0-9\*]+)"}
33 | node_fmt = "/{subject}/s{ses_id}/el{el_id}/cell{cell_id}"
34 |
35 | f_regexp = fpath.format(**rec_regexp)
36 | pattern = re.compile(f_regexp)
37 | nodes = []
38 | for f in files:
39 | dataset_match = pattern.match(f)
40 | rec = rec_wildcard.copy()
41 | rec.update(dataset_match.groupdict())
42 | nodes.append(node_fmt.format(**rec))
43 |
44 | return nodes
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2012, Bartosz Telenczuk, Dmytro Bielievtsov
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation
11 | and/or other materials provided with the distribution.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 |
24 | The views and conclusions contained in the software and documentation are those
25 | of the authors and should not be interpreted as representing official policies,
26 | either expressed or implied, of the FreeBSD Project.
27 |
28 | Third-party code
29 | ----------------
30 |
31 | Initial version of the third-party code in `spike_beans.base` library was
32 | created by Zoran Isailovski and Ed Swierk and published under the permissive PSF
33 | license. For details see notes in the module.
34 |
--------------------------------------------------------------------------------
/docs/source/_themes/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2010 by Armin Ronacher.
2 |
3 | Some rights reserved.
4 |
5 | Redistribution and use in source and binary forms of the theme, with or
6 | without modification, are permitted provided that the following conditions
7 | are met:
8 |
9 | * Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above
13 | copyright notice, this list of conditions and the following
14 | disclaimer in the documentation and/or other materials provided
15 | with the distribution.
16 |
17 | * The names of the contributors may not be used to endorse or
18 | promote products derived from this software without specific
19 | prior written permission.
20 |
21 | We kindly ask you to only use these themes in an unmodified manner just
22 | for Flask and Flask-related products, not for unrelated projects. If you
23 | like the visual style and want to use it for your own projects, please
24 | consider making some larger changes to the themes (such as changing
25 | font faces, sizes, colors or margins).
26 |
27 | THIS THEME IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
30 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
31 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
32 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
33 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
34 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
35 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
36 | ARISING IN ANY WAY OUT OF THE USE OF THIS THEME, EVEN IF ADVISED OF THE
37 | POSSIBILITY OF SUCH DAMAGE.
38 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import setuptools
5 |
6 | from numpy.distutils.core import setup, Extension
7 |
8 | ext_modules = []
9 |
10 | import os
11 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
12 |
13 | if not on_rtd:
14 | diptst_ext = Extension(name = 'spike_sort.stats._diptst',
15 | sources = ['src/spike_sort/stats/diptst/diptst.f',
16 | 'src/spike_sort/stats/diptst/diptst.pyf'])
17 | ext_modules.append(diptst_ext)
18 |
19 |
20 | setup(name='SpikeSort',
21 | version='0.13',
22 | description='Python Spike Sorting Package',
23 | long_description="""SpikeSort is a flexible spike sorting framework
24 | implemented entirely in Python based on widely-used packages such as numpy,
25 | PyTables and matplotlib. It features manual and automatic clustering, many
26 | data formats and it is memory-efficient.""",
27 | author='Bartosz Telenczuk and Dmytro Bielievtsov',
28 | author_email='bartosz.telenczuk@gmail.com',
29 | url='http://spikesort.org',
30 | ext_modules = ext_modules,
31 | classifiers = [
32 | "Development Status :: 4 - Beta",
33 | "Environment :: Console",
34 | "Environment :: X11 Applications",
35 | "Intended Audience :: Science/Research",
36 | "License :: OSI Approved :: BSD License",
37 | "Operating System :: OS Independent",
38 | "Programming Language :: Python :: 2.7"
39 | ],
40 | packages=['spike_sort',
41 | 'spike_sort.core',
42 | 'spike_sort.stats',
43 | 'spike_sort.ui',
44 | 'spike_sort.io',
45 | 'spike_beans',
46 | 'spike_analysis',
47 | ],
48 | package_dir = {"": "src"},
49 | install_requires=[
50 | 'matplotlib',
51 | 'tables',
52 | 'numpy >= 1.4.1',
53 | 'scipy',
54 | 'PyWavelets'
55 | ]
56 | )
57 |
58 |
59 |
--------------------------------------------------------------------------------
/examples/sorting/cluster_beans.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | matplotlib.use("TkAgg")
5 | matplotlib.interactive(True)
6 |
7 | from spike_beans import components, base
8 |
9 | ####################################
10 | # Adjust these fields for your needs
11 |
12 | # data source
13 | hdf5filename = 'tutorial.h5'
14 | dataset = "/SubjectA/session01/el1"
15 |
16 | # spike detection/extraction properties
17 | contact = 3
18 | detection_type = "max"
19 | thresh = "auto"
20 | filter_freq = (800.0, 100.0)
21 |
22 | sp_win = [-0.6, 0.8]
23 |
24 | path = filter(None, os.environ['DATAPATH'].split(os.sep)) + [hdf5filename]
25 | hdf5file = os.path.join(os.sep, *path)
26 |
27 | io = components.PyTablesSource(hdf5file, dataset)
28 | io_filter = components.FilterStack()
29 |
30 | base.register("RawSource", io)
31 | base.register("EventsOutput", io)
32 | base.register("SignalSource", io_filter)
33 | base.register("SpikeMarkerSource",
34 | components.SpikeDetector(contact=contact,
35 | thresh=thresh,
36 | type=detection_type,
37 | sp_win=sp_win,
38 | resample=1,
39 | align=True))
40 | base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win))
41 | base.register("FeatureSource", components.FeatureExtractor())
42 | base.register("LabelSource", components.ClusterAnalyzer("gmm", 4))
43 |
44 | browser = components.SpikeBrowser()
45 | feature_plot = components.PlotFeaturesTimeline()
46 | wave_plot = components.PlotSpikes()
47 | legend = components.Legend()
48 | export = components.ExportCells()
49 |
50 | #############################################################
51 | # Add filters here:
52 | base.features["SignalSource"].add_filter("LinearIIR", *filter_freq)
53 |
54 | # Add the features here:
55 | base.features["FeatureSource"].add_feature("P2P")
56 | base.features["FeatureSource"].add_feature("PCA", ncomps=2)
57 |
58 | #############################################################
59 | # Run the analysis (this can take a while)
60 | browser.update()
61 |
--------------------------------------------------------------------------------
/src/spike_sort/io/neo_filters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import spike_sort as sort
5 | from spike_beans import components
6 | from spike_beans.components import GenericSource
7 | try:
8 | import neo
9 | except ImportError:
10 | raise ImportError, "To use the extra data filters you have to install Neo package"
11 | import numpy as np
12 | import os
13 |
14 | class AxonFilter(object):
15 | """Read Axon .abf files
16 |
17 | Parameters
18 | ----------
19 | fname : str
20 | path to file
21 |
22 | electrodes : (optional) list of ints
23 | list of electrode indices to use
24 |
25 | """
26 |
27 | def __init__(self, fname, electrodes=None):
28 | self.reader = neo.io.AxonIO(fname)
29 | self.block = self.reader.read_block()
30 | self.electrodes = electrodes
31 |
32 | def read_sp(self, dataset=None):
33 | electrodes = self.electrodes
34 | analogsignals = self.block.segments[0].analogsignals
35 | if electrodes is not None:
36 | analogsignals = [analogsignals[i] for i in electrodes]
37 | sp_raw = np.array(analogsignals)
38 | FS = float(analogsignals[0].sampling_rate.magnitude)
39 | n_contacts, _ = sp_raw.shape
40 | return {"data": sp_raw, "FS": FS, "n_contacts": n_contacts}
41 |
42 | def write_sp(self):
43 |         raise NotImplementedError("Writing to Axon files is not yet implemented")
44 |     def write_spt(self, spt_dict, dataset, overwrite=False):
45 |         raise NotImplementedError("Writing spikes in Axon format is not yet implemented")
46 |
47 | class NeoSource(components.GenericSource):
48 | def __init__(self, fname, electrodes=None, overwrite=False):
49 | GenericSource.__init__(self, '', overwrite)
50 |
51 | root, ext = os.path.splitext(fname)
52 | ext = ext.lower()
53 | if ext == '.abf':
54 | self.io_filter = AxonFilter(fname, electrodes)
55 | else:
56 |             raise IOError("Format {0} not recognised".format(ext))
57 | self.read_sp = self.io_filter.read_sp
58 | self.write_sp = self.io_filter.write_sp
59 | self.write_spt = self.io_filter.write_spt
60 |
--------------------------------------------------------------------------------
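AxonFilter returns the same raw-signal dictionary used elsewhere in the package ('data', 'FS', 'n_contacts'). A minimal sketch, assuming the neo package is installed and an Axon recording named recording.abf (hypothetical) is available:

from spike_sort.io.neo_filters import AxonFilter

flt = AxonFilter('recording.abf', electrodes=[0, 1])  # read only the first two channels
sp = flt.read_sp()
print(sp['FS'], sp['n_contacts'], sp['data'].shape)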
/src/spike_sort/ui/zoomer.py:
--------------------------------------------------------------------------------
1 | class Zoomer(object):
2 |     '''Zoom individual subplots: pressing 'z' over a subplot toggles it between its grid position and a full-figure view'''
3 | def __init__(self, plt, fig):
4 | self.axlist = []
5 |
6 | self.zoomed_state = {'geometry': (1, 1, 1),
7 | 'xlabel_visible': True,
8 | 'ylabel_visible': True}
9 |
10 | self.old_state = {'geometry': [1, 1, 1],
11 | 'xlabel_visible': True,
12 | 'ylabel_visible': True}
13 |
14 | plt.connect('key_press_event', self.zoom)
15 | self.fig = fig
16 |
17 | def zoom(self, event):
18 | axis = event.inaxes
19 |
20 | if axis is None or event.key != 'z':
21 | return
22 |
23 | if axis.get_geometry() == self.zoomed_state['geometry']:
24 | zoomed = True
25 | else:
26 | zoomed = False
27 |
28 | if not zoomed:
29 | # saving previous state
30 | self.old_state['geometry'] = axis.get_geometry()
31 | self.old_state['xlabel_visible'] = axis.xaxis.label.get_visible()
32 | self.old_state['ylabel_visible'] = axis.yaxis.label.get_visible()
33 |
34 | # removing old axes
35 | self.axlist = list(self.fig.get_axes())
36 | for ax in self.axlist:
37 | self.fig.delaxes(ax)
38 |
39 | # modifying state (zooming)
40 | axis.change_geometry(*self.zoomed_state['geometry'])
41 | axis.xaxis.label.set_visible(self.zoomed_state['xlabel_visible'])
42 | axis.yaxis.label.set_visible(self.zoomed_state['ylabel_visible'])
43 | self.fig.add_axes(axis)
44 | self.fig.show()
45 |
46 | else:
47 | # removing current axes
48 | self.fig.delaxes(axis)
49 |
50 | # bringing the old state back
51 | axis.change_geometry(*self.old_state['geometry'])
52 | axis.xaxis.label.set_visible(self.old_state['xlabel_visible'])
53 | axis.yaxis.label.set_visible(self.old_state['ylabel_visible'])
54 |
55 | # adding old axes
56 | for ax in self.axlist:
57 | self.fig.add_axes(ax)
58 |
59 | self.fig.show()
60 |
--------------------------------------------------------------------------------
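A hedged usage sketch for Zoomer: constructed with the pyplot module and a figure, it listens for the 'z' key over any subplot and toggles that subplot between its normal position and a full-figure view (the figure layout below is illustrative):

import matplotlib.pyplot as plt
from spike_sort.ui.zoomer import Zoomer

fig, axes = plt.subplots(2, 2)   # any figure with several subplots
zoomer = Zoomer(plt, fig)        # keep a reference so the key-press callback stays alive
plt.show()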
/src/spike_analysis/dashboard.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import numpy as np
5 | from scipy import stats
6 | import basic
7 |
8 | import matplotlib.pyplot as plt
9 | from io_tools import read_dataset
10 |
11 | def plot_psth(ax, dataset, **kwargs):
12 | spt = dataset['spt']
13 | stim = dataset['stim']
14 | ev = dataset['ev']
15 |
16 |     basic.plotPSTH(spt, stim, ax=ax, **kwargs)
17 | ymin, ymax = plt.ylim()
18 | if len(ev)>0:
19 | plt.vlines(ev, ymin, ymax)
20 |
21 |     ax.text(0.95, 0.9, "total no. of spikes: %d" % (len(spt),),
22 | transform=ax.transAxes,
23 | ha='right')
24 |
25 | def plot_isi(ax, dataset, win=[0,5], bin=0.1, color='k'):
26 | spt = dataset['spt']
27 | stim = dataset['stim']
28 |
29 | isi = np.diff(spt)
30 | intvs = np.arange(win[0], win[1], bin)
31 | counts, bins = np.histogram(isi, intvs)
32 | mode,n = stats.mode(isi)
33 |
34 | ax.set_xlim(win)
35 | ax.set_ylim(0, counts.max())
36 | ax.plot(intvs[:-1], counts, color=color, drawstyle='steps-post')
37 | ax.axvline(mode, color='k')
38 |
39 | ax.text(0.95, 0.9,"mode: %.2f ms" % (mode,),
40 | transform=ax.transAxes,
41 | ha='right')
42 | ax.set_xlabel("interval (ms)")
43 | ax.set_ylabel("count")
44 |
45 | def plot_trains(ax, dataset, **kwargs):
46 | spt = dataset['spt']
47 | stim = dataset['stim']
48 | ev = dataset['ev']
49 |
50 |     basic.plotraster(spt, stim, ax=ax, **kwargs)
51 | ymin, ymax = plt.ylim()
52 | if len(ev)>0:
53 | plt.vlines(ev, ymin, ymax)
54 |
55 | def plot_nspikes(ax, dataset, win=[0,30], color="k"):
56 | spt = dataset['spt']
57 | stim = dataset['stim']
58 |
59 | trains = basic.SortSpikes(spt, stim, win)
60 | n_spks = np.array([len(t) for t in trains])
61 | count, bins = np.histogram(n_spks, np.arange(10))
62 | ax.bar(bins[:-1]-0.5, count, color=color)
63 | ax.set_xlim((-1,10))
64 |
65 | burst_frac = np.mean(n_spks>1)
66 | ax.text(0.95, 0.9,"%d %% bursts" % (burst_frac*100,),
67 | transform=ax.transAxes,
68 | ha='right')
69 | ax.set_xlabel("no. spikes")
70 | ax.set_ylabel("count")
71 |
72 |
73 | def plot_dataset(dataset, fig=None, **kwargs):
74 |
75 | if not fig:
76 | fig = plt.gcf()
77 | plt.subplots_adjust(hspace=0.3)
78 | ax1=fig.add_subplot(2,2,1)
79 | plot_psth(ax1, dataset, **kwargs)
80 | plt.title("PSTH")
81 | ax2= fig.add_subplot(2,2,2)
82 | plot_isi(ax2, dataset, **kwargs)
83 | plt.title("ISIH")
84 | ax3=fig.add_subplot(2,2,3)
85 | plot_trains(ax3,dataset)
86 | plt.title("raster")
87 | ax4= fig.add_subplot(2,2,4)
88 | plot_nspikes(ax4, dataset, **kwargs)
89 | plt.title("burst order")
90 |
91 | def show_cell(filter, cell):
92 |
93 | dataset = read_dataset(filter, cell)
94 | plot_dataset(dataset)
95 |
--------------------------------------------------------------------------------
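plot_dataset only relies on the 'spt', 'stim' and 'ev' entries of the dataset dictionary (times in milliseconds), so it can also be exercised with synthetic data; a minimal sketch:

import numpy as np
from spike_analysis.dashboard import plot_dataset

dataset = {'stim': np.arange(0., 10000., 100.),            # stimulus onsets (ms)
           'spt': np.sort(np.random.rand(2000) * 10000.),  # spike times (ms)
           'ev': np.array([10.])}                          # extra event markers (ms)
plot_dataset(dataset)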
/src/spike_sort/stats/tests.py:
--------------------------------------------------------------------------------
1 | try:
2 | import _diptst
3 | except ImportError:
4 | _diptst = None
5 |
6 |
7 | import numpy as np
8 | from scipy import stats as st
9 |
10 |
11 | def multidimensional(func1d):
12 | """apply 1d function along specified axis
13 | """
14 | def _decorated(data, axis=0):
15 | if data.ndim <= 1:
16 | return func1d(data)
17 | else:
18 | return np.apply_along_axis(func1d, axis, data)
19 | _decorated.__doc__ = func1d.__doc__
20 |
21 | return _decorated
22 |
23 |
24 | def unsqueeze(data, axis):
25 | """inserts new axis to data at position `axis`.
26 |
27 | This is very useful when one wants to do operations which support
28 | broadcasting without using np.newaxis every time.
29 |
30 | Parameters
31 | ----------
32 | data : array
33 | input array
34 | axis : int
35 | axis to be inserted
36 |
37 | Returns
38 | -------
39 | out : array
40 | array which has data.ndim+1 dimensions. Additional dimension
41 | has length 1
42 |
43 | Example
44 | -------
45 | >>> data = [[1,2,3], [4,5,6]]
46 | >>> m = np.mean(data, 1)
47 | >>> data -= unsqueeze(m, 1)
48 | >>> data.mean(1)
49 | array([ 0., 0.])
50 |
51 | """
52 | shape = data.shape
53 | shape = np.insert(shape, axis, 1)
54 | return np.reshape(data, shape)
55 |
56 |
57 | def std_r(data, axis=0):
58 | """Computes robust estimate of standard deviation
59 | (Quiroga et al, 2004)
60 |
61 | Parameters
62 | ----------
63 | data : array
64 | input data array
65 | axis : int
66 |
67 | Returns
68 | -------
69 | data : array
70 |
71 | """
72 | median = unsqueeze(np.median(data, axis), axis)
73 | std_r = np.median(np.abs(data - median), axis)/0.6745
74 | return std_r
75 |
76 |
77 | @multidimensional
78 | def dip(data):
79 | """Computes DIP statistic (Hartigan & Hartigan 1985)
80 |
81 | Parameters
82 | ----------
83 | data : array
84 | input data array
85 | axis : int
86 | axis along which to compute dip
87 |
88 | Returns
89 | -------
90 | data : float or array
91 | DIP statistic. If the input data is flat, returns float
92 | """
93 | if not _diptst:
94 |         raise NotImplementedError("dip statistic unavailable: the _diptst extension module is not compiled")
95 | sdata = np.sort(data)
96 | return _diptst.diptst1(sdata)[0]
97 |
98 |
99 | @multidimensional
100 | def ks(data):
101 | """Computes Kolmogorov-Smirnov statistic (Lilliefors modification)
102 |
103 | Parameters
104 | ----------
105 | data : array
106 | (n_vecs) input data array
107 | axis : int
108 | axis along which to compute ks
109 |
110 | Returns
111 | -------
112 | data : array or float
113 | KS statistic. If the input data is flat, returns float
114 | """
115 | mr = np.median(data)
116 | stdr = std_r(data)
117 |
118 | # avoid zero-variance
119 | if stdr == 0:
120 | return 0.
121 |
122 | return st.kstest(data, st.norm(loc=mr, scale=stdr).cdf)[0]
123 |
124 | def std(*args, **kwargs):
125 | return np.std(*args, **kwargs)
126 |
--------------------------------------------------------------------------------
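The @multidimensional decorator lets the 1-d statistics above be applied along a chosen axis of a 2-d array; a short sketch (dip() additionally needs the compiled _diptst extension):

import numpy as np
from spike_sort.stats.tests import ks, std_r

data = np.random.randn(1000, 4)   # e.g. 1000 spikes x 4 features
print(ks(data, axis=0))           # one KS (Lilliefors) statistic per feature
print(std_r(data, axis=0))        # robust standard deviation per feature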
/src/spike_sort/ui/manual_sort.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from matplotlib.widgets import Lasso
4 | from matplotlib.path import Path
5 | from matplotlib.colors import colorConverter
6 | from matplotlib.collections import RegularPolyCollection # , LineCollection
7 |
8 | from matplotlib.pyplot import figure
9 | from numpy import nonzero
10 |
11 | import numpy as np
12 |
13 |
14 | class LassoManager(object):
15 | def __init__(self, ax, data, labels=None, color_on='r', color_off='k', markersize=1):
16 | self.axes = ax
17 | self.canvas = ax.figure.canvas
18 | self.data = data
19 | self.call_list = []
20 |
21 | self.Nxy = data.shape[0]
22 | self.color_on = colorConverter.to_rgba(color_on)
23 | self.color_off = colorConverter.to_rgba(color_off)
24 |
25 | facecolors = [self.color_on for _ in range(self.Nxy)]
26 | fig = ax.figure
27 | self.collection = RegularPolyCollection(
28 | fig.dpi, 6, sizes=(markersize,),
29 | facecolors=facecolors,
30 | edgecolors=facecolors,
31 | offsets=data,
32 | transOffset=ax.transData)
33 |
34 | ax.add_collection(self.collection, autolim=True)
35 | ax.autoscale_view()
36 |
37 | if labels is not None:
38 | ax.set_xlabel(labels[0])
39 | ax.set_ylabel(labels[1])
40 | self.cid = self.canvas.mpl_connect('button_press_event', self.onpress)
41 | self.ind = None
42 | self.canvas.draw()
43 |
44 | def register(self, callback_func):
45 | self.call_list.append(callback_func)
46 |
47 | def callback(self, verts):
48 | facecolors = self.collection.get_facecolors()
49 | edgecolors = self.collection.get_edgecolors()
50 | ind = nonzero(Path(verts).contains_points(self.data))[0]
51 | for i in range(self.Nxy):
52 | if i in ind:
53 | facecolors[i] = self.color_on
54 | edgecolors[i] = self.color_on
55 | else:
56 | facecolors[i] = self.color_off
57 | edgecolors[i] = self.color_off
58 |
59 | self.canvas.draw_idle()
60 | self.canvas.widgetlock.release(self.lasso)
61 | del self.lasso
62 | self.ind = ind
63 |
64 | for func in self.call_list:
65 | func(ind)
66 |
67 | def onpress(self, event):
68 | if self.canvas.widgetlock.locked():
69 | return
70 | if event.inaxes is None:
71 | return
72 | self.lasso = Lasso(event.inaxes, (event.xdata, event.ydata),
73 | self.callback)
74 | # acquire a lock on the widget drawing
75 | self.canvas.widgetlock(self.lasso)
76 |
77 |
78 | def manual_sort(features_dict, feat_idx):
79 |
80 | features = features_dict['data']
81 | names = features_dict['names']
82 | if type(feat_idx[0]) is int:
83 | ii = np.array(feat_idx)
84 | else:
85 | ii = np.array([np.nonzero(names == f)[0][0] for f in feat_idx])
86 |     return _cluster(features[:, ii], [names[i] for i in ii])
87 |
88 |
89 | def _cluster(data, names=None, markersize=1):
90 | fig_cluster = figure(figsize=(6, 6))
91 | ax_cluster = fig_cluster.add_subplot(111,
92 | xlim=(-0.1, 1.1),
93 | ylim=(-0.1, 1.1),
94 | autoscale_on=True)
95 | lman = LassoManager(ax_cluster, data, names, markersize=markersize)
96 |
97 | while lman.ind is None:
98 | time.sleep(.01)
99 | fig_cluster.canvas.flush_events()
100 |
101 | n_spikes = data.shape[0]
102 | clust_idx = np.zeros(n_spikes, dtype='int16')
103 | clust_idx[lman.ind] = 1
104 | return clust_idx
105 |
--------------------------------------------------------------------------------
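manual_sort expects the usual features dictionary ('data' of shape (n_spikes, n_features) plus a 'names' array) and a pair of feature indices or names; it blocks until a lasso selection is made. A hedged interactive sketch with made-up feature names:

import numpy as np
from spike_sort.ui.manual_sort import manual_sort

features = {'data': np.random.rand(200, 3),
            'names': np.array(['P2P:0', 'PCA:0', 'PCA:1'])}  # hypothetical names
labels = manual_sort(features, [0, 1])  # lasso a cluster: 1 inside the lasso, 0 outside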
/tests/test_beans.py:
--------------------------------------------------------------------------------
1 | from spike_beans import base
2 | from nose.tools import ok_, raises
3 | from nose import with_setup
4 |
5 |
6 | def setup():
7 | "set up test fixtures"
8 | pass
9 |
10 |
11 | def teardown():
12 | "tear down test fixtures"
13 | base.features = base.FeatureBroker()
14 |
15 |
16 | class Dummy(base.Component):
17 | con = base.RequiredFeature('Data', base.HasAttributes('data'))
18 | opt_con = base.OptionalFeature('OptionalData', base.HasAttributes('data'))
19 |
20 | def __init__(self):
21 | self.data = 0
22 | super(Dummy, self).__init__()
23 |
24 | def get_data(self):
25 | return self.con.data
26 |
27 | def get_optional_data(self):
28 | if self.opt_con:
29 | return self.opt_con.data
30 |
31 | def _update(self):
32 | self.data += 1
33 |
34 | class DummyDataWithZeroDivision(object):
35 | @property
36 | def data(self):
37 | data = 1/0
38 | return data
39 |
40 | class DummyTwoWay(Dummy):
41 | con2 = base.RequiredFeature('Data2', base.HasAttributes('get_data'))
42 |
43 | def get_data(self):
44 | return self.con.data + self.con2.get_data()
45 |
46 |
47 | class DummyDataProvider(base.Component):
48 | data = "some data"
49 |
50 |
51 | class NixProvider(base.Component):
52 | pass
53 |
54 |
55 | @with_setup(setup, teardown)
56 | def test_dependency_resolution():
57 | base.features.Provide('Data', DummyDataProvider)
58 | comp = Dummy()
59 | ok_(comp.get_data() == 'some data')
60 |
61 |
62 | @with_setup(setup, teardown)
63 | def test_diamond_dependency():
64 | base.features.Provide("Data", DummyDataProvider())
65 | base.features.Provide("Data2", Dummy())
66 | out = DummyTwoWay()
67 | data = out.get_data()
68 | base.features['Data'].update()
69 | print out.data
70 | ok_(out.data == 1)
71 |
72 |
73 | @raises(AssertionError)
74 | @with_setup(setup, teardown)
75 | def test_missing_attribute():
76 | base.features.Provide('Data', NixProvider)
77 | comp = Dummy()
78 | data = comp.get_data()
79 |
80 |
81 | @raises(AttributeError)
82 | @with_setup(setup, teardown)
83 | def test_missing_dependency():
84 | comp = Dummy()
85 | data = comp.get_data()
86 |
87 |
88 | @with_setup(setup, teardown)
89 | def test_on_change():
90 | base.features.Provide('Data', DummyDataProvider())
91 | comp = Dummy()
92 | comp.get_data()
93 | base.features['Data'].update()
94 | ok_(comp.data)
95 |
96 | @raises(ZeroDivisionError)
97 | @with_setup(setup, teardown)
98 | def test_hasattribute_exceptions():
99 | '''test whether HasAttributes lets exceptions through (other than AttributeError)'''
100 | c = DummyDataWithZeroDivision()
101 | base.features.Provide('Data', c)
102 | comp = Dummy()
103 | data = comp.get_data()
104 | assert True
105 |
106 | @with_setup(setup, teardown)
107 | def test_register_new_feature_by_setitem():
108 | base.features['Data']=DummyDataProvider()
109 | comp = Dummy()
110 | ok_(comp.get_data() == 'some data')
111 |
112 | @with_setup(setup, teardown)
113 | def test_register_new_feature_by_register():
114 | dep_comp = DummyDataProvider()
115 | added_comp = base.register("Data", dep_comp)
116 | comp = Dummy()
117 | ok_(comp.get_data() == 'some data')
118 | assert dep_comp is added_comp
119 |
120 | @with_setup(setup, teardown)
121 | def test_missing_optional_dependency():
122 | required_dep = base.register("Data", DummyDataProvider())
123 | comp = Dummy()
124 | ok_(comp.opt_con is None)
125 |
126 | @raises(ZeroDivisionError)
127 | @with_setup(setup, teardown)
128 | def test_hasattribute_exceptions_for_optional_deps():
129 | '''test whether HasAttributes raises exceptions for optional features'''
130 | required_dep = base.register("Data", DummyDataProvider())
131 | optional_dep = base.register('OptionalData', DummyDataWithZeroDivision())
132 | comp = Dummy()
133 | data = comp.get_optional_data()
134 | assert True
135 |
136 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | from spike_beans import base
2 | import numpy as np
3 |
4 | spike_dur = 5.
5 | spike_amp = 100.
6 | FS = 25E3
7 | period = 100
8 | n_spikes = 100
9 |
10 | class DummySignalSource(base.Component):
11 |
12 | def __init__(self):
13 |
14 | self.period = period
15 | self.n_spikes = n_spikes
16 | self.f_filter = None
17 | super(DummySignalSource, self).__init__()
18 |
19 | self._generate_data()
20 |
21 | def _generate_data(self):
22 | n_pts = int(self.n_spikes*self.period/1000.*FS)
23 | sp_idx = (np.arange(1,self.n_spikes-1)*self.period*FS/1000).astype(int)
24 | spikes = np.zeros(n_pts)[np.newaxis,:]
25 | spikes[0,sp_idx]=spike_amp
26 |
27 | n = int(spike_dur/1000.*FS) #spike length
28 | spikes[0,:] = np.convolve(spikes[0,:], np.ones(n), 'full')[:n_pts]
29 | self.spt = (sp_idx+0.5)*1000./FS
30 | self.FS = FS
31 | self._spikes = spikes
32 |
33 | def read_signal(self):
34 |
35 |         # times are in milliseconds
36 |
37 | spk_data ={"data":self._spikes,"n_contacts":1, "FS":self.FS}
38 | return spk_data
39 |
40 | def _update(self):
41 | self.period = period*2
42 | self.n_spikes = n_spikes/2
43 | self._generate_data()
44 |
45 | signal = property(read_signal)
46 |
47 | class DummySpikeDetector(base.Component):
48 | def __init__(self):
49 |
50 | self.threshold = 500
51 | self.type = 'min'
52 | self.contact = 0
53 | self.sp_win = [-0.6, 0.8]
54 | super(DummySpikeDetector, self).__init__()
55 | self._generate_data()
56 |
57 | def _generate_data(self):
58 | n_pts = int(n_spikes*period/1000.*FS)
59 | sp_idx = (np.arange(1,n_spikes-1)*period*FS/1000).astype(int)
60 | spt = (sp_idx+0.5)*1000./FS
61 | self._spt_data = {'data':spt}
62 |
63 | def read_events(self):
64 | return self._spt_data
65 |
66 | events = property(read_events)
67 |
68 | class DummyLabelSource(base.Component):
69 | def __init__(self):
70 | self.labels = np.random.randint(0,5, n_spikes-2)
71 |
72 | class DummyFeatureExtractor(base.Component):
73 |
74 | def __init__(self):
75 | n_feats=2
76 | features = np.vstack((
77 | np.zeros((n_spikes, n_feats)),
78 | np.ones((n_spikes, n_feats))
79 | ))
80 | names = ["Fet{0}".format(i) for i in range(n_feats)]
81 |
82 | self._features = {"data": features, "names":names}
83 |
84 | super(DummyFeatureExtractor, self).__init__()
85 |
86 | def read_features(self):
87 |
88 | return self._features
89 |
90 | def add_feature(self, name):
91 |         '''adds a random feature with the specified name'''
92 |
93 | self._features['data'] = np.hstack((self._features['data'], np.random.randn(n_spikes * 2, 1)))
94 | self._features['names'].append(name)
95 |
96 | def add_spikes(self, num = 10):
97 | '''appends `num` random values to each feature'''
98 |
99 | n_features = self._features['data'].shape[1]
100 | self._features['data'] = np.vstack((self._features['data'], np.random.randn(num, n_features)))
101 |
102 | features = property(read_features)
103 |
104 | class DummySpikeSource(base.Component):
105 |
106 | def __init__(self):
107 | n_pts = 100
108 | spike_shape = np.zeros(n_pts)
109 | spike_shape[n_pts/2] = 1.
110 | data = spike_shape[:,np.newaxis, np.newaxis]*np.ones(n_spikes-2)[np.newaxis,:,np.newaxis]
111 | self._sp_waves = {'data':data, 'time':np.ones(n_pts)*1000./FS}
112 | def read_spikes(self):
113 |
114 | return self._sp_waves
115 |
116 | spikes = property(read_spikes)
117 |
118 | class RandomFeatures(base.Component):
119 | def read_features(self):
120 | n_feats=2
121 | features = np.random.randn(n_spikes, n_feats)
122 | names = ["Fet{0}".format(i) for i in range(n_feats)]
123 | return {"data": features, "names":names}
124 | features = property(read_features)
125 |
126 |
--------------------------------------------------------------------------------
/docs/source/_themes/flask_theme_support.py:
--------------------------------------------------------------------------------
1 | # flasky extensions. flasky pygments style based on tango style
2 | from pygments.style import Style
3 | from pygments.token import Keyword, Name, Comment, String, Error, \
4 | Number, Operator, Generic, Whitespace, Punctuation, Other, Literal
5 |
6 |
7 | class FlaskyStyle(Style):
8 | background_color = "#f8f8f8"
9 | default_style = ""
10 |
11 | styles = {
12 | # No corresponding class for the following:
13 | #Text: "", # class: ''
14 | Whitespace: "underline #f8f8f8", # class: 'w'
15 | Error: "#a40000 border:#ef2929", # class: 'err'
16 | Other: "#000000", # class 'x'
17 |
18 | Comment: "italic #8f5902", # class: 'c'
19 | Comment.Preproc: "noitalic", # class: 'cp'
20 |
21 | Keyword: "bold #004461", # class: 'k'
22 | Keyword.Constant: "bold #004461", # class: 'kc'
23 | Keyword.Declaration: "bold #004461", # class: 'kd'
24 | Keyword.Namespace: "bold #004461", # class: 'kn'
25 | Keyword.Pseudo: "bold #004461", # class: 'kp'
26 | Keyword.Reserved: "bold #004461", # class: 'kr'
27 | Keyword.Type: "bold #004461", # class: 'kt'
28 |
29 | Operator: "#582800", # class: 'o'
30 | Operator.Word: "bold #004461", # class: 'ow' - like keywords
31 |
32 | Punctuation: "bold #000000", # class: 'p'
33 |
34 | # because special names such as Name.Class, Name.Function, etc.
35 | # are not recognized as such later in the parsing, we choose them
36 | # to look the same as ordinary variables.
37 | Name: "#000000", # class: 'n'
38 | Name.Attribute: "#c4a000", # class: 'na' - to be revised
39 | Name.Builtin: "#004461", # class: 'nb'
40 | Name.Builtin.Pseudo: "#3465a4", # class: 'bp'
41 | Name.Class: "#000000", # class: 'nc' - to be revised
42 | Name.Constant: "#000000", # class: 'no' - to be revised
43 | Name.Decorator: "#888", # class: 'nd' - to be revised
44 | Name.Entity: "#ce5c00", # class: 'ni'
45 | Name.Exception: "bold #cc0000", # class: 'ne'
46 | Name.Function: "#000000", # class: 'nf'
47 | Name.Property: "#000000", # class: 'py'
48 | Name.Label: "#f57900", # class: 'nl'
49 | Name.Namespace: "#000000", # class: 'nn' - to be revised
50 | Name.Other: "#000000", # class: 'nx'
51 | Name.Tag: "bold #004461", # class: 'nt' - like a keyword
52 | Name.Variable: "#000000", # class: 'nv' - to be revised
53 | Name.Variable.Class: "#000000", # class: 'vc' - to be revised
54 | Name.Variable.Global: "#000000", # class: 'vg' - to be revised
55 | Name.Variable.Instance: "#000000", # class: 'vi' - to be revised
56 |
57 | Number: "#990000", # class: 'm'
58 |
59 | Literal: "#000000", # class: 'l'
60 | Literal.Date: "#000000", # class: 'ld'
61 |
62 | String: "#4e9a06", # class: 's'
63 | String.Backtick: "#4e9a06", # class: 'sb'
64 | String.Char: "#4e9a06", # class: 'sc'
65 | String.Doc: "italic #8f5902", # class: 'sd' - like a comment
66 | String.Double: "#4e9a06", # class: 's2'
67 | String.Escape: "#4e9a06", # class: 'se'
68 | String.Heredoc: "#4e9a06", # class: 'sh'
69 | String.Interpol: "#4e9a06", # class: 'si'
70 | String.Other: "#4e9a06", # class: 'sx'
71 | String.Regex: "#4e9a06", # class: 'sr'
72 | String.Single: "#4e9a06", # class: 's1'
73 | String.Symbol: "#4e9a06", # class: 'ss'
74 |
75 | Generic: "#000000", # class: 'g'
76 | Generic.Deleted: "#a40000", # class: 'gd'
77 | Generic.Emph: "italic #000000", # class: 'ge'
78 | Generic.Error: "#ef2929", # class: 'gr'
79 | Generic.Heading: "bold #000080", # class: 'gh'
80 | Generic.Inserted: "#00A000", # class: 'gi'
81 | Generic.Output: "#888", # class: 'go'
82 | Generic.Prompt: "#745334", # class: 'gp'
83 | Generic.Strong: "bold #000000", # class: 'gs'
84 | Generic.Subheading: "bold #800080", # class: 'gu'
85 | Generic.Traceback: "bold #a40000", # class: 'gt'
86 | }
87 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = build
9 |
10 | # Internal variables.
11 | PAPEROPT_a4 = -D latex_paper_size=a4
12 | PAPEROPT_letter = -D latex_paper_size=letter
13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source
14 |
15 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest
16 |
17 | help:
18 | @echo "Please use \`make <target>' where <target> is one of"
19 | @echo " html to make standalone HTML files"
20 | @echo " dirhtml to make HTML files named index.html in directories"
21 | @echo " singlehtml to make a single large HTML file"
22 | @echo " pickle to make pickle files"
23 | @echo " json to make JSON files"
24 | @echo " htmlhelp to make HTML files and a HTML help project"
25 | @echo " qthelp to make HTML files and a qthelp project"
26 | @echo " devhelp to make HTML files and a Devhelp project"
27 | @echo " epub to make an epub"
28 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
29 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
30 | @echo " text to make text files"
31 | @echo " man to make manual pages"
32 | @echo " changes to make an overview of all changed/added/deprecated items"
33 | @echo " linkcheck to check all external links for integrity"
34 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
35 |
36 | clean:
37 | -rm -rf $(BUILDDIR)/*
38 |
39 | html:
40 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
41 | @echo
42 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
43 |
44 | dirhtml:
45 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
46 | @echo
47 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
48 |
49 | singlehtml:
50 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
51 | @echo
52 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
53 |
54 | pickle:
55 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
56 | @echo
57 | @echo "Build finished; now you can process the pickle files."
58 |
59 | json:
60 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
61 | @echo
62 | @echo "Build finished; now you can process the JSON files."
63 |
64 | htmlhelp:
65 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
66 | @echo
67 | @echo "Build finished; now you can run HTML Help Workshop with the" \
68 | ".hhp project file in $(BUILDDIR)/htmlhelp."
69 |
70 | qthelp:
71 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
72 | @echo
73 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
74 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
75 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/SortSpikes.qhcp"
76 | @echo "To view the help file:"
77 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/SortSpikes.qhc"
78 |
79 | devhelp:
80 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
81 | @echo
82 | @echo "Build finished."
83 | @echo "To view the help file:"
84 | @echo "# mkdir -p $$HOME/.local/share/devhelp/SortSpikes"
85 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/SortSpikes"
86 | @echo "# devhelp"
87 |
88 | epub:
89 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
90 | @echo
91 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
92 |
93 | latex:
94 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
95 | @echo
96 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
97 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
98 | "(use \`make latexpdf' here to do that automatically)."
99 |
100 | latexpdf:
101 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
102 | @echo "Running LaTeX files through pdflatex..."
103 | make -C $(BUILDDIR)/latex all-pdf
104 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
105 |
106 | text:
107 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
108 | @echo
109 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
110 |
111 | man:
112 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
113 | @echo
114 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
115 |
116 | changes:
117 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
118 | @echo
119 | @echo "The overview file is in $(BUILDDIR)/changes."
120 |
121 | linkcheck:
122 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
123 | @echo
124 | @echo "Link check complete; look for any errors in the above output " \
125 | "or in $(BUILDDIR)/linkcheck/output.txt."
126 |
127 | doctest:
128 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
129 | @echo "Testing of doctests in the sources finished, look at the " \
130 | "results in $(BUILDDIR)/doctest/output.txt."
131 |
--------------------------------------------------------------------------------
/src/spike_analysis/basic.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | def gettrains(spt, stim, win, binsz):
5 | i = np.searchsorted(stim, spt)
6 | spt2 = spt - stim[i - 1]
7 | bin = np.arange(win[0], win[1], binsz)
8 | bin = np.concatenate((bin, [bin[-1] + binsz, np.inf]))
9 | j = np.searchsorted(bin, spt2)
10 | npoints = len(bin) - 1
11 | ntrials = len(stim)
12 | trains = np.zeros((npoints, ntrials))
13 | trains[j - 1, i - 1] = 1
14 | return trains[:-1, :]
15 |
16 | def SortSpikes(spt, stim, win=None):
17 |     """Given spike and stimulus times, return a list of per-trial spike
18 |     trains. If win is given, only spikes occurring in the time window
19 |     are returned.
20 |     """
21 | i = np.searchsorted(stim, spt)
22 | spt2 = spt - stim[i - 1]
23 | if win:
24 | corrected = filter(lambda x: win[1] > x[0] >= win[0], zip(spt2, i))
25 | spt2 = np.array([x[0] for x in corrected])
26 | i = np.array([x[1] for x in corrected])
27 | return [spt2[i == j] for j in xrange(1, len(stim) + 1)]
28 |
29 | def plotraster(spt, stim, win=[0, 30], ntrials=None, ax=None, height=1.0):
30 |     """Create a raster plot of spike trains:
31 |
32 |     spt - times of spike occurrence,
33 | stim - stimulus times
34 | win - range of time axis
35 | ntrials - number of trials to plot (if None plot all)
36 | """
37 | if not ntrials:
38 | ntrials = len(stim) - 1
39 |
40 | if not ax:
41 | ax = plt.gca()
42 |
43 | spt2 = spt[spt < stim[ntrials]].copy()
44 | i = np.searchsorted(stim[:ntrials], spt2)
45 | spt2 -= stim[i - 1]
46 |
47 | plt.vlines(spt2, i, i + height)
48 | plt.xlim(win)
49 | plt.ylim((1, ntrials))
50 | plt.xlabel('time (ms)')
51 | plt.ylabel('trials')
52 |
53 | def plottrains(trains, win=[0, 30], ntrials=None, height=1.0):
54 |     print "Deprecation: Please use plotRasterTrains instead"
55 | plotRasterTrains(trains, win, ntrials, height)
56 |
57 | def plotRasterTrains(trains, win=[0, 30], ntrials=None, height=1.0):
58 |     """Create a raster plot from a list of spike trains:
59 |
60 |     trains - list of per-trial arrays of spike times
61 |     win - range of time axis
62 |     ntrials - number of trials to plot (if None plot all)
63 |     height - height of the tick drawn for each spike
64 |     """
65 | if ntrials:
66 | trains = trains[:ntrials]
67 |
68 | ax = plt.gca()
69 | ax.set_xlim(win)
70 |     ax.set_ylim((1, len(trains)))
71 |
72 | lines = [plt.vlines(sp, np.ones(len(sp)) * i,
73 | np.ones(len(sp)) * (i + height))
74 | for i, sp in enumerate(trains) if len(sp)]
75 | plt.xlabel('time (ms)')
76 | plt.ylabel('trials')
77 | return lines
78 |
79 | def BinTrains(trains, win, tsamp=0.25):
80 | """Convert a list of spike trains into a binary sequence"""
81 | bins = np.arange(win[0], win[1], tsamp)
82 | trains = [np.histogram(spt, bins)[0][1:] for spt in trains]
83 | return bins[1:-1], np.array(trains).T
84 |
85 | def plotPSTH(spt, stim, win=[0, 30], bin=0.25, ax=None,
86 | rate=False, **kwargs):
87 | """Plot peri-stimulus time histogram (PSTH)"""
88 | i = np.searchsorted(stim + win[0], spt)
89 | spt2 = spt - stim[i - 1]
90 | bins = np.arange(win[0], win[1], bin)
91 | psth, bins = np.histogram(spt2, bins)
92 |
93 | if not ax:
94 | ax = plt.gca()
95 |
96 | if rate:
97 | psth = psth * 1.0 / len(stim) / bin * 1000.0
98 | ax.set_ylabel('firing rate (Hz)')
99 | else:
100 | ax.set_ylabel('number of spikes')
101 |
102 | lines = ax.plot(bins[:-1], psth, **kwargs)
103 | ax.set_xlabel('time (ms)')
104 | return lines
105 |
106 | def plotTrainsPSTH(trains, win, bin=0.25, rate=False, **kwargs):
107 | """Plot peri-stimulus time histogram (PSTH) from a list of spike times
108 | (trains)"""
109 | ax = plt.gca()
110 | time, binned = BinTrains(trains, win, bin)
111 | psth = np.mean(binned, 1) / bin * 1000
112 | plt.plot(time, psth, **kwargs)
113 |     # psth is averaged over trials and normalised by the bin width,
114 |     # so it is already expressed as a firing rate in Hz
115 |     if rate:
116 |         ax.set_ylabel('firing rate (Hz)')
117 |     else:
118 |         ax.set_ylabel('number of spikes')
119 |     ax.set_xlabel('time (ms)')
120 |
121 | def CalcTrainsPSTH(trains, win, bin=0.25):
122 | time, binned = BinTrains(trains, win, bin)
123 | psth = np.mean(binned, 1) / bin * 1000
124 | return time, psth
125 |
126 | def CalcPSTH(spt, stim, win=[0, 30], bin=0.25, ax=None, norm=False, **kwargs):
127 | """Calculate peri-stimulus time histogram (PSTH).
128 | Output:
129 | -- psth - spike counts
130 |     -- bins - bin edges"""
131 | i = np.searchsorted(stim + win[0], spt)
132 | spt2 = spt - stim[i - 1]
133 | bins = np.arange(win[0], win[1], bin)
134 | psth, bins = np.histogram(spt2, bins)
135 | if norm:
136 | psth = psth * 1000.0 / (len(stim) * bin)
137 | return psth[1:], bins[1:-1]
138 |
139 | def plotPSTHBar(spt, stim, win=[0, 30], bin=0.25, **kwargs):
140 | """Plot peri-stimulus time histogram (PSTH)"""
141 | i = np.searchsorted(stim + win[0], spt)
142 | spt2 = spt - stim[i - 1]
143 | bins = np.arange(win[0], win[1], bin)
144 | psth, bins = np.histogram(spt2, bins)
145 | psth = psth * 1.0 / len(stim) / bin * 1000.0
146 | plt.gca().bar(bins[:-1], psth, bin, **kwargs)
147 | plt.xlabel('time (ms)')
148 | plt.ylabel('firing rate (Hz)')
149 |
--------------------------------------------------------------------------------
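plotPSTH and plotraster both take raw spike times and stimulus onsets (in milliseconds) and align each spike to the preceding stimulus; a minimal sketch with synthetic data:

import numpy as np
import matplotlib.pyplot as plt
from spike_analysis import basic

stim = np.arange(0., 10000., 100.)            # stimulus onsets every 100 ms
spt = np.sort(np.random.rand(3000) * 10000.)  # spike times (ms)

basic.plotPSTH(spt, stim, win=[0, 30], bin=0.5, rate=True)
plt.figure()
basic.plotraster(spt, stim, win=[0, 30])
plt.show()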
/src/spike_sort/core/filters.py:
--------------------------------------------------------------------------------
1 | import tables
2 | import tempfile
3 | import numpy as np
4 | from scipy import signal
5 | import os
6 | import atexit
7 | _open_files = {}
8 |
9 |
10 | class ZeroPhaseFilter(object):
11 | """IIR Filter with zero phase delay"""
12 | def __init__(self, ftype, fband, tw=200., stop=20):
13 | self.gstop = stop
14 | self.gpass = 1
15 | self.fband = fband
16 | self.tw = tw
17 | self.ftype = ftype
18 | self._coefs_cache = {}
19 |
20 | def _design_filter(self, FS):
21 |
22 |         if FS not in self._coefs_cache:
23 | wp = np.array(self.fband)
24 | ws = wp + np.array([-self.tw, self.tw])
25 | wp, ws = wp * 2.0 / FS, ws * 2.0 / FS
26 | b, a = signal.iirdesign(wp=wp,
27 | ws=ws,
28 | gstop=self.gstop,
29 | gpass=self.gpass,
30 | ftype=self.ftype)
31 | self._coefs_cache[FS] = (b, a)
32 | else:
33 | b, a = self._coefs_cache[FS]
34 | return b, a
35 |
36 | def __call__(self, x, FS):
37 | b, a = self._design_filter(FS)
38 | return signal.filtfilt(b, a, x)
39 |
40 |
41 | class FilterFir(object):
42 | """FIR filter with zero phase delay
43 |
44 | Attributes
45 | ----------
46 | f_pass : float
47 | normalised low-cutoff frequency
48 |
49 | f_stop : float
50 | normalised high-cutoff frequency
51 |
52 | order : int
53 | filter order
54 |
55 | """
56 | def __init__(self, f_pass, f_stop, order):
57 | self._coefs_cache = {}
58 | self.fp = f_pass
59 | self.fs = f_stop
60 | self.order = order
61 |
62 | def _design_filter(self, FS):
63 |         if FS not in self._coefs_cache:
64 | bands = [0, min(self.fs, self.fp), max(self.fs, self.fp), FS / 2]
65 | gains = [int(self.fp < self.fs), int(self.fp > self.fs)]
66 | b, a = signal.remez(self.order, bands, gains, Hz=FS), [1]
67 | self._coefs_cache[FS] = (b, a)
68 | else:
69 | b, a = self._coefs_cache[FS]
70 | return b, a
71 |
72 | def __call__(self, x, FS):
73 | b, a = self._design_filter(FS)
74 | return signal.filtfilt(b, a, x)
75 |
76 |
77 | class Filter(object):
78 | def __init__(self, fpass, fstop, gpass=1, gstop=10, ftype='butter'):
79 | self.ftype = ftype
80 | self.fp = np.asarray(fpass)
81 | self.fs = np.asarray(fstop)
82 | self._coefs_cache = {}
83 | self.gstop = gstop
84 | self.gpass = gpass
85 |
86 | def _design_filter(self, FS):
87 |         if FS not in self._coefs_cache:
88 | wp, ws = self.fp * 2 / FS, self.fs * 2 / FS
89 | b, a = signal.iirdesign(wp=wp,
90 | ws=ws,
91 | gstop=self.gstop,
92 | gpass=self.gpass,
93 | ftype=self.ftype)
94 | self._coefs_cache[FS] = (b, a)
95 | else:
96 | b, a = self._coefs_cache[FS]
97 | return b, a
98 |
99 | def __call__(self, x, FS):
100 | b, a = self._design_filter(FS)
101 | return signal.filtfilt(b, a, x)
102 |
103 |
104 | def filter_proxy(spikes, filter_obj, chunksize=1E6):
105 | """Proxy object to read filtered data
106 |
107 | Parameters
108 | ----------
109 | spikes : dict
110 | unfiltered raw recording
111 |     filter_obj : object
112 | Filter to filter the data
113 | chunksize : int
114 | size of segments in which data is filtered
115 |
116 | Returns
117 | -------
118 | sp_dict : dict
119 | filtered recordings
120 | """
121 | data = spikes['data']
122 | sp_dict = spikes.copy()
123 |
124 | if filter_obj is None:
125 | return spikes
126 |
127 | filename = tempfile.mktemp(suffix='.h5')
128 | atom = tables.Atom.from_dtype(np.dtype('float64'))
129 | shape = data.shape
130 | h5f = tables.openFile(filename, 'w')
131 | carray = h5f.createCArray('/', 'test', atom, shape)
132 |
133 | _open_files[filename] = h5f
134 |
135 | chunksize = int(chunksize)
136 | n_chunks = int(np.ceil(shape[1] * 1.0 / chunksize))
137 | for i in range(shape[0]):
138 | for j in range(n_chunks):
139 | stop = int(np.min(((j + 1) * chunksize, shape[1])))
140 | carray[i, j * chunksize:stop] = filter_obj(
141 | data[i, j * chunksize:stop], sp_dict['FS'])
142 | sp_dict['data'] = carray
143 | return sp_dict
144 |
145 |
146 | def fltLinearIIR(signal, fpass, fstop, gpass=1, gstop=10, ftype='butter'):
147 |     """An acausal (zero-phase) IIR linear filter, applied through the
148 |     spike_sort.core.filters.filter_proxy mechanism
149 |
150 | Parameters
151 | ----------
152 | signal : dict
153 | input [raw] signal
154 | fpass, fstop : float
155 | Passband and stopband edge frequencies [Hz]
156 | For more details see scipy.signal.iirdesign
157 | gpass : float
158 | The maximum loss in the passband (dB).
159 | gstop : float
160 | The minimum attenuation in the stopband (dB).
161 | ftype : str, optional
162 | The type of IIR filter to design:
163 |
164 | - elliptic : 'ellip'
165 | - Butterworth : 'butter',
166 | - Chebyshev I : 'cheby1',
167 | - Chebyshev II: 'cheby2',
168 | - Bessel : 'bessel'
169 | """
170 | filt = Filter(fpass, fstop, gpass, gstop, ftype)
171 | return filter_proxy(signal, filt)
172 |
173 |
174 | def clean_after_exit():
175 | for fname, fid in _open_files.items():
176 | fid.close()
177 | try:
178 | os.remove(fname)
179 | except OSError:
180 | pass
181 | _open_files.clear()
182 |
183 | atexit.register(clean_after_exit)
184 |
--------------------------------------------------------------------------------
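ZeroPhaseFilter designs its IIR coefficients once per sampling rate and applies them with filtfilt, so the output has no phase delay; a minimal sketch on synthetic data:

import numpy as np
from spike_sort.core.filters import ZeroPhaseFilter

FS = 25000.
x = np.random.randn(int(FS))                      # 1 s of noise sampled at 25 kHz
bandpass = ZeroPhaseFilter('butter', fband=(300., 3000.))
y = bandpass(x, FS)                               # zero-phase band-pass filtered trace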
/src/spike_sort/ui/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 | from matplotlib.collections import LineCollection
7 |
8 | # used by other modules which import plotting.py. Do not remove!
9 | from matplotlib.pyplot import show, figure, close
10 |
11 | import spike_sort
12 | cmap = plt.cm.jet
13 |
14 | import _mpl_helpers # performance boosters
15 |
16 |
17 | def label_color(labels):
18 | """Map labels to number range [0, 1]"""
19 |
20 | num_labels = np.linspace(0, 1., len(labels))
21 | mapper = dict(zip(labels, num_labels))
22 |
23 | @np.vectorize
24 | def map_func(lab):
25 | return mapper[lab]
26 |
27 | def color_func(lab):
28 | return cmap(map_func(lab))
29 | return color_func
30 |
31 |
32 | def plot_spikes(spikes, clust_idx=None, show_cells='all', **kwargs):
33 | """Plot Spike waveshapes
34 |
35 | Parameters
36 | ----------
37 |     spikes : dict
38 | clust_idx : sequence
39 | sequence of the length equal to the number of spikes; labels of
40 | clusters to which spikes belong
41 | show_cells : list or 'all'
42 | list of identifiers of clusters (cells) to plot
43 | plot_avg: bool
44 | if True plot waveform averages
45 |
46 | Returns
47 | -------
48 | lines_segments : object
49 | matplotlib line collection of spike waveshapes
50 | """
51 |
52 | if clust_idx is None:
53 | spikegraph(spikes, **kwargs)
54 | else:
55 | spikes_cell = spike_sort.extract.split_cells(spikes, clust_idx)
56 |
57 | if show_cells == 'all':
58 | labs = spikes_cell.keys()
59 | else:
60 | labs = show_cells
61 |
62 | color_func = label_color(spikes_cell.keys())
63 | for l in labs:
64 | spikegraph(spikes_cell[l], color_func(l), **kwargs)
65 |
66 |
67 | def spikegraph(spike_data, color='k', alpha=0.2, n_spikes='all',
68 | contacts='all', plot_avg=True, fig=None):
69 |
70 | spikes = spike_data['data']
71 | time = spike_data['time']
72 |
73 | if contacts == 'all':
74 | contacts = np.arange(spikes.shape[2])
75 |
76 | n_pts = len(time)
77 |
78 | if not n_spikes == 'all':
79 | sample_idx = np.argsort(np.random.randn(spikes.shape[1]))[:n_spikes]
80 | spikes = spikes[:, sample_idx, :]
81 | n_spikes = spikes.shape[1]
82 | if fig is None:
83 | fig = plt.gcf()
84 | line_segments = []
85 | for i, contact_id in enumerate(contacts):
86 | ax = fig.add_subplot(2, 2, i + 1)
87 | # ax.set_xlim(time.min(), time.max())
88 | # ax.set_ylim(spikes.min(), spikes.max())
89 |
90 | segs = np.zeros((n_spikes, n_pts, 2))
91 | segs[:, :, 0] = time[np.newaxis, :]
92 | segs[:, :, 1] = spikes[:, :, contact_id].T
93 | collection = LineCollection(segs, colors=color,
94 | alpha=alpha)
95 | line_segments.append(collection)
96 | ax.add_collection(collection, autolim=True)
97 |
98 | if plot_avg:
99 |             spikes_mean = spikes[:, :, contact_id].mean(1)
100 | ax.plot(time, spikes_mean, color='w', lw=3)
101 | ax.plot(time, spikes_mean, color=color, lw=2)
102 | ax.autoscale_view(tight=True)
103 |
104 | return line_segments
105 |
106 |
107 | def plot_features(features, clust_idx=None, show_cells='all', **kwargs):
108 | """Plot features and their histograms
109 |
110 | Parameters
111 | ----------
112 |     features : dict
113 | features data structure
114 | clust_idx : array or None
115 | array of size (n_spikes,) containing indices of clusters to which
116 |         each spike was classified
117 | show_cells : list or 'all'
118 | list of identifiers of clusters (cells) to plot
119 |
120 | """
121 |
122 | if clust_idx is None:
123 | featuresgraph(features, **kwargs)
124 | else:
125 | features_cell = spike_sort.features.split_cells(features, clust_idx)
126 |
127 | if show_cells == 'all':
128 | labs = features_cell.keys()
129 | else:
130 | labs = show_cells
131 |
132 | color_func = label_color(features_cell.keys())
133 | for l in labs:
134 | featuresgraph(features_cell[l], color_func(l), **kwargs)
135 |
136 |
137 | def featuresgraph(features_dict, color='k', size=1, datarange=None, fig=None, n_spikes='all'):
138 |
139 | features = features_dict['data']
140 | names = features_dict['names']
141 |
142 | _, n_feats = features.shape
143 | if fig is None:
144 | fig = plt.gcf()
145 | axes = [[fig.add_subplot(n_feats, n_feats, i * n_feats + j + 1, projection='thin')
146 | for i in range(n_feats)] for j in range(n_feats)]
147 |
148 | if not n_spikes == 'all':
149 | sample_idx = np.argsort(np.random.randn(features.shape[0]))[:n_spikes]
150 | features = features[sample_idx, :]
151 |
152 | for i in range(n_feats):
153 | for j in range(n_feats):
154 | ax = axes[i][j]
155 | if i != j:
156 | ax.plot(features[:, i],
157 | features[:, j], ".",
158 | color=color, markersize=size)
159 | if datarange:
160 | ax.set_xlim(datarange)
161 |
162 | else:
163 | ax.set_frame_on(False)
164 | n, bins = np.histogram(features[:, i], 20, datarange, normed=True)
165 | ax.plot(bins[:-1], n, '-', color=color, drawstyle='steps')
166 | if datarange:
167 | ax.set_xlim(datarange)
168 | ax.set_xticks([])
169 | ax.set_yticks([])
170 | ax.set_xlabel(names[i])
171 | ax.set_ylabel(names[j])
172 | ax.xaxis.set_label_position("top")
173 | ax.xaxis.label.set_visible(False)
174 | ax.yaxis.label.set_visible(False)
175 |
176 | for i in range(n_feats):
177 | ax = axes[i][0]
178 | ax.xaxis.label.set_visible(True)
179 | ax = axes[0][i]
180 | ax.yaxis.label.set_visible(True)
181 |
182 |
183 | def legend(labels, colors=None, ax=None):
184 |
185 | if ax is None:
186 | ax = plt.gca()
187 | if colors is None:
188 | color_func = label_color(labels)
189 | colors = [color_func(i) for i in labels]
190 |
191 | ax.set_frame_on(False)
192 | n_classes = len(labels)
193 | x, y = np.zeros(n_classes) + 0.4, 0.1 * np.arange(n_classes)
194 | ax.scatter(x, y, c=colors, marker='s', edgecolors="none", s=100)
195 | ax.set_xlim([0, 1])
196 | ax.set_xticks([])
197 | ax.set_yticks([])
198 |
199 | for i, l in enumerate(labels):
200 | ax.text(x[i] + 0.1, y[i], "Cell {0}".format(l), va='center', ha='left',
201 | transform=ax.transData)
202 |
--------------------------------------------------------------------------------
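spikegraph (and plot_spikes) expect a spike-waveform dictionary with a 'data' array of shape (n_pts, n_spikes, n_contacts) and a 'time' vector in milliseconds; a sketch with synthetic waveforms:

import numpy as np
from spike_sort.ui import plotting

n_pts, n_spikes, n_contacts = 35, 200, 4
waves = {'data': np.random.randn(n_pts, n_spikes, n_contacts),
         'time': np.linspace(-0.6, 0.8, n_pts)}      # ms, matching the extraction window
plotting.spikegraph(waves, color='b', n_spikes=50)   # plot a random subset of 50 spikes
plotting.show()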
/src/spike_sort/stats/diptst/diptst.f:
--------------------------------------------------------------------------------
1 | SUBROUTINE DIPTST1(X,N,DIP,XL,XU,IFAULT,GCM,LCM,MN,MJ,DDX,DDXSGN)
2 | C
3 | C ALGORITHM AS 217 APPL. STATIST. (1985) VOL.34, NO.3
4 | C
5 | C Does the dip calculation for an ordered vector X using the
6 | C greatest convex minorant and the least concave majorant, skipping
7 | C through the data using the change points of these distributions.
8 | C It returns the dip statistic 'DIP' and the modal interval
9 | C (XL, XU).
10 | C
11 | C MODIFICATIONS SEP 2 2002 BY F. MECHLER TO FIX PROBLEMS WITH
12 | C UNIMODAL (INCLUDING MONOTONIC) INPUT
13 | C
14 | REAL X(N)
15 | INTEGER MN(N), MJ(N), LCM(N), GCM(N), HIGH
16 | REAL ZERO, HALF, ONE
17 | C NEXT TWO LINES ARE ADDED
18 | REAL DDX(N)
19 | INTEGER DDXSGN(N), POSK, NEGK
20 | DATA ZERO/0.0/, HALF/0.5/, ONE/1.0/
21 | C
22 | IFAULT = 1
23 | IF (N .LE. 0) RETURN
24 | IFAULT = 0
25 | C
26 | C Check if N = 1
27 | C
28 | IF (N .EQ. 1) GO TO 4
29 | C
30 | C Check that X is sorted
31 | C
32 | IFAULT = 2
33 | DO 3 K = 2, N
34 | IF (X(K) .LT. X(K-1)) RETURN
35 | 3 CONTINUE
36 | IFAULT = 0
37 | C
38 | C Check for all values of X identical,
39 | C and for 1 < N < 4.
40 | C
41 | IF (X(N) .GT. X(1) .AND. N .GE. 4) GO TO 505
42 | 4 XL = X(1)
43 | XU = X(N)
44 | DIP = ZERO
45 | RETURN
46 | C The code amendment below is intended to be inserted above the line marked "5" in the original FORTRAN code
47 | C The amendment checks the condition whether the input X is perfectly unimodal
48 | C Hartigan's original DIPTST algorithm did not check for this condition
49 | C and DIPTST runs into an infinite cycle for a unimodal input
50 | C The condition that the input is unimodal is equivalent to having
51 | C at most 1 sign change in the second derivative of the input p.d.f.
52 | C In MATLAB syntax, we check the flips in the function xsign=-sign(diff(1./diff(x)))=-sign(diff(diff(x)));
53 | C with DDXSGN=xsign in the fortran code below
54 | 505 NEGK=0
55 | POSK=0
56 | DO 104 K = 3,N
57 | DDX(K) = X(K)+X(K-2)-2*X(K-1)
58 | IF (DDX(K) .LT. 0) DDXSGN(K) = 1
59 | IF (DDX(K) .EQ. 0) DDXSGN(K) = 0
60 | IF (DDX(K) .GT. 0) DDXSGN(K) = -1
61 | IF (DDXSGN(K) .GT. 0) POSK = K
62 | IF ((DDXSGN(K) .LT. 0) .AND. (NEGK .EQ. 0)) NEGK = K
63 | 104 CONTINUE
64 |
65 | C The condition check below examines whether the greatest position with a positive second derivative
66 | C is smaller than the smallest position with a negative second derivative
67 | C The boolean check gets it right even if
68 | C the unimodal p.d.f. has its mode in the very first or last point of the input
69 |
70 | IF ((POSK .GT. NEGK) .AND. (NEGK .GT. 0)) GOTO 5
71 | XL=X(1)
72 | XU=X(N)
73 | DIP=0
74 | IFAULT=5
75 | RETURN
76 | C
77 | C LOW contains the index of the current estimate of the lower end
78 | C of the modal interval, HIGH contains the index for the upper end.
79 | C
80 | 5 FN = FLOAT(N)
81 | LOW = 1
82 | HIGH = N
83 | DIP = ONE / FN
84 | XL = X(LOW)
85 | XU = X(HIGH)
86 | C
87 | C Establish the indices over which combination is necessary for the
88 | C convex minorant fit.
89 | C
90 | MN(1) = 1
91 | DO 28 J = 2, N
92 | MN(J) = J - 1
93 | 25 MNJ = MN(J)
94 | MNMNJ = MN(MNJ)
95 | A = FLOAT(MNJ - MNMNJ)
96 | B = FLOAT(J - MNJ)
97 | IF (MNJ .EQ. 1 .OR. (X(J) - X(MNJ))*A .LT. (X(MNJ) - X(MNMNJ))
98 | + *B) GO TO 28
99 | MN(J) = MNMNJ
100 | GO TO 25
101 | 28 CONTINUE
102 | C
103 | C Establish the indices over which combination is necessary for the
104 | C concave majorant fit.
105 | C
106 | MJ(N) = N
107 | NA = N - 1
108 | DO 34 JK = 1, NA
109 | K = N - JK
110 | MJ(K) = K + 1
111 | 32 MJK = MJ(K)
112 | MJMJK = MJ(MJK)
113 | A = FLOAT(MJK - MJMJK)
114 | B = FLOAT(K - MJK)
115 | IF (MJK .EQ. N .OR. (X(K) - X(MJK))*A .LT. (X(MJK) - X(MJMJK))
116 | + *B) GO TO 34
117 | MJ(K) = MJMJK
118 | GO TO 32
119 | 34 CONTINUE
120 | C
121 | C Start the cycling.
122 | C Collect the change points for the GCM from HIGH to LOW.
123 | C
124 | 40 IC = 1
125 | GCM(1) = HIGH
126 | 42 IGCM1 = GCM(IC)
127 | IC = IC + 1
128 | GCM(IC) = MN(IGCM1)
129 | IF (GCM(IC) .GT. LOW) GO TO 42
130 | ICX = IC
131 | C
132 | C Collect the change points for the LCM from LOW to HIGH.
133 | C
134 | IC = 1
135 | LCM(1) = LOW
136 | 44 LCM1 = LCM(IC)
137 | IC = IC + 1
138 | LCM(IC) = MJ(LCM1)
139 | IF (LCM(IC) .LT. HIGH) GO TO 44
140 | ICV = IC
141 | C
142 | C ICX, IX, IG are counters for the convex minorant,
143 | C ICV, IV, IH are counters for the concave majorant.
144 | C
145 | IG = ICX
146 | IH = ICV
147 | C
148 | C Find the largest distance greater than 'DIP' between the GCM and
149 | C the LCM from LOW to HIGH.
150 | C
151 | IX = ICX - 1
152 | IV = 2
153 | D = ZERO
154 | IF (ICX .NE. 2 .OR. ICV .NE. 2) GO TO 50
155 | D = ONE / FN
156 | GO TO 65
157 | 50 IGCMX = GCM(IX)
158 | LCMIV = LCM(IV)
159 | IF (IGCMX .GT. LCMIV) GO TO 55
160 | C
161 | C If the next point of either the GCM or LCM is from the LCM,
162 | C calculate the distance here.
163 | C
164 | LCMIV1 = LCM(IV - 1)
165 | A = FLOAT(LCMIV - LCMIV1)
166 | B = FLOAT(IGCMX - LCMIV1 - 1)
167 | DX = (X(IGCMX) - X(LCMIV1))*A / (FN*(X(LCMIV) - X(LCMIV1)))
168 | + - B / FN
169 | IX = IX - 1
170 | IF (DX .LT. D) GO TO 60
171 | D = DX
172 | IG = IX + 1
173 | IH = IV
174 | GO TO 60
175 | C
176 | C If the next point of either the GCM or LCM is from the GCM,
177 | C calculate the distance here.
178 | C
179 | 55 LCMIV = LCM(IV)
180 | IGCM = GCM(IX)
181 | IGCM1 = GCM(IX + 1)
182 | A = FLOAT(LCMIV - IGCM1 + 1)
183 | B = FLOAT(IGCM - IGCM1)
184 | DX = A / FN - ((X(LCMIV) - X(IGCM1))*B) / (FN * (X(IGCM)
185 | + - X(IGCM1)))
186 | IV = IV + 1
187 | IF (DX .LT. D) GO TO 60
188 | D = DX
189 | IG = IX + 1
190 | IH = IV - 1
191 | 60 IF (IX .LT. 1) IX = 1
192 | IF (IV .GT. ICV) IV = ICV
193 | IF (GCM(IX) .NE. LCM(IV)) GO TO 50
194 | 65 IF (D .LT. DIP) GO TO 100
195 | C
196 | C Calculate the DIPs for the current LOW and HIGH.
197 | C
198 | C The DIP for the convex minorant.
199 | C
200 | DL = ZERO
201 | IF (IG .EQ. ICX) GO TO 80
202 | ICXA = ICX - 1
203 | DO 76 J = IG, ICXA
204 | TEMP = ONE / FN
205 | JB = GCM(J + 1)
206 | JE = GCM(J)
207 | IF (JE - JB .LE. 1) GO TO 74
208 | IF (X(JE) .EQ. X(JB)) GO TO 74
209 | A = FLOAT(JE - JB)
210 | CONST = A / (FN * (X(JE) - X(JB)))
211 | DO 72 JR = JB, JE
212 | B = FLOAT(JR - JB + 1)
213 | T = B / FN - (X(JR) - X(JB))*CONST
214 | IF (T .GT. TEMP) TEMP = T
215 | 72 CONTINUE
216 | 74 IF (DL .LT. TEMP) DL = TEMP
217 | 76 CONTINUE
218 | C
219 | C The DIP for the concave majorant.
220 | C
221 | 80 DU = ZERO
222 | IF (IH .EQ. ICV) GO TO 90
223 | ICVA = ICV - 1
224 | DO 88 K = IH, ICVA
225 | TEMP = ONE / FN
226 | KB = LCM(K)
227 | KE = LCM(K + 1)
228 | IF (KE - KB .LE. 1) GO TO 86
229 | IF (X(KE) .EQ. X(KB)) GO TO 86
230 | A = FLOAT(KE - KB)
231 | CONST = A / (FN * (X(KE) - X(KB)))
232 | DO 84 KR = KB, KE
233 | B = FLOAT(KR - KB - 1)
234 | T = (X(KR) - X(KB))*CONST - B / FN
235 | IF (T .GT. TEMP) TEMP = T
236 | 84 CONTINUE
237 | 86 IF (DU .LT. TEMP) DU = TEMP
238 | 88 CONTINUE
239 | C
240 | C Determine the current maximum.
241 | C
242 | 90 DIPNEW = DL
243 | IF (DU .GT. DL) DIPNEW = DU
244 | IF (DIP .LT. DIPNEW) DIP = DIPNEW
245 | LOW = GCM(IG)
246 | HIGH = LCM(IH)
247 | C
248 | C Recycle
249 | C
250 | GO TO 40
251 | C
252 | 100 DIP = HALF * DIP
253 | XL = X(LOW)
254 | XU = X(HIGH)
255 | C
256 | RETURN
257 | END
258 |
259 |
--------------------------------------------------------------------------------
/src/spike_sort/core/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import numpy as np
5 | import extract
6 | import cluster
7 | import warnings
8 |
9 |
10 | def deprecation(message):
11 | warnings.warn(message, DeprecationWarning, stacklevel=2)
12 |
13 |
14 | def snr_spike(spike_waves, scale=5.):
15 | """Estimate signal-to-noise ratio (SNR) as a ratio of
16 | peak-to-peak amplitude of an average spike to the std. deviation
17 | of residuals
18 |
19 | Parameters
20 | ----------
21 | spike_waves : dict
22 |
23 | Returns
24 | -------
25 | snr : float
26 | signal to noise ratio
27 | """
28 | sp_data = spike_waves['data']
29 | avg_spike = sp_data.mean(axis=1)
30 | peak_to_peak = np.ptp(avg_spike, axis=None)
31 | residuals = sp_data - avg_spike[:, np.newaxis]
32 | noise_std = np.sqrt(residuals.var())
33 | snr = peak_to_peak / (noise_std * scale)
34 | return snr
35 |
36 |
37 | def snr_clust(spike_waves, noise_waves):
38 | """Calculate signal-to-noise ratio.
39 |
40 | Comparing average P2P amplitude of spike cluster to noise cluster
41 | (randomly selected signal segments)
42 |
43 | Parameters
44 | ----------
45 | spike_waves : dict
46 | noise_waves : dict
47 |
48 | Returns
49 | -------
50 | snr : float
51 | signal-to-noise ratio
52 | """
53 | _calc_p2p = lambda data: np.ptp(data, axis=0).mean()
54 |
55 | sp_data = spike_waves['data']
56 | avg_p2p_spk = _calc_p2p(sp_data)
57 |
58 | noise_data = noise_waves['data']
59 | avg_p2p_ns = _calc_p2p(noise_data)
60 |
61 | snr = avg_p2p_spk / avg_p2p_ns
62 |
63 | return snr
64 |
65 |
66 | def extract_noise_cluster(sp, spt, sp_win, type="positive"):
67 |
68 | deprecation("extract_noise_cluster deprecated. Use"
69 | " detect_noise and extract_spikes instead.")
70 | spt_noise = detect_noise(sp, spt, sp_win, type)
71 | sp_waves = extract.extract_spikes(sp, spt_noise, sp_win)
72 |
73 | return sp_waves
74 |
75 |
76 | def rand_sample_spt(spt, max_spikes):
77 | n_spikes = len(spt['data'])
78 | spt_data = spt['data']
79 | spt_new = spt.copy()
80 | if max_spikes and n_spikes > max_spikes:
81 | i = np.random.rand(n_spikes).argsort()[:max_spikes]
82 | spt_new['data'] = spt_data[i]
83 | return spt_new
84 |
85 |
86 | def detect_noise(sp, spt, sp_win, type="positive", max_spikes=None,
87 | resample=1):
88 | """Find noisy spikes"""
89 |
90 | spike_waves = extract.extract_spikes(sp, spt, sp_win)
91 |
92 | if type == "positive":
93 | threshold = calc_noise_threshold(spike_waves, 1)
94 | spt_noise = extract.detect_spikes(sp, threshold, 'rising')
95 | spt_noise = rand_sample_spt(spt_noise, max_spikes)
96 | spt_noise = extract.remove_spikes(spt_noise, spt, sp_win)
97 | spt_noise = extract.align_spikes(sp, spt_noise, sp_win, 'max',
98 | resample=resample)
99 | else:
100 | threshold = calc_noise_threshold(spike_waves, -1)
101 | spt_noise = extract.detect_spikes(sp, threshold, 'falling')
102 | spt_noise = rand_sample_spt(spt_noise, max_spikes)
103 | spt_noise = extract.remove_spikes(spt_noise, spt, sp_win)
104 | spt_noise = extract.align_spikes(sp, spt_noise, sp_win, 'min',
105 | resample=resample)
106 | return spt_noise
107 |
108 |
109 | def calc_noise_threshold(spike_waves, sign=1, frac_spikes=0.02, frac_max=0.5):
110 | """ Find threshold to extract noise cluster.
111 |
112 | According to algorithm described in Joshua et al. (2007)
113 |
114 | Parameters
115 | ----------
116 | spike_waves : dict
117 | waveshapes of spikes from the identified single unit (extracted
118 | with extract.extrac_spikes)
119 |
120 | sign : int
121 | sign should be negative for negative-going spikes and positive for
122 |         positive-going spikes. Note that only the sign of this number is
123 | taken into account.
124 |
125 | frac_spikes : float, optional
126 | fraction of largest (smallest) spikes to calculate the threshold
127 | from (default 0.02)
128 |
129 | frac_max : float
130 |         fraction of the average peak amplitude to use as a threshold
131 |
132 |
133 | Returns
134 | -------
135 | threshold : float
136 | threshold to obtain a noise cluster"""
137 |
138 | gain = np.sign(sign)
139 |
140 | peak_amp = np.max(gain*spike_waves['data'], 0)
141 |     # frac_spikes and frac_max are taken from the keyword arguments
142 |     # (defaults: 0.02 and 0.5), so callers can tune the noise threshold
143 | peak_amp.sort()
144 | threshold = frac_max*np.mean(peak_amp[:int(frac_spikes*len(peak_amp))])
145 | threshold *= gain
146 |
147 | return threshold
148 |
149 |
150 | def isolation_score(sp, spt, sp_win, spike_type='positive', lam=10.,
151 | max_spikes=None):
152 | "calculate spike isolation score from raw data and spike times"
153 |
154 | spike_waves = extract.extract_spikes(sp, spt, sp_win)
155 | spt_noise = detect_noise(sp, spt, sp_win, spike_type)
156 | noise_waves = extract.extract_spikes(sp, spt_noise, sp_win)
157 |
158 | iso_score = calc_isolation_score(spike_waves, noise_waves,
159 | spike_type,
160 | lam=lam,
161 | max_spikes=max_spikes)
162 |
163 | return iso_score
164 |
165 |
166 | def _iso_score_dist(dist, lam, n_spikes):
167 |
168 | """Calculate isolation score from a distance matrix
169 |
170 | Parameters
171 | ----------
172 | dist : array
173 | NxM matrix, where N is number of spikes and M is number of all
174 | events (spikes + noise)
175 |
176 | lam : float
177 | lambda parameter
178 |
179 | n_spikes : int
180 | number of spikes (N)
181 |
182 | Returns
183 | -------
184 | isolation_score : float
185 | """
186 |
187 | distSS = dist[:, :n_spikes]
188 | distSN = dist[:, n_spikes:]
189 |
190 | d0 = distSS.mean()
191 | expSS = np.exp(-distSS*lam*1./d0)
192 | expSN = np.exp(-distSN*lam*1./d0)
193 |
194 | sumSS = np.sum(expSS - np.eye(n_spikes), 1)
195 | sumSN = np.sum(expSN, 1)
196 |
197 | correctProbS = sumSS / (sumSS + sumSN)
198 | isolation_score = correctProbS.mean()
199 |
200 | return isolation_score
201 |
202 |
203 | def calc_isolation_score(spike_waves, noise_waves, spike_type='positive',
204 | lam=10., max_spikes=None):
205 | """Calculate isolation index according to Joshua et al. (2007)
206 |
207 | Parameters
208 | ----------
209 | spike_waves : dict
210 | noise_waves : dict
211 |     spike_type : {'positive', 'negative'}
212 |         indicates whether the spikes are positive- or negative-going
213 |     lam : float
214 |         determines the "softness" of the clusters
215 |     max_spikes : int, optional
216 |         if given, randomly subsample the spike and noise waveforms to at
217 |         most this many events (limits memory use)
218 |
219 | Returns
220 | -------
221 | isolation_score : float
222 | a value from the range [0,1] indicating the quality of sorting
223 | (1=ideal isolation of spikes)
224 | """
225 |
226 | #Memory issue: sample spikes if too many
227 | if max_spikes is not None:
228 | if spike_waves['data'].shape[1] > max_spikes:
229 |             i = np.random.rand(spike_waves['data'].shape[1]).argsort()[:max_spikes]
230 | spike_waves = spike_waves.copy()
231 | spike_waves['data'] = spike_waves['data'][:, i]
232 | if noise_waves['data'].shape[1] > max_spikes:
233 |             i = np.random.rand(noise_waves['data'].shape[1]).argsort()[:max_spikes]
234 | noise_waves = noise_waves.copy()
235 | noise_waves['data'] = noise_waves['data'][:, i]
236 |
237 | n_spikes = spike_waves['data'].shape[1]
238 |
239 | #calculate distance between spikes and all other events
240 | all_waves = {'data': np.concatenate((spike_waves['data'],
241 | noise_waves['data']), 1)}
242 | dist_matrix = cluster.dist_euclidean(spike_waves, all_waves)
243 | #d_0 = dist_matrix[:,:n_spikes].mean()
244 |
245 | isolation_score = _iso_score_dist(dist_matrix, lam,
246 | n_spikes)
247 |
248 | return isolation_score
249 |
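# --- Illustrative sketch (editor's example, not part of the library) ---
# For two well separated waveform clusters the score should approach 1;
# the arrays below are synthetic (n_points, n_spikes, n_contacts) data.
#
# >>> import numpy as np
# >>> np.random.seed(0)
# >>> spikes = {'data': np.random.randn(20, 50, 1) + 100.0}
# >>> noise = {'data': np.random.randn(20, 50, 1)}
# >>> calc_isolation_score(spikes, noise) > 0.99
# True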
--------------------------------------------------------------------------------
/src/spike_beans/base.py:
--------------------------------------------------------------------------------
1 | """
2 | The initial version of this code was adapted from a recipe by Zoran Isailovski
3 | (published under PSF License).
4 |
5 | http://code.activestate.com/recipes/413268-dependency-injection-the-python-way/
6 | """
7 |
8 | import logging
9 |
10 |
11 | class FeatureBroker(object):
12 | def __init__(self, allowReplace=False):
13 | self.providers = {}
14 | self.allowReplace = allowReplace
15 |
16 | def Provide(self, feature, provider, *args, **kwargs):
17 | if not self.allowReplace:
18 | assert feature not in self.providers, \
19 | "Duplicate feature: %r" % feature
20 |
21 | if callable(provider):
22 | call = lambda: provider(*args, **kwargs)
23 | else:
24 | call = lambda: provider
25 | self.providers[feature] = call
26 |
27 | def __getitem__(self, feature):
28 |         if feature not in self.providers:
29 | raise AttributeError("Unknown feature named %r" % feature)
30 | else:
31 | return self.providers[feature]()
32 |
33 | def __setitem__(self, feature, component):
34 | self.Provide(feature, component)
35 |
36 | def __contains__(self, feature):
37 | return feature in self.providers
38 |
39 |
40 | features = FeatureBroker()
41 |
42 |
43 | def register(feature, component):
44 | """register `component` as providing `feature`"""
45 | features[feature] = component
46 | return component
47 |
48 | ## Representation of Required Features and Feature Assertions
49 |
50 | # Some basic assertions to test the suitability of injected features
51 |
52 | def NoAssertion():
53 | def test(obj):
54 | return True
55 | return test
56 |
57 |
58 | def IsInstanceOf(*classes):
59 | def test(obj):
60 | return isinstance(obj, classes)
61 | return test
62 |
63 |
64 | def HasAttributes(*attributes):
65 | def test(obj):
66 | for attr_name in attributes:
67 | try:
68 | getattr(obj, attr_name)
69 | except AttributeError:
70 | return False
71 | return True
72 | return test
73 |
74 |
75 | def HasMethods(*methods):
76 | def test(obj):
77 | for each in methods:
78 | try:
79 | attr = getattr(obj, each)
80 | except AttributeError:
81 | return False
82 | if not callable(attr):
83 | return False
84 | return True
85 | return test
86 |
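# --- Illustrative sketch (editor's example, not part of the library) ---
# The assertion factories above return a test function that can be applied
# to any candidate component; 'read_sp' is just an example method name.
#
# >>> check = HasMethods('read_sp')
# >>> check(object())            # a plain object has no read_sp method
# False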
87 | # An attribute descriptor to "declare" required features
88 |
89 | class DataAttribute(object):
90 | """A data descriptor that sets and returns values
91 |     """A data descriptor that sets and returns values normally
92 |     and notifies registered observers when the value changes.
93 |
94 | def __init__(self, initval=None, name='var'):
95 | self.val = initval
96 | self.name = name
97 |
98 | def __get__(self, obj, objtype):
99 | return self.val
100 |
101 | def __set__(self, obj, val):
102 | self.val = val
103 | for handler in obj.observers:
104 | handler()
105 |
106 |
107 | class RequiredFeature(object):
108 | """Descriptor class for required dependencies. Implements dependency
109 | injection."""
110 | def __init__(self, feature, assertion=NoAssertion(),
111 | alt_name="_alternative_"):
112 | """
113 | Parameters
114 | ----------
115 | feature : string
116 | name of the associated dependency
117 | assertion : function
118 | additional tests for the associated dependency
119 | alt_name : string
120 | in case of renaming the dependency, the variable named
121 | 'alt_name'+'feature' should contain the new name of the dependency
122 | """
123 |
124 | self.feature = feature
125 | self.alt_name = alt_name
126 | self.assertion = assertion
127 | self.result = None
128 |
129 | def __get__(self, callee, T):
130 | self.result = self.Request(callee)
131 |         return self.result  # the feature is looked up on every attribute access
132 |
133 | def __set__(self, instance, value):
134 | '''Rename the feature'''
135 | if isinstance(value, str):
136 | logging.info("changed %s to %s" % (self.feature, value))
137 | setattr(instance, self.alt_name + self.feature, value)
138 | else:
139 | raise TypeError("can't change the feature name to non-string type")
140 |
141 | def __getattr__(self, name):
142 | assert name == 'result', \
143 |             "Unexpected attribute request other than 'result'"
144 | return self.result
145 |
146 | def Request(self, callee):
147 | fet_name = self.feature
148 | if hasattr(callee, self.alt_name + self.feature):
149 | fet_name = getattr(callee, self.alt_name + self.feature)
150 | obj = features[fet_name]
151 |
152 | try:
153 | obj.register_handler(callee)
154 | except AttributeError:
155 | pass
156 |
157 | isComponentCorrect = self.assertion(obj)
158 | assert isComponentCorrect, \
159 | "The value %r of %r does not match the specified criteria" \
160 | % (obj, self.feature)
161 | return obj
162 |
163 | class OptionalFeature(RequiredFeature):
164 | """Descriptor class for optional dependencies. Implements dependency
165 |     injection. Evaluates to None if the dependency is not provided."""
166 | def Request(self, callee):
167 | fet_name = self.feature
168 | if hasattr(callee, self.alt_name + self.feature):
169 | fet_name = getattr(callee, self.alt_name + self.feature)
170 |
171 |         if fet_name not in features:
172 | return None
173 |
174 | return super(OptionalFeature, self).Request(callee)
175 |
176 | class Component(object):
177 | "Symbolic base class for components"
178 | def __init__(self):
179 | self.observers = []
180 |
181 | @staticmethod
182 | def _rm_duplicate_deps(deps):
183 | new_deps = []
184 | for i, d in enumerate(deps):
185 |             if d not in deps[i + 1:]:
186 | new_deps.append(d)
187 | return new_deps
188 |
189 | def get_dependencies(self):
190 | deps = [o.get_dependencies() for o in self.observers]
191 | deps = sum(deps, self.observers)
192 | deps = Component._rm_duplicate_deps(deps)
193 | return deps
194 |
195 | def register_handler(self, handler):
196 | if handler not in self.observers:
197 | self.observers.append(handler)
198 |
199 | def unregister_handler(self, handler):
200 | if handler in self.observers:
201 | self.observers.remove(handler)
202 |
203 | def notify_observers(self):
204 | for dep in self.get_dependencies():
205 | dep._update()
206 |
207 | def _update(self):
208 | pass
209 |
210 | def update(self):
211 | self._update()
212 | self.notify_observers()
213 |
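# --- Illustrative sketch (editor's example, not part of the library) ---
# A component declares its dependency with RequiredFeature and the broker
# injects whatever was registered under that name; the feature name
# "SignalSource" and the DummySource class are invented for this example.
#
# >>> class DummySource(Component):
# ...     def read(self):
# ...         return [1, 2, 3]
# >>> class Detector(Component):
# ...     source = RequiredFeature("SignalSource", HasMethods("read"))
# >>> src = register("SignalSource", DummySource())
# >>> Detector().source.read()
# [1, 2, 3]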
214 |
215 | class dictproperty(object):
216 | """implements collection properties with dictionary-like access.
217 |     Copied and modified from a recipe by Ed Swierk,
218 |     published under the PSF license.
219 |
220 | """
221 |
222 | class _proxy(object):
223 | def __init__(self, obj, fget, fset, fdel):
224 | self._obj = obj
225 | self._fget = fget
226 | self._fset = fset
227 | self._fdel = fdel
228 |
229 | def __getitem__(self, key):
230 | if self._fget is None:
231 | raise TypeError("can't read item")
232 | return self._fget(self._obj, key)
233 |
234 | def __setitem__(self, key, value):
235 | if self._fset is None:
236 | raise TypeError("can't set item")
237 | self._fset(self._obj, key, value)
238 |
239 | def __delitem__(self, key):
240 | if self._fdel is None:
241 | raise TypeError("can't delete item")
242 | self._fdel(self._obj, key)
243 |
244 | def __init__(self, fget=None, fset=None, fdel=None, doc=None):
245 | self._fget = fget
246 | self._fset = fset
247 | self._fdel = fdel
248 | self.__doc__ = doc
249 |
250 | def __get__(self, obj, objtype=None):
251 | if obj is None:
252 | return self
253 | return self._proxy(obj, self._fget, self._fset, self._fdel)
254 |
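# --- Illustrative sketch (editor's example, not part of the library) ---
# dictproperty exposes per-key getter/setter methods through a dict-like
# attribute; the Settings class below is invented for this example.
#
# >>> class Settings(object):
# ...     def __init__(self):
# ...         self._d = {}
# ...     def _get(self, key):
# ...         return self._d[key]
# ...     def _set(self, key, value):
# ...         self._d[key] = value
# ...     options = dictproperty(_get, _set)
# >>> s = Settings()
# >>> s.options['color'] = 'red'
# >>> s.options['color']
# 'red'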
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # SpikeSort documentation build configuration file, created by
4 | # sphinx-quickstart on Thu Jan 27 14:09:31 2011.
5 | #
6 | # This file is execfile()d with the current directory set to its containing dir.
7 | #
8 | # Note that not all possible configuration values are present in this
9 | # autogenerated file.
10 | #
11 | # All configuration values have a default; values that are commented out
12 | # serve to show the default.
13 |
14 | import sys, os
15 |
16 | # If extensions (or modules to document with autodoc) are in another directory,
17 | # add these directories to sys.path here. If the directory is relative to the
18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 | sys.path.append(os.path.abspath('../../src/'))
20 |
21 | # -- General configuration -----------------------------------------------------
22 |
23 | # If your documentation needs a minimal Sphinx version, state it here.
24 | #needs_sphinx = '1.0'
25 |
26 | # Add any Sphinx extension module names here, as strings. They can be extensions
27 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
28 | extensions = ['sphinx.ext.autodoc',
29 | 'sphinx.ext.doctest',
30 | 'sphinx.ext.coverage',
31 | 'sphinx.ext.pngmath',
32 | 'sphinx.ext.viewcode',
33 | 'sphinx.ext.autosummary',
34 | 'matplotlib.sphinxext.only_directives',
35 | 'matplotlib.sphinxext.plot_directive',
36 | 'numpydoc'
37 | ]
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ['_templates']
41 |
42 | # The suffix of source filenames.
43 | source_suffix = '.rst'
44 |
45 | # The encoding of source files.
46 | #source_encoding = 'utf-8-sig'
47 |
48 | # The master toctree document.
49 | master_doc = 'index'
50 |
51 | #automatically generate autosummary stub files
52 | #autosummary_generate = True
53 |
54 | # General information about the project.
55 | project = u'SpikeSort'
56 | copyright = u'2011, Bartosz Telenczuk, Dmytro Bielievstsov'
57 |
58 | # The version info for the project you're documenting, acts as replacement for
59 | # |version| and |release|, also used in various other places throughout the
60 | # built documents.
61 | #
62 | # The short X.Y version.
63 | version = '0.13'
64 | # The full version, including alpha/beta/rc tags.
65 | release = '0.13dev'
66 |
67 | # The language for content autogenerated by Sphinx. Refer to documentation
68 | # for a list of supported languages.
69 | #language = None
70 |
71 | # There are two options for replacing |today|: either, you set today to some
72 | # non-false value, then it is used:
73 | #today = ''
74 | # Else, today_fmt is used as the format for a strftime call.
75 | #today_fmt = '%B %d, %Y'
76 |
77 | # List of patterns, relative to source directory, that match files and
78 | # directories to ignore when looking for source files.
79 | exclude_patterns = []
80 |
81 | # The reST default role (used for this markup: `text`) to use for all documents.
82 | #default_role = None
83 |
84 | # If true, '()' will be appended to :func: etc. cross-reference text.
85 | #add_function_parentheses = True
86 |
87 | # If true, the current module name will be prepended to all description
88 | # unit titles (such as .. function::).
89 | #add_module_names = True
90 |
91 | # If true, sectionauthor and moduleauthor directives will be shown in the
92 | # output. They are ignored by default.
93 | #show_authors = False
94 |
95 | # The name of the Pygments (syntax highlighting) style to use.
96 | pygments_style = 'sphinx'
97 |
98 | # A list of ignored prefixes for module index sorting.
99 | #modindex_common_prefix = []
100 |
101 |
102 | # -- Options for HTML output ---------------------------------------------------
103 |
104 | # The theme to use for HTML and HTML Help pages. See the documentation for
105 | # a list of builtin themes.
106 | html_theme = 'flask'
107 |
108 | # Theme options are theme-specific and customize the look and feel of a theme
109 | # further. For a list of options available for each theme, see the
110 | # documentation.
111 | html_theme_options = {'index_logo': 'logo.png',
112 | 'index_logo_height': '150px'}
113 |
114 | # Add any paths that contain custom themes here, relative to this directory.
115 | html_theme_path = ['_themes']
116 |
117 | # The name for this set of Sphinx documents. If None, it defaults to
118 | # " v documentation".
119 | #html_title = None
120 |
121 | # A shorter title for the navigation bar. Default is the same as html_title.
122 | #html_short_title = None
123 |
124 | # The name of an image file (relative to this directory) to place at the top
125 | # of the sidebar.
126 | #html_logo = None
127 |
128 | # The name of an image file (within the static path) to use as favicon of the
129 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
130 | # pixels large.
131 | #html_favicon = None
132 |
133 | # Add any paths that contain custom static files (such as style sheets) here,
134 | # relative to this directory. They are copied after the builtin static files,
135 | # so a file named "default.css" will overwrite the builtin "default.css".
136 | html_static_path = ['_static']
137 |
138 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
139 | # using the given strftime format.
140 | #html_last_updated_fmt = '%b %d, %Y'
141 |
142 | # If true, SmartyPants will be used to convert quotes and dashes to
143 | # typographically correct entities.
144 | #html_use_smartypants = True
145 |
146 | # Custom sidebar templates, maps document names to template names.
147 | #html_sidebars = {}
148 |
149 | # Additional templates that should be rendered to pages, maps page names to
150 | # template names.
151 | #html_additional_pages = {}
152 |
153 | # If false, no module index is generated.
154 | #html_domain_indices = True
155 |
156 | # If false, no index is generated.
157 | #html_use_index = True
158 |
159 | # If true, the index is split into individual pages for each letter.
160 | #html_split_index = False
161 |
162 | # If true, links to the reST sources are added to the pages.
163 | #html_show_sourcelink = True
164 |
165 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
166 | #html_show_sphinx = True
167 |
168 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
169 | #html_show_copyright = True
170 |
171 | # If true, an OpenSearch description file will be output, and all pages will
172 | # contain a tag referring to it. The value of this option must be the
173 | # base URL from which the finished HTML is served.
174 | #html_use_opensearch = ''
175 |
176 | # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml").
177 | #html_file_suffix = ''
178 |
179 | # Output file base name for HTML help builder.
180 | htmlhelp_basename = 'SpikeSortsdoc'
181 |
182 |
183 | # -- Options for LaTeX output --------------------------------------------------
184 |
185 | # The paper size ('letter' or 'a4').
186 | #latex_paper_size = 'letter'
187 |
188 | # The font size ('10pt', '11pt' or '12pt').
189 | #latex_font_size = '10pt'
190 |
191 | # Grouping the document tree into LaTeX files. List of tuples
192 | # (source start file, target name, title, author, documentclass [howto/manual]).
193 | latex_documents = [
194 | ('index', 'SpikeSort.tex', u'SpikeSort Documentation',
195 | u'Bartosz Telenczuk', 'manual'),
196 | ]
197 |
198 | # The name of an image file (relative to this directory) to place at the top of
199 | # the title page.
200 | #latex_logo = None
201 |
202 | # For "manual" documents, if this is true, then toplevel headings are parts,
203 | # not chapters.
204 | #latex_use_parts = False
205 |
206 | # If true, show page references after internal links.
207 | #latex_show_pagerefs = False
208 |
209 | # If true, show URL addresses after external links.
210 | #latex_show_urls = False
211 |
212 | # Additional stuff for the LaTeX preamble.
213 | #latex_preamble = ''
214 |
215 | # Documents to append as an appendix to all manuals.
216 | #latex_appendices = []
217 |
218 | # If false, no module index is generated.
219 | #latex_domain_indices = True
220 |
221 |
222 | # -- Options for manual page output --------------------------------------------
223 |
224 | # One entry per manual page. List of tuples
225 | # (source start file, name, description, authors, manual section).
226 | man_pages = [
227 |     ('index', 'sortspikes', u'SpikeSort Documentation',
228 | [u'Bartosz Telenczuk'], 1)
229 | ]
230 |
--------------------------------------------------------------------------------
/docs/source/_themes/flask/static/flasky.css_t:
--------------------------------------------------------------------------------
1 | /*
2 | * flasky.css_t
3 | * ~~~~~~~~~~~~
4 | *
5 | * :copyright: Copyright 2010 by Armin Ronacher.
6 | * :license: Flask Design License, see LICENSE for details.
7 | */
8 |
9 | {% set page_width = '90%' %}
10 | {% set sidebar_width = '220px' %}
11 |
12 | @import url("basic.css");
13 |
14 | /* -- page layout ----------------------------------------------------------- */
15 |
16 | body {
17 | font-family: 'Georgia', serif;
18 | font-size: 17px;
19 | background-color: white;
20 | color: #000;
21 | margin: 0;
22 | padding: 0;
23 | }
24 |
25 | div.document {
26 | width: {{ page_width }};
27 | margin: 30px auto 0 20px;
28 | }
29 |
30 | div.documentwrapper {
31 | float: left;
32 | width: 100%;
33 | }
34 |
35 | div.bodywrapper {
36 | margin: 0 0 0 {{ sidebar_width }};
37 | }
38 |
39 | div.sphinxsidebar {
40 | width: {{ sidebar_width }};
41 | }
42 |
43 | hr {
44 | border: 1px solid #B1B4B6;
45 | }
46 |
47 | div.body {
48 | background-color: #ffffff;
49 | color: #3E4349;
50 | padding: 0 30px 0 30px;
51 | }
52 |
53 | img.floatingflask {
54 | padding: 0 0 10px 10px;
55 | float: right;
56 | }
57 |
58 | div.footer {
59 | width: {{ page_width }};
60 | margin: 20px auto 30px auto;
61 | font-size: 14px;
62 | color: #888;
63 | text-align: right;
64 | }
65 |
66 | div.footer a {
67 | color: #888;
68 | }
69 |
70 | div.related {
71 | display:inline;
72 | }
73 |
74 | div.related h3 {
75 | display:none;
76 | }
77 |
78 | div.sphinxsidebar a {
79 | color: #444;
80 | text-decoration: none;
81 | border-bottom: 1px dotted #999;
82 | }
83 |
84 | div.sphinxsidebar a:hover {
85 | border-bottom: 1px solid #999;
86 | }
87 |
88 | div.sphinxsidebar {
89 | font-size: 14px;
90 | line-height: 1.5;
91 | }
92 |
93 | div.sphinxsidebarwrapper {
94 | padding: 18px 10px;
95 | }
96 |
97 | div.sphinxsidebarwrapper p.logo {
98 | padding: 0 0 20px 0;
99 | margin: 0;
100 | text-align: center;
101 | }
102 |
103 | div.sphinxsidebar h3,
104 | div.sphinxsidebar h4 {
105 | font-family: 'Garamond', 'Georgia', serif;
106 | color: #444;
107 | font-size: 24px;
108 | font-weight: normal;
109 | margin: 0 0 5px 0;
110 | padding: 0;
111 | }
112 |
113 | div.sphinxsidebar h4 {
114 | font-size: 20px;
115 | }
116 |
117 | div.sphinxsidebar h3 a {
118 | color: #444;
119 | }
120 |
121 | div.sphinxsidebar p.logo a,
122 | div.sphinxsidebar h3 a,
123 | div.sphinxsidebar p.logo a:hover,
124 | div.sphinxsidebar h3 a:hover {
125 | border: none;
126 | }
127 |
128 | div.sphinxsidebar p {
129 | color: #555;
130 | margin: 10px 0;
131 | }
132 |
133 | div.sphinxsidebar ul {
134 | margin: 10px 0;
135 | padding: 0;
136 | color: #000;
137 | }
138 |
139 | div.sphinxsidebar input {
140 | border: 1px solid #ccc;
141 | font-family: 'Georgia', serif;
142 | font-size: 1em;
143 | }
144 |
145 | /* -- body styles ----------------------------------------------------------- */
146 |
147 | a {
148 | color: #004B6B;
149 | text-decoration: underline;
150 | }
151 |
152 | a:hover {
153 | color: #6D4100;
154 | text-decoration: underline;
155 | }
156 |
157 | div.body h1,
158 | div.body h2,
159 | div.body h3,
160 | div.body h4,
161 | div.body h5,
162 | div.body h6 {
163 | font-family: 'Garamond', 'Georgia', serif;
164 | font-weight: normal;
165 | margin: 30px 0px 10px 0px;
166 | padding: 0;
167 | }
168 |
169 | {% if theme_index_logo %}
170 | div.indexwrapper h1 {
171 | text-indent: -999999px;
172 | background: url({{ theme_index_logo }}) no-repeat center center;
173 | height: {{ theme_index_logo_height }};
174 | }
175 | {% endif %}
176 |
177 | div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; }
178 | div.body h2 { font-size: 180%; }
179 | div.body h3 { font-size: 150%; }
180 | div.body h4 { font-size: 130%; }
181 | div.body h5 { font-size: 100%; }
182 | div.body h6 { font-size: 100%; }
183 |
184 | a.headerlink {
185 | color: #ddd;
186 | padding: 0 4px;
187 | text-decoration: none;
188 | }
189 |
190 | a.headerlink:hover {
191 | color: #444;
192 | background: #eaeaea;
193 | }
194 |
195 | div.body p, div.body dd, div.body li {
196 | line-height: 1.4em;
197 | }
198 |
199 | div.admonition {
200 | background: #fafafa;
201 | margin: 20px -30px;
202 | padding: 10px 30px;
203 | border-top: 1px solid #ccc;
204 | border-bottom: 1px solid #ccc;
205 | }
206 |
207 | div.admonition tt.xref, div.admonition a tt {
208 | border-bottom: 1px solid #fafafa;
209 | }
210 |
211 | dd div.admonition {
212 | margin-left: -60px;
213 | padding-left: 60px;
214 | }
215 |
216 | div.admonition p.admonition-title {
217 | font-family: 'Garamond', 'Georgia', serif;
218 | font-weight: normal;
219 | font-size: 24px;
220 | margin: 0 0 10px 0;
221 | padding: 0;
222 | line-height: 1;
223 | }
224 |
225 | div.admonition p.last {
226 | margin-bottom: 0;
227 | }
228 |
229 | div.highlight {
230 | background-color: white;
231 | }
232 |
233 | dt:target, .highlight {
234 | background: #FAF3E8;
235 | }
236 |
237 | div.note {
238 | background-color: #eee;
239 | border: 1px solid #ccc;
240 | }
241 |
242 | div.seealso {
243 | background-color: rgb(240,240,240);
244 | }
245 |
246 | td.field-body strong {
247 | font-weight: normal;
248 | font-style: italic;
249 | }
250 | div.topic {
251 | background-color: #eee;
252 | }
253 |
254 | p.admonition-title {
255 | display: inline;
256 | }
257 |
258 | p.admonition-title:after {
259 | content: ":";
260 | }
261 |
262 | pre, tt {
263 | font-family: 'Consolas', 'Menlo', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace;
264 | font-size: 0.9em;
265 | }
266 |
267 | img.screenshot {
268 | }
269 |
270 | tt.descname, tt.descclassname {
271 | font-size: 0.95em;
272 | }
273 |
274 | tt.descname {
275 | padding-right: 0.08em;
276 | }
277 |
278 | img.screenshot {
279 | -moz-box-shadow: 2px 2px 4px #eee;
280 | -webkit-box-shadow: 2px 2px 4px #eee;
281 | box-shadow: 2px 2px 4px #eee;
282 | }
283 |
284 | table.docutils {
285 | border: 1px solid #888;
286 | -moz-box-shadow: 2px 2px 4px #eee;
287 | -webkit-box-shadow: 2px 2px 4px #eee;
288 | box-shadow: 2px 2px 4px #eee;
289 | }
290 |
291 | table.docutils td, table.docutils th {
292 | border: 1px solid #888;
293 | padding: 0.25em 0.7em;
294 | }
295 |
296 | table.field-list, table.footnote {
297 | border: none;
298 | -moz-box-shadow: none;
299 | -webkit-box-shadow: none;
300 | box-shadow: none;
301 | }
302 |
303 | table.footnote {
304 | margin: 15px 0;
305 | width: 100%;
306 | border: 1px solid #eee;
307 | background: #fdfdfd;
308 | font-size: 0.9em;
309 | }
310 |
311 | table.footnote + table.footnote {
312 | margin-top: -15px;
313 | border-top: none;
314 | }
315 |
316 | table.field-list th {
317 | padding: 0 0.8em 0 0;
318 | width: 120px;
319 | }
320 |
321 | table.field-list td {
322 | padding: 0;
323 | }
324 |
325 | table.footnote td.label {
326 | width: 0px;
327 | padding: 0.3em 0 0.3em 0.5em;
328 | }
329 |
330 | table.footnote td {
331 | padding: 0.3em 0.5em;
332 | }
333 |
334 | dl {
335 | margin: 0;
336 | padding: 0;
337 | }
338 |
339 | dl dd {
340 | margin-left: 30px;
341 | }
342 |
343 | dl.method, dl.attribute {
344 | border-top: 1px solid rgb(170,170,170);
345 | }
346 |
347 | dl.class, dl.function {
348 | border-top: 2px solid rgb(136,136,136);
349 | }
350 |
351 | blockquote {
352 | margin: 0 0 0 30px;
353 | padding: 0;
354 | }
355 |
356 | ul, ol {
357 | margin: 10px 0 10px 30px;
358 | padding: 0;
359 | }
360 |
361 | pre {
362 | background: #eee;
363 | padding: 7px 30px;
364 | margin: 15px -30px;
365 | line-height: 1.3em;
366 | }
367 |
368 | dl pre, blockquote pre, li pre {
369 | margin-left: -60px;
370 | padding-left: 60px;
371 | }
372 |
373 | dl dl pre {
374 | margin-left: -90px;
375 | padding-left: 90px;
376 | }
377 |
378 | tt {
379 | background-color: #ecf0f3;
380 | color: #222;
381 | /* padding: 1px 2px; */
382 | }
383 |
384 | tt.xref, a tt {
385 | background-color: #FBFBFB;
386 | border-bottom: 1px solid white;
387 | }
388 |
389 | a.reference {
390 | text-decoration: none;
391 | border-bottom: 1px dotted #004B6B;
392 | }
393 |
394 | a.reference:hover {
395 | border-bottom: 1px solid #6D4100;
396 | }
397 |
398 | a.footnote-reference {
399 | text-decoration: none;
400 | font-size: 0.7em;
401 | vertical-align: top;
402 | border-bottom: 1px dotted #004B6B;
403 | }
404 |
405 | a.footnote-reference:hover {
406 | border-bottom: 1px solid #6D4100;
407 | }
408 |
409 | a:hover tt {
410 | background: #EEE;
411 | }
412 |
--------------------------------------------------------------------------------
/src/spike_sort/core/cluster.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | import numpy as np
5 |
6 | from spike_sort.ui import manual_sort
7 | from spike_sort.core.features import requires
8 |
9 | #optional scikits.learn imports
10 | try:
11 | #import scikits.learn >= 0.9
12 | from sklearn import cluster as skcluster
13 | from sklearn import mixture
14 | from sklearn import decomposition
15 | from sklearn import neighbors
16 | except ImportError:
17 | try:
18 | #import scikits.learn < 0.9
19 | from scikits.learn import cluster as skcluster
20 | from scikits.learn import mixture
21 | from scikits.learn import decomposition
22 | from scikits.learn import neighbors
23 | except ImportError:
24 | pass
25 |
26 | @requires(skcluster, "scikits.learn must be installed to use spectral")
27 | def spectral(data, n_clusters=2, affinity='rbf'):
28 |
29 | sp = skcluster.SpectralClustering(k=n_clusters, affinity=affinity)
30 | sp.fit(data)
31 | labels = sp.labels_
32 | return labels
33 |
34 | @requires(skcluster, "scikits.learn must be installed to use dbscan")
35 | def dbscan(data, eps=0.3, min_samples=10):
36 | """DBScan clustering
37 |
38 | Parameters
39 | ----------
40 | data : float array
41 | features array
42 |
43 | Returns
44 | -------
45 | cl : int array
46 |         cluster indices
47 |
48 | Notes
49 | -----
50 | This function requires scikits-learn
51 | """
52 |
53 | db = skcluster.DBSCAN(eps=eps, min_samples=min_samples).fit(data)
54 | labels = db.labels_
55 | return labels
56 |
57 | @requires(skcluster, "scikits.learn must be installed to use mean_shift")
58 | def mean_shift(data, bandwidth=None, n_samples=500, quantile=0.3):
59 |     if bandwidth is None:
60 | bandwidth = skcluster.estimate_bandwidth(data,
61 | quantile=quantile,
62 | n_samples=n_samples)
63 |
64 | ms = skcluster.MeanShift(bandwidth=bandwidth).fit(data)
65 | labels = ms.labels_
66 | return labels
67 |
68 | @requires(skcluster, "scikits.learn must be installed to use k_means_plus")
69 | def k_means_plus(data, K=2, whiten=False):
70 | """k means with smart initialization.
71 |
72 | Notes
73 | -----
74 | This function requires scikits-learn
75 |
76 | See Also
77 | --------
78 | kmeans
79 |
80 | """
81 | if whiten:
82 | pca = decomposition.PCA(whiten=True).fit(data)
83 | data = pca.transform(data)
84 |
85 | clusters = skcluster.k_means(data, n_clusters=K)[1]
86 |
87 | return clusters
88 |
89 |
90 | def gmm(data, k=2, cvtype='full'):
91 | """Cluster based on gaussian mixture models
92 |
93 | Parameters
94 | ----------
95 | data : dict
96 | features structure
97 | k : int
98 | number of clusters
99 |
100 | Returns
101 | -------
102 | cl : int array
103 |         cluster indices
104 |
105 | Notes
106 | -----
107 | This function requires scikits-learn
108 |
109 | """
110 |
111 | try:
112 | #scikits.learn 0.8
113 | clf = mixture.GMM(n_states=k, cvtype=cvtype)
114 | except TypeError:
115 | try:
116 | clf = mixture.GMM(n_components=k, cvtype=cvtype)
117 | except TypeError:
118 | #scikits.learn 0.11
119 | clf = mixture.GMM(n_components=k, covariance_type=cvtype)
120 | except NameError:
121 | raise NotImplementedError(
122 | "scikits.learn must be installed to use gmm")
123 |
124 | clf.fit(data)
125 | cl = clf.predict(data)
126 | return cl
127 |
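# --- Illustrative sketch (editor's example, not part of the library;
# requires scikits.learn/scikit-learn) --- two well separated point clouds
# should end up with two distinct labels; the data below is synthetic.
#
# >>> np.random.seed(0)
# >>> data = np.vstack((np.random.randn(50, 2), np.random.randn(50, 2) + 10.))
# >>> labels = gmm(data, k=2)
# >>> len(np.unique(labels))
# 2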
128 |
129 | def manual(data, n_spikes='all', *args, **kwargs):
130 | """Sort spikes manually by cluster cutting
131 |
132 | Opens a new window in which you can draw cluster of arbitrary
133 | shape.
134 |
135 | Notes
136 | -----
137 | Only two first features are plotted
138 | """
139 |     if n_spikes == 'all':
140 | return manual_sort._cluster(data[:, :2], **kwargs)
141 | else:
142 | idx = np.argsort(np.random.rand(data.shape[0]))[:n_spikes]
143 | labels_subsampled = manual_sort._cluster(data[idx, :2], **kwargs)
144 | try:
145 | neigh = neighbors.KNeighborsClassifier(15)
146 | except NameError:
147 | raise NotImplementedError(
148 | "scikits.learn must be installed to use subsampling")
149 | neigh.fit(data[idx, :2], labels_subsampled)
150 | return neigh.predict(data[:, :2])
151 |
152 |
153 |
154 |
155 | def none(data):
156 | """Do nothing"""
157 | return np.zeros(data.shape[0], dtype='int16')
158 |
159 |
160 | def _metric_euclidean(data1, data2):
161 | n_pts1, n_dims1 = data1.shape
162 | n_pts2, n_dims2 = data2.shape
163 | if not n_dims1 == n_dims2:
164 | raise TypeError("data1 and data2 must have the same number of columns")
165 | delta = np.zeros((n_pts1, n_pts2), 'd')
166 | for d in xrange(n_dims1):
167 | _data1 = data1[:, d]
168 | _data2 = data2[:, d]
169 | _delta = np.subtract.outer(_data1, _data2) ** 2
170 | delta += _delta
171 | return np.sqrt(delta)
172 |
173 |
174 | def dist_euclidean(spike_waves1, spike_waves2=None):
175 | """Given spike_waves calculate pairwise Euclidean distance between
176 | them"""
177 |
178 | sp_data1 = np.concatenate(spike_waves1['data'], 1)
179 |
180 | if spike_waves2 is None:
181 | sp_data2 = sp_data1
182 | else:
183 | sp_data2 = np.concatenate(spike_waves2['data'], 1)
184 | d = _metric_euclidean(sp_data1, sp_data2)
185 |
186 | return d
187 |
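# --- Illustrative sketch (editor's example, not part of the library) ---
# 'data' follows the spike-waveform convention (n_points, n_spikes,
# n_contacts); the result is an (n_spikes1, n_spikes2) distance matrix.
#
# >>> w1 = {'data': np.zeros((2, 3, 1))}      # three flat waveforms
# >>> w2 = {'data': np.ones((2, 4, 1))}       # four unit waveforms
# >>> d = dist_euclidean(w1, w2)
# >>> d.shape                                 # every entry equals sqrt(2)
# (3, 4)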
188 |
189 | def cluster(method, features, *args, **kwargs):
190 |     """Cluster spikes automatically using the selected clustering method
191 |
192 | Parameters
193 | ----------
194 |     method : str
195 |         name of the clustering function to use (for example, 'k_means')
196 |     features : dict
197 |         spike features datastructure
198 |     args, kwargs :
199 |         extra arguments (such as the number of clusters) passed to the clustering algorithm
200 |
201 | Returns
202 | -------
203 | labels : array
204 |         array of cluster (unit) labels - one for each spike
205 |
206 | Examples
207 | --------
208 | Create a sample feature dataset and use k-means clustering to find
209 | groups of spikes (units)
210 |
211 | >>> import spike_sort
212 | >>> import numpy as np
213 | >>> np.random.seed(1234) #k_means uses random initialization
214 | >>> features = {'data':np.array([[0.,0.],
215 | ... [0, 1.],
216 | ... [0, 0.9],
217 | ... [0.1,0]])}
218 | >>> labels = spike_sort.cluster.cluster('k_means', features, 2)
219 | >>> print labels
220 | [0 1 1 0]
221 | """
222 | try:
223 | cluster_func = eval(method)
224 | except NameError:
225 | raise NotImplementedError(
226 | "clustering method %s is not implemented" % method)
227 |
228 | data = features['data']
229 | mask = features.get('is_valid')
230 | if mask is not None:
231 | valid_data = data[mask, :]
232 | cl = cluster_func(valid_data, *args, **kwargs)
233 | labels = np.zeros(data.shape[0], dtype='int') - 1
234 | labels[mask] = cl
235 | else:
236 | labels = cluster_func(data, *args, **kwargs)
237 | return labels
238 |
239 |
240 | def k_means(features, K=2):
241 | """Perform K means clustering
242 |
243 | Parameters
244 | ----------
245 |     features : array
246 | data vectors (n,m) where n is the number of datapoints and m is
247 | the number of variables
248 | K : int
249 | number of distinct clusters to identify
250 |
251 | Returns
252 | -------
253 | partition : array
254 | vector of cluster labels (ints) for each datapoint from `data`
255 | """
256 | n_dim = features.shape[1]
257 | centers = np.random.rand(K, n_dim)
258 | centers_new = np.random.rand(K, n_dim)
259 | partition = np.zeros(features.shape[0], dtype=np.int)
260 | while not (centers_new == centers).all():
261 | centers = centers_new.copy()
262 |
263 | distances = (centers[:, np.newaxis, :] - features)
264 | distances *= distances
265 | distances = distances.sum(axis=2)
266 | partition = distances.argmin(axis=0)
267 |
268 | for i in range(K):
269 | if np.sum(partition == i) > 0:
270 | centers_new[i, :] = features[partition == i, :].mean(0)
271 | return partition
272 |
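# --- Illustrative sketch (editor's example, not part of the library) ---
# Two well separated 1-D groups; the algorithm is randomly initialised, so
# the two labels may be swapped between runs.
#
# >>> np.random.seed(42)
# >>> data = np.array([[0.], [0.1], [1.], [0.9]])
# >>> k_means(data, K=2)                      # up to a permutation of labels
# array([1, 1, 0, 0])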
273 |
274 | def split_cells(spt_dict, idx, which='all'):
275 | """return the spike times belonging to the cluster and the rest"""
276 |
277 | if which == 'all':
278 | classes = np.unique(idx)
279 | else:
280 | classes = which
281 | spt = spt_dict['data']
282 | spt_dicts = dict([(cl, {'data': spt[idx == cl]}) for cl in classes])
283 | return spt_dicts
284 |
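# --- Illustrative sketch (editor's example, not part of the library) ---
# Split a spike-times dict into one dict per cluster, using the labels
# returned by cluster(); the numbers below are made up.
#
# >>> spt = {'data': np.array([10., 20., 30., 40.])}
# >>> labels = np.array([0, 1, 0, 1])
# >>> out = split_cells(spt, labels)
# >>> out[0]['data']
# array([ 10.,  30.])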
--------------------------------------------------------------------------------
/tests/test_io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import filecmp
3 | import json
4 | import glob
5 | import tempfile
6 |
7 | from nose.tools import ok_, eq_, raises
8 | import tables
9 | import numpy as np
10 |
11 | from spike_sort.io.filters import BakerlabFilter, PyTablesFilter
12 | from spike_sort.io import export
13 | from spike_sort.io import neo_filters
14 |
15 | class TestNeo:
16 |
17 | def setUp(self):
18 | path = os.path.dirname(os.path.abspath(__file__))
19 | self.samples_dir = os.path.join(path, 'samples')
20 | self.file_names = {'Axon': 'axonio.abf'}
21 |
22 | def test_read_abf(self):
23 | fname = self.file_names['Axon']
24 | file_path = os.path.join(self.samples_dir, fname)
25 | abf = neo_filters.AxonFilter(file_path)
26 | sp = abf.read_sp()
27 | assert len(np.array(sp['data']))>0
28 | assert sp['FS'] == 125000
29 |
30 | def test_read_abf_via_component(self):
31 | fname = self.file_names['Axon']
32 | file_path = os.path.join(self.samples_dir, fname)
33 | source = neo_filters.NeoSource(file_path)
34 | sp = source.read_sp()
35 | assert len(np.array(sp['data']))>0
36 |
37 | class TestHDF(object):
38 | def setUp(self):
39 | self.data = np.random.randint(1000, size=(4, 100))
40 | self.spt = np.random.randint(0, 100, (10,)) / 200.0
41 | self.el_node = '/Subject/Session/Electrode'
42 | self.fname = 'test.h5'
43 | self.cell_node = self.el_node + '/cell'
44 | self.h5f = tables.openFile(self.fname, 'a')
45 |
46 | self.spt.sort()
47 | atom = tables.Atom.from_dtype(self.data.dtype)
48 | shape = self.data.shape
49 | filter = tables.Filters(complevel=0, complib='zlib')
50 | new_array = self.h5f.createCArray(self.el_node, "raw", atom, shape,
51 | filters=filter,
52 | createparents=True)
53 | self.sampfreq = 5.0E3
54 | new_array.attrs['sampfreq'] = self.sampfreq
55 | new_array[:] = self.data
56 | self.h5f.createArray(self.el_node, "cell", self.spt,
57 | title="",
58 | createparents="True")
59 | self.h5f.close()
60 |
61 | def tearDown(self):
62 | self.filter.close()
63 | os.unlink(self.fname)
64 |
65 | def test_write(self):
66 | sp_dict = {'data': self.data, 'FS': self.sampfreq}
67 | spt_dict = {'data': self.spt}
68 | self.filter = PyTablesFilter("test2.h5")
69 | self.filter.write_sp(sp_dict, self.el_node + "/raw")
70 | self.filter.write_spt(spt_dict, self.cell_node)
71 | self.filter.close()
72 | exit_code = os.system('h5diff ' + self.fname + ' test2.h5')
73 | os.unlink("test2.h5")
74 | ok_(exit_code == 0)
75 |
76 | def test_read_sp(self):
77 | self.filter = PyTablesFilter(self.fname)
78 | sp = self.filter.read_sp(self.el_node)
79 | ok_((sp['data'][:] == self.data).all())
80 |
81 | def test_read_sp_attr(self):
82 | #check n_contacts attribute
83 | self.filter = PyTablesFilter(self.fname)
84 | sp = self.filter.read_sp(self.el_node)
85 | n_contacts = sp['n_contacts']
86 | ok_(n_contacts == self.data.shape[0])
87 |
88 | def test_read_spt(self):
89 | self.filter = PyTablesFilter(self.fname)
90 | spt = self.filter.read_spt(self.cell_node)
91 | ok_((spt['data'] == self.spt).all())
92 |
93 |
94 | class TestBakerlab(object):
95 | def setup(self):
96 | file_descr = {"fspike": "{ses_id}{el_id}.sp",
97 | "cell": "{ses_id}{el_id}{cell_id}.spt",
98 | "dirname": ".",
99 | "FS": 5.E3,
100 | "n_contacts": 1}
101 | self.el_node = '/Test/s32test01/el1'
102 | self.cell_node = self.el_node + '/cell1'
103 | self.data = np.random.randint(-1000, 1000, (100,))
104 | self.spt_data = np.random.randint(0, 100, (10,)) / 200.0
105 | self.spt_metadata = {'element1': 5, 'element2': 10}
106 | self.conf_file = 'test.conf'
107 | self.fname = "32test011.sp"
108 | self.spt_fname = "32test0111.spt"
109 | self.spt_log_fname = "32test0111.log"
110 |
111 | with open(self.conf_file, 'w') as fp:
112 | json.dump(file_descr, fp)
113 |
114 | self.data.astype(np.int16).tofile(self.fname)
115 | (self.spt_data * 200).astype(np.int32).tofile(self.spt_fname)
116 |
117 | with open(self.spt_log_fname, 'w') as lf:
118 | json.dump(self.spt_metadata, lf)
119 |
120 | def tearDown(self):
121 | os.unlink(self.conf_file)
122 | os.unlink(self.fname)
123 | os.unlink(self.spt_fname)
124 | os.unlink(self.spt_log_fname)
125 |
126 | def test_write_spt(self):
127 | cell_node_tmp = '/Test/s32test01/el2/cell1'
128 | spt_dict = {'data': self.spt_data}
129 | filter = BakerlabFilter(self.conf_file)
130 | filter.write_spt(spt_dict, cell_node_tmp)
131 | files_eq = filecmp.cmp(self.spt_fname, "32test0121.spt", shallow=0)
132 | os.unlink("32test0121.spt")
133 | ok_(files_eq)
134 |
135 | def test_write_spt_metadata(self):
136 | cell_node_tmp = '/Test/s32test01/el2/cell1'
137 | spt_dict = {'data': self.spt_data,
138 | 'metadata': self.spt_metadata}
139 | filter = BakerlabFilter(self.conf_file)
140 | filter.write_spt(spt_dict, cell_node_tmp, overwrite=True)
141 | files_eq = filecmp.cmp(self.spt_log_fname, "32test0121.log", shallow=0)
142 | os.unlink("32test0121.spt")
143 | os.unlink("32test0121.log")
144 | ok_(files_eq)
145 |
146 | @raises(IOError)
147 | def test_writespt_overwrite_exc(self):
148 | cell_node_tmp = '/Test/s32test01/el1/cell1'
149 | spt_dict = {'data': self.spt_data}
150 | filter = BakerlabFilter(self.conf_file)
151 | filter.write_spt(spt_dict, cell_node_tmp)
152 |
153 | def test_read_spt(self):
154 | filter = BakerlabFilter(self.conf_file)
155 | sp = filter.read_spt(self.cell_node)
156 | read_data = sp['data']
157 | ok_((np.abs(read_data - self.spt_data) <= 1 / 200.0).all())
158 |
159 | def test_write_sp(self):
160 | el_node_tmp = '/Test/s32test01/el2'
161 | sp_dict = {'data': self.data[np.newaxis, :]}
162 | filter = BakerlabFilter(self.conf_file)
163 | filter.write_sp(sp_dict, el_node_tmp)
164 | files_eq = filecmp.cmp("32test011.sp", "32test012.sp", shallow=0)
165 | os.unlink("32test012.sp")
166 | ok_(files_eq)
167 |
168 | def test_write_multichan(self):
169 | n_contacts = 4
170 | data = np.repeat(self.data[np.newaxis, :], n_contacts, 0)
171 | sp_dict = {'data': data}
172 | with open(self.conf_file, 'r+') as fid:
173 | file_desc = json.load(fid)
174 | file_desc['n_contacts'] = 4
175 | file_desc["fspike"] = "test{contact_id}.sp"
176 | fid.seek(0)
177 | json.dump(file_desc, fid)
178 | filter = BakerlabFilter(self.conf_file)
179 | filter.write_sp(sp_dict, self.el_node)
180 | all_chan_files = glob.glob("test?.sp")
181 | [os.unlink(p) for p in all_chan_files]
182 | eq_(len(all_chan_files), n_contacts)
183 |
184 | def test_read_sp(self):
185 | filter = BakerlabFilter(self.conf_file)
186 | sp = filter.read_sp(self.el_node)
187 | read_data = sp['data'][0, :]
188 | print read_data.shape
189 | ok_((np.abs(read_data - self.data) <= 1 / 200.0).all())
190 |
191 | def test_sp_shape(self):
192 | with open(self.conf_file, 'r+') as fid:
193 | file_desc = json.load(fid)
194 | file_desc['n_contacts'] = 4
195 | fid.seek(0)
196 | json.dump(file_desc, fid)
197 | filter = BakerlabFilter(self.conf_file)
198 | sp = filter.read_sp(self.el_node)
199 | data = sp['data']
200 | ok_(data.shape == (4, len(self.data)))
201 |
202 |
203 | class TestExport(object):
204 | def test_export_cells(self):
205 | n_cells = 4
206 | self.spt_data = np.random.randint(0, 10000, (100, n_cells))
207 | self.spt_data.sort(0)
208 | self.cells_dict = dict([(i, {"data": self.spt_data[:, i]})
209 | for i in range(n_cells)])
210 | tempdir = tempfile.mkdtemp()
211 | fname = os.path.join(tempdir, "test.h5")
212 | ptfilter = PyTablesFilter(fname)
213 | tmpl = "/Subject/Session/Electrode/Cell{cell_id}"
214 | export.export_cells(ptfilter, tmpl, self.cells_dict)
215 | test = []
216 | for i in range(n_cells):
217 | spt_dict = ptfilter.read_spt(tmpl.format(cell_id=i))
218 | test.append((spt_dict['data'] == self.spt_data[:, i]).all())
219 | test = np.array(test)
220 | ptfilter.close()
221 |
222 | os.unlink(fname)
223 | os.rmdir(tempdir)
224 |
225 | ok_(test.all())
226 |
--------------------------------------------------------------------------------
/docs/source/datastructures.rst:
--------------------------------------------------------------------------------
1 | .. testsetup::
2 |
3 | import numpy
4 | numpy.random.seed(1221)
5 |
6 |
7 | Data Structures
8 | ===============
9 |
10 | .. _raw_recording:
11 |
12 | To achieve the best compatibility with external libraries, most of the data
13 | structures are standard Python *dictionaries* with at least one key -- `data`. The
14 | `data` key contains the actual data in an array-like object (such as a
15 | NumPy array). Other keys provide metadata that are required by
16 | some methods.
17 |
18 |
19 | Raw recording
20 | -------------
21 |
22 | Raw electrophysiological data sampled at equally spaced time points. It
23 | can contain multiple channels, but all of them need to be of the same
24 | sampling frequency and duration (for example, multiple contacts of
25 | a tetrode).
26 |
27 | The following keys are defined:
28 |
29 | :data: *array*, required
30 |
31 | array-like object (for example :py:class:`numpy.ndarray`) of
32 | dimensions (N_channels, N_samples)
33 |
34 | :FS: *int*, required
35 |
36 | sampling frequency in Hz
37 |
38 | :n_contacts: *int*, required
39 |
40 | number of channels (tetrode contacts). It is equal to the size
41 | of the first dimension of `data`.
42 |
43 | .. note::
44 |
45 | You may read/write the data with your own functions, but to make the
46 | interface with SpikeSort a bit cleaner, you might also want to
47 | define your custom IO filters (see :ref:`io_filters`)
48 |
49 | .. rubric:: Example
50 |
51 | We will read the raw tetrode data from :ref:`tutorial_data` using
52 | standard :py:class:`~spike_sort.io.filters.PyTablesFilter`:
53 |
54 | >>> from spike_sort.io.filters import PyTablesFilter
55 | >>> io_filter = PyTablesFilter('../data/tutorial.h5')
56 | >>> raw_data = io_filter.read_sp('/SubjectA/session01/el1')
57 | >>> print(raw_data.keys()) # print all keys
58 | ['n_contacts', 'FS', 'data']
59 | >>> shape = raw_data['data'].shape # check size
60 | >>> print("{0} channels, {1} samples".format(*shape))
61 | 4 channels, 23512500 samples
62 | >>> print(raw_data['FS']) # check sampling frequency
63 | 25000
64 |
65 |
66 | .. _spike_times:
67 |
68 | Spike times
69 | -----------
70 |
71 | A sequence of (sorted) time readings at which spikes were generated
72 | (or other discrete events happened).
73 |
74 | The data is stored in a dictionary with the following keys:
75 |
76 | :data: *array*, required
77 |
78 | one-dimensional array-like object with event times (in
79 | milliseconds)
80 |
81 | :is_valid: *array*, optional
82 |
83 |     boolean array of the same size as `data` -- if an element is False
84 | the event of the same index is masked (or invalid)
85 |
86 | .. note::
87 |
88 | You may read/write the data with your own functions, but to make the
89 | interface with SpikeSort a bit cleaner, you might also want to
90 | define your custom IO filters (see :ref:`io_filters`)
91 |
92 | .. rubric:: Example
93 |
94 | :py:func:`spike_sort.core.extract.detect_spikes` is one of functions
95 | which takes the raw recordings and returns spike times dictionary:
96 |
97 |
98 | >>> import numpy as np
99 | >>> raw_dict = {
100 | ... 'data': np.array([[0,1,0,0,0,1]]),
101 | ... 'FS' : 10000,
102 | ... 'n_contacts': 1
103 | ... }
104 | >>> from spike_sort.core.extract import detect_spikes
105 | >>> spt_dict = detect_spikes(raw_dict, thresh=0.8)
106 | >>> print(spt_dict.keys())
107 | ['thresh', 'contact', 'data']
108 | >>> print('Spike times (ms): {0}'.format(spt_dict['data']))
109 | Spike times (ms): [ 0. 0.4]
110 |
111 |
112 | Note that in addition to the required data key,
113 | :py:func:`~spike_sort.core.extract.detect_spikes`
114 | appends some extra attributes: :py:attr:`thresh` (detection threshold)
115 | and :py:attr:`contact` (contact on which spikes were detected). These
116 | attributes are ignored by other methods.
117 |
118 | .. _spike_wave:
119 |
120 | Spike waveforms
121 | ---------------
122 |
123 | Spike waveform structure contains waveforms of extracted spikes. It may be
124 | any mapping data structure (usually a dictionary) with the following keys:
125 |
126 | :data: *array*, required
127 |
128 | three-dimensional array-like object of size (N_points, N_spikes,
129 | N_contacts), where:
130 |
131 | * `N_points` -- the number of data points in a single waveform,
132 | * `N_spikes` -- the total number of spikes and
133 | * `N_contacts` -- the number of independent channels (for example 4 in a
134 | tetrode)
135 |
136 | :time: *array*, required
137 |
138 |     Timeline of the spike waveshapes (in milliseconds). It must be of
139 | the same size as the first dimension of data (`n_pts`).
140 |
141 | :FS: *int*, optional
142 |
143 | Sampling frequency.
144 |
145 |
146 | :n_contacts: *int*, optional
147 |
148 | Number of independent channels with spike
149 | waveshapes (see also :ref:`raw_recording`).
150 |
151 | :is_valid: *array*, optional
152 |
153 |     boolean array of the size of the second dimension of `data` (N_spikes) -- if an element is False
154 | the spike with the same index is masked (or invalid)
155 |
156 |
157 | .. rubric:: Example
158 |
159 | Spike waveforms can be extracted from raw recordings (see :ref:`raw_recording`)
160 | given a sequence of spike times (see :ref:`spike_times`) by means of
161 | :py:func:`spike_sort.core.extract.extract_spikes` function:
162 |
163 | >>> from spike_sort.core.extract import extract_spikes
164 | >>> raw_dict = {
165 | ... 'data': np.array([[0,1,1,0,0,0,1,-1,0,0, 0]]),
166 | ... 'FS' : 10000,
167 | ... 'n_contacts': 1
168 | ... } # raw signal
169 | >>> spt_dict = {
170 | ...    'data': np.array([0.15, 0.65, 1])
171 | ... } # timestamps of three spikes
172 | >>> sp_win = [0, 0.4] # window in which spikes should be extracted
173 | >>> waves_dict = extract_spikes(raw_dict, spt_dict, sp_win)
174 |
175 | Now let us investigate the returned spike waveforms structure:
176 |
177 | * keys:
178 |
179 | >>> print(waves_dict.keys())
180 | ['is_valid', 'FS', 'data', 'time']
181 |
182 | * data array shape:
183 |
184 | >>> print(waves_dict['data'].shape)
185 | (4, 3, 1)
186 |
187 | * extracted spikes:
188 |
189 | >>> print(waves_dict['data'][:,:,0].T) # data contains three spikes
190 | [[ 1. 1. 0. 0.]
191 | [ 1. -1. 0. 0.]
192 | [ 0. 0. 0. 0.]]
193 | >>> print(waves_dict['time']) # defined over 4 time points
194 | [ 0. 0.1 0.2 0.3]
195 |
196 | * and potential invalid (truncated spikes):
197 |
198 | >>> print(waves_dict['is_valid']) # last spike is invalid (truncated)
199 | [ True True False]
200 |
201 | Note that the :py:attr:`is_valid` element of the truncated spike is
202 | :py:data:`False`.
203 |
204 | .. _spike_features:
205 |
206 | Spike features
207 | --------------
208 |
209 | This data structure contains features calculated from spike waveforms
210 | using one of the methods defined in :py:mod:`spike_sort.core.features` module
211 | (one of the :py:func:`fet*` functions, see :ref:`features_doc`).
212 |
213 | The spike features dictionary consists of the following keys:
214 |
215 | :data: *array*, required
216 |
217 | two-dimensional array of size (N_spikes, N_features) that contains
218 | the actual feature values
219 |
220 | :names: *list of str*, required
221 |
222 | list of length N_features containing feature labels
223 |
224 | :is_valid: *array*, optional
225 |
226 |     boolean array of length N_spikes; if an element is False
227 | the spike with the same index is masked (or invalid, see also
228 | :ref:`spike_wave`)
229 |
230 |
231 | .. rubric:: Example
232 |
233 | Let us try to calculate peak-to-peak amplitude from some spikes
234 | extracted in :ref:`spike_wave`:
235 |
236 | >>> from spike_sort.core.features import fetP2P
237 | >>> print(waves_dict['data'].shape) # 3 spikes, 4 data points each
238 | (4, 3, 1)
239 | >>> feature_dict = fetP2P(waves_dict)
240 | >>> print(feature_dict.keys())
241 | ['is_valid', 'data', 'names']
242 | >>> print(feature_dict['data'].shape)
243 | (3, 1)
244 |
245 | So we have one feature for each of the 3 spikes. Let's check whether the peak-to-peak amplitudes
246 | are correctly calculated:
247 |
248 | >>> print(feature_dict['data'])
249 | [[ 1.]
250 | [ 2.]
251 | [ 0.]]
252 |
253 | as expected (compare with example above). There is only one
254 | peak-to-peak (`P2P`) feature on a single channel (`Ch0`) and its name
255 | is:
256 |
257 | >>> print(feature_dict['names'])
258 | ['Ch0:P2P']
259 |
260 | The mask array is inherited from :py:data:`waves_dict`:
261 |
262 | >>> print(feature_dict['is_valid'])
263 | [ True True False]
264 |
265 | .. _spike_labels:
266 |
267 | Spike labels
268 | ------------
269 |
270 | Spike labels are the identifiers of a cell (unit) each spike was
271 | classified to. Spike labels are **not** dictionaries, but arrays of
272 | integers -- one cluster index per spike.
273 |
274 | .. rubric:: Example
275 |
276 | Let us try to cluster the spikes described by the `Sample` feature using
277 | K-means with K=2:
278 |
279 | >>> from spike_sort.core.cluster import cluster
280 | >>> feature_dict = {
281 | ... 'data' : np.array([[1],[-1], [1]]),
282 | ... 'names' : ['Sample']
283 | ... }
284 | >>> labels = cluster('k_means', feature_dict, 2)
285 | >>> print(labels)
286 | [1 0 1]
287 |
288 | As expected :py:data:`labels` is an array describing two clusters: 0 and 1.
289 |
290 |
--------------------------------------------------------------------------------
/tests/test_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import spike_sort as ss
3 |
4 | from nose.tools import ok_, eq_, raises
5 | from numpy.testing import assert_array_almost_equal as almost_equal
6 |
7 |
8 | class TestFeatures(object):
9 | def setup(self):
10 | self.gain = 1000
11 |
12 | n_pts = 100
13 | FS = 5E3
14 | time = np.arange(n_pts, dtype=np.float32) / n_pts - 0.5
15 | cells = self.gain * np.vstack(
16 | (np.sin(time * 2 * np.pi) / 2.0,
17 | np.abs(time) * 2 - 1
18 | )
19 | )
20 | self.cells = cells.astype(np.int)
21 |
22 | #define data interface (dictionary)
23 | self.spikes_dict = {"data": self.cells.T[:, :, np.newaxis],
24 | "time": time, "FS": FS}
25 |
26 | def test_fetP2P(self):
27 | spikes_dict = self.spikes_dict.copy()
28 |
29 | n_spikes = 200
30 | amps = np.random.randint(1, 100, n_spikes)
31 | amps = amps[:, np.newaxis]
32 | spikes = amps * self.cells[0, :]
33 | spikes_dict['data'] = spikes.T[:, :, np.newaxis]
34 | p2p = ss.features.fetP2P(spikes_dict)
35 |
36 | ok_((p2p['data'] == amps * self.gain).all())
37 |
38 | def test_PCA(self):
39 | n_dim = 2
40 | n_obs = 100
41 | raw_data = np.random.randn(n_dim, n_obs)
42 | mixing = np.array([[-1, 1], [1, 1]])
43 |
44 | mix_data = np.dot(mixing, raw_data)
45 | mix_data = mix_data - mix_data.mean(1)[:, np.newaxis]
46 |
47 | evals, evecs, score = ss.features.PCA(mix_data, n_dim)
48 |
49 | proj_cov = np.cov(score)
50 | error = np.mean((proj_cov - np.eye(n_dim)) ** 2)
51 | ok_(error < 0.01)
52 |
53 | def test_fetPCA(self):
54 | spikes_dict = self.spikes_dict.copy()
55 | n_spikes = 200
56 | n_cells = 2
57 | amp_var_fact = 0.4
58 | amp_var = 1 + amp_var_fact * np.random.rand(n_spikes, 1)
59 | _amps = np.random.rand(n_spikes) > 0.5
60 | amps = amp_var * np.vstack((_amps >= 0.5, _amps < 0.5)).T
61 |
62 | spikes = np.dot(amps, self.cells).T
63 | spikes = spikes.astype(np.float32)
64 | spikes_dict['data'] = spikes[:, :, np.newaxis]
65 |
66 | pcs = ss.features.fetPCA(spikes_dict, ncomps=1)
67 | pcs = pcs['data']
68 | compare = ~np.logical_xor(pcs[:, 0].astype(int) + 1, _amps)
69 | correct = np.sum(compare)
70 | eq_(n_spikes, correct)
71 |
72 | def test_fetPCA_multichannel_labels(self):
73 | # tests whether features are properly labeled, based on the linearity
74 | # of PCA
75 |
76 | spike1, spike2 = self.cells
77 | ch0_spikes = np.vstack((spike1, spike2, spike1 + spike2)).T
78 | ch1_spikes = np.vstack((spike1*2, spike2/2., spike1*2 - spike2/2.)).T
79 |
80 | spikes = np.empty((ch0_spikes.shape[0], ch0_spikes.shape[1], 2))
81 | spikes[:, :, 0] = ch0_spikes
82 | spikes[:, :, 1] = ch1_spikes
83 | spikes_dict = {'data' : spikes}
84 |
85 | features = ss.features.fetPCA(spikes_dict, ncomps=3)
86 | names = features['names']
87 |
88 | ch0_feats = np.empty((3, 3))
89 | ch1_feats = np.empty((3, 3))
90 | # sort features by channel using labels
91 | for fidx in range(3):
92 | i0 = names.index('Ch%d:PC%d' % (0, fidx))
93 | i1 = names.index('Ch%d:PC%d' % (1, fidx))
94 | ch0_feats[fidx, :] = features['data'][:, i0]
95 | ch1_feats[fidx, :] = features['data'][:, i1]
96 |
97 | # spike combinations on channel 0
98 | almost_equal(ch0_feats[:, 2], ch0_feats[:, 0] + ch0_feats[:, 1])
99 | # spike combinations on channel 1
100 | almost_equal(ch1_feats[:, 2], ch1_feats[:, 0] - ch1_feats[:, 1])
101 |
102 | def test_getSpProjection(self):
103 | spikes_dict = self.spikes_dict.copy()
104 | cells = spikes_dict['data']
105 | spikes_dict['data'] = np.repeat(cells, 10, 1)
106 | labels = np.repeat([0, 1], 10, 0)
107 |
108 | feat = ss.features.fetSpProjection(spikes_dict, labels)
109 | ok_(((feat['data'][:, 0] > 0.5) == labels).all())
110 |
111 | def test_fetMarkers(self):
112 | spikes_dict = self.spikes_dict.copy()
113 |
114 | n_spikes = 200
115 | n_pts = self.cells.shape[1]
116 | amps = np.random.randint(1, 100, n_spikes)
117 | amps = amps[:, np.newaxis]
118 | spikes = amps * self.cells[0, :]
119 | spikes_dict['data'] = spikes.T[:, :, np.newaxis]
120 | time = spikes_dict['time']
121 | indices = np.array([0, n_pts/2, n_pts-1])
122 | values = ss.features.fetMarkers(spikes_dict, time[indices])
123 |
124 | ok_((values['data'] == spikes[:, indices]).all())
125 |
126 | def test_WT(self):
127 | "simple test for linearity of wavelet transform"
128 | spike1, spike2 = self.cells
129 | spike3 = 0.1 * spike1 + 0.7 * spike2
130 | spikes = np.vstack((spike1, spike2, spike3)).T
131 | spikes = spikes[:, :, np.newaxis] # WT only accepts 3D arrays
132 |
133 | wavelet = 'db3'
134 | wt = ss.features.WT(spikes, wavelet)
135 | wt1, wt2, wt3 = wt.squeeze().T
136 |
137 | almost_equal(wt3, 0.1 * wt1 + 0.7 * wt2)
138 |
139 | def test_fetWT_multichannel_labels(self):
140 | # tests whether features are properly labeled, based on linearity of
141 | # Wavelet Transform
142 |
143 | spike1, spike2 = self.cells
144 | ch0_spikes = np.vstack((spike1, spike2)).T
145 | ch1_spikes = np.vstack((spike1 + spike2, spike1 - spike2)).T
146 |
147 | spikes = np.empty((ch0_spikes.shape[0], ch0_spikes.shape[1], 2))
148 | spikes[:, :, 0] = ch0_spikes
149 | spikes[:, :, 1] = ch1_spikes
150 | spikes_dict = {'data' : spikes}
151 |
152 | features = ss.features.fetWT(spikes_dict, 3, wavelet='haar', select_method=None)
153 | names = features['names']
154 |
155 | ch0_feats = np.empty((3, 2))
156 | ch1_feats = np.empty((3, 2))
157 | # sort features by channel using labels
158 | for fidx in range(3):
159 | i0 = names.index('Ch%d:haarWC%d' % (0, fidx))
160 | i1 = names.index('Ch%d:haarWC%d' % (1, fidx))
161 | ch0_feats[fidx, :] = features['data'][:, i0]
162 | ch1_feats[fidx, :] = features['data'][:, i1]
163 |
164 | almost_equal(ch1_feats[:, 0], ch0_feats[:, 0] + ch0_feats[:, 1])
165 | almost_equal(ch1_feats[:, 1], ch0_feats[:, 0] - ch0_feats[:, 1])
166 |
167 | def test_fetWT_math(self):
168 | n_samples = 256
169 |
170 | # upsampled haar wavelet
171 | spike1 = np.hstack((np.ones(n_samples / 2), -1 * np.ones(n_samples / 2)))
172 | # upsampled haar scaling function
173 | spike2 = np.ones(n_samples)
174 |
175 | spikes = np.vstack((spike1, spike2)).T
176 | spikes = spikes[:, :, np.newaxis]
177 | spikes_dict = {'data' : spikes}
178 |
179 | features = ss.features.fetWT(spikes_dict, n_samples, wavelet='haar', select_method=None)
180 | idx = np.nonzero(features['data']) # nonzero indices
181 |
182 | # if nonzero elements are ONLY at (0,1) and (1,0),
183 | # this should be eye(2)
184 | eye = np.fliplr(np.vstack(idx))
185 |
186 | ok_((eye == np.eye(2)).all())
187 |
188 | def test_fetWT_selection(self):
189 | n_samples = 30
190 | n_channels = 2
191 | n_spikes = 50
192 | n_features = 10
193 | methods = [None, 'std', 'std_r', 'ks', 'dip', 'ksPCA', 'dipPCA']
194 |
195 | spikes = np.random.randn(n_samples, n_spikes, n_channels)
196 | spikes_dict = {'data' : spikes}
197 |
198 | shapes = [(n_spikes, n_features * n_channels)]
199 |
200 | for met in methods:
201 | wt = ss.features.fetWT(spikes_dict, n_features, wavelet='haar', select_method=met)
202 | shapes.append(wt['data'].shape)
203 |
204 | # all returned shapes must match the expected shape
205 | success = all(s == shapes[0] for s in shapes)
206 |
207 | ok_(success)
208 |
209 | def test_add_mask_decorator(self):
210 | spikes_dict = {'data': np.zeros((10, 2)),
211 | 'is_valid': np.zeros(2,)}
212 |
213 | fetIdentity = lambda x: {'data': x['data'], 'names': 'Identity'}
214 | deco_fet = ss.features.add_mask(fetIdentity)
215 | features = deco_fet(spikes_dict)
216 |
217 | ok_((features['is_valid'] == spikes_dict['is_valid']).all())
218 |
219 | def test_combine_features_without_mask(self):
220 | feature1 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature1']}
221 | feature2 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature2']}
222 | combined = ss.features.combine((feature1, feature2))
223 | ok_('is_valid' not in combined)
224 |
225 | def test_combine_features_with_one_mask(self):
226 | feature1 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature1']}
227 | feature2 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature2'],
228 | 'is_valid': np.ones(5, dtype=np.bool)}
229 | combined = ss.features.combine((feature1, feature2))
230 | ok_((combined['is_valid'] == feature2['is_valid']).all())
231 |
232 | def test_combine_features_with_different_masks(self):
233 | mask1 = np.ones(5, dtype=np.bool)
234 | mask1[-1] = False
235 | mask2 = mask1.copy()
236 | mask2[:2] = False
237 | feature1 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature1'],
238 | 'is_valid': mask1}
239 | feature2 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature2'],
240 | 'is_valid': mask2}
241 | combined = ss.features.combine((feature1, feature2))
242 | ok_((combined['is_valid'] == (mask1 & mask2)).all())
243 |
244 | def test_add_method_suffix(self):
245 | names = ['methA', 'methB', 'methA_1', 'methC', 'methA_2', 'methD']
246 |
247 | test1 = ss.features._add_method_suffix('methE', names) == 'methE'
248 | test2 = ss.features._add_method_suffix('methB', names) == 'methB_1'
249 | test3 = ss.features._add_method_suffix('methA', names) == 'methA_3'
250 |
251 | ok_(test1 and test2 and test3)
252 |
--------------------------------------------------------------------------------
/docs/source/tutorials/tutorial_manual.rst:
--------------------------------------------------------------------------------
1 | .. _lowlevel_tutorial:
2 |
3 | Using low-level interface
4 | ==========================
5 |
6 | .. testsetup::
7 |
8 | import numpy
9 | numpy.random.seed(1221)
10 | import os, shutil, tempfile
11 |
12 | temp_path = tempfile.mkdtemp()
13 | data_path = os.path.join(temp_path, 'data')
14 | os.mkdir(data_path)
15 | shutil.copyfile('../data/tutorial.h5', os.path.join(data_path, 'tutorial.h5'))
16 | os.chdir(temp_path)
17 |
18 | .. testcleanup::
19 |
20 | shutil.rmtree(temp_path)
21 |
22 |
23 | In this tutorial we will go deeper into the lower-level interface of
24 | SpikeSort: :py:mod:`spike_sort.core`. This interface is a bit more
25 | complex than SpikeBeans (see :ref:`beans_tutorial`), but it offers
26 | more flexibility and allows you to embed SpikeSort in your own
27 | programs.
28 |
29 | To start this tutorial you will need:
30 |
31 | * working installation of SpikeSort
32 |
33 | * the sample :ref:`tutorial_data`
34 |
35 |
36 | 1. Read data
37 | ------------
38 |
39 |
40 | We will assume that you downloaded the sample data file :file:`tutorial.h5` and saved it to the :file:`data`
41 | directory.
42 |
43 | You can load this file using one of the I/O filters from the
44 | :py:mod:`spike_sort.io.filters` module:
45 |
46 | .. doctest::
47 |
48 | >>> from spike_sort.io.filters import PyTablesFilter
49 | >>> dataset = '/SubjectA/session01/el1'
50 | >>> io_filter = PyTablesFilter('data/tutorial.h5')
51 | >>> raw = io_filter.read_sp(dataset)
52 |
53 | :py:data:`raw` is a dictionary which contains the raw data (in this case it is
54 | a PyTables compressed array) under the :py:attr:`data`
55 | key:
56 |
57 | .. doctest::
58 |
59 | >>> print raw['data']
60 | /SubjectA/session01/el1/raw (CArray(4, 23512500)) ''
61 |
62 | The size of the data is 23512500 samples in 4 independent channels (`contacts`
63 | in the tetrode).
64 |
65 | .. note::
66 |
67 | HDF5 files are organised hierarchically and may contain multiple
68 | datasets. You can access the datasets via simple paths - in this
69 | case `/SubjectA/session01/el1` refers to the dataset of SubjectA
70 | recorded in session01 from electrode el1.
71 |
72 | 2. Detect spikes
73 | ----------------
74 |
75 |
76 | The first step of spike sorting is spike detection. It is usually done by
77 | thresholding the raw recordings. Let us use an automatic threshold on the
78 | 4th contact, i.e. index 3 (channel indexing always starts with 0!):
79 |
80 | .. doctest::
81 |
82 | >>> from spike_sort import extract
83 | >>> spt = extract.detect_spikes(raw, contact=3, thresh='auto')
84 |
85 | Let us see now how many events were detected:
86 |
87 | .. doctest::
88 |
89 | >>> print len(spt['data'])
90 | 16293
91 |
92 | We should make sure that all events are aligned to the same point of reference,
93 | for example, the maximum amplitude. To this end we first define a window
94 | within which spikes should be centered and then recalculate the aligned event times:
95 |
96 | .. doctest::
97 |
98 | >>> sp_win = [-0.2, 0.8]
99 | >>> spt = extract.align_spikes(raw, spt, sp_win, type="max",
100 | ... resample=10)
101 |
102 | ``resample`` is optional - it enables upsampling (in this case 10-fold)
103 | of the original waveforms to obtain better resolution of event times.
104 |
105 | After spike detection and alignment we can finally extract the spike waveforms:
106 |
107 | .. doctest::
108 |
109 | >>> sp_waves = extract.extract_spikes(raw, spt, sp_win)
110 |
111 | The resulting structure is a dictionary whose :py:attr:`data` key is an array
112 | containing the spike waveshapes. Note that the array is three-dimensional and
113 | the sizes of its dimensions reflect:
114 |
115 | * 1st dimension: the number of samples in each waveform,
116 | * 2nd: the number of spikes,
117 | * 3rd: the number of contacts.
118 |
119 | .. doctest::
120 |
121 | >>> print sp_waves['data'].shape
122 | (25, 15537, 4)
123 |
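Since the :py:attr:`data` entry is a plain NumPy array, you can inspect it
with standard slicing. For example, to pull out all waveforms recorded on
the first contact (a quick sketch, not executed as part of this session)::

    contact0_waves = sp_waves['data'][:, :, 0]  # shape: (n_samples, n_spikes)
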
124 | In practice, you do not need to take care of such details. However, it is always
125 | a good idea to take a look at the obtained waveforms. The
126 | :py:mod:`spike_sort.ui.plotting` module contains various functions which will
127 | help you visualize the data. To plot waveshapes you can use the :py:func:`plot_spikes` function from this module:
128 |
129 | .. doctest::
130 |
131 | >>> from spike_sort.ui import plotting
132 | >>> plotting.plot_spikes(sp_waves, n_spikes=200)
133 |
134 | .. figure:: images_manual/tutorial_spikes.png
135 |
136 | It is apparent from the plot that the spike waveforms of a few different cells
137 | and also some artifacts were detected. In order to separate these activities,
138 | in the next step we will perform *spike clustering*.
139 |
140 | 3. Calculate features
141 | ---------------------
142 |
143 | Before we can cluster spikes, we should calculate some characteristic features
144 | that may be used to differentiate between the waveshapes. The
145 | :py:mod:`~spike_sort.core.features` module defines several such features, for example
146 | peak-to-peak amplitude (:py:func:`fetP2P`) and projections on principal
147 | components (:py:func:`fetPCA`). Now, we will calculate peak-to-peak amplitudes
148 | and PC projections on each of the contacts, and then combine them into a single
149 | object:
150 |
151 | .. doctest::
152 |
153 | >>> from spike_sort import features
154 | >>> sp_feats = features.combine(
155 | ... (
156 | ... features.fetP2P(sp_waves),
157 | ... features.fetPCA(sp_waves)
158 | ... )
159 | ... )
160 |
161 | To help the user identify the features, abbreviated
162 | labels are assigned to all features:
163 |
164 | .. doctest::
165 |
166 | >>> print sp_feats['names']
167 | ['Ch0:P2P' 'Ch1:P2P' 'Ch2:P2P' 'Ch3:P2P' 'Ch0:PC0' 'Ch1:PC0' 'Ch2:PC0'
168 | 'Ch3:PC0' 'Ch0:PC1' 'Ch1:PC1' 'Ch2:PC1' 'Ch3:PC1']
169 |
170 | For example, the feature ``Ch0:P2P`` denotes the peak-to-peak amplitude on contact
171 | (channel) 0.
172 |
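These labels can also be used to locate a feature programmatically. For
example, to select the column holding the peak-to-peak amplitude of contact
0 (a small sketch, not executed here)::

    import numpy as np
    i = np.flatnonzero(np.asarray(sp_feats['names']) == 'Ch0:P2P')[0]
    p2p_ch0 = sp_feats['data'][:, i]
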
173 | Let us plot the two-dimensional
174 | projections of the feature space and histograms of features:
175 |
176 | .. doctest::
177 |
178 | >>> plotting.plot_features(sp_feats)
179 |
180 | .. figure:: images_manual/tutorial_features.png
181 |
182 | 4. Cluster spikes
183 | -----------------
184 |
185 | Finally, based on the calculated features we can perform spike clustering. This
186 | step is a little bit more complex and the best settings have to be identified
187 | using a trial-and-error procedure.
188 |
189 | There are several automatic, semi-automatic and manual methods for clustering.
190 | Their performance and accuracy depend to a large degree on the particular dataset
191 | and recording setup. In SpikeSort you can choose from several available methods,
192 | whose names are given as the first argument of the :py:func:`~spike_sort.core.cluster.cluster`
193 | function.
194 |
195 | We will start with the automatic clustering method :py:func:`~spike_sort.core.cluster.gmm`, which requires only the feature object :py:data:`sp_feats` and the number of clusters to identify.
196 | It attempts to find the mixture of Gaussian distributions which best approximates the
197 | distribution of spike features (a Gaussian mixture model).
198 | Since we do not know how many cells were picked up by the electrode, we guess
199 | an initial number of clusters, which we can modify later on:
200 |
201 | .. doctest::
202 |
203 | >>> from spike_sort import cluster
204 | >>> clust_idx = cluster.cluster("gmm", sp_feats, 4)
205 |
206 | The result simply assigns a number (cluster index) to each spike from
207 | the feature array :py:data:`sp_feats`.
208 |
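Because the cluster indices form a plain integer array, you can inspect them
with standard NumPy tools. For instance, to count how many spikes were
assigned to each cluster (a sketch, assuming integer labels as returned
above)::

    import numpy as np
    np.bincount(clust_idx)  # number of spikes in each of the 4 clusters
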
209 | You can use the plotting module to draw the
210 | feature vectors with colors reflecting the groups to which the spikes were assigned:
211 |
212 | .. doctest::
213 |
214 | >>> plotting.plot_features(sp_feats, clust_idx)
215 |
216 | .. figure:: images_manual/tutorial_clusters.png
217 |
218 | or you can see the spike waveshapes:
219 |
220 | .. doctest::
221 |
222 | >>> plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200)
223 | >>> plotting.show()
224 |
225 | .. figure:: images_manual/tutorial_cells.png
226 |
227 | If you are not satisfied with the results or you think you might do better,
228 | you can also try manual sorting using the cluster cutting method::
229 |
230 | >>> from spike_sort.ui import manual_sort
231 | >>> clust_idx = manual_sort.show(sp_feats, sp_waves,
232 | ... ['Ch0:P2P','Ch3:P2P'],
233 | ... show_spikes=True)
234 |
235 | This function will open a window in which you can draw clusters of arbitrary
236 | shapes, but beware: you can draw only on a two-dimensional plane, so you
237 | are limited to only two features (``Ch0:P2P`` and ``Ch3:P2P`` in this case)!
238 |
239 | 5. Export data
240 | --------------
241 |
242 | Once you are done with spike sorting, you can export the results to a file.
243 | To this end you can use the same :py:mod:`~spike_sort.io.filters` module we used
244 | for reading. Here, we will save the spike times of a selected cell
245 | back to the file we read the data from.
246 |
247 | First, we need to extract the spike times
248 | of the discriminated cells:
249 |
250 | .. doctest::
251 |
252 | >>> spt_clust = cluster.split_cells(spt, clust_idx)
253 |
254 | This creates a dictionary whose keys are the cell labels, each pointing
255 | to the spike times of that cell. For example, to extract the spike
256 | times of cell 0:
257 |
258 | .. doctest::
259 |
260 | >>> print spt_clust[0]
261 | {'data': array([ 5.68152000e+02, 1.56978000e+03, 2.23985200e+03,
262 | ...
263 | 9.24276876e+05, 9.33539168e+05])}
264 |
265 |
266 | Then we may export them to the datafile:
267 |
268 | .. doctest::
269 |
270 | >>> from spike_sort.io import export
271 | >>> cell_template = dataset + '/cell{cell_id}'
272 | >>> export.export_cells(io_filter, cell_template, spt_clust, overwrite=True)
273 |
274 | This will create new nodes in :file:`tutorial.h5` containing the spike times of
275 | the discriminated cells (``/SubjectA/session01/el1/cell{1-4}``),
276 | which you can use for further analysis.
277 |
278 | Do not forget to close the I/O filter at the end of your analysis:
279 |
280 | .. doctest::
281 |
282 | >>> io_filter.close()
283 |
284 | Good luck!!!
285 |
286 |
287 |
--------------------------------------------------------------------------------
/docs/source/intro.rst:
--------------------------------------------------------------------------------
1 | Introduction
2 | ============
3 |
4 | What is spike sorting?
5 | ----------------------
6 |
7 | *Spike sorting is a class of techniques used in the analysis of
8 | electrophysiological data. Spike sorting algorithms use the
9 | shape(s) of waveforms collected with one or more electrodes in the
10 | brain to distinguish the activity of one or more neurons from
11 | background electrical noise.*
12 |
13 | Source: Spike sorting, Wikipedia_
14 |
15 | Spike sorting usually consists of the following steps [Quiroga2004]_:
16 |
17 | 1. Spike detection (*detect*)
18 |
19 | Spikes are very rapid and often sparse events, so that they occupy
20 | only a tiny fraction of the recordings. To achieve a sort of
21 | compression and ease the subsequent analysis, the times of
22 | occurrence of putative spikes are first identified in the continuous
23 | signal by means of thresholding. This method returns only events that cross
24 | a specified threshold (selected based on some data statistics, such
25 | as the standard deviation, or visually).
26 |
27 | #. Spike waveform extraction (*extract*)
28 |
29 | Based on the series of spike times identified in the previous step,
30 | we now may proceed to extract the spike waveforms by taking a small
31 | segment of the signal around each spike time. Such segments are
32 | usually automatically aligned to a specific feature of the
33 | waveform, for example maximum or minimum.
34 |
35 | #. Feature extraction (*feature*)
36 |
37 | This is one of the most important steps, in which the salient
38 | features of the spikes, such as peak-to-peak amplitude or spike
39 | width, are calculated based on the spike waveshapes. The features should
40 | preferably be low-dimensional and should differentiate well between spikes
41 | of different cells and noise.
42 |
43 | #. Clustering (*cluster*)
44 |
45 | At the heart of spike sorting is the clustering, which uses
46 | automatic, semi-automatic or manual methods to identify groups of
47 | spikes belonging to the same cell (often called a unit). The
48 | procedure is usually applied in the n-dimensional space of spike
49 | features, where each feature is a single dimension. Since it is
50 | very difficult to do that visually for more than 2 features (on a
51 | plane), one usually resorts to clustering algorithms, such
52 | as K-means or Gaussian Mixture Models, that can handle a large number
53 | of dimensions.
54 |
55 |
56 | #. Sorting evaluation (*evaluate*)
57 |
58 | After sorting is done and we have determined the spike times of
59 | different units, we have to make sure that the quality of the sorting
60 | is sufficient for further analysis. There are multiple visual and
61 | statistical methods that help in this evaluation, such as
62 | inter-spike interval histograms, waveform overlays,
63 | cross-correlograms etc. [Hill2011]_.
64 |
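These five steps map directly onto functions of SpikeSort's low-level
interface. The sketch below shows how they fit together; it is not a
complete script and assumes that a raw-data dictionary ``raw`` has already
been loaded with one of the I/O filters, as described in the tutorials::

    from spike_sort import extract, features, cluster
    from spike_sort.ui import plotting

    # detect: find threshold crossings on one contact
    spt = extract.detect_spikes(raw, contact=0, thresh='auto')

    # extract: align the events and cut out the waveforms
    sp_win = [-0.2, 0.8]
    spt = extract.align_spikes(raw, spt, sp_win, type='max')
    sp_waves = extract.extract_spikes(raw, spt, sp_win)

    # feature: peak-to-peak amplitudes and PCA projections
    sp_feats = features.combine((features.fetP2P(sp_waves),
                                 features.fetPCA(sp_waves)))

    # cluster: e.g. a Gaussian mixture model with 4 putative units
    clust_idx = cluster.cluster('gmm', sp_feats, 4)

    # evaluate: visually inspect the sorted waveforms
    plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200)
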
65 | SpikeSort is a comprehensive library whose goal is to accompany you through
66 | the entire process of spike sorting - from detection to evaluation.
67 | It is not a fully automatic or on-line sorting program. Although we
68 | include lots of functions that help to automate some of the
69 | repetitive tasks, good sorting will always require human supervision.
70 |
71 | Design goals
72 | ------------
73 |
74 | There are several spike sorting programs on the market, both commercial and free
75 | and open source. We started this project to offer an alternative that
76 | would be free, scriptable and preferably in our favourite language:
77 | Python. However, we did not try to re-invent the
78 | wheel, and whenever we could we leveraged established libraries.
79 |
80 | We had a few design goals when we worked on SpikeSort:
81 |
82 | * modular
83 |
84 | Spike sorting is modular - there are several steps and different
85 | techniques that can and
86 | should be mixed-and-matched to adjust the process to the data we
87 | try to analyze. A spike sorting library should be modular as well to allow for
88 | maximum flexibility. This is achieved by composing the library of
89 | independent components that can be easily inserted into or deleted
90 | from the spike sorting workflow.
91 |
92 | * customizable
93 |
94 | No two data sets are the same. Different experimental protocols,
95 | different acquisition systems, different neural systems result in
96 | different properties of the data and thus require different methods.
97 | Therefore, a good spike sorting library should allow for easy and
98 | flexible customizations of the algorithms used at each stage of spike
99 | sorting and their parameters.
100 |
101 | * easy-to-use
102 |
103 | Flexibility usually comes at a price: usability. Nevertheless, a
104 | transparent design can allow complex systems to be user-friendly.
105 | The interface should let you focus on the data and not on the
106 | peculiarities of specific software solutions.
107 |
108 | * fast
109 |
110 | In practice, one will try to discriminate thousands of spikes from
111 | tens of different neurons. Any performance optimization will save
112 | you precious minutes (hours or even days) for the task you are most
113 | interested in: decoding what the cells actually do.
114 |
115 | * compatible with standard libraries (NumPy, SciPy, matplotlib)
116 |
117 | There is no need to reinvent the wheel. Python developers provide
118 | hundreds of optimized, well-tested and widely-used libraries with
119 | great community support. Why not use them? Moreover, any data coming
120 | from the spike sorting libraries should be easily pluggable into
121 | third-party analysis routines and databases.
122 |
123 | Although much work is still required to meet all the goals, we kept
124 | them all in mind while designing SpikeSort. As a result, SpikeSort
125 | is already a usable and powerful framework that will help you
126 | get the most out of your data.
127 |
128 | Why Python?
129 | -----------
130 |
131 | Python is an interpreted and very dynamic language with a huge and very
132 | enthusiastic community. It has grown to be the de-facto standard in
133 | computational science [Langtangen2009]_ and is rapidly gaining momentum in
134 | experimental disciplines [Hanke2009]_. Python is also completely free and available
135 | for multiple platforms - which means that you can run it at home, at work
136 | or give it to your students at no additional cost. Last but not
137 | least, Python can be easily interfaced with other languages, making it a
138 | ''glue'' language that lets two independent libraries
139 | communicate.
140 |
141 | Installation
142 | ------------
143 |
144 | You can download the most recent release of SpikeSort from GitHub::
145 |
146 | git clone git://github.com/btel/SpikeSort.git
147 |
148 | In order to install SpikeSort you need the following libraries:
149 |
150 | * python 2.6 or 2.7
151 | * setuptools
152 | * scipy
153 | * numpy
154 | * pytables
155 | * matplotlib (only for plotting)
156 |
157 | Optional dependencies are:
158 |
159 | * scikits.learn - clustering algorithms
160 | * neurotools - spike train analysis
161 | * ipython - enhanced python shell
162 |
163 | If some of the Python packages are not available on your system you
164 | can install them with easy_install::
165 |
166 | easy_install numpy scipy pytables matplotlib
167 |
168 | .. note::
169 |
170 | If you are not familiar with the Python packaging system we recommend
171 | installing a complete Python distribution from a company called
172 | Enthought: `EPD `_
173 | (there are free academic licenses). Installers for Windows, MacOSX
174 | and Linux are available.
175 |
176 | If you have the above libraries you can install SpikeSort by simply
177 | issuing the command::
178 |
179 | python setup.py install
180 |
181 | If you prefer to install it in your home directory you may try::
182 |
183 | python setup.py install --user
184 |
185 | but remember to add :file:`$HOME/.local/lib/python2.6/site-packages` to your python
186 | path.
187 |
188 | After a successful installation you can run the supplied tests::
189 |
190 | python setup.py nosetests
191 |
192 | If you don't have all the optional dependencies, be prepared for some
193 | test errors.
194 |
195 | Examples
196 | --------
197 |
198 | In the :file:`examples/sorting` subdirectory you will find some sample scripts
199 | which use SpikeSort for spike sorting:
200 |
201 | * :file:`cluster_manual.py` - sort spikes by manual cluster cutting
202 | * :file:`cluster_auto.py` - automatically cluster with GMM (Gaussian
203 | Mixture Models) algorithm (see our tutorial
204 | :ref:`lowlevel_tutorial`)
205 | * :file:`cluster_beans.py` - run the full-stack spike-sorting environment
206 | and show spikes in a spike browser (see our tutorial :ref:`beans_tutorial`)
207 |
208 | In order to run these examples, you need to download :ref:`tutorial_data` and define an environment variable ``DATAPATH``::
209 |
210 | export DATAPATH=/path/to/data/directory
211 |
212 | where ``/path/to/data/directory`` points to the directory where you
213 | downloaded the data file.
214 |
215 | Once you have the tutorial data, you may run the above scripts, for example::
216 |
217 | python -i cluster_auto.py
218 |
219 | .. note::
220 |
221 | The ``-i`` option in the python command will leave the Python interpreter open
222 | for interactive exploration - read more in our tutorials.
223 |
224 | Similar software
225 | ----------------
226 |
227 | There are a few open source packages for spike sorting that have
228 | different designs and use cases:
229 |
230 | * `spyke `_
231 |
232 | * `OpenElectrophy `_
233 |
234 | * `spikepy `_
235 |
236 | * `Klusters `_
237 |
238 |
239 |
240 |
241 | References
242 | ----------
243 |
244 | .. _tutorial.h5: https://github.com/btel/SpikeSort/releases/download/v0.12/tutorial.h5
245 |
246 | .. _Wikipedia: http://en.wikipedia.org/wiki/Spike_sorting
247 |
248 | .. [Quiroga2004] Quiroga, RQ, Z. Nadasdy, Y. Ben-Shaul, and others. *“Unsupervised Spike Detection and Sorting with Wavelets and Superparamagnetic Clustering.”* Neural Computation **16**, no. 8 (2004):1661. ``_
249 |
250 | .. [Hill2011] Hill, Daniel N, Samar B Mehta, and David Kleinfeld. *“Quality Metrics to Accompany Spike Sorting of Extracellular Signals.”* The Journal of Neuroscience **31**, no. 24 (2011): 8699-8705. ``_
251 |
252 | .. [Hanke2009] Hanke, Michael, Yaroslav O. Halchenko, Per B. Sederberg, Emanuele Olivetti, Ingo Fründ, Jochem W. Rieger, Christoph S. Herrmann, James V. Haxby, Stephen José Hanson, and Stefan Pollmann. *“PyMVPA: a Unifying Approach to the Analysis of Neuroscientific Data.”* Frontiers in Neuroinformatics **3** (2009): 3.
253 |
254 | .. [Langtangen2009] Langtangen, Hans Petter. *Python Scripting for Computational Science*. 3rd ed. Springer, 2009.
255 |
256 |
--------------------------------------------------------------------------------
/src/spike_sort/core/extract.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding=utf-8
3 |
4 | from warnings import warn
5 | import operator
6 |
7 | from scipy import interpolate
8 | import numpy as np
9 |
10 |
11 | def split_cells(spikes, idx, which='all'):
12 | """Return the spike features splitted into separate cells
13 | """
14 |
15 | if which == 'all':
16 | classes = np.unique(idx)
17 | else:
18 | classes = which
19 |
20 | data = spikes['data']
21 | time = spikes['time']
22 | spikes_dict = dict([(cl, {'data': data[:, idx == cl, :], 'time': time})
23 | for cl in classes])
24 | return spikes_dict
25 |
26 |
27 | def remove_spikes(spt_dict, remove_dict, tolerance):
28 | """Remove spikes with given spike times from the spike time
29 | structure """
30 | spt_data = spt_dict['data']
31 | spt_remove = remove_dict['data']
32 |
33 | mn, mx = tolerance
34 |
35 | for t in spt_remove:
36 | spt_data = spt_data[(spt_data > (t + mx)) | (spt_data < (t + mn))]
37 |
38 | spt_ret = spt_dict.copy()
39 | spt_ret['data'] = spt_data
40 | return spt_ret
41 |
42 |
43 | def detect_spikes(spike_data, thresh='auto', edge="rising",
44 | contact=0):
45 | r"""Detects spikes in extracellular data using amplitude thresholding.
46 |
47 | Parameters
48 | ----------
49 | spike_data : dict
50 | extracellular waveforms
51 | thresh : float or str
52 | threshold for detection. a string is interpreted as a multiple of the
53 | standard deviation estimated from the data ('auto' corresponds to 8).
54 | edge : {'rising', 'max', 'falling', 'min'}
55 | which edge to trigger on ('max' and 'min' are synonyms for 'rising' and 'falling')
56 | contact : int, optional
57 | index of tetrode contact to use for detection, defaults to
58 | first contact
59 |
60 | Returns
61 | -------
62 | spt_dict : dict
63 | dictionary with a 'data' key which contains the detected threshold
64 | crossings in milliseconds
65 |
66 | """
67 |
68 | sp_data = spike_data['data'][contact, :]
69 |
70 | FS = spike_data['FS']
71 |
72 | edges = ('rising', 'max', 'falling', 'min')
73 | if isinstance(thresh, basestring):
74 | if thresh == 'auto':
75 | thresh_frac = 8.0
76 | else:
77 | thresh_frac = float(thresh)
78 |
79 | thresh = thresh_frac * np.sqrt(sp_data[:10 * FS].var(dtype=np.float64))
80 | if edge in edges[2:]:
81 | thresh = -thresh
82 |
83 | if edge not in edges:
84 | raise TypeError("'edge' parameter must be one of 'rising', 'max', 'falling', 'min'")
85 |
86 | op1, op2 = operator.lt, operator.gt
87 |
88 | if edge in edges[2:]:
89 | op1, op2 = op2, op1
90 |
91 | i, = np.where(op1(sp_data[:-1], thresh) & op2(sp_data[1:], thresh))
92 |
93 | spt = i * 1000.0 / FS
94 | return {'data': spt, 'thresh': thresh, 'contact': contact}
95 |
96 |
97 | def filter_spt(spike_data, spt_dict, sp_win):
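"""Return indices of the spike times for which the whole window
sp_win fits within the extent of the recording."""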
98 | spt = spt_dict['data']
99 | sp_data = spike_data['data']
100 | FS = spike_data['FS']
101 |
102 | try:
103 | n_pts = sp_data.shape[1]
104 | except IndexError:
105 | n_pts = len(sp_data)
106 | max_time = n_pts * 1000.0 / FS
107 |
108 | t_min = np.max((-sp_win[0], 0))
109 | t_max = np.min((max_time, max_time - sp_win[1]))
110 | idx, = np.nonzero((spt >= t_min) & (spt <= t_max))
111 | return idx
112 |
113 |
114 | def extract_spikes(spike_data, spt_dict, sp_win,
115 | resample=1, contacts='all'):
116 | """Extract spikes from recording.
117 |
118 | Parameters
119 | ----------
120 | spike_data : dict
121 | extracellular data (see :ref:`raw_recording`)
122 | spt : dict
123 | spike times structure (see :ref:`spike_times`)
124 | sp_win : list of int
125 | temporal extent of the wave shape
126 |
127 | Returns
128 | -------
129 | wavedict : dict
130 | spike waveforms structure (see :ref:`spike_wave`)
131 |
132 |
133 | """
134 | sp_data = spike_data['data']
135 | n_contacts = spike_data['n_contacts']
136 |
137 | if contacts == "all":
138 | contacts = np.arange(n_contacts)
139 | elif isinstance(contacts, int):
140 | contacts = np.array([contacts])
141 | else:
142 | contacts = np.asarray(contacts)
143 |
144 | FS = spike_data['FS']
145 | spt = spt_dict['data']
146 | idx = np.arange(len(spt))
147 | inner_idx = filter_spt(spike_data, spt_dict, sp_win)
148 | outer_idx = idx[~np.in1d(idx, inner_idx)]
149 |
150 | indices = (spt / 1000.0 * FS).astype(np.int32)
151 | win = (np.asarray(sp_win) / 1000.0 * FS).astype(np.int32)
152 | time = np.arange(win[1] - win[0]) * 1000.0 / FS + sp_win[0]
153 | n_contacts, n_pts = sp_data.shape
154 |
155 | # auxiliary function to find a valid spike window within data range
156 | minmax = lambda x: np.max([np.min([n_pts, x]), 0])
157 | spWave = np.zeros((len(time), len(spt), len(contacts)),
158 | dtype=np.float32)
159 |
160 | for i in inner_idx:
161 | sp = indices[i]
162 | spWave[:, i, :] = np.atleast_2d(sp_data[contacts,
163 | sp + win[0]:sp + win[1]]).T
164 |
165 | for i in outer_idx:
166 | sp = indices[i]
167 | l, r = map(minmax, sp + win)
168 | if l != r:
169 | spWave[(l - sp) - win[0]:(r - sp) - win[0], i, :] = \
170 | sp_data[contacts, l:r].T
171 |
172 | wavedict = {"data": spWave, "time": time, "FS": FS}
173 |
174 | if len(idx) != len(inner_idx):
175 | is_valid = np.zeros(len(spt), dtype=np.bool)
176 | is_valid[inner_idx] = True
177 | wavedict['is_valid'] = is_valid
178 |
179 | if resample != 1:
180 | warn("resample argument is deprecated."
181 | "Please update your code to use function"
182 | "resample_spikes", DeprecationWarning)
183 | wavedict = resample_spikes(wavedict, FS * resample)
184 | return wavedict
185 |
186 |
187 | def resample_spikes(spikes_dict, FS_new):
188 | """Upsample spike waveforms using spline interpolation"""
189 |
190 | sp_waves = spikes_dict['data']
191 | time = spikes_dict['time']
192 | FS = spikes_dict['FS']
193 |
194 | resamp_time = np.arange(time[0], time[-1], 1000.0 / FS_new)
195 | n_pts, n_spikes, n_contacts = sp_waves.shape
196 |
197 | spike_resamp = np.empty((len(resamp_time), n_spikes, n_contacts))
198 |
199 | for i in range(n_spikes):
200 | for contact in range(n_contacts):
201 | tck = interpolate.splrep(time, sp_waves[:, i, contact], s=0)
202 | spike_resamp[:, i, contact] = interpolate.splev(resamp_time,
203 | tck, der=0)
204 |
205 | return {"data": spike_resamp, "time": resamp_time, "FS": FS}
206 |
207 |
208 | def align_spikes(spike_data, spt_dict, sp_win, type="max", resample=1,
209 | contact=0, remove=True):
210 | """Aligns spike waves and returns corrected spike times
211 |
212 | Parameters
213 | ----------
214 | spike_data : dict
215 | spt_dict : dict
216 | sp_win : list of int
217 | type : {'max', 'min'}, optional
218 | resample : int, optional
219 | contact : int, optional
220 | remove : bool, optional
221 |
222 | Returns
223 | -------
224 | ret_dict : dict
225 | spike times of aligned spikes
226 |
227 | """
228 |
229 | tol = 0.1
230 |
231 | if (sp_win[0] > -tol) or (sp_win[1] < tol):
232 | warn('You are using very short sp_win. '
233 | 'This may lead to alignment problems.')
234 |
235 | spt = spt_dict['data'].copy()
236 |
237 | idx_align = np.arange(len(spt))
238 |
239 | #go in a loop until all spikes are correctly aligned
240 | iter_id = 0
241 | while len(idx_align) > 0:
242 | spt_align = {'data': spt[idx_align]}
243 | spt_inbound = filter_spt(spike_data, spt_align, sp_win)
244 | idx_align = idx_align[spt_inbound]
245 | sp_waves_dict = extract_spikes(spike_data, spt_align, sp_win,
246 | resample=resample, contacts=contact)
247 |
248 | sp_waves = sp_waves_dict['data'][:, spt_inbound, 0]
249 | time = sp_waves_dict['time']
250 |
251 | if type == "max":
252 | i = sp_waves.argmax(0)
253 | elif type == "min":
254 | i = sp_waves.argmin(0)
255 |
256 | #move spike markers
257 | shift = time[i]
258 | spt[idx_align] += shift
259 |
260 | #if spike maximum/minimum was at the edge we have to extract it at the
261 | # new marker and repeat the alignment
262 |
263 | idx_align = idx_align[(shift < (sp_win[0] + tol)) |
264 | (shift > (sp_win[1] - tol))]
265 | iter_id += 1
266 |
267 | ret_dict = {'data': spt}
268 |
269 | if remove:
270 | #remove double spikes
271 | FS = spike_data['FS']
272 | ret_dict = remove_doubles(ret_dict, 1000.0 / FS)
273 |
274 | return ret_dict
275 |
276 |
277 | def remove_doubles(spt_dict, tol):
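"""Remove double detections of the same spike: spike times that fall
within (roughly) twice ``tol`` milliseconds of the preceding spike are
dropped."""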
278 | new_dict = spt_dict.copy()
279 | spt = spt_dict['data']
280 |
281 | isi = np.diff(spt)
282 | intisi = (isi/tol).astype('int')
283 |
284 | if len(spt) > 0:
285 | spt = spt[np.concatenate(([True], intisi > 1))]
286 |
287 | new_dict['data'] = spt
288 | return new_dict
289 |
290 |
291 | def merge_spikes(spike_waves1, spike_waves2):
292 | """Merges two sets of spike waves
293 |
294 | Parameters
295 | ----------
296 |
297 | spike_waves1 : dict
298 | spike_waves2 : dict
299 | spike waveforms to merge; both spike wave sets must be defined
300 | within the same time window and with the same sampling
301 | frequency
302 |
303 | Returns
304 | -------
305 |
306 | spike_waves : dict
307 | merged spike waveshapes
308 |
309 | clust_idx : array
310 | labels denoting the set to which each spike originally belonged
311 | """
312 |
313 | sp_data1 = spike_waves1['data']
314 | sp_data2 = spike_waves2['data']
315 |
316 | sp_data = np.hstack((sp_data1, sp_data2))
317 | spike_waves = spike_waves1.copy()
318 | spike_waves['data'] = sp_data
319 |
320 | clust_idx = np.concatenate((np.ones(sp_data1.shape[1]),
321 | np.zeros(sp_data2.shape[1])))
322 |
323 | return spike_waves, clust_idx
324 |
325 |
326 | def merge_spiketimes(spt1, spt2, sort=True):
327 | """Merges two sets of spike times
328 |
329 | Parameters
330 | ----------
331 | spt1 : dict
332 | spt2 : dict
333 | sort : bool, optional
334 | False if you don't want the spike times to be sorted.
335 |
336 | Returns
337 | -------
338 | spt : dict
339 | dictionary with the merged spike time array under the 'data' key
340 | clust_idx : array
341 | labels denoting the set to which each spike originally
342 | belonged
343 |
344 | """
345 |
346 | spt_data1 = spt1['data']
347 | spt_data2 = spt2['data']
348 |
349 | spt_data = np.concatenate((spt_data1, spt_data2))
350 | clust_idx = np.concatenate((np.ones(spt_data1.shape[0]),
351 | np.zeros(spt_data2.shape[0])))
352 | if sort:  # honour the 'sort' flag instead of always sorting
353 | i = spt_data.argsort()
354 | spt_data = spt_data[i]
355 | clust_idx = clust_idx[i]
356 | spt_dict = {"data": spt_data}
357 |
358 | return spt_dict, clust_idx
359 |
--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import spike_sort as ss
3 |
4 | from nose.tools import ok_, eq_, raises
5 | from numpy.testing import assert_array_almost_equal as almost_equal
6 |
7 | import os
8 |
9 | class TestFilter(object):
10 | def __init__(self):
11 | self.n_spikes = 100
12 |
13 | self.FS = 25E3
14 | self.period = 1000.0 / self.FS * 100
15 | self.time = np.arange(0, self.period * self.n_spikes, 1000.0 / self.FS)
16 | self.spikes = np.sin(2 * np.pi / self.period * self.time)[np.newaxis, :]
17 | self.spk_data = {"data": self.spikes, "n_contacts": 1, "FS": self.FS}
18 |
19 | def test_filter_proxy(self):
20 | sp_freq = 1000.0 / self.period
21 | filter = ss.filters.Filter(sp_freq * 0.5, sp_freq * 0.4, 1, 10, 'ellip')
22 | spk_filt = ss.filters.filter_proxy(self.spk_data, filter)
23 | ok_(self.spk_data['data'].shape == spk_filt['data'].shape)
24 |
25 | def test_remove_proxy_files_after_exit(self):
26 | filter_func = lambda x, f: x
27 | spk_filt = ss.filters.filter_proxy(self.spk_data, filter_func)
28 | fname = spk_filt['data']._v_file.filename
29 | ss.filters.clean_after_exit()
30 | assert not os.path.isfile(fname)
31 |
32 |
33 | def test_LinearIIR_detect(self):
34 | n_spikes = self.n_spikes
35 | period = self.period
36 | threshold = 0.5
37 | sp_freq = 1000.0 / period
38 |
39 | self.spk_data['data'] += 2
40 | spk_filt = ss.filters.fltLinearIIR(self.spk_data, sp_freq * 0.5, sp_freq * 0.4, 1, 10, 'ellip')
41 |
42 | spt = ss.extract.detect_spikes(spk_filt, thresh=threshold)
43 | ok_(len(spt['data']) == n_spikes)
44 |
45 | class TestExtract(object):
46 | def __init__(self):
47 | self.n_spikes = 100
48 |
49 | self.FS = 25E3
50 | self.period = 1000.0 / self.FS * 100
51 | self.time = np.arange(0, self.period * self.n_spikes, 1000.0 / self.FS)
52 | self.spikes = np.sin(2 * np.pi / self.period * self.time)[np.newaxis, :]
53 | self.spk_data = {"data": self.spikes, "n_contacts": 1, "FS": self.FS}
54 |
55 | def test_detect(self):
56 | n_spikes = self.n_spikes
57 | period = self.period
58 | FS = self.FS
59 | time = self.time
60 | spikes = self.spikes
61 | threshold = 0.5
62 | crossings_real = period / 12.0 + np.arange(n_spikes) * period
63 | spt = ss.extract.detect_spikes(self.spk_data, thresh=threshold)
64 | ok_((np.abs(spt['data'] - crossings_real) <= 1000.0 / FS).all())
65 |
66 | def test_align(self):
67 | #check whether spikes are correctly aligned to maxima
68 | maxima_idx = self.period * (1 / 4.0 + np.arange(self.n_spikes))
69 | thr_crossings = self.period * (1 / 6.0 + np.arange(self.n_spikes))
70 | spt_dict = {"data": thr_crossings}
71 | sp_win = [-self.period / 6.0, self.period / 3.0]
72 | spt = ss.extract.align_spikes(self.spk_data, spt_dict, sp_win)
73 | ok_((np.abs(spt['data'] - maxima_idx) <= 1000.0 / self.FS).all())
74 |
75 | def test_align_short_win(self):
76 | #test spike alignment with windows shorter than total spike duration
77 | maxima_idx = self.period * (1 / 4.0 + np.arange(self.n_spikes))
78 | thr_crossings = self.period * (1 / 6.0 + np.arange(self.n_spikes))
79 | spt_dict = {"data": thr_crossings}
80 | sp_win = [-self.period / 24.0, self.period / 12.0]
81 | spt = ss.extract.align_spikes(self.spk_data, spt_dict, sp_win)
82 | ok_((np.abs(spt['data'] - maxima_idx) <= 1000.0 / self.FS).all())
83 |
84 | def test_align_edge(self):
85 | # spikes near the end of the recording should be aligned to a time within the data range
86 | spikes = np.sin(2 * np.pi / self.period * self.time + np.pi / 2.0)[np.newaxis, :]
87 | maxima_idx = self.period * (np.arange(1, self.n_spikes + 1))
88 | thr_crossings = self.period * (-1 / 6.0 + np.arange(1, self.n_spikes + 1))
89 | spt_dict = {"data": thr_crossings}
90 | sp_win = [-self.period / 24.0, self.period / 12.0]
91 | spk_data = {"data": spikes, "n_contacts": 1, "FS": self.FS}
92 | spt = ss.extract.align_spikes(spk_data, spt_dict, sp_win)
93 | last = spt['data'][-1]
94 | ok_((last >= (self.time[-1] - sp_win[1])) & (last <= self.time[-1]))
95 |
96 | def test_align_double_spikes(self):
97 | #double detections of the same spike should be removed
98 | maxima_idx = self.period * (1 / 4.0 + np.arange(self.n_spikes))
99 | thr_crossings = self.period * (1 / 6.0 + np.arange(0, self.n_spikes, 0.5))
100 | spt_dict = {"data": thr_crossings}
101 | sp_win = [-self.period / 24.0, self.period / 12.0]
102 | spt = ss.extract.align_spikes(self.spk_data, spt_dict, sp_win)
103 | ok_((np.abs(spt['data'] - maxima_idx) <= 1000.0 / self.FS).all())
104 |
105 | def test_remove_doubles_roundoff(self):
106 | # remove_doubles should account for slight variations around
107 | # the 'tolerance' which may occur due to some round-off errors
108 | # during spike detection/alignment. This won't affect any useful
109 | # information, because the tolerance is one sample large in this
110 | # test.
111 |
112 | tol = 1000.0 / self.FS # [ms], one sample
113 | data = [1.0, 1.0 + tol + tol * 0.01]
114 | spt_dict = {"data": np.array(data)}
115 | clean_spt_dict = ss.extract.remove_doubles(spt_dict, tol)
116 | ok_(len(clean_spt_dict['data']) == 1) # duplicate removed
117 |
118 | def test_extract(self):
119 | zero_crossing = self.period * np.arange(self.n_spikes)
120 | zero_crossing += 1000.0 / self.FS / 2.0 # move by half a sample to avoid round-off errors
121 | spt_dict = {"data": zero_crossing}
122 | sp_win = [0, self.period]
123 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win)
124 | ref_sp = np.sin(2 * np.pi / self.period * sp_waves['time'])
125 | ok_((np.abs(ref_sp[:, np.newaxis] - sp_waves['data'][:, :, 0]) < 1E-6).all())
126 | #ok_((np.abs(sp_waves['data'][:,:,0].mean(1)-ref_sp)<2*1000*np.pi/(self.FS*self.period)).all())
127 | #ok_(np.abs(np.sum(sp_waves['data'][:,:,0].mean(1)-ref_sp))<1E-6)
128 |
129 | #def test_extract_resample_deprecation(self):
130 | # zero_crossing = self.period*np.arange(self.n_spikes)
131 | # spt_dict = {"data":zero_crossing}
132 | # sp_win = [0, self.period]
133 | # with warnings.catch_warnings(True) as w:
134 | # sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win,
135 | # resample=2.)
136 | # ok_(len(w)>=1)
137 |
138 | def test_extract_and_resample(self):
139 | zero_crossing = self.period * np.arange(self.n_spikes)
140 | zero_crossing += 1000.0 / self.FS / 2.0 # move by half a sample to avoid round-off errors
141 | spt_dict = {"data": zero_crossing}
142 | sp_win = [0, self.period]
143 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win)
144 | sp_resamp = ss.extract.resample_spikes(sp_waves, self.FS * 2)
145 | ref_sp = np.sin(2 * np.pi / self.period * sp_resamp['time'])
146 | ok_((np.abs(ref_sp[:, np.newaxis] - sp_resamp['data'][:, :, 0]) < 1E-6).all())
147 |
148 | def test_mask_of_truncated_spikes(self):
149 | zero_crossing = self.period * np.arange(self.n_spikes + 1)
150 | spt_dict = {"data": zero_crossing}
151 | sp_win = [0, self.period]
152 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win)
153 | correct_mask = np.ones(self.n_spikes + 1, np.bool)
154 | correct_mask[-1] = False
155 | ok_((sp_waves['is_valid'] == correct_mask).all())
156 | #ok_(np.abs(np.sum(sp_waves['data'][:,:,0].mean(1)-ref_sp))<1E-6)
157 |
158 | def test_extract_truncated_spike_end(self):
159 | zero_crossing = np.array([self.period * (self.n_spikes - 0.5)])
160 | spt_dict = {"data": zero_crossing}
161 | sp_win = [0, self.period]
162 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win)
163 | ref_sp = -np.sin(2 * np.pi / self.period * sp_waves['time'])
164 | ref_sp[len(ref_sp) / 2:] = 0
165 | almost_equal(sp_waves['data'][:, 0, 0], ref_sp)
166 |
167 | def test_extract_truncated_spike_start(self):
168 | zero_crossing = np.array([-self.period * 0.5])
169 | spt_dict = {"data": zero_crossing}
170 | sp_win = [0, self.period]
171 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win)
172 | ref_sp = -np.sin(2 * np.pi / self.period * sp_waves['time'])
173 | ref_sp[:len(ref_sp) / 2] = 0
174 | almost_equal(sp_waves['data'][:, 0, 0], ref_sp)
175 |
176 | def test_filter_spt(self):
177 | #out of band spikes should be removed
178 | zero_crossing = self.period * (np.arange(self.n_spikes))
179 | spt_dict = {"data": zero_crossing}
180 | sp_win = [0, self.period]
181 | spt_filt = ss.extract.filter_spt(self.spk_data, spt_dict, sp_win)
182 | ok_(len(spt_filt) == self.n_spikes)
183 |
184 | def test_filter_spt_shorten_left(self):
185 | #remove out-of-band spikes from the beginning of the train
186 | zero_crossing = self.period * (np.arange(self.n_spikes))
187 | spt_dict = {"data": zero_crossing}
188 | sp_win = [-self.period / 8, self.period / 8.0]
189 | spt_filt = ss.extract.filter_spt(self.spk_data, spt_dict, sp_win)
190 | ok_(len(spt_filt) == (self.n_spikes - 1))
191 |
192 | def test_filter_spt_shorten_right(self):
193 | #remove out-of-band spikes from the end of the train
194 | zero_crossing = self.period * (np.arange(self.n_spikes))
195 | spt_dict = {"data": zero_crossing}
196 | sp_win = [0, self.period + 1000.0 / self.FS]
197 | spt_filt = ss.extract.filter_spt(self.spk_data, spt_dict, sp_win)
198 | ok_(len(spt_filt) == (self.n_spikes - 1))
199 |
200 | class TestCluster(object):
201 | """test clustering algorithms"""
202 |
203 | def _cmp_bin_partitions(self, cl1, cl2):
204 | return (~np.logical_xor(cl1, cl2)).all() or (np.logical_xor(cl1, cl2)).all()
205 |
206 | def setup(self):
207 |
208 | self.K = 2
209 |
210 | n_dim = 2
211 | pts_in_clust = 100
212 | np.random.seed(1234)
213 | data = np.vstack((np.random.rand(pts_in_clust, n_dim),
214 | np.random.rand(pts_in_clust, n_dim) + 2 * np.ones(n_dim)))
215 | self.labels = np.concatenate((np.zeros(pts_in_clust, dtype=int),
216 | np.ones(pts_in_clust, dtype=int)))
217 | feature_labels = ["feat%d" % i for i in range(n_dim)]
218 | self.features = {"data": data, "names": feature_labels}
219 |
220 | def test_k_means(self):
221 | """test own k-means algorithm"""
222 |
223 | cl = ss.cluster.cluster('k_means', self.features, self.K)
224 | ok_(self._cmp_bin_partitions(cl, self.labels))
225 |
226 | def test_k_means_plus(self):
227 | """test scikits k-means plus algorithm"""
228 |
229 | cl = ss.cluster.cluster('k_means_plus', self.features, self.K)
230 | ok_(self._cmp_bin_partitions(cl, self.labels))
231 |
232 | def test_gmm(self):
233 | """test gmm clustering algorithm"""
234 |
235 | cl = ss.cluster.cluster('gmm', self.features, self.K)
236 | ok_(self._cmp_bin_partitions(cl, self.labels))
237 |
238 | def test_random(self):
239 | cl = np.random.rand(len(self.labels)) > 0.5
240 | ok_(~self._cmp_bin_partitions(cl, self.labels))
241 |
242 | @raises(NotImplementedError)
243 | def test_method_notimplemented(self):
244 | cl = ss.cluster.cluster("notimplemented", self.features)
245 |
--------------------------------------------------------------------------------