├── tests ├── __init__.py ├── samples │ └── axonio.abf ├── test_stats.py ├── test_beans.py ├── utils.py ├── test_io.py ├── test_features.py └── test_core.py ├── src ├── spike_analysis │ ├── __init__.py │ ├── xcorr.py │ ├── io_tools.py │ ├── dashboard.py │ └── basic.py ├── spike_beans │ ├── __init__.py │ └── base.py └── spike_sort │ ├── io │ ├── __init__.py │ ├── export.py │ └── neo_filters.py │ ├── ui │ ├── __init__.py │ ├── _mpl_helpers.py │ ├── zoomer.py │ ├── manual_sort.py │ └── plotting.py │ ├── stats │ ├── __init__.py │ ├── diptst │ │ ├── diptst.pyf │ │ └── diptst.f │ └── tests.py │ ├── __init__.py │ └── core │ ├── __init__.py │ ├── filters.py │ ├── evaluate.py │ ├── cluster.py │ └── extract.py ├── setup.cfg ├── docs ├── source │ ├── _static │ │ └── logo.png │ ├── tutorials │ │ ├── images_beans │ │ │ ├── waves.png │ │ │ ├── features.png │ │ │ ├── browser_zoom.png │ │ │ ├── browser_nozoom.png │ │ │ ├── waves_2_deleted.png │ │ │ └── waves_one_left.png │ │ ├── images_manual │ │ │ ├── tutorial_cells.png │ │ │ ├── tutorial_spikes.png │ │ │ ├── tutorial_clusters.png │ │ │ └── tutorial_features.png │ │ ├── index.rst │ │ ├── tutorial_io.rst │ │ └── tutorial_manual.rst │ ├── modules │ │ ├── components.rst │ │ ├── evaluate.rst │ │ ├── ui.rst │ │ ├── extract.rst │ │ ├── cluster.rst │ │ ├── features.rst │ │ └── io.rst │ ├── _themes │ │ ├── flask │ │ │ ├── theme.conf │ │ │ ├── layout.html │ │ │ ├── relations.html │ │ │ └── static │ │ │ │ ├── small_flask.css │ │ │ │ └── flasky.css_t │ │ ├── README │ │ ├── LICENSE │ │ └── flask_theme_support.py │ ├── pyplots │ │ ├── tutorial_spikes.py │ │ ├── tutorial_features.py │ │ ├── tutorial_cells.py │ │ └── tutorial_clusters.py │ ├── index.rst │ ├── datafiles.rst │ ├── conf.py │ ├── datastructures.rst │ └── intro.rst ├── README └── Makefile ├── requirements.txt ├── .gitignore ├── readthedocs.yml ├── data ├── gollum.inf ├── gollum_export.inf └── gen_tutorial_data.py ├── README ├── examples ├── analysis │ ├── cell_dashboard.py │ └── 
cell_xcorr.py └── sorting │ ├── browse_data.py │ ├── read_axon.py │ ├── cluster_auto.py │ ├── cluster_manual.py │ └── cluster_beans.py ├── .travis.yml ├── conda-environment.yml ├── LICENSE └── setup.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spike_analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spike_beans/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/spike_sort/io/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spike_sort/ui/__init__.py: -------------------------------------------------------------------------------- 1 | from plotting import * -------------------------------------------------------------------------------- /src/spike_sort/stats/__init__.py: -------------------------------------------------------------------------------- 1 | from tests import * 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | verbosity=2 3 | with-doctest=1 4 | tests=tests 5 | 6 | -------------------------------------------------------------------------------- /tests/samples/axonio.abf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/tests/samples/axonio.abf -------------------------------------------------------------------------------- /docs/source/_static/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/_static/logo.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_beans/waves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/waves.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_beans/features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/features.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy >= 1.4 2 | matplotlib >= 1.1.0 3 | tables >= 2.3.0 4 | scipy >= 0.9.0 5 | neo>=0.2.0 6 | scikits.learn 7 | pywavelets 8 | -------------------------------------------------------------------------------- /src/spike_sort/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | from core import * 5 | 6 | from ui import plotting 7 | import io 8 | -------------------------------------------------------------------------------- /docs/source/tutorials/images_beans/browser_zoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/browser_zoom.png -------------------------------------------------------------------------------- /src/spike_sort/core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | __all__ = 
["extract", "features", "filters", "cluster", "evaluate"] 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.swo 3 | *.pyc 4 | *.so 5 | docs/build/ 6 | build/ 7 | /dist 8 | /src/SpikeSort.egg-info/**/* 9 | /src/SpikeSort.egg-info 10 | -------------------------------------------------------------------------------- /docs/source/tutorials/images_beans/browser_nozoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/browser_nozoom.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_beans/waves_2_deleted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/waves_2_deleted.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_beans/waves_one_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_beans/waves_one_left.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_manual/tutorial_cells.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_cells.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_manual/tutorial_spikes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_spikes.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_manual/tutorial_clusters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_clusters.png -------------------------------------------------------------------------------- /docs/source/tutorials/images_manual/tutorial_features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/btel/SpikeSort/HEAD/docs/source/tutorials/images_manual/tutorial_features.png -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | formats: 2 | - none 3 | 4 | conda: 5 | file: conda-environment.yml 6 | 7 | python: 8 | version: 2 9 | setup_py_install: true 10 | -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. _tutorials-section: 2 | 3 | Tutorials 4 | ========= 5 | 6 | .. toctree:: 7 | 8 | tutorial_beans 9 | tutorial_manual 10 | tutorial_io 11 | -------------------------------------------------------------------------------- /docs/source/modules/components.rst: -------------------------------------------------------------------------------- 1 | Components (:mod:`spike_beans.components`) 2 | ========================================== 3 | 4 | .. 
automodule:: spike_beans.components 5 | :show-inheritance: 6 | 7 | -------------------------------------------------------------------------------- /docs/source/_themes/flask/theme.conf: -------------------------------------------------------------------------------- 1 | [theme] 2 | inherit = basic 3 | stylesheet = flasky.css 4 | pygments_style = flask_theme_support.FlaskyStyle 5 | 6 | [options] 7 | index_logo = '' 8 | index_logo_height = 120px 9 | touch_icon = 10 | -------------------------------------------------------------------------------- /data/gollum.inf: -------------------------------------------------------------------------------- 1 | { 2 | "fspike":"{subject}/{ses_id}/{ses_id}-{el_id}{contact_id}.sp", 3 | "cell": "{subject}/spt/{ses_id}-{el_id}-{cell_id}.spt", 4 | "stim": "{subject}/{ses_id}/{ses_id}-10.spt", 5 | "n_contacts":4, 6 | "dirname":"{DATAPATH}", 7 | "FS":25000 8 | } 9 | -------------------------------------------------------------------------------- /data/gollum_export.inf: -------------------------------------------------------------------------------- 1 | { 2 | "fspike":"{subject}/{ses_id}/{ses_id}-{el_id}{contact_id}.sp", 3 | "cell": "spike_sort/{ses_id}-{el_id}-{cell_id}.spt", 4 | "stim": "{subject}/{ses_id}/{ses_id}-10.spt", 5 | "n_contacts":4, 6 | "dirname":"{DATAPATH}", 7 | "FS":25000 8 | } 9 | -------------------------------------------------------------------------------- /data/gen_tutorial_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | from spike_sort.io.filters import PyTablesFilter, BakerlabFilter 5 | 6 | in_dataset = "/Gollum/s5gollum01/el3" 7 | out_dataset = "/SubjectA/session01/el1/raw" 8 | 9 | in_filter = BakerlabFilter("gollum.inf") 10 | out_filter = PyTablesFilter("tutorial.h5") 11 | 12 | sp = in_filter.read_sp(in_dataset) 13 | out_filter.write_sp(sp, out_dataset) 14 | 15 | in_filter.close() 16 | out_filter.close() 17 | 
-------------------------------------------------------------------------------- /docs/source/_themes/flask/layout.html: -------------------------------------------------------------------------------- 1 | {%- extends "basic/layout.html" %} 2 | 3 | {% block sidebarlogo %} {% endblock %} 5 | {% block header %} 6 | {{ super() }} 7 | {% endblock %} 8 | {% block relbar2 %} {% endblock %} 9 | 10 | {%- block footer %} 11 | 15 | 16 | {%- endblock %} 17 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | Spike sorting library implemented in Python/NumPy/PyTables 2 | ---------------------------------------------------------- 3 | 4 | Project website: spikesort.org 5 | 6 | Requirements: 7 | 8 | * Python >= 2.6 9 | * PyTables 10 | * NumPy 11 | * matplotlib 12 | 13 | Optional: 14 | 15 | * scikits.learn -- clustering algorithms 16 | * neurotools -- spike train analysis 17 | 18 | Test dependencies: 19 | 20 | * all the above 21 | * hdf5-tools 22 | 23 | To see the library in action, see the examples folder. 24 | -------------------------------------------------------------------------------- /docs/source/modules/evaluate.rst: -------------------------------------------------------------------------------- 1 | Evaluate (:mod:`spike_sort.evaluate`) 2 | ===================================== 3 | 4 | .. currentmodule:: spike_sort.core.evaluate 5 | 6 | 7 | Utility functions to evaluate the quality of spike sorting 8 | 9 | .. autosummary:: 10 | 11 | snr_spike 12 | snr_clust 13 | detect_noise 14 | calc_noise_threshold 15 | isolation_score 16 | calc_isolation_score 17 | 18 | 19 | Reference 20 | --------- 21 | 22 | ..
automodule:: spike_sort.core.evaluate 23 | :members: 24 | -------------------------------------------------------------------------------- /examples/analysis/cell_dashboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import os 5 | import matplotlib.pyplot as plt 6 | 7 | from spike_analysis import dashboard 8 | from spike_sort.io.filters import BakerlabFilter 9 | 10 | if __name__ == "__main__": 11 | cell = "/Gollum/s39gollum03/el1/cell1" 12 | path = os.path.join(os.pardir, os.pardir, 'data', 'gollum_export.inf') 13 | filt = BakerlabFilter(path) 14 | 15 | dashboard.show_cell(filt, cell) 16 | 17 | plt.show() 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /docs/source/modules/ui.rst: -------------------------------------------------------------------------------- 1 | User Interface (:mod:`spike_sort.ui`) 2 | ===================================== 3 | 4 | .. currentmodule:: spike_sort.ui 5 | 6 | Plotting (:mod:`spike_sort.ui.plotting`) 7 | ---------------------------------------- 8 | 9 | This module provides basic plotting capabilities using matplotlib 10 | library. 11 | 12 | .. autosummary:: 13 | 14 | ~plotting.plot_spikes 15 | ~plotting.plot_features 16 | 17 | 18 | Reference 19 | --------- 20 | 21 | .. 
automodule:: spike_sort.ui.plotting 22 | :members: 23 | :undoc-members: 24 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | virtualenv: 5 | system_site_packages: true 6 | before_install: 7 | - sudo apt-get update -qq 8 | - sudo apt-get install libhdf5-serial-dev python-numpy python-scipy libatlas-dev liblapack-dev gfortran python-tk python-matplotlib hdf5-tools python-tables cython tk-dev 9 | # command to install dependencies 10 | install: 11 | - "pip install numexpr cython --use-mirrors" 12 | - "pip install -r requirements.txt --use-mirrors" 13 | - "pip install ." 14 | # command to run tests 15 | script: nosetests -P 16 | -------------------------------------------------------------------------------- /docs/README: -------------------------------------------------------------------------------- 1 | The documentation was created using Sphinx. You may create an HTML 2 | version of the documentation using: 3 | 4 | make html 5 | 6 | Alternatively, you can call your sphinx-build tool manually by running the 7 | following command from the docs directory: 8 | 9 | sphinx-build-2.6 -b html -d build/doctrees source build/html 10 | 11 | Note: In order to compile figures in tutorials you need to download the 12 | tutorial.h5 file and copy it to the spike_sort/data directory. In 13 | addition, matplotlib has to be installed at the time of compilation 14 | (it is also a dependency of spike_sort). 15 | -------------------------------------------------------------------------------- /docs/source/_themes/flask/relations.html: -------------------------------------------------------------------------------- 1 |

Related Topics

2 | 20 | -------------------------------------------------------------------------------- /docs/source/modules/extract.rst: -------------------------------------------------------------------------------- 1 | Extract (:mod:`spike_sort.core.extract`) 2 | ======================================== 3 | 4 | This module is a collection of functions for pre-processing of raw 5 | recordings (filtering and upsampling), detecting spikes and extracting 6 | spikes based on detected spike times. 7 | 8 | .. currentmodule:: spike_sort.core.extract 9 | 10 | .. autosummary:: 11 | 12 | align_spikes 13 | detect_spikes 14 | extract_spikes 15 | filter_proxy 16 | merge_spikes 17 | merge_spiketimes 18 | remove_spikes 19 | resample_spikes 20 | split_cells 21 | 22 | Reference 23 | --------- 24 | 25 | .. automodule:: spike_sort.core.extract 26 | :members: 27 | :undoc-members: 28 | -------------------------------------------------------------------------------- /examples/analysis/cell_xcorr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import os 5 | import itertools 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from spike_sort.io.filters import BakerlabFilter 10 | from spike_analysis import io_tools, xcorr 11 | 12 | cell_pattern = "/Gollum/s4gollum*/el*/cell*" 13 | path = os.path.join(os.pardir, os.pardir, 'data', 'gollum.inf') 14 | filt = BakerlabFilter(path) 15 | 16 | if __name__ == "__main__": 17 | cell_nodes = io_tools.list_cells(filt, cell_pattern) 18 | cells = [io_tools.read_dataset(filt, node) for node in cell_nodes] 19 | for data, name in itertools.izip(cells, cell_nodes): 20 | data['dataset'] = name 21 | xcorr.show_xcorr(cells) 22 | plt.show() 23 | 24 | 25 | -------------------------------------------------------------------------------- /examples/sorting/browse_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 
#coding=utf-8 3 | 4 | """ 5 | Simple raw data browser. 6 | 7 | Keyboard shortcuts: 8 | 9 | +/- - zoom in/out 10 | """ 11 | 12 | import spike_sort as sort 13 | from spike_sort.io.filters import PyTablesFilter 14 | from spike_sort.ui import spike_browser 15 | import os 16 | 17 | DATAPATH = os.environ['DATAPATH'] 18 | 19 | if __name__ == "__main__": 20 | dataset = "/SubjectA/session01/el1" 21 | data_fname = os.path.join(DATAPATH, "tutorial.h5") 22 | 23 | io_filter = PyTablesFilter(data_fname) 24 | sp = io_filter.read_sp(dataset) 25 | spt = sort.extract.detect_spikes(sp, contact=3, thresh='auto') 26 | 27 | spike_browser.browse_data_tk(sp, spt, win=50) 28 | -------------------------------------------------------------------------------- /docs/source/pyplots/tutorial_spikes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from spike_sort.io.filters import PyTablesFilter 5 | from spike_sort import extract 6 | from spike_sort import features 7 | from spike_sort import cluster 8 | from spike_sort.ui import plotting 9 | import os 10 | 11 | dataset = '/SubjectA/session01/el1' 12 | datapath = '../../../data/tutorial.h5' 13 | 14 | io_filter = PyTablesFilter(datapath) 15 | raw = io_filter.read_sp(dataset) 16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto') 17 | 18 | sp_win = [-0.2, 0.8] 19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10) 20 | sp_waves = extract.extract_spikes(raw, spt, sp_win) 21 | plotting.plot_spikes(sp_waves, n_spikes=200) 22 | plotting.show() 23 | io_filter.close() 24 | -------------------------------------------------------------------------------- /docs/source/tutorials/tutorial_io.rst: -------------------------------------------------------------------------------- 1 | .. _io_tutorial: 2 | 3 | Reading custom data formats 4 | =========================== 5 | 6 | .. testsetup:: 7 | 8 | .. 
testcleanup:: 9 | 10 | SpikeSort provides a flexible interface to various data sources via 11 | so-called input-output `Filters`. The code available for download already 12 | contains a few filters that allow for reading (and in some cases writing) 13 | of several data formats, including: 14 | 15 | * raw binary data, 16 | * HDF5 files, 17 | * proprietary data formats (such as Axon `.abf`) supported by 18 | `Neo `_ library. 19 | 20 | In this tutorial we show you how to use the available filters to read and plot data 21 | from `.abf` files, and then we demonstrate how easy it is to add support 22 | for custom formats by defining your own `Filter`. 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/spike_sort/io/export.py: -------------------------------------------------------------------------------- 1 | def export_cells(io_filter, node_templ, spike_times, overwrite=False): 2 | """Export discriminated spike times of all cells to a file. 3 | 4 | Parameters 5 | ---------- 6 | io_filter : object, 7 | read/write filter object (see :py:mod:`spike_sort.io.filters`) 8 | node_templ : string 9 | string identifying the dataset name. It will be passed to 10 | IOFilters.write_spt method. It can contain the 11 | `{cell_id}` placeholder that will be substituted by cell 12 | identifier. 13 | spike_times : dict 14 | dictionary in which keys are the cell IDs and values are spike 15 | times structures 16 | """ 17 | for cell_id, spt_cell in spike_times.items(): 18 | dataset = node_templ.format(cell_id=cell_id) 19 | io_filter.write_spt(spt_cell, dataset, overwrite=overwrite) 20 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. SortSpikes documentation master file, created by 2 | sphinx-quickstart on Thu Jan 27 14:09:31 2011.
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to SpikesSort's documentation! 7 | ====================================== 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | intro 15 | tutorials/index 16 | datastructures 17 | datafiles 18 | 19 | Modules: 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | 24 | modules/extract 25 | modules/features 26 | modules/cluster 27 | modules/evaluate 28 | modules/io 29 | modules/ui 30 | 31 | modules/components 32 | 33 | 34 | 35 | Indices and tables 36 | ================== 37 | 38 | * :ref:`genindex` 39 | * :ref:`modindex` 40 | * :ref:`search` 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /docs/source/pyplots/tutorial_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from spike_sort.io.filters import PyTablesFilter 5 | from spike_sort import extract 6 | from spike_sort import features 7 | from spike_sort import cluster 8 | from spike_sort.ui import plotting 9 | import os 10 | 11 | dataset = '/SubjectA/session01/el1' 12 | datapath = '../../../data/tutorial.h5' 13 | 14 | io_filter = PyTablesFilter(datapath) 15 | raw = io_filter.read_sp(dataset) 16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto') 17 | 18 | sp_win = [-0.2, 0.8] 19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10) 20 | sp_waves = extract.extract_spikes(raw, spt, sp_win) 21 | sp_feats = features.combine( 22 | ( 23 | features.fetP2P(sp_waves), 24 | features.fetPCA(sp_waves) 25 | ) 26 | ) 27 | 28 | plotting.plot_features(sp_feats) 29 | plotting.show() 30 | -------------------------------------------------------------------------------- /docs/source/datafiles.rst: -------------------------------------------------------------------------------- 1 | Sample data files 2 | 
================= 3 | 4 | 5 | These files are provided only for demonstration purposes. The authors do 6 | not hold any responsibility for using these files for purposes other 7 | than testing and learning the functions of SpikeSort. Unless stated 8 | otherwise, use of these datasets in publications is allowed only with 9 | written consent of the authors. 10 | 11 | 12 | .. _tutorial_data: 13 | 14 | Tutorial data 15 | ------------- 16 | 17 | `tutorial.h5 `_ 18 | 19 | This is a single dataset of extracellular spikes recorded with a 20 | tetrode in a stimulus-driven paradigm. For more information see 21 | [Telenczuk2011]_. 22 | 23 | .. [Telenczuk2011] Telenczuk, Bartosz, Stuart N Baker, Andreas V M Herz, and Gabriel Curio. *“High-frequency EEG Covaries with Spike Burst Patterns Detected in Cortical Neurons.”* Journal of Neurophysiology **105**, no. 6 (2011): 2951–2959. 24 | 25 | -------------------------------------------------------------------------------- /docs/source/pyplots/tutorial_cells.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from spike_sort.io.filters import PyTablesFilter 5 | from spike_sort import extract 6 | from spike_sort import features 7 | from spike_sort import cluster 8 | from spike_sort.ui import plotting 9 | import os 10 | 11 | dataset = '/SubjectA/session01/el1' 12 | datapath = '../../../data/tutorial.h5' 13 | 14 | io_filter = PyTablesFilter(datapath) 15 | raw = io_filter.read_sp(dataset) 16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto') 17 | 18 | sp_win = [-0.2, 0.8] 19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10) 20 | sp_waves = extract.extract_spikes(raw, spt, sp_win) 21 | sp_feats = features.combine( 22 | ( 23 | features.fetP2P(sp_waves), 24 | features.fetPCA(sp_waves) 25 | ) 26 | ) 27 | clust_idx = cluster.cluster("gmm",sp_feats,4) 28 | plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200) 29 |
plotting.show() 30 | io_filter.close() 31 | -------------------------------------------------------------------------------- /docs/source/pyplots/tutorial_clusters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from spike_sort.io.filters import PyTablesFilter 5 | from spike_sort import extract 6 | from spike_sort import features 7 | from spike_sort import cluster 8 | from spike_sort.ui import plotting 9 | import os 10 | 11 | dataset = '/SubjectA/session01/el1' 12 | datapath = '../../../data/tutorial.h5' 13 | 14 | io_filter = PyTablesFilter(datapath) 15 | raw = io_filter.read_sp(dataset) 16 | spt = extract.detect_spikes(raw, contact=3, thresh='auto') 17 | 18 | sp_win = [-0.2, 0.8] 19 | spt = extract.align_spikes(raw, spt, sp_win, type="max", resample=10) 20 | sp_waves = extract.extract_spikes(raw, spt, sp_win) 21 | sp_feats = features.combine( 22 | ( 23 | features.fetP2P(sp_waves), 24 | features.fetPCA(sp_waves) 25 | ) 26 | ) 27 | 28 | clust_idx = cluster.cluster("gmm",sp_feats,4) 29 | plotting.plot_features(sp_feats, clust_idx) 30 | plotting.show() 31 | io_filter.close() 32 | -------------------------------------------------------------------------------- /docs/source/modules/cluster.rst: -------------------------------------------------------------------------------- 1 | Cluster (:mod:`spike_sort.cluster`) 2 | =================================== 3 | 4 | .. currentmodule:: spike_sort.core.cluster 5 | 6 | Module with clustering algorithms. 7 | 8 | Utility functions 9 | ----------------- 10 | 11 | Spike sorting is usually done with the 12 | :py:func:`~spike_sort.core.cluster.cluster` function which takes as an 13 | argument one of the clustering methods (as a string). 14 | 15 | Other functions help to manipulate the results: 16 | 17 | ..
autosummary:: 18 | 19 | cluster 20 | split_cells 21 | 22 | 23 | 24 | Clustering methods 25 | ------------------ 26 | 27 | Several different clustering methods are defined in the module. Each 28 | method should take at least one argument -- the features structure. 29 | 30 | .. autosummary:: 31 | 32 | k_means_plus 33 | gmm 34 | manual 35 | none 36 | k_means 37 | 38 | 39 | Reference 40 | --------- 41 | 42 | .. automodule:: spike_sort.core.cluster 43 | :members: 44 | -------------------------------------------------------------------------------- /examples/sorting/read_axon.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | matplotlib.use("TkAgg") 5 | matplotlib.interactive(True) 6 | 7 | from spike_beans import components, base 8 | from spike_sort.io import neo_filters 9 | 10 | #################################### 11 | # Adjust these fields for your needs 12 | 13 | sp_win = [-0.6, 0.8] 14 | 15 | url = 'https://portal.g-node.org/neo/axon/File_axon_1.abf' 16 | path = 'file_axon.abf' 17 | 18 | import urllib 19 | urllib.urlretrieve(url, path) 20 | 21 | io = neo_filters.NeoSource(path) 22 | 23 | base.register("SignalSource", io) 24 | base.register("SpikeMarkerSource", 25 | components.SpikeDetector(contact=0, 26 | thresh='auto', 27 | type='max', 28 | sp_win=sp_win, 29 | resample=1, 30 | align=True)) 31 | base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win)) 32 | 33 | browser = components.SpikeBrowser() 34 | 35 | browser.show() 36 | -------------------------------------------------------------------------------- /tests/test_stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spike_sort import stats 3 | 4 | from nose.tools import ok_ 5 | 6 | class TestStats(object): 7 | def setup(self): 8 | n_samples = 1000 9 | 10 | # unimodal 11 | normal = np.random.randn(n_samples) 12 | self.normal = (normal - normal.mean()) / 
normal.std() 13 | uniform = np.random.uniform(size = n_samples) 14 | self.uniform = (uniform - uniform.mean()) / uniform.std() 15 | laplace = np.random.laplace(size = n_samples) 16 | self.laplace = (laplace - laplace.mean()) / laplace.std() 17 | 18 | # multimodal 19 | multimodal = np.hstack((self.normal, self.uniform + 4, self.laplace + 8)) 20 | np.random.shuffle(multimodal) 21 | multimodal = multimodal[:n_samples] 22 | self.multimodal = (multimodal - multimodal.mean()) / multimodal.std() 23 | 24 | def test_multimodality_detection(self): 25 | data = [self.normal, self.uniform, self.laplace] 26 | tests = [stats.dip, stats.ks] 27 | 28 | detected = [test(self.multimodal) > test(dist) for dist in data for test in tests] 29 | 30 | ok_(np.array(detected).all()) 31 | -------------------------------------------------------------------------------- /src/spike_sort/stats/diptst/diptst.pyf: -------------------------------------------------------------------------------- 1 | ! -*- f90 -*- 2 | ! Note: the context of this file is case sensitive. 3 | 4 | python module _diptst ! in 5 | interface ! in :_diptst 6 | subroutine diptst1(x,n,dip,xl,xu,ifault,gcm,lcm,mn,mj,ddx,ddxsgn) ! in :_diptst:diptst.f 7 | real dimension(n),intent(in) :: x 8 | integer optional,check(len(x)>=n),intent(hide),depend(x) :: n=len(x) 9 | real intent(out) :: dip 10 | real intent(out) :: xl 11 | real intent(out) :: xu 12 | integer intent(out) :: ifault 13 | integer intent(hide,cache),dimension(n),depend(n) :: gcm 14 | integer intent(hide,cache),dimension(n),depend(n) :: lcm 15 | integer intent(hide,cache),dimension(n),depend(n) :: mn 16 | integer intent(hide,cache),dimension(n),depend(n) :: mj 17 | real intent(hide,cache),dimension(n),depend(n) :: ddx 18 | integer intent(hide,cache),dimension(n),depend(n) :: ddxsgn 19 | end subroutine diptst1 20 | end interface 21 | end python module _diptst 22 | 23 | ! This file was auto-generated with f2py (version:2). 24 | ! 
See http://cens.ioc.ee/projects/f2py2e/ 25 | -------------------------------------------------------------------------------- /conda-environment.yml: -------------------------------------------------------------------------------- 1 | name: spikesort-rtd 2 | dependencies: 3 | - alabaster=0.7.7=py27_0 4 | - babel=2.2.0=py27_0 5 | - cairo=1.12.18=6 6 | - cycler=0.10.0=py27_0 7 | - docutils=0.12=py27_0 8 | - fontconfig=2.11.1=5 9 | - freetype=2.5.5=0 10 | - hdf5=1.8.15.1=2 11 | - jinja2=2.8=py27_0 12 | - libgfortran=1.0=0 13 | - libpng=1.6.17=0 14 | - libxml2=2.9.2=0 15 | - markupsafe=0.23=py27_0 16 | - matplotlib=1.5.1=np110py27_0 17 | - mkl=11.3.1=0 18 | - numexpr=2.4.6=np110py27_1 19 | - numpy=1.10.4=py27_0 20 | - numpydoc=0.5=py27_1 21 | - openssl=1.0.2g=0 22 | - pip=8.0.3=py27_0 23 | - pixman=0.32.6=0 24 | - pycairo=1.10.0=py27_0 25 | - pygments=2.1.1=py27_0 26 | - pyparsing=2.0.3=py27_0 27 | - pyqt=4.11.4=py27_1 28 | - pytables=3.2.2=np110py27_0 29 | - python=2.7.11=0 30 | - python-dateutil=2.4.2=py27_0 31 | - pytz=2015.7=py27_0 32 | - qt=4.8.7=1 33 | - readline=6.2=2 34 | - scikit-learn=0.17.1=np110py27_0 35 | - scipy=0.17.0=np110py27_1 36 | - setuptools=20.1.1=py27_0 37 | - sip=4.16.9=py27_0 38 | - six=1.10.0=py27_0 39 | - snowballstemmer=1.2.1=py27_0 40 | - sphinx=1.3.5=py27_0 41 | - sphinx_rtd_theme=0.1.9=py27_0 42 | - sqlite=3.9.2=0 43 | - tk=8.5.18=0 44 | - wheel=0.29.0=py27_0 45 | - zlib=1.2.8=0 46 | - pip: 47 | - pywavelets==0.4.0 48 | -------------------------------------------------------------------------------- /docs/source/_themes/README: -------------------------------------------------------------------------------- 1 | Flask Sphinx Styles 2 | =================== 3 | 4 | This repository contains sphinx styles for Flask and Flask related 5 | projects. To use this style in your Sphinx documentation, follow 6 | this guide: 7 | 8 | 1. put this folder as _themes into your docs folder. 
Alternatively 9 | you can also use git submodules to check out the contents there. 10 | 2. add this to your conf.py: 11 | 12 | sys.path.append(os.path.abspath('_themes')) 13 | html_theme_path = ['_themes'] 14 | html_theme = 'flask' 15 | 16 | The following themes exist: 17 | 18 | - 'flask' - the standard flask documentation theme for large 19 | projects 20 | - 'flask_small' - small one-page theme. Intended to be used by 21 | very small addon libraries for flask. 22 | 23 | The following options exist for the flask_small theme: 24 | 25 | [options] 26 | index_logo = '' filename of a picture in _static 27 | to be used as replacement for the 28 | h1 in the index.rst file. 29 | index_logo_height = 120px height of the index logo 30 | github_fork = '' repository name on github for the 31 | "fork me" badge 32 | -------------------------------------------------------------------------------- /docs/source/_themes/flask/static/small_flask.css: -------------------------------------------------------------------------------- 1 | /* 2 | * small_flask.css_t 3 | * ~~~~~~~~~~~~~~~~~ 4 | * 5 | * :copyright: Copyright 2010 by Armin Ronacher. 6 | * :license: Flask Design License, see LICENSE for details. 
7 | */ 8 | 9 | body { 10 | margin: 0; 11 | padding: 20px 30px; 12 | } 13 | 14 | div.documentwrapper { 15 | float: none; 16 | background: white; 17 | } 18 | 19 | div.sphinxsidebar { 20 | display: block; 21 | float: none; 22 | width: 102.5%; 23 | margin: 50px -30px -20px -30px; 24 | padding: 10px 20px; 25 | background: #333; 26 | color: white; 27 | } 28 | 29 | div.sphinxsidebar h3, div.sphinxsidebar h4, div.sphinxsidebar p, 30 | div.sphinxsidebar h3 a { 31 | color: white; 32 | } 33 | 34 | div.sphinxsidebar a { 35 | color: #aaa; 36 | } 37 | 38 | div.sphinxsidebar p.logo { 39 | display: none; 40 | } 41 | 42 | div.document { 43 | width: 100%; 44 | margin: 0; 45 | } 46 | 47 | div.related { 48 | display: block; 49 | margin: 0; 50 | padding: 10px 0 20px 0; 51 | } 52 | 53 | div.related ul, 54 | div.related ul li { 55 | margin: 0; 56 | padding: 0; 57 | } 58 | 59 | div.footer { 60 | display: none; 61 | } 62 | 63 | div.bodywrapper { 64 | margin: 0; 65 | } 66 | 67 | div.body { 68 | min-height: 0; 69 | padding: 0; 70 | } 71 | -------------------------------------------------------------------------------- /docs/source/modules/features.rst: -------------------------------------------------------------------------------- 1 | Calculate features (:mod:`spike_sort.features`) 2 | =============================================== 3 | 4 | .. currentmodule:: spike_sort.core.features 5 | 6 | Provides functions to calculate spike waveforms features. 7 | 8 | .. _features_doc: 9 | 10 | Features 11 | -------- 12 | 13 | Functions starting with `fet` implement various features calculated 14 | from the spike waveshapes. 
They usually have one required argument 15 | :ref:`spike_wave` structure and some can have optional arguments (see 16 | below). 17 | 18 | Each of the functions returns a mapping structure (dictionary) with the following keys: 19 | 20 | * `data` -- an array of shape (n_spikes x n_features) 21 | * `names` -- a list of length n_features with feature labels 22 | 23 | The following features are implemented: 24 | 25 | .. autosummary:: 26 | 27 | fetPCA 28 | fetP2P 29 | fetSpIdx 30 | fetSpTime 31 | fetSpProjection 32 | 33 | Tools 34 | ----- 35 | 36 | This module provides a few tools to facilitate working with the features 37 | data structure: 38 | 39 | .. autosummary:: 40 | 41 | split_cells 42 | select 43 | combine 44 | normalize 45 | 46 | Auxiliary 47 | --------- 48 | .. autosummary:: 49 | 50 | PCA 51 | add_mask 52 | 53 | 54 | Reference 55 | --------- 56 | 57 | .. automodule:: spike_sort.core.features 58 | :members: 59 | -------------------------------------------------------------------------------- /examples/sorting/cluster_auto.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | """ 5 | Based on raw recordings detect spikes, calculate features and do automatic 6 | clustering with gaussian mixture models.
7 | """ 8 | 9 | import os 10 | 11 | import spike_sort as sort 12 | from spike_sort.io.filters import PyTablesFilter 13 | import spike_sort.ui.manual_sort 14 | 15 | DATAPATH = os.environ['DATAPATH'] 16 | 17 | if __name__ == "__main__": 18 | h5_fname = os.path.join(DATAPATH, "tutorial.h5") 19 | h5filter = PyTablesFilter(h5_fname, 'a') 20 | 21 | dataset = "/SubjectA/session01/el1" 22 | sp_win = [-0.2, 0.8] 23 | 24 | sp = h5filter.read_sp(dataset) 25 | spt = sort.extract.detect_spikes(sp, contact=3, thresh='auto') 26 | 27 | spt = sort.extract.align_spikes(sp, spt, sp_win, type="max", resample=10) 28 | sp_waves = sort.extract.extract_spikes(sp, spt, sp_win) 29 | features = sort.features.combine( 30 | (sort.features.fetP2P(sp_waves), 31 | sort.features.fetPCA(sp_waves)), 32 | norm=True 33 | ) 34 | 35 | 36 | clust_idx = sort.cluster.cluster("gmm",features,4) 37 | 38 | spike_sort.ui.plotting.plot_features(features, clust_idx) 39 | spike_sort.ui.plotting.figure() 40 | spike_sort.ui.plotting.plot_spikes(sp_waves, clust_idx,n_spikes=200) 41 | 42 | 43 | spike_sort.ui.plotting.show() 44 | h5filter.close() 45 | -------------------------------------------------------------------------------- /src/spike_sort/ui/_mpl_helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | from matplotlib.axes import Axes 5 | from matplotlib.ticker import NullLocator 6 | from matplotlib.projections import register_projection 7 | import matplotlib.axis as maxis 8 | 9 | class NoTicksXAxis(maxis.XAxis): 10 | def reset_ticks(self): 11 | self._lastNumMajorTicks = 1 12 | self._lastNumMinorTicks = 1 13 | def set_clip_path(self, clippath, transform=None): 14 | pass 15 | 16 | class NoTicksYAxis(maxis.YAxis): 17 | def reset_ticks(self): 18 | self._lastNumMajorTicks = 1 19 | self._lastNumMinorTicks = 1 20 | def set_clip_path(self, clippath, transform=None): 21 | pass 22 | 23 | class ThinAxes(Axes): 24 | """Thin axes 
without spines and ticks to accelerate axes creation""" 25 | 26 | name = 'thin' 27 | 28 | def _init_axis(self): 29 | self.xaxis = NoTicksXAxis(self) 30 | self.yaxis = NoTicksYAxis(self) 31 | 32 | def cla(self): 33 | """ 34 | Override to set up some reasonable defaults. 35 | """ 36 | Axes.cla(self) 37 | self.xaxis.set_minor_locator(NullLocator()) 38 | self.yaxis.set_minor_locator(NullLocator()) 39 | self.xaxis.set_major_locator(NullLocator()) 40 | self.yaxis.set_major_locator(NullLocator()) 41 | 42 | def _gen_axes_spines(self): 43 | return {} 44 | 45 | 46 | # Now register the projection with matplotlib so the user can select 47 | # it. 48 | register_projection(ThinAxes) 49 | -------------------------------------------------------------------------------- /docs/source/modules/io.rst: -------------------------------------------------------------------------------- 1 | File I/O (:mod:`spike_sort.io`) 2 | =============================== 3 | 4 | .. currentmodule:: spike_sort.io 5 | 6 | Functions for reading and writing datafiles. 7 | 8 | .. _io_filters: 9 | 10 | Read/Write Filters (:mod:`spike_sort.io.filters`) 11 | ------------------------------------------------- 12 | 13 | Filters are basic backends for read/write operations. They offer the following 14 | methods: 15 | 16 | * :py:func:`read_spt` -- read event times (such as spike times) 17 | * :py:func:`read_sp` -- read raw spike waveforms 18 | * :py:func:`write_spt` -- write spike times 19 | * :py:func:`write_sp` -- write raw spike waveforms 20 | 21 | The `read_*` methods usually take one argument (`datapath`), but it is 22 | not required. The `write_*` methods take `datapath` and the data to be 23 | written. 24 | 25 | If you want to read/write your custom data format, it is enough that 26 | you implement a class with these methods. 27 | 28 | The following filters are implemented: 29 | 30 | ..
autosummary:: 31 | 32 | filters.BakerlabFilter 33 | filters.PyTablesFilter 34 | 35 | 36 | Export tools (:mod:`spike_sort.io.export`) 37 | -------------------------------------------- 38 | 39 | These tools take one of the `io.filters` as an argument and export 40 | data to a file using the `write_spt` or `write_sp` methods. 41 | 42 | .. autosummary:: 43 | 44 | export.export_cells 45 | 46 | Reference 47 | --------- 48 | 49 | .. automodule:: spike_sort.io.filters 50 | :members: 51 | 52 | .. automodule:: spike_sort.io.export 53 | :members: 54 | 55 | -------------------------------------------------------------------------------- /src/spike_analysis/xcorr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | def raise_exception(*args, **kwargs): 8 | raise NotImplementedError("This function requires NeuroTools") 9 | 10 | try: 11 | from NeuroTools.analysis import crosscorrelate 12 | except ImportError: 13 | crosscorrelate = raise_exception 14 | 15 | def show_xcorr(cells): 16 | n = len(cells) 17 | maxlag = 3 18 | bins = np.arange(-maxlag, maxlag, 0.01) 19 | ax = None 20 | for i in range(n): 21 | for j in range(i, n): 22 | if i != j: 23 | ax = plt.subplot(n, n, i + j * n + 1, sharey=ax) 24 | 25 | crosscorrelate(cells[i]['spt'], cells[j]['spt'], maxlag, 26 | display=ax, 27 | kwargs={"bins": bins}) 28 | plt.axvline(0, color='k') 29 | 30 | if ax: 31 | ax.set_ylabel("") 32 | ax.set_yticks([]) 33 | ax.set_xticks([]) 34 | ax.set_xlabel("") 35 | else: 36 | ax_label = plt.subplot(n, n, i + j * n + 1, frameon=False) 37 | ax_label.set_xticks([]) 38 | ax_label.set_yticks([]) 39 | ax_label.text(0,0, cells[j]['dataset'], 40 | transform=ax_label.transAxes, 41 | rotation=45, ha='left', va='bottom') 42 | 43 | ax.set_xticks((-maxlag, maxlag)) 44 | ax.set_yticks(ax.get_ylim()) 45 |
-------------------------------------------------------------------------------- /examples/sorting/cluster_manual.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | """ 5 | Based on raw recordings detect spikes, calculate features and do 6 | clustering by means of manual cluster-cutting. 7 | """ 8 | 9 | import os 10 | 11 | import matplotlib 12 | matplotlib.use("TkAgg") 13 | matplotlib.interactive(True) 14 | 15 | import spike_sort as sort 16 | from spike_sort.io.filters import PyTablesFilter 17 | 18 | DATAPATH = os.environ['DATAPATH'] 19 | 20 | if __name__ == "__main__": 21 | h5_fname = os.path.join(DATAPATH, "tutorial.h5") 22 | h5filter = PyTablesFilter(h5_fname, 'r') 23 | 24 | dataset = "/SubjectA/session01/el1" 25 | sp_win = [-0.2, 0.8] 26 | 27 | sp = h5filter.read_sp(dataset) 28 | spt = sort.extract.detect_spikes(sp, contact=3, thresh=300) 29 | 30 | spt = sort.extract.align_spikes(sp, spt, sp_win, type="max", resample=10) 31 | sp_waves = sort.extract.extract_spikes(sp, spt, sp_win) 32 | features = sort.features.combine( 33 | ( 34 | sort.features.fetSpIdx(sp_waves), 35 | sort.features.fetP2P(sp_waves), 36 | sort.features.fetPCA(sp_waves)), 37 | norm=True 38 | ) 39 | 40 | clust_idx = sort.ui.manual_sort.manual_sort(features, 41 | ['Ch0:P2P', 'Ch3:P2P']) 42 | 43 | clust, rest = sort.cluster.split_cells(spt, clust_idx, [1, 0]) 44 | 45 | sort.ui.plotting.figure() 46 | sort.ui.plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200) 47 | 48 | raw_input('Press any key to exit...') 49 | 50 | h5filter.close() 51 | -------------------------------------------------------------------------------- /src/spike_analysis/io_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | import glob, re 4 | import os 5 | 6 | def read_dataset(filter, dataset): 7 | spt = filter.read_spt(dataset) 8 | stim_node = 
"/".join(dataset.split('/')[:-1]+['stim']) 9 | stim = filter.read_spt(stim_node) 10 | return {'spt': spt['data'], 'stim': stim['data'], 'ev': []} 11 | 12 | def list_cells(filter, dataset): 13 | """List all cells which fit the pattern given in dataset. Dataset can contain 14 | wildcards. 15 | 16 | Example: 17 | 18 | dataset = "/Subject/sSession01/el*/cell*" 19 | """ 20 | regexp = "^/(?P[a-zA-z\*]+)/s(?P.+)/el(?P[0-9\*]+)/?(?P[a-zA-Z]+)?(?P[0-9\*]+)?$" 21 | 22 | conf = filter.conf_dict 23 | fpath = (conf['dirname'].format(DATAPATH=os.environ['DATAPATH'])+conf['cell']) 24 | rec_wildcard = re.match(regexp, dataset).groupdict() 25 | fname = fpath.format(**rec_wildcard) 26 | 27 | files = glob.glob(fname) 28 | 29 | rec_regexp = {"subject": "(?P[a-zA-z\*]+)", 30 | "ses_id": "(?P.+)", 31 | "el_id": "(?P[0-9\*]+)", 32 | "cell_id": "(?P[0-9\*]+)"} 33 | node_fmt = "/{subject}/s{ses_id}/el{el_id}/cell{cell_id}" 34 | 35 | f_regexp = fpath.format(**rec_regexp) 36 | pattern = re.compile(f_regexp) 37 | nodes = [] 38 | for f in files: 39 | dataset_match = pattern.match(f) 40 | rec = rec_wildcard.copy() 41 | rec.update(dataset_match.groupdict()) 42 | nodes.append(node_fmt.format(**rec)) 43 | 44 | return nodes 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012, Bartosz Telenczuk, Dmytro Bielievtsov 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 
12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | 24 | The views and conclusions contained in the software and documentation are those 25 | of the authors and should not be interpreted as representing official policies, 26 | either expressed or implied, of the FreeBSD Project. 27 | 28 | Third-party code 29 | ---------------- 30 | 31 | Initial version of the third-party code in the `spike_beans.base` library was 32 | created by Zoran Isailovski and Ed Swierk and published under a permissive PSF 33 | license. For details see notes in the module. 34 | -------------------------------------------------------------------------------- /docs/source/_themes/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010 by Armin Ronacher. 2 | 3 | Some rights reserved. 4 | 5 | Redistribution and use in source and binary forms of the theme, with or 6 | without modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | * Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer.
11 | 12 | * Redistributions in binary form must reproduce the above 13 | copyright notice, this list of conditions and the following 14 | disclaimer in the documentation and/or other materials provided 15 | with the distribution. 16 | 17 | * The names of the contributors may not be used to endorse or 18 | promote products derived from this software without specific 19 | prior written permission. 20 | 21 | We kindly ask you to only use these themes in an unmodified manner just 22 | for Flask and Flask-related products, not for unrelated projects. If you 23 | like the visual style and want to use it for your own projects, please 24 | consider making some larger changes to the themes (such as changing 25 | font faces, sizes, colors or margins). 26 | 27 | THIS THEME IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 28 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 29 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 30 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 31 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 32 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 33 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 34 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 35 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 36 | ARISING IN ANY WAY OUT OF THE USE OF THIS THEME, EVEN IF ADVISED OF THE 37 | POSSIBILITY OF SUCH DAMAGE. 
38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import setuptools 5 | 6 | from numpy.distutils.core import setup, Extension 7 | 8 | ext_modules = [] 9 | 10 | import os 11 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 12 | 13 | if not on_rtd: 14 | diptst_ext = Extension(name = 'spike_sort.stats._diptst', 15 | sources = ['src/spike_sort/stats/diptst/diptst.f', 16 | 'src/spike_sort/stats/diptst/diptst.pyf']) 17 | ext_modules.append(diptst_ext) 18 | 19 | 20 | setup(name='SpikeSort', 21 | version='0.13', 22 | description='Python Spike Sorting Package', 23 | long_description="""SpikeSort is a flexible spike sorting framework 24 | implemented entirely in Python based on widely-used packages such as numpy, 25 | PyTables and matplotlib. It features manual and automatic clustering, many 26 | data formats and it is memory-efficient.""", 27 | author='Bartosz Telenczuk and Dmytro Bielievtsov', 28 | author_email='bartosz.telenczuk@gmail.com', 29 | url='http://spikesort.org', 30 | ext_modules = ext_modules, 31 | classifiers = [ 32 | "Development Status :: 4 - Beta", 33 | "Environment :: Console", 34 | "Environment :: X11 Applications", 35 | "Intended Audience :: Science/Research", 36 | "License :: OSI Approved :: BSD License", 37 | "Operating System :: OS Independent", 38 | "Programming Language :: Python :: 2.7" 39 | ], 40 | packages=['spike_sort', 41 | 'spike_sort.core', 42 | 'spike_sort.stats', 43 | 'spike_sort.ui', 44 | 'spike_sort.io', 45 | 'spike_beans', 46 | 'spike_analysis', 47 | ], 48 | package_dir = {"": "src"}, 49 | install_requires=[ 50 | 'matplotlib', 51 | 'tables', 52 | 'numpy >= 1.4.1', 53 | 'scipy', 54 | 'PyWavelets' 55 | ] 56 | ) 57 | 58 | 59 | -------------------------------------------------------------------------------- /examples/sorting/cluster_beans.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | matplotlib.use("TkAgg") 5 | matplotlib.interactive(True) 6 | 7 | from spike_beans import components, base 8 | 9 | #################################### 10 | # Adjust these fields for your needs 11 | 12 | # data source 13 | hdf5filename = 'tutorial.h5' 14 | dataset = "/SubjectA/session01/el1" 15 | 16 | # spike detection/extraction properties 17 | contact = 3 18 | detection_type = "max" 19 | thresh = "auto" 20 | filter_freq = (800.0, 100.0) 21 | 22 | sp_win = [-0.6, 0.8] 23 | 24 | path = filter(None, os.environ['DATAPATH'].split(os.sep)) + [hdf5filename] 25 | hdf5file = os.path.join(os.sep, *path) 26 | 27 | io = components.PyTablesSource(hdf5file, dataset) 28 | io_filter = components.FilterStack() 29 | 30 | base.register("RawSource", io) 31 | base.register("EventsOutput", io) 32 | base.register("SignalSource", io_filter) 33 | base.register("SpikeMarkerSource", 34 | components.SpikeDetector(contact=contact, 35 | thresh=thresh, 36 | type=detection_type, 37 | sp_win=sp_win, 38 | resample=1, 39 | align=True)) 40 | base.register("SpikeSource", components.SpikeExtractor(sp_win=sp_win)) 41 | base.register("FeatureSource", components.FeatureExtractor()) 42 | base.register("LabelSource", components.ClusterAnalyzer("gmm", 4)) 43 | 44 | browser = components.SpikeBrowser() 45 | feature_plot = components.PlotFeaturesTimeline() 46 | wave_plot = components.PlotSpikes() 47 | legend = components.Legend() 48 | export = components.ExportCells() 49 | 50 | ############################################################# 51 | # Add filters here: 52 | base.features["SignalSource"].add_filter("LinearIIR", *filter_freq) 53 | 54 | # Add the features here: 55 | base.features["FeatureSource"].add_feature("P2P") 56 | base.features["FeatureSource"].add_feature("PCA", ncomps=2) 57 | 58 | ############################################################# 59 | # Run the analysis (this 
can take a while) 60 | browser.update() 61 | -------------------------------------------------------------------------------- /src/spike_sort/io/neo_filters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import spike_sort as sort 5 | from spike_beans import components 6 | from spike_beans.components import GenericSource 7 | try: 8 | import neo 9 | except ImportError: 10 | raise ImportError, "To use the extra data filters you have to install Neo package" 11 | import numpy as np 12 | import os 13 | 14 | class AxonFilter(object): 15 | """Read Axon .abf files 16 | 17 | Parameters 18 | ---------- 19 | fname : str 20 | path to file 21 | 22 | electrodes : (optional) list of ints 23 | list of electrode indices to use 24 | 25 | """ 26 | 27 | def __init__(self, fname, electrodes=None): 28 | self.reader = neo.io.AxonIO(fname) 29 | self.block = self.reader.read_block() 30 | self.electrodes = electrodes 31 | 32 | def read_sp(self, dataset=None): 33 | electrodes = self.electrodes 34 | analogsignals = self.block.segments[0].analogsignals 35 | if electrodes is not None: 36 | analogsignals = [analogsignals[i] for i in electrodes] 37 | sp_raw = np.array(analogsignals) 38 | FS = float(analogsignals[0].sampling_rate.magnitude) 39 | n_contacts, _ = sp_raw.shape 40 | return {"data": sp_raw, "FS": FS, "n_contacts": n_contacts} 41 | 42 | def write_sp(self): 43 | raise NotImplementedError, "Writing to Axon files not yet implemented" 44 | def write_spt(self, spt_dict, dataset, overwrite=False): 45 | raise NotImplementedError, "Writing spikes in Axon format not yet implemented" 46 | 47 | class NeoSource(components.GenericSource): 48 | def __init__(self, fname, electrodes=None, overwrite=False): 49 | GenericSource.__init__(self, '', overwrite) 50 | 51 | root, ext = os.path.splitext(fname) 52 | ext = ext.lower() 53 | if ext == '.abf': 54 | self.io_filter = AxonFilter(fname, electrodes) 55 | else: 56 | 
raise IOError, "Format {0} not recognised".format(ext) 57 | self.read_sp = self.io_filter.read_sp 58 | self.write_sp = self.io_filter.write_sp 59 | self.write_spt = self.io_filter.write_spt 60 | -------------------------------------------------------------------------------- /src/spike_sort/ui/zoomer.py: -------------------------------------------------------------------------------- 1 | class Zoomer(object): 2 | '''Allows to zoom subplots''' 3 | def __init__(self, plt, fig): 4 | self.axlist = [] 5 | 6 | self.zoomed_state = {'geometry': (1, 1, 1), 7 | 'xlabel_visible': True, 8 | 'ylabel_visible': True} 9 | 10 | self.old_state = {'geometry': [1, 1, 1], 11 | 'xlabel_visible': True, 12 | 'ylabel_visible': True} 13 | 14 | plt.connect('key_press_event', self.zoom) 15 | self.fig = fig 16 | 17 | def zoom(self, event): 18 | axis = event.inaxes 19 | 20 | if axis is None or event.key != 'z': 21 | return 22 | 23 | if axis.get_geometry() == self.zoomed_state['geometry']: 24 | zoomed = True 25 | else: 26 | zoomed = False 27 | 28 | if not zoomed: 29 | # saving previous state 30 | self.old_state['geometry'] = axis.get_geometry() 31 | self.old_state['xlabel_visible'] = axis.xaxis.label.get_visible() 32 | self.old_state['ylabel_visible'] = axis.yaxis.label.get_visible() 33 | 34 | # removing old axes 35 | self.axlist = list(self.fig.get_axes()) 36 | for ax in self.axlist: 37 | self.fig.delaxes(ax) 38 | 39 | # modifying state (zooming) 40 | axis.change_geometry(*self.zoomed_state['geometry']) 41 | axis.xaxis.label.set_visible(self.zoomed_state['xlabel_visible']) 42 | axis.yaxis.label.set_visible(self.zoomed_state['ylabel_visible']) 43 | self.fig.add_axes(axis) 44 | self.fig.show() 45 | 46 | else: 47 | # removing current axes 48 | self.fig.delaxes(axis) 49 | 50 | # bringing the old state back 51 | axis.change_geometry(*self.old_state['geometry']) 52 | axis.xaxis.label.set_visible(self.old_state['xlabel_visible']) 53 | axis.yaxis.label.set_visible(self.old_state['ylabel_visible']) 54 | 
55 | # adding old axes 56 | for ax in self.axlist: 57 | self.fig.add_axes(ax) 58 | 59 | self.fig.show() 60 | -------------------------------------------------------------------------------- /src/spike_analysis/dashboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import numpy as np 5 | from scipy import stats 6 | import basic 7 | 8 | import matplotlib.pyplot as plt 9 | from io_tools import read_dataset 10 | 11 | def plot_psth(ax, dataset, **kwargs): 12 | spt = dataset['spt'] 13 | stim = dataset['stim'] 14 | ev = dataset['ev'] 15 | 16 | basic.plotPSTH(spt, stim,ax=ax, **kwargs) 17 | ymin, ymax = plt.ylim() 18 | if len(ev)>0: 19 | plt.vlines(ev, ymin, ymax) 20 | 21 | ax.text(0.95, 0.9,"total n/o spikes: %d" % (len(spt),), 22 | transform=ax.transAxes, 23 | ha='right') 24 | 25 | def plot_isi(ax, dataset, win=[0,5], bin=0.1, color='k'): 26 | spt = dataset['spt'] 27 | stim = dataset['stim'] 28 | 29 | isi = np.diff(spt) 30 | intvs = np.arange(win[0], win[1], bin) 31 | counts, bins = np.histogram(isi, intvs) 32 | mode,n = stats.mode(isi) 33 | 34 | ax.set_xlim(win) 35 | ax.set_ylim(0, counts.max()) 36 | ax.plot(intvs[:-1], counts, color=color, drawstyle='steps-post') 37 | ax.axvline(mode, color='k') 38 | 39 | ax.text(0.95, 0.9,"mode: %.2f ms" % (mode,), 40 | transform=ax.transAxes, 41 | ha='right') 42 | ax.set_xlabel("interval (ms)") 43 | ax.set_ylabel("count") 44 | 45 | def plot_trains(ax, dataset, **kwargs): 46 | spt = dataset['spt'] 47 | stim = dataset['stim'] 48 | ev = dataset['ev'] 49 | 50 | basic.plotraster(spt, stim,ax=ax, **kwargs) 51 | ymin, ymax = plt.ylim() 52 | if len(ev)>0: 53 | plt.vlines(ev, ymin, ymax) 54 | 55 | def plot_nspikes(ax, dataset, win=[0,30], color="k"): 56 | spt = dataset['spt'] 57 | stim = dataset['stim'] 58 | 59 | trains = basic.SortSpikes(spt, stim, win) 60 | n_spks = np.array([len(t) for t in trains]) 61 | count, bins = np.histogram(n_spks, 
np.arange(10)) 62 | ax.bar(bins[:-1]-0.5, count, color=color) 63 | ax.set_xlim((-1,10)) 64 | 65 | burst_frac = np.mean(n_spks>1) 66 | ax.text(0.95, 0.9,"%d %% bursts" % (burst_frac*100,), 67 | transform=ax.transAxes, 68 | ha='right') 69 | ax.set_xlabel("no. spikes") 70 | ax.set_ylabel("count") 71 | 72 | 73 | def plot_dataset(dataset, fig=None, **kwargs): 74 | 75 | if not fig: 76 | fig = plt.gcf() 77 | plt.subplots_adjust(hspace=0.3) 78 | ax1=fig.add_subplot(2,2,1) 79 | plot_psth(ax1, dataset, **kwargs) 80 | plt.title("PSTH") 81 | ax2= fig.add_subplot(2,2,2) 82 | plot_isi(ax2, dataset, **kwargs) 83 | plt.title("ISIH") 84 | ax3=fig.add_subplot(2,2,3) 85 | plot_trains(ax3,dataset) 86 | plt.title("raster") 87 | ax4= fig.add_subplot(2,2,4) 88 | plot_nspikes(ax4, dataset, **kwargs) 89 | plt.title("burst order") 90 | 91 | def show_cell(filter, cell): 92 | 93 | dataset = read_dataset(filter, cell) 94 | plot_dataset(dataset) 95 | -------------------------------------------------------------------------------- /src/spike_sort/stats/tests.py: -------------------------------------------------------------------------------- 1 | try: 2 | import _diptst 3 | except ImportError: 4 | _diptst = None 5 | 6 | 7 | import numpy as np 8 | from scipy import stats as st 9 | 10 | 11 | def multidimensional(func1d): 12 | """apply 1d function along specified axis 13 | """ 14 | def _decorated(data, axis=0): 15 | if data.ndim <= 1: 16 | return func1d(data) 17 | else: 18 | return np.apply_along_axis(func1d, axis, data) 19 | _decorated.__doc__ = func1d.__doc__ 20 | 21 | return _decorated 22 | 23 | 24 | def unsqueeze(data, axis): 25 | """inserts new axis to data at position `axis`. 26 | 27 | This is very useful when one wants to do operations which support 28 | broadcasting without using np.newaxis every time. 
29 | 30 | Parameters 31 | ---------- 32 | data : array 33 | input array 34 | axis : int 35 | axis to be inserted 36 | 37 | Returns 38 | ------- 39 | out : array 40 | array which has data.ndim+1 dimensions. Additional dimension 41 | has length 1 42 | 43 | Example 44 | ------- 45 | >>> data = [[1,2,3], [4,5,6]] 46 | >>> m = np.mean(data, 1) 47 | >>> data -= unsqueeze(m, 1) 48 | >>> data.mean(1) 49 | array([ 0., 0.]) 50 | 51 | """ 52 | shape = data.shape 53 | shape = np.insert(shape, axis, 1) 54 | return np.reshape(data, shape) 55 | 56 | 57 | def std_r(data, axis=0): 58 | """Computes robust estimate of standard deviation 59 | (Quiroga et al, 2004) 60 | 61 | Parameters 62 | ---------- 63 | data : array 64 | input data array 65 | axis : int 66 | 67 | Returns 68 | ------- 69 | data : array 70 | 71 | """ 72 | median = unsqueeze(np.median(data, axis), axis) 73 | std_r = np.median(np.abs(data - median), axis)/0.6745 74 | return std_r 75 | 76 | 77 | @multidimensional 78 | def dip(data): 79 | """Computes DIP statistic (Hartigan & Hartigan 1985) 80 | 81 | Parameters 82 | ---------- 83 | data : array 84 | input data array 85 | axis : int 86 | axis along which to compute dip 87 | 88 | Returns 89 | ------- 90 | data : float or array 91 | DIP statistic. If the input data is flat, returns float 92 | """ 93 | if not _diptst: 94 | raise NotImplemented, "module unavailable" 95 | sdata = np.sort(data) 96 | return _diptst.diptst1(sdata)[0] 97 | 98 | 99 | @multidimensional 100 | def ks(data): 101 | """Computes Kolmogorov-Smirnov statistic (Lilliefors modification) 102 | 103 | Parameters 104 | ---------- 105 | data : array 106 | (n_vecs) input data array 107 | axis : int 108 | axis along which to compute ks 109 | 110 | Returns 111 | ------- 112 | data : array or float 113 | KS statistic. If the input data is flat, returns float 114 | """ 115 | mr = np.median(data) 116 | stdr = std_r(data) 117 | 118 | # avoid zero-variance 119 | if stdr == 0: 120 | return 0. 
121 | 122 | return st.kstest(data, st.norm(loc=mr, scale=stdr).cdf)[0] 123 | 124 | def std(*args, **kwargs): 125 | return np.std(*args, **kwargs) 126 | -------------------------------------------------------------------------------- /src/spike_sort/ui/manual_sort.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from matplotlib.widgets import Lasso 4 | from matplotlib.path import Path 5 | from matplotlib.colors import colorConverter 6 | from matplotlib.collections import RegularPolyCollection # , LineCollection 7 | 8 | from matplotlib.pyplot import figure 9 | from numpy import nonzero 10 | 11 | import numpy as np 12 | 13 | 14 | class LassoManager(object): 15 | def __init__(self, ax, data, labels=None, color_on='r', color_off='k', markersize=1): 16 | self.axes = ax 17 | self.canvas = ax.figure.canvas 18 | self.data = data 19 | self.call_list = [] 20 | 21 | self.Nxy = data.shape[0] 22 | self.color_on = colorConverter.to_rgba(color_on) 23 | self.color_off = colorConverter.to_rgba(color_off) 24 | 25 | facecolors = [self.color_on for _ in range(self.Nxy)] 26 | fig = ax.figure 27 | self.collection = RegularPolyCollection( 28 | fig.dpi, 6, sizes=(markersize,), 29 | facecolors=facecolors, 30 | edgecolors=facecolors, 31 | offsets=data, 32 | transOffset=ax.transData) 33 | 34 | ax.add_collection(self.collection, autolim=True) 35 | ax.autoscale_view() 36 | 37 | if labels is not None: 38 | ax.set_xlabel(labels[0]) 39 | ax.set_ylabel(labels[1]) 40 | self.cid = self.canvas.mpl_connect('button_press_event', self.onpress) 41 | self.ind = None 42 | self.canvas.draw() 43 | 44 | def register(self, callback_func): 45 | self.call_list.append(callback_func) 46 | 47 | def callback(self, verts): 48 | facecolors = self.collection.get_facecolors() 49 | edgecolors = self.collection.get_edgecolors() 50 | ind = nonzero(Path(verts).contains_points(self.data))[0] 51 | for i in range(self.Nxy): 52 | if i in ind: 53 | facecolors[i] = 
self.color_on 54 | edgecolors[i] = self.color_on 55 | else: 56 | facecolors[i] = self.color_off 57 | edgecolors[i] = self.color_off 58 | 59 | self.canvas.draw_idle() 60 | self.canvas.widgetlock.release(self.lasso) 61 | del self.lasso 62 | self.ind = ind 63 | 64 | for func in self.call_list: 65 | func(ind) 66 | 67 | def onpress(self, event): 68 | if self.canvas.widgetlock.locked(): 69 | return 70 | if event.inaxes is None: 71 | return 72 | self.lasso = Lasso(event.inaxes, (event.xdata, event.ydata), 73 | self.callback) 74 | # acquire a lock on the widget drawing 75 | self.canvas.widgetlock(self.lasso) 76 | 77 | 78 | def manual_sort(features_dict, feat_idx): 79 | 80 | features = features_dict['data'] 81 | names = features_dict['names'] 82 | if type(feat_idx[0]) is int: 83 | ii = np.array(feat_idx) 84 | else: 85 | ii = np.array([np.nonzero(names == f)[0][0] for f in feat_idx]) 86 | return _cluster(features[:, ii], names[:, ii]) 87 | 88 | 89 | def _cluster(data, names=None, markersize=1): 90 | fig_cluster = figure(figsize=(6, 6)) 91 | ax_cluster = fig_cluster.add_subplot(111, 92 | xlim=(-0.1, 1.1), 93 | ylim=(-0.1, 1.1), 94 | autoscale_on=True) 95 | lman = LassoManager(ax_cluster, data, names, markersize=markersize) 96 | 97 | while lman.ind is None: 98 | time.sleep(.01) 99 | fig_cluster.canvas.flush_events() 100 | 101 | n_spikes = data.shape[0] 102 | clust_idx = np.zeros(n_spikes, dtype='int16') 103 | clust_idx[lman.ind] = 1 104 | return clust_idx 105 | -------------------------------------------------------------------------------- /tests/test_beans.py: -------------------------------------------------------------------------------- 1 | from spike_beans import base 2 | from nose.tools import ok_, raises 3 | from nose import with_setup 4 | 5 | 6 | def setup(): 7 | "set up test fixtures" 8 | pass 9 | 10 | 11 | def teardown(): 12 | "tear down test fixtures" 13 | base.features = base.FeatureBroker() 14 | 15 | 16 | class Dummy(base.Component): 17 | con = 
base.RequiredFeature('Data', base.HasAttributes('data')) 18 | opt_con = base.OptionalFeature('OptionalData', base.HasAttributes('data')) 19 | 20 | def __init__(self): 21 | self.data = 0 22 | super(Dummy, self).__init__() 23 | 24 | def get_data(self): 25 | return self.con.data 26 | 27 | def get_optional_data(self): 28 | if self.opt_con: 29 | return self.opt_con.data 30 | 31 | def _update(self): 32 | self.data += 1 33 | 34 | class DummyDataWithZeroDivision(object): 35 | @property 36 | def data(self): 37 | data = 1/0 38 | return data 39 | 40 | class DummyTwoWay(Dummy): 41 | con2 = base.RequiredFeature('Data2', base.HasAttributes('get_data')) 42 | 43 | def get_data(self): 44 | return self.con.data + self.con2.get_data() 45 | 46 | 47 | class DummyDataProvider(base.Component): 48 | data = "some data" 49 | 50 | 51 | class NixProvider(base.Component): 52 | pass 53 | 54 | 55 | @with_setup(setup, teardown) 56 | def test_dependency_resolution(): 57 | base.features.Provide('Data', DummyDataProvider) 58 | comp = Dummy() 59 | ok_(comp.get_data() == 'some data') 60 | 61 | 62 | @with_setup(setup, teardown) 63 | def test_diamond_dependency(): 64 | base.features.Provide("Data", DummyDataProvider()) 65 | base.features.Provide("Data2", Dummy()) 66 | out = DummyTwoWay() 67 | data = out.get_data() 68 | base.features['Data'].update() 69 | print out.data 70 | ok_(out.data == 1) 71 | 72 | 73 | @raises(AssertionError) 74 | @with_setup(setup, teardown) 75 | def test_missing_attribute(): 76 | base.features.Provide('Data', NixProvider) 77 | comp = Dummy() 78 | data = comp.get_data() 79 | 80 | 81 | @raises(AttributeError) 82 | @with_setup(setup, teardown) 83 | def test_missing_dependency(): 84 | comp = Dummy() 85 | data = comp.get_data() 86 | 87 | 88 | @with_setup(setup, teardown) 89 | def test_on_change(): 90 | base.features.Provide('Data', DummyDataProvider()) 91 | comp = Dummy() 92 | comp.get_data() 93 | base.features['Data'].update() 94 | ok_(comp.data) 95 | 96 | @raises(ZeroDivisionError) 
97 | @with_setup(setup, teardown) 98 | def test_hasattribute_exceptions(): 99 | '''test whether HasAttributes lets exceptions through (other than AttributeError)''' 100 | c = DummyDataWithZeroDivision() 101 | base.features.Provide('Data', c) 102 | comp = Dummy() 103 | data = comp.get_data() 104 | assert True 105 | 106 | @with_setup(setup, teardown) 107 | def test_register_new_feature_by_setitem(): 108 | base.features['Data']=DummyDataProvider() 109 | comp = Dummy() 110 | ok_(comp.get_data() == 'some data') 111 | 112 | @with_setup(setup, teardown) 113 | def test_register_new_feature_by_register(): 114 | dep_comp = DummyDataProvider() 115 | added_comp = base.register("Data", dep_comp) 116 | comp = Dummy() 117 | ok_(comp.get_data() == 'some data') 118 | assert dep_comp is added_comp 119 | 120 | @with_setup(setup, teardown) 121 | def test_missing_optional_dependency(): 122 | required_dep = base.register("Data", DummyDataProvider()) 123 | comp = Dummy() 124 | ok_(comp.opt_con is None) 125 | 126 | @raises(ZeroDivisionError) 127 | @with_setup(setup, teardown) 128 | def test_hasattribute_exceptions_for_optional_deps(): 129 | '''test whether HasAttributes raises exceptions for optional features''' 130 | required_dep = base.register("Data", DummyDataProvider()) 131 | optional_dep = base.register('OptionalData', DummyDataWithZeroDivision()) 132 | comp = Dummy() 133 | data = comp.get_optional_data() 134 | assert True 135 | 136 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from spike_beans import base 2 | import numpy as np 3 | 4 | spike_dur = 5. 5 | spike_amp = 100. 
6 | FS = 25E3 7 | period = 100 8 | n_spikes = 100 9 | 10 | class DummySignalSource(base.Component): 11 | 12 | def __init__(self): 13 | 14 | self.period = period 15 | self.n_spikes = n_spikes 16 | self.f_filter = None 17 | super(DummySignalSource, self).__init__() 18 | 19 | self._generate_data() 20 | 21 | def _generate_data(self): 22 | n_pts = int(self.n_spikes*self.period/1000.*FS) 23 | sp_idx = (np.arange(1,self.n_spikes-1)*self.period*FS/1000).astype(int) 24 | spikes = np.zeros(n_pts)[np.newaxis,:] 25 | spikes[0,sp_idx]=spike_amp 26 | 27 | n = int(spike_dur/1000.*FS) #spike length 28 | spikes[0,:] = np.convolve(spikes[0,:], np.ones(n), 'full')[:n_pts] 29 | self.spt = (sp_idx+0.5)*1000./FS 30 | self.FS = FS 31 | self._spikes = spikes 32 | 33 | def read_signal(self): 34 | 35 | #in milisecs 36 | 37 | spk_data ={"data":self._spikes,"n_contacts":1, "FS":self.FS} 38 | return spk_data 39 | 40 | def _update(self): 41 | self.period = period*2 42 | self.n_spikes = n_spikes/2 43 | self._generate_data() 44 | 45 | signal = property(read_signal) 46 | 47 | class DummySpikeDetector(base.Component): 48 | def __init__(self): 49 | 50 | self.threshold = 500 51 | self.type = 'min' 52 | self.contact = 0 53 | self.sp_win = [-0.6, 0.8] 54 | super(DummySpikeDetector, self).__init__() 55 | self._generate_data() 56 | 57 | def _generate_data(self): 58 | n_pts = int(n_spikes*period/1000.*FS) 59 | sp_idx = (np.arange(1,n_spikes-1)*period*FS/1000).astype(int) 60 | spt = (sp_idx+0.5)*1000./FS 61 | self._spt_data = {'data':spt} 62 | 63 | def read_events(self): 64 | return self._spt_data 65 | 66 | events = property(read_events) 67 | 68 | class DummyLabelSource(base.Component): 69 | def __init__(self): 70 | self.labels = np.random.randint(0,5, n_spikes-2) 71 | 72 | class DummyFeatureExtractor(base.Component): 73 | 74 | def __init__(self): 75 | n_feats=2 76 | features = np.vstack(( 77 | np.zeros((n_spikes, n_feats)), 78 | np.ones((n_spikes, n_feats)) 79 | )) 80 | names = ["Fet{0}".format(i) for i 
in range(n_feats)] 81 | 82 | self._features = {"data": features, "names":names} 83 | 84 | super(DummyFeatureExtractor, self).__init__() 85 | 86 | def read_features(self): 87 | 88 | return self._features 89 | 90 | def add_feature(self, name): 91 | ''' adds random feature with specifid name ''' 92 | 93 | self._features['data'] = np.hstack((self._features['data'], np.random.randn(n_spikes * 2, 1))) 94 | self._features['names'].append(name) 95 | 96 | def add_spikes(self, num = 10): 97 | '''appends `num` random values to each feature''' 98 | 99 | n_features = self._features['data'].shape[1] 100 | self._features['data'] = np.vstack((self._features['data'], np.random.randn(num, n_features))) 101 | 102 | features = property(read_features) 103 | 104 | class DummySpikeSource(base.Component): 105 | 106 | def __init__(self): 107 | n_pts = 100 108 | spike_shape = np.zeros(n_pts) 109 | spike_shape[n_pts/2] = 1. 110 | data = spike_shape[:,np.newaxis, np.newaxis]*np.ones(n_spikes-2)[np.newaxis,:,np.newaxis] 111 | self._sp_waves = {'data':data, 'time':np.ones(n_pts)*1000./FS} 112 | def read_spikes(self): 113 | 114 | return self._sp_waves 115 | 116 | spikes = property(read_spikes) 117 | 118 | class RandomFeatures(base.Component): 119 | def read_features(self): 120 | n_feats=2 121 | features = np.random.randn(n_spikes, n_feats) 122 | names = ["Fet{0}".format(i) for i in range(n_feats)] 123 | return {"data": features, "names":names} 124 | features = property(read_features) 125 | 126 | -------------------------------------------------------------------------------- /docs/source/_themes/flask_theme_support.py: -------------------------------------------------------------------------------- 1 | # flasky extensions. 
. flasky pygments style based on tango style 2 | from pygments.style import Style 3 | from pygments.token import Keyword, Name, Comment, String, Error, \ 4 | Number, Operator, Generic, Whitespace, Punctuation, Other, Literal 5 | 6 | 7 | class FlaskyStyle(Style): 8 | background_color = "#f8f8f8" 9 | default_style = "" 10 | 11 | styles = { 12 | # No corresponding class for the following: 13 | #Text: "", # class: '' 14 | Whitespace: "underline #f8f8f8", # class: 'w' 15 | Error: "#a40000 border:#ef2929", # class: 'err' 16 | Other: "#000000", # class: 'x' 17 | 18 | Comment: "italic #8f5902", # class: 'c' 19 | Comment.Preproc: "noitalic", # class: 'cp' 20 | 21 | Keyword: "bold #004461", # class: 'k' 22 | Keyword.Constant: "bold #004461", # class: 'kc' 23 | Keyword.Declaration: "bold #004461", # class: 'kd' 24 | Keyword.Namespace: "bold #004461", # class: 'kn' 25 | Keyword.Pseudo: "bold #004461", # class: 'kp' 26 | Keyword.Reserved: "bold #004461", # class: 'kr' 27 | Keyword.Type: "bold #004461", # class: 'kt' 28 | 29 | Operator: "#582800", # class: 'o' 30 | Operator.Word: "bold #004461", # class: 'ow' - like keywords 31 | 32 | Punctuation: "bold #000000", # class: 'p' 33 | 34 | # because special names such as Name.Class, Name.Function, etc. 35 | # are not recognized as such later in the parsing, we choose them 36 | # to look the same as ordinary variables.
37 | Name: "#000000", # class: 'n' 38 | Name.Attribute: "#c4a000", # class: 'na' - to be revised 39 | Name.Builtin: "#004461", # class: 'nb' 40 | Name.Builtin.Pseudo: "#3465a4", # class: 'bp' 41 | Name.Class: "#000000", # class: 'nc' - to be revised 42 | Name.Constant: "#000000", # class: 'no' - to be revised 43 | Name.Decorator: "#888", # class: 'nd' - to be revised 44 | Name.Entity: "#ce5c00", # class: 'ni' 45 | Name.Exception: "bold #cc0000", # class: 'ne' 46 | Name.Function: "#000000", # class: 'nf' 47 | Name.Property: "#000000", # class: 'py' 48 | Name.Label: "#f57900", # class: 'nl' 49 | Name.Namespace: "#000000", # class: 'nn' - to be revised 50 | Name.Other: "#000000", # class: 'nx' 51 | Name.Tag: "bold #004461", # class: 'nt' - like a keyword 52 | Name.Variable: "#000000", # class: 'nv' - to be revised 53 | Name.Variable.Class: "#000000", # class: 'vc' - to be revised 54 | Name.Variable.Global: "#000000", # class: 'vg' - to be revised 55 | Name.Variable.Instance: "#000000", # class: 'vi' - to be revised 56 | 57 | Number: "#990000", # class: 'm' 58 | 59 | Literal: "#000000", # class: 'l' 60 | Literal.Date: "#000000", # class: 'ld' 61 | 62 | String: "#4e9a06", # class: 's' 63 | String.Backtick: "#4e9a06", # class: 'sb' 64 | String.Char: "#4e9a06", # class: 'sc' 65 | String.Doc: "italic #8f5902", # class: 'sd' - like a comment 66 | String.Double: "#4e9a06", # class: 's2' 67 | String.Escape: "#4e9a06", # class: 'se' 68 | String.Heredoc: "#4e9a06", # class: 'sh' 69 | String.Interpol: "#4e9a06", # class: 'si' 70 | String.Other: "#4e9a06", # class: 'sx' 71 | String.Regex: "#4e9a06", # class: 'sr' 72 | String.Single: "#4e9a06", # class: 's1' 73 | String.Symbol: "#4e9a06", # class: 'ss' 74 | 75 | Generic: "#000000", # class: 'g' 76 | Generic.Deleted: "#a40000", # class: 'gd' 77 | Generic.Emph: "italic #000000", # class: 'ge' 78 | Generic.Error: "#ef2929", # class: 'gr' 79 | Generic.Heading: "bold #000080", # class: 'gh' 80 | Generic.Inserted: "#00A000", # class:
'gi' 81 | Generic.Output: "#888", # class: 'go' 82 | Generic.Prompt: "#745334", # class: 'gp' 83 | Generic.Strong: "bold #000000", # class: 'gs' 84 | Generic.Subheading: "bold #800080", # class: 'gu' 85 | Generic.Traceback: "bold #a40000", # class: 'gt' 86 | } 87 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # Internal variables. 11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 14 | 15 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest 16 | 17 | help: 18 | @echo "Please use \`make ' where is one of" 19 | @echo " html to make standalone HTML files" 20 | @echo " dirhtml to make HTML files named index.html in directories" 21 | @echo " singlehtml to make a single large HTML file" 22 | @echo " pickle to make pickle files" 23 | @echo " json to make JSON files" 24 | @echo " htmlhelp to make HTML files and a HTML help project" 25 | @echo " qthelp to make HTML files and a qthelp project" 26 | @echo " devhelp to make HTML files and a Devhelp project" 27 | @echo " epub to make an epub" 28 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 29 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 30 | @echo " text to make text files" 31 | @echo " man to make manual pages" 32 | @echo " changes to make an overview of all changed/added/deprecated items" 33 | @echo " linkcheck to check all external links for integrity" 34 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 35 |
36 | clean: 37 | -rm -rf $(BUILDDIR)/* 38 | 39 | html: 40 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 41 | @echo 42 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 43 | 44 | dirhtml: 45 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 46 | @echo 47 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 48 | 49 | singlehtml: 50 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 51 | @echo 52 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 53 | 54 | pickle: 55 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 56 | @echo 57 | @echo "Build finished; now you can process the pickle files." 58 | 59 | json: 60 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 61 | @echo 62 | @echo "Build finished; now you can process the JSON files." 63 | 64 | htmlhelp: 65 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 66 | @echo 67 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 68 | ".hhp project file in $(BUILDDIR)/htmlhelp." 69 | 70 | qthelp: 71 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 72 | @echo 73 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 74 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 75 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/SortSpikes.qhcp" 76 | @echo "To view the help file:" 77 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/SortSpikes.qhc" 78 | 79 | devhelp: 80 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 81 | @echo 82 | @echo "Build finished." 83 | @echo "To view the help file:" 84 | @echo "# mkdir -p $$HOME/.local/share/devhelp/SortSpikes" 85 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/SortSpikes" 86 | @echo "# devhelp" 87 | 88 | epub: 89 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 90 | @echo 91 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
92 | 93 | latex: 94 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 95 | @echo 96 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 97 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 98 | "(use \`make latexpdf' here to do that automatically)." 99 | 100 | latexpdf: 101 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 102 | @echo "Running LaTeX files through pdflatex..." 103 | make -C $(BUILDDIR)/latex all-pdf 104 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 105 | 106 | text: 107 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 108 | @echo 109 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 110 | 111 | man: 112 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 113 | @echo 114 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 115 | 116 | changes: 117 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 118 | @echo 119 | @echo "The overview file is in $(BUILDDIR)/changes." 120 | 121 | linkcheck: 122 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 123 | @echo 124 | @echo "Link check complete; look for any errors in the above output " \ 125 | "or in $(BUILDDIR)/linkcheck/output.txt." 126 | 127 | doctest: 128 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 129 | @echo "Testing of doctests in the sources finished, look at the " \ 130 | "results in $(BUILDDIR)/doctest/output.txt."
131 | -------------------------------------------------------------------------------- /src/spike_analysis/basic.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def gettrains(spt, stim, win, binsz): 5 | i = np.searchsorted(stim, spt) 6 | spt2 = spt - stim[i - 1] 7 | bin = np.arange(win[0], win[1], binsz) 8 | bin = np.concatenate((bin, [bin[-1] + binsz, np.inf])) 9 | j = np.searchsorted(bin, spt2) 10 | npoints = len(bin) - 1 11 | ntrials = len(stim) 12 | trains = np.zeros((npoints, ntrials)) 13 | trains[j - 1, i - 1] = 1 14 | return trains[:-1, :] 15 | 16 | def SortSpikes(spt, stim, win=None): 17 | """Given spike and stimuli times return 2D array with spike trains. 18 | If win is given only spikes occuring in the time window are 19 | returned. 20 | """ 21 | i = np.searchsorted(stim, spt) 22 | spt2 = spt - stim[i - 1] 23 | if win: 24 | corrected = filter(lambda x: win[1] > x[0] >= win[0], zip(spt2, i)) 25 | spt2 = np.array([x[0] for x in corrected]) 26 | i = np.array([x[1] for x in corrected]) 27 | return [spt2[i == j] for j in xrange(1, len(stim) + 1)] 28 | 29 | def plotraster(spt, stim, win=[0, 30], ntrials=None, ax=None, height=1.0): 30 | """Creates raster plots of spike trains: 31 | 32 | spt - times of spike occurance, 33 | stim - stimulus times 34 | win - range of time axis 35 | ntrials - number of trials to plot (if None plot all) 36 | """ 37 | if not ntrials: 38 | ntrials = len(stim) - 1 39 | 40 | if not ax: 41 | ax = plt.gca() 42 | 43 | spt2 = spt[spt < stim[ntrials]].copy() 44 | i = np.searchsorted(stim[:ntrials], spt2) 45 | spt2 -= stim[i - 1] 46 | 47 | plt.vlines(spt2, i, i + height) 48 | plt.xlim(win) 49 | plt.ylim((1, ntrials)) 50 | plt.xlabel('time (ms)') 51 | plt.ylabel('trials') 52 | 53 | def plottrains(trains, win=[0, 30], ntrials=None, height=1.0): 54 | print "Deprecation: Please use plotRasterTrains insted" 55 | plotRasterTrains(trains, win, 
ntrials, height) 56 | 57 | def plotRasterTrains(trains, win=[0, 30], ntrials=None, height=1.0): 58 | """Creates raster plots of spike trains: 59 | 60 | spt - times of spike occurance, 61 | stim - stimulus times 62 | win - range of time axis 63 | ntrials - number of trials to plot (if None plot all) 64 | """ 65 | if ntrials: 66 | trains = trains[:ntrials] 67 | 68 | ax = plt.gca() 69 | ax.set_xlim(win) 70 | ax.set_ylim((1, ntrials)) 71 | 72 | lines = [plt.vlines(sp, np.ones(len(sp)) * i, 73 | np.ones(len(sp)) * (i + height)) 74 | for i, sp in enumerate(trains) if len(sp)] 75 | plt.xlabel('time (ms)') 76 | plt.ylabel('trials') 77 | return lines 78 | 79 | def BinTrains(trains, win, tsamp=0.25): 80 | """Convert a list of spike trains into a binary sequence""" 81 | bins = np.arange(win[0], win[1], tsamp) 82 | trains = [np.histogram(spt, bins)[0][1:] for spt in trains] 83 | return bins[1:-1], np.array(trains).T 84 | 85 | def plotPSTH(spt, stim, win=[0, 30], bin=0.25, ax=None, 86 | rate=False, **kwargs): 87 | """Plot peri-stimulus time histogram (PSTH)""" 88 | i = np.searchsorted(stim + win[0], spt) 89 | spt2 = spt - stim[i - 1] 90 | bins = np.arange(win[0], win[1], bin) 91 | psth, bins = np.histogram(spt2, bins) 92 | 93 | if not ax: 94 | ax = plt.gca() 95 | 96 | if rate: 97 | psth = psth * 1.0 / len(stim) / bin * 1000.0 98 | ax.set_ylabel('firing rate (Hz)') 99 | else: 100 | ax.set_ylabel('number of spikes') 101 | 102 | lines = ax.plot(bins[:-1], psth, **kwargs) 103 | ax.set_xlabel('time (ms)') 104 | return lines 105 | 106 | def plotTrainsPSTH(trains, win, bin=0.25, rate=False, **kwargs): 107 | """Plot peri-stimulus time histogram (PSTH) from a list of spike times 108 | (trains)""" 109 | ax = plt.gca() 110 | time, binned = BinTrains(trains, win, bin) 111 | psth = np.mean(binned, 1) / bin * 1000 112 | plt.plot(time, psth, **kwargs) 113 | if rate: 114 | psth = psth * 1.0 / len(trains) / bin * 1000.0 115 | ax.set_ylabel('firing rate (Hz)') 116 | else: 117 | 
ax.set_ylabel('number of spikes') 118 | ax.set_xlabel('time (ms)') 119 | ax.set_ylabel('firing rate (Hz)') 120 | 121 | def CalcTrainsPSTH(trains, win, bin=0.25): 122 | time, binned = BinTrains(trains, win, bin) 123 | psth = np.mean(binned, 1) / bin * 1000 124 | return time, psth 125 | 126 | def CalcPSTH(spt, stim, win=[0, 30], bin=0.25, ax=None, norm=False, **kwargs): 127 | """Calculate peri-stimulus time histogram (PSTH). 128 | Output: 129 | -- psth - spike counts 130 | -- bins - bins edges""" 131 | i = np.searchsorted(stim + win[0], spt) 132 | spt2 = spt - stim[i - 1] 133 | bins = np.arange(win[0], win[1], bin) 134 | psth, bins = np.histogram(spt2, bins) 135 | if norm: 136 | psth = psth * 1000.0 / (len(stim) * bin) 137 | return psth[1:], bins[1:-1] 138 | 139 | def plotPSTHBar(spt, stim, win=[0, 30], bin=0.25, **kwargs): 140 | """Plot peri-stimulus time histogram (PSTH)""" 141 | i = np.searchsorted(stim + win[0], spt) 142 | spt2 = spt - stim[i - 1] 143 | bins = np.arange(win[0], win[1], bin) 144 | psth, bins = np.histogram(spt2, bins) 145 | psth = psth * 1.0 / len(stim) / bin * 1000.0 146 | plt.gca().bar(bins[:-1], psth, bin, **kwargs) 147 | plt.xlabel('time (ms)') 148 | plt.ylabel('firing rate (Hz)') 149 | -------------------------------------------------------------------------------- /src/spike_sort/core/filters.py: -------------------------------------------------------------------------------- 1 | import tables 2 | import tempfile 3 | import numpy as np 4 | from scipy import signal 5 | import os 6 | import atexit 7 | _open_files = {} 8 | 9 | 10 | class ZeroPhaseFilter(object): 11 | """IIR Filter with zero phase delay""" 12 | def __init__(self, ftype, fband, tw=200., stop=20): 13 | self.gstop = stop 14 | self.gpass = 1 15 | self.fband = fband 16 | self.tw = tw 17 | self.ftype = ftype 18 | self._coefs_cache = {} 19 | 20 | def _design_filter(self, FS): 21 | 22 | if not FS in self._coefs_cache: 23 | wp = np.array(self.fband) 24 | ws = wp + np.array([-self.tw, 
self.tw]) 25 | wp, ws = wp * 2.0 / FS, ws * 2.0 / FS 26 | b, a = signal.iirdesign(wp=wp, 27 | ws=ws, 28 | gstop=self.gstop, 29 | gpass=self.gpass, 30 | ftype=self.ftype) 31 | self._coefs_cache[FS] = (b, a) 32 | else: 33 | b, a = self._coefs_cache[FS] 34 | return b, a 35 | 36 | def __call__(self, x, FS): 37 | b, a = self._design_filter(FS) 38 | return signal.filtfilt(b, a, x) 39 | 40 | 41 | class FilterFir(object): 42 | """FIR filter with zero phase delay 43 | 44 | Attributes 45 | ---------- 46 | f_pass : float 47 | normalised low-cutoff frequency 48 | 49 | f_stop : float 50 | normalised high-cutoff frequency 51 | 52 | order : int 53 | filter order 54 | 55 | """ 56 | def __init__(self, f_pass, f_stop, order): 57 | self._coefs_cache = {} 58 | self.fp = f_pass 59 | self.fs = f_stop 60 | self.order = order 61 | 62 | def _design_filter(self, FS): 63 | if not FS in self._coefs_cache: 64 | bands = [0, min(self.fs, self.fp), max(self.fs, self.fp), FS / 2] 65 | gains = [int(self.fp < self.fs), int(self.fp > self.fs)] 66 | b, a = signal.remez(self.order, bands, gains, Hz=FS), [1] 67 | self._coefs_cache[FS] = (b, a) 68 | else: 69 | b, a = self._coefs_cache[FS] 70 | return b, a 71 | 72 | def __call__(self, x, FS): 73 | b, a = self._design_filter(FS) 74 | return signal.filtfilt(b, a, x) 75 | 76 | 77 | class Filter(object): 78 | def __init__(self, fpass, fstop, gpass=1, gstop=10, ftype='butter'): 79 | self.ftype = ftype 80 | self.fp = np.asarray(fpass) 81 | self.fs = np.asarray(fstop) 82 | self._coefs_cache = {} 83 | self.gstop = gstop 84 | self.gpass = gpass 85 | 86 | def _design_filter(self, FS): 87 | if not FS in self._coefs_cache: 88 | wp, ws = self.fp * 2 / FS, self.fs * 2 / FS 89 | b, a = signal.iirdesign(wp=wp, 90 | ws=ws, 91 | gstop=self.gstop, 92 | gpass=self.gpass, 93 | ftype=self.ftype) 94 | self._coefs_cache[FS] = (b, a) 95 | else: 96 | b, a = self._coefs_cache[FS] 97 | return b, a 98 | 99 | def __call__(self, x, FS): 100 | b, a = self._design_filter(FS) 101 | return 
signal.filtfilt(b, a, x) 102 | 103 | 104 | def filter_proxy(spikes, filter_obj, chunksize=1E6): 105 | """Proxy object to read filtered data 106 | 107 | Parameters 108 | ---------- 109 | spikes : dict 110 | unfiltered raw recording 111 | filter_object : object 112 | Filter to filter the data 113 | chunksize : int 114 | size of segments in which data is filtered 115 | 116 | Returns 117 | ------- 118 | sp_dict : dict 119 | filtered recordings 120 | """ 121 | data = spikes['data'] 122 | sp_dict = spikes.copy() 123 | 124 | if filter_obj is None: 125 | return spikes 126 | 127 | filename = tempfile.mktemp(suffix='.h5') 128 | atom = tables.Atom.from_dtype(np.dtype('float64')) 129 | shape = data.shape 130 | h5f = tables.openFile(filename, 'w') 131 | carray = h5f.createCArray('/', 'test', atom, shape) 132 | 133 | _open_files[filename] = h5f 134 | 135 | chunksize = int(chunksize) 136 | n_chunks = int(np.ceil(shape[1] * 1.0 / chunksize)) 137 | for i in range(shape[0]): 138 | for j in range(n_chunks): 139 | stop = int(np.min(((j + 1) * chunksize, shape[1]))) 140 | carray[i, j * chunksize:stop] = filter_obj( 141 | data[i, j * chunksize:stop], sp_dict['FS']) 142 | sp_dict['data'] = carray 143 | return sp_dict 144 | 145 | 146 | def fltLinearIIR(signal, fpass, fstop, gpass=1, gstop=10, ftype='butter'): 147 | """An IIR acausal linear filter. Works through 148 | spike_sort.core.filters.filter_proxy method 149 | 150 | Parameters 151 | ---------- 152 | signal : dict 153 | input [raw] signal 154 | fpass, fstop : float 155 | Passband and stopband edge frequencies [Hz] 156 | For more details see scipy.signal.iirdesign 157 | gpass : float 158 | The maximum loss in the passband (dB). 159 | gstop : float 160 | The minimum attenuation in the stopband (dB). 
161 | ftype : str, optional 162 | The type of IIR filter to design: 163 | 164 | - elliptic : 'ellip' 165 | - Butterworth : 'butter', 166 | - Chebyshev I : 'cheby1', 167 | - Chebyshev II: 'cheby2', 168 | - Bessel : 'bessel' 169 | """ 170 | filt = Filter(fpass, fstop, gpass, gstop, ftype) 171 | return filter_proxy(signal, filt) 172 | 173 | 174 | def clean_after_exit(): 175 | for fname, fid in _open_files.items(): 176 | fid.close() 177 | try: 178 | os.remove(fname) 179 | except OSError: 180 | pass 181 | _open_files.clear() 182 | 183 | atexit.register(clean_after_exit) 184 | -------------------------------------------------------------------------------- /src/spike_sort/ui/plotting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | from matplotlib.collections import LineCollection 7 | 8 | # used by other modules which import plotting.py. Do not remove! 
9 | from matplotlib.pyplot import show, figure, close 10 | 11 | import spike_sort 12 | cmap = plt.cm.jet 13 | 14 | import _mpl_helpers # performance boosters 15 | 16 | 17 | def label_color(labels): 18 | """Map labels to number range [0, 1]""" 19 | 20 | num_labels = np.linspace(0, 1., len(labels)) 21 | mapper = dict(zip(labels, num_labels)) 22 | 23 | @np.vectorize 24 | def map_func(lab): 25 | return mapper[lab] 26 | 27 | def color_func(lab): 28 | return cmap(map_func(lab)) 29 | return color_func 30 | 31 | 32 | def plot_spikes(spikes, clust_idx=None, show_cells='all', **kwargs): 33 | """Plot Spike waveshapes 34 | 35 | Parameters 36 | ---------- 37 | spike_data : dict 38 | clust_idx : sequence 39 | sequence of the length equal to the number of spikes; labels of 40 | clusters to which spikes belong 41 | show_cells : list or 'all' 42 | list of identifiers of clusters (cells) to plot 43 | plot_avg: bool 44 | if True plot waveform averages 45 | 46 | Returns 47 | ------- 48 | lines_segments : object 49 | matplotlib line collection of spike waveshapes 50 | """ 51 | 52 | if clust_idx is None: 53 | spikegraph(spikes, **kwargs) 54 | else: 55 | spikes_cell = spike_sort.extract.split_cells(spikes, clust_idx) 56 | 57 | if show_cells == 'all': 58 | labs = spikes_cell.keys() 59 | else: 60 | labs = show_cells 61 | 62 | color_func = label_color(spikes_cell.keys()) 63 | for l in labs: 64 | spikegraph(spikes_cell[l], color_func(l), **kwargs) 65 | 66 | 67 | def spikegraph(spike_data, color='k', alpha=0.2, n_spikes='all', 68 | contacts='all', plot_avg=True, fig=None): 69 | 70 | spikes = spike_data['data'] 71 | time = spike_data['time'] 72 | 73 | if contacts == 'all': 74 | contacts = np.arange(spikes.shape[2]) 75 | 76 | n_pts = len(time) 77 | 78 | if not n_spikes == 'all': 79 | sample_idx = np.argsort(np.random.randn(spikes.shape[1]))[:n_spikes] 80 | spikes = spikes[:, sample_idx, :] 81 | n_spikes = spikes.shape[1] 82 | if fig is None: 83 | fig = plt.gcf() 84 | line_segments = [] 85 | for 
i, contact_id in enumerate(contacts): 86 | ax = fig.add_subplot(2, 2, i + 1) 87 | # ax.set_xlim(time.min(), time.max()) 88 | # ax.set_ylim(spikes.min(), spikes.max()) 89 | 90 | segs = np.zeros((n_spikes, n_pts, 2)) 91 | segs[:, :, 0] = time[np.newaxis, :] 92 | segs[:, :, 1] = spikes[:, :, contact_id].T 93 | collection = LineCollection(segs, colors=color, 94 | alpha=alpha) 95 | line_segments.append(collection) 96 | ax.add_collection(collection, autolim=True) 97 | 98 | if plot_avg: 99 | spikes_mean = spikes[:, :, i].mean(1) 100 | ax.plot(time, spikes_mean, color='w', lw=3) 101 | ax.plot(time, spikes_mean, color=color, lw=2) 102 | ax.autoscale_view(tight=True) 103 | 104 | return line_segments 105 | 106 | 107 | def plot_features(features, clust_idx=None, show_cells='all', **kwargs): 108 | """Plot features and their histograms 109 | 110 | Parameters 111 | ---------- 112 | features_dict : dict 113 | features data structure 114 | clust_idx : array or None 115 | array of size (n_spikes,) containing indices of clusters to which 116 | each spike was classfied 117 | show_cells : list or 'all' 118 | list of identifiers of clusters (cells) to plot 119 | 120 | """ 121 | 122 | if clust_idx is None: 123 | featuresgraph(features, **kwargs) 124 | else: 125 | features_cell = spike_sort.features.split_cells(features, clust_idx) 126 | 127 | if show_cells == 'all': 128 | labs = features_cell.keys() 129 | else: 130 | labs = show_cells 131 | 132 | color_func = label_color(features_cell.keys()) 133 | for l in labs: 134 | featuresgraph(features_cell[l], color_func(l), **kwargs) 135 | 136 | 137 | def featuresgraph(features_dict, color='k', size=1, datarange=None, fig=None, n_spikes='all'): 138 | 139 | features = features_dict['data'] 140 | names = features_dict['names'] 141 | 142 | _, n_feats = features.shape 143 | if fig is None: 144 | fig = plt.gcf() 145 | axes = [[fig.add_subplot(n_feats, n_feats, i * n_feats + j + 1, projection='thin') 146 | for i in range(n_feats)] for j in 
range(n_feats)] 147 | 148 | if not n_spikes == 'all': 149 | sample_idx = np.argsort(np.random.randn(features.shape[0]))[:n_spikes] 150 | features = features[sample_idx, :] 151 | 152 | for i in range(n_feats): 153 | for j in range(n_feats): 154 | ax = axes[i][j] 155 | if i != j: 156 | ax.plot(features[:, i], 157 | features[:, j], ".", 158 | color=color, markersize=size) 159 | if datarange: 160 | ax.set_xlim(datarange) 161 | 162 | else: 163 | ax.set_frame_on(False) 164 | n, bins = np.histogram(features[:, i], 20, datarange, normed=True) 165 | ax.plot(bins[:-1], n, '-', color=color, drawstyle='steps') 166 | if datarange: 167 | ax.set_xlim(datarange) 168 | ax.set_xticks([]) 169 | ax.set_yticks([]) 170 | ax.set_xlabel(names[i]) 171 | ax.set_ylabel(names[j]) 172 | ax.xaxis.set_label_position("top") 173 | ax.xaxis.label.set_visible(False) 174 | ax.yaxis.label.set_visible(False) 175 | 176 | for i in range(n_feats): 177 | ax = axes[i][0] 178 | ax.xaxis.label.set_visible(True) 179 | ax = axes[0][i] 180 | ax.yaxis.label.set_visible(True) 181 | 182 | 183 | def legend(labels, colors=None, ax=None): 184 | 185 | if ax is None: 186 | ax = plt.gca() 187 | if colors is None: 188 | color_func = label_color(labels) 189 | colors = [color_func(i) for i in labels] 190 | 191 | ax.set_frame_on(False) 192 | n_classes = len(labels) 193 | x, y = np.zeros(n_classes) + 0.4, 0.1 * np.arange(n_classes) 194 | ax.scatter(x, y, c=colors, marker='s', edgecolors="none", s=100) 195 | ax.set_xlim([0, 1]) 196 | ax.set_xticks([]) 197 | ax.set_yticks([]) 198 | 199 | for i, l in enumerate(labels): 200 | ax.text(x[i] + 0.1, y[i], "Cell {0}".format(l), va='center', ha='left', 201 | transform=ax.transData) 202 | -------------------------------------------------------------------------------- /src/spike_sort/stats/diptst/diptst.f: -------------------------------------------------------------------------------- 1 | SUBROUTINE DIPTST1(X,N,DIP,XL,XU,IFAULT,GCM,LCM,MN,MJ,DDX,DDXSGN) 2 | C 3 | C ALGORITHM AS 217 
APPL. STATIST. (1985) VOL.34, NO.3 4 | C 5 | C Does the dip calculation for an ordered vector X using the 6 | C greatest convex minorant and the least concave majorant, skipping 7 | C through the data using the change points of these distributions. 8 | C It returns the dip statistic 'DIP' and the modal interval 9 | C (XL, XU). 10 | C 11 | C MODIFICATIONS SEP 2 2002 BY F. MECHLER TO FIX PROBLEMS WITH 12 | C UNIMODAL (INCLUDING MONOTONIC) INPUT 13 | C 14 | REAL X(N) 15 | INTEGER MN(N), MJ(N), LCM(N), GCM(N), HIGH 16 | REAL ZERO, HALF, ONE 17 | C NEXT TWO LINES ARE ADDED 18 | REAL DDX(N) 19 | INTEGER DDXSGN(N), POSK, NEGK 20 | DATA ZERO/0.0/, HALF/0.5/, ONE/1.0/ 21 | C 22 | IFAULT = 1 23 | IF (N .LE. 0) RETURN 24 | IFAULT = 0 25 | C 26 | C Check if N = 1 27 | C 28 | IF (N .EQ. 1) GO TO 4 29 | C 30 | C Check that X is sorted 31 | C 32 | IFAULT = 2 33 | DO 3 K = 2, N 34 | IF (X(K) .LT. X(K-1)) RETURN 35 | 3 CONTINUE 36 | IFAULT = 0 37 | C 38 | C Check for all values of X identical, 39 | C and for 1 < N < 4. 40 | C 41 | IF (X(N) .GT. X(1) .AND. N .GE. 4) GO TO 505 42 | 4 XL = X(1) 43 | XU = X(N) 44 | DIP = ZERO 45 | RETURN 46 | C The code amendment below is intended to be inseted above the line marked "5" in the original FORTRAN code 47 | C The amendment checks the condition whether the input X is perfectly unimodal 48 | C Hartigan's original DIPTST algorithm did not check for this condition 49 | C and DIPTST runs into an infinite cycle for a unimodal input 50 | C The condition that the input is unimodal is equivalent to having 51 | C at most 1 sign change in the second derivative of the input p.d.f. 52 | C In MATLAB syntax, we check the flips in the function xsign=-sign(diff(1./diff(x)))=-sign(diff(diff(x))); 53 | C with DDXSGN=xsign in the fortran code below 54 | 505 NEGK=0 55 | POSK=0 56 | DO 104 K = 3,N 57 | DDX(K) = X(K)+X(K-2)-2*X(K-1) 58 | IF (DDX(K) .LT. 0) DDXSGN(K) = 1 59 | IF (DDX(K) .EQ. 0) DDXSGN(K) = 0 60 | IF (DDX(K) .GT. 
0) DDXSGN(K) = -1 61 | IF (DDXSGN(K) .GT. 0) POSK = K 62 | IF ((DDXSGN(K) .LT. 0) .AND. (NEGK .EQ. 0)) NEGK = K 63 | 104 CONTINUE 64 | 65 | C The condition check below examines whether the greatest position with a positive second derivative 66 | C is smaller than the smallest position with a negative second derivative 67 | C The boolean check gets it right even if 68 | C the unimodal p.d.f. has its mode in the very first or last point of the input 69 | 70 | IF ((POSK .GT. NEGK) .AND. (NEGK .GT. 0)) GOTO 5 71 | XL=X(1) 72 | XU=X(N) 73 | DIP=0 74 | IFAULT=5 75 | RETURN 76 | C 77 | C LOW contains the index of the current estimate of the lower end 78 | C of the modal interval, HIGH contains the index for the upper end. 79 | C 80 | 5 FN = FLOAT(N) 81 | LOW = 1 82 | HIGH = N 83 | DIP = ONE / FN 84 | XL = X(LOW) 85 | XU = X(HIGH) 86 | C 87 | C Establish the indices over which combination is necessary for the 88 | C convex minorant fit. 89 | C 90 | MN(1) = 1 91 | DO 28 J = 2, N 92 | MN(J) = J - 1 93 | 25 MNJ = MN(J) 94 | MNMNJ = MN(MNJ) 95 | A = FLOAT(MNJ - MNMNJ) 96 | B = FLOAT(J - MNJ) 97 | IF (MNJ .EQ. 1 .OR. (X(J) - X(MNJ))*A .LT. (X(MNJ) - X(MNMNJ)) 98 | + *B) GO TO 28 99 | MN(J) = MNMNJ 100 | GO TO 25 101 | 28 CONTINUE 102 | C 103 | C Establish the indices over which combination is necessary for the 104 | C concave majorant fit. 105 | C 106 | MJ(N) = N 107 | NA = N - 1 108 | DO 34 JK = 1, NA 109 | K = N - JK 110 | MJ(K) = K + 1 111 | 32 MJK = MJ(K) 112 | MJMJK = MJ(MJK) 113 | A = FLOAT(MJK - MJMJK) 114 | B = FLOAT(K - MJK) 115 | IF (MJK .EQ. N .OR. (X(K) - X(MJK))*A .LT. (X(MJK) - X(MJMJK)) 116 | + *B) GO TO 34 117 | MJ(K) = MJMJK 118 | GO TO 32 119 | 34 CONTINUE 120 | C 121 | C Start the cycling. 122 | C Collect the change points for the GCM from HIGH to LOW. 123 | C 124 | 40 IC = 1 125 | GCM(1) = HIGH 126 | 42 IGCM1 = GCM(IC) 127 | IC = IC + 1 128 | GCM(IC) = MN(IGCM1) 129 | IF (GCM(IC) .GT. 
LOW) GO TO 42 130 | ICX = IC 131 | C 132 | C Collect the change points for the LCM from LOW to HIGH. 133 | C 134 | IC = 1 135 | LCM(1) = LOW 136 | 44 LCM1 = LCM(IC) 137 | IC = IC + 1 138 | LCM(IC) = MJ(LCM1) 139 | IF (LCM(IC) .LT. HIGH) GO TO 44 140 | ICV = IC 141 | C 142 | C ICX, IX, IG are counters for the convex minorant, 143 | C ICV, IV, IH are counters for the concave majorant. 144 | C 145 | IG = ICX 146 | IH = ICV 147 | C 148 | C Find the largest distance greater than 'DIP' between the GCM and 149 | C the LCM from LOW to HIGH. 150 | C 151 | IX = ICX - 1 152 | IV = 2 153 | D = ZERO 154 | IF (ICX .NE. 2 .OR. ICV .NE. 2) GO TO 50 155 | D = ONE / FN 156 | GO TO 65 157 | 50 IGCMX = GCM(IX) 158 | LCMIV = LCM(IV) 159 | IF (IGCMX .GT. LCMIV) GO TO 55 160 | C 161 | C If the next point of either the GCM or LCM is from the LCM, 162 | C calculate the distance here. 163 | C 164 | LCMIV1 = LCM(IV - 1) 165 | A = FLOAT(LCMIV - LCMIV1) 166 | B = FLOAT(IGCMX - LCMIV1 - 1) 167 | DX = (X(IGCMX) - X(LCMIV1))*A / (FN*(X(LCMIV) - X(LCMIV1))) 168 | + - B / FN 169 | IX = IX - 1 170 | IF (DX .LT. D) GO TO 60 171 | D = DX 172 | IG = IX + 1 173 | IH = IV 174 | GO TO 60 175 | C 176 | C If the next point of either the GCM or LCM is from the GCM, 177 | C calculate the distance here. 178 | C 179 | 55 LCMIV = LCM(IV) 180 | IGCM = GCM(IX) 181 | IGCM1 = GCM(IX + 1) 182 | A = FLOAT(LCMIV - IGCM1 + 1) 183 | B = FLOAT(IGCM - IGCM1) 184 | DX = A / FN - ((X(LCMIV) - X(IGCM1))*B) / (FN * (X(IGCM) 185 | + - X(IGCM1))) 186 | IV = IV + 1 187 | IF (DX .LT. D) GO TO 60 188 | D = DX 189 | IG = IX + 1 190 | IH = IV - 1 191 | 60 IF (IX .LT. 1) IX = 1 192 | IF (IV .GT. ICV) IV = ICV 193 | IF (GCM(IX) .NE. LCM(IV)) GO TO 50 194 | 65 IF (D .LT. DIP) GO TO 100 195 | C 196 | C Calculate the DIPs for the current LOW and HIGH. 197 | C 198 | C The DIP for the convex minorant. 199 | C 200 | DL = ZERO 201 | IF (IG .EQ. 
ICX) GO TO 80 202 | ICXA = ICX - 1 203 | DO 76 J = IG, ICXA 204 | TEMP = ONE / FN 205 | JB = GCM(J + 1) 206 | JE = GCM(J) 207 | IF (JE - JB .LE. 1) GO TO 74 208 | IF (X(JE) .EQ. X(JB)) GO TO 74 209 | A = FLOAT(JE - JB) 210 | CONST = A / (FN * (X(JE) - X(JB))) 211 | DO 72 JR = JB, JE 212 | B = FLOAT(JR - JB + 1) 213 | T = B / FN - (X(JR) - X(JB))*CONST 214 | IF (T .GT. TEMP) TEMP = T 215 | 72 CONTINUE 216 | 74 IF (DL .LT. TEMP) DL = TEMP 217 | 76 CONTINUE 218 | C 219 | C The DIP for the concave majorant. 220 | C 221 | 80 DU = ZERO 222 | IF (IH .EQ. ICV) GO TO 90 223 | ICVA = ICV - 1 224 | DO 88 K = IH, ICVA 225 | TEMP = ONE / FN 226 | KB = LCM(K) 227 | KE = LCM(K + 1) 228 | IF (KE - KB .LE. 1) GO TO 86 229 | IF (X(KE) .EQ. X(KB)) GO TO 86 230 | A = FLOAT(KE - KB) 231 | CONST = A / (FN * (X(KE) - X(KB))) 232 | DO 84 KR = KB, KE 233 | B = FLOAT(KR - KB - 1) 234 | T = (X(KR) - X(KB))*CONST - B / FN 235 | IF (T .GT. TEMP) TEMP = T 236 | 84 CONTINUE 237 | 86 IF (DU .LT. TEMP) DU = TEMP 238 | 88 CONTINUE 239 | C 240 | C Determine the current maximum. 241 | C 242 | 90 DIPNEW = DL 243 | IF (DU .GT. DL) DIPNEW = DU 244 | IF (DIP .LT. DIPNEW) DIP = DIPNEW 245 | LOW = GCM(IG) 246 | HIGH = LCM(IH) 247 | C 248 | C Recycle 249 | C 250 | GO TO 40 251 | C 252 | 100 DIP = HALF * DIP 253 | XL = X(LOW) 254 | XU = X(HIGH) 255 | C 256 | RETURN 257 | END 258 | 259 | -------------------------------------------------------------------------------- /src/spike_sort/core/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import numpy as np 5 | import extract 6 | import cluster 7 | import warnings 8 | 9 | 10 | def deprecation(message): 11 | warnings.warn(message, DeprecationWarning, stacklevel=2) 12 | 13 | 14 | def snr_spike(spike_waves, scale=5.): 15 | """Estimate signal-to-noise ratio (SNR) as a ratio of 16 | peak-to-peak amplitude of an average spike to the std. 
deviation 17 | of residuals 18 | 19 | Parameters 20 | ---------- 21 | spike_waves : dict 22 | 23 | Returns 24 | ------- 25 | snr : float 26 | signal to noise ratio 27 | """ 28 | sp_data = spike_waves['data'] 29 | avg_spike = sp_data.mean(axis=1) 30 | peak_to_peak = np.ptp(avg_spike, axis=None) 31 | residuals = sp_data - avg_spike[:, np.newaxis] 32 | noise_std = np.sqrt(residuals.var()) 33 | snr = peak_to_peak / (noise_std * scale) 34 | return snr 35 | 36 | 37 | def snr_clust(spike_waves, noise_waves): 38 | """Calculate signal-to-noise ratio. 39 | 40 | Comparing average P2P amplitude of spike cluster to noise cluster 41 | (randomly selected signal segments) 42 | 43 | Parameters 44 | ---------- 45 | spike_waves : dict 46 | noise_waves : dict 47 | 48 | Returns 49 | ------- 50 | snr : float 51 | signal-to-noise ratio 52 | """ 53 | _calc_p2p = lambda data: np.ptp(data, axis=0).mean() 54 | 55 | sp_data = spike_waves['data'] 56 | avg_p2p_spk = _calc_p2p(sp_data) 57 | 58 | noise_data = noise_waves['data'] 59 | avg_p2p_ns = _calc_p2p(noise_data) 60 | 61 | snr = avg_p2p_spk / avg_p2p_ns 62 | 63 | return snr 64 | 65 | 66 | def extract_noise_cluster(sp, spt, sp_win, type="positive"): 67 | 68 | deprecation("extract_noise_cluster deprecated. 
Use" 69 | " detect_noise and extract_spikes instead.") 70 | spt_noise = detect_noise(sp, spt, sp_win, type) 71 | sp_waves = extract.extract_spikes(sp, spt_noise, sp_win) 72 | 73 | return sp_waves 74 | 75 | 76 | def rand_sample_spt(spt, max_spikes): 77 | n_spikes = len(spt['data']) 78 | spt_data = spt['data'] 79 | spt_new = spt.copy() 80 | if max_spikes and n_spikes > max_spikes: 81 | i = np.random.rand(n_spikes).argsort()[:max_spikes] 82 | spt_new['data'] = spt_data[i] 83 | return spt_new 84 | 85 | 86 | def detect_noise(sp, spt, sp_win, type="positive", max_spikes=None, 87 | resample=1): 88 | """Find noisy spikes""" 89 | 90 | spike_waves = extract.extract_spikes(sp, spt, sp_win) 91 | 92 | if type == "positive": 93 | threshold = calc_noise_threshold(spike_waves, 1) 94 | spt_noise = extract.detect_spikes(sp, threshold, 'rising') 95 | spt_noise = rand_sample_spt(spt_noise, max_spikes) 96 | spt_noise = extract.remove_spikes(spt_noise, spt, sp_win) 97 | spt_noise = extract.align_spikes(sp, spt_noise, sp_win, 'max', 98 | resample=resample) 99 | else: 100 | threshold = calc_noise_threshold(spike_waves, -1) 101 | spt_noise = extract.detect_spikes(sp, threshold, 'falling') 102 | spt_noise = rand_sample_spt(spt_noise, max_spikes) 103 | spt_noise = extract.remove_spikes(spt_noise, spt, sp_win) 104 | spt_noise = extract.align_spikes(sp, spt_noise, sp_win, 'min', 105 | resample=resample) 106 | return spt_noise 107 | 108 | 109 | def calc_noise_threshold(spike_waves, sign=1, frac_spikes=0.02, frac_max=0.5): 110 | """ Find threshold to extract noise cluster. 111 | 112 | According to algorithm described in Joshua et al. (2007) 113 | 114 | Parameters 115 | ---------- 116 | spike_waves : dict 117 | waveshapes of spikes from the identified single unit (extracted 118 | with extract.extrac_spikes) 119 | 120 | sign : int 121 | sign should be negative for negative-going spikes and positive for 122 | postitive-going spikes. Note that only sign of this number is 123 | taken into account. 
124 | 125 | frac_spikes : float, optional 126 | fraction of largest (smallest) spikes to calculate the threshold 127 | from (default 0.02) 128 | 129 | frac_max : float 130 | fraction of the average peak amplitude to use as a treshold 131 | 132 | 133 | Returns 134 | ------- 135 | threshold : float 136 | threshold to obtain a noise cluster""" 137 | 138 | gain = np.sign(sign) 139 | 140 | peak_amp = np.max(gain*spike_waves['data'], 0) 141 | frac_spikes = 0.02 142 | frac_max = 0.5 143 | peak_amp.sort() 144 | threshold = frac_max*np.mean(peak_amp[:int(frac_spikes*len(peak_amp))]) 145 | threshold *= gain 146 | 147 | return threshold 148 | 149 | 150 | def isolation_score(sp, spt, sp_win, spike_type='positive', lam=10., 151 | max_spikes=None): 152 | "calculate spike isolation score from raw data and spike times" 153 | 154 | spike_waves = extract.extract_spikes(sp, spt, sp_win) 155 | spt_noise = detect_noise(sp, spt, sp_win, spike_type) 156 | noise_waves = extract.extract_spikes(sp, spt_noise, sp_win) 157 | 158 | iso_score = calc_isolation_score(spike_waves, noise_waves, 159 | spike_type, 160 | lam=lam, 161 | max_spikes=max_spikes) 162 | 163 | return iso_score 164 | 165 | 166 | def _iso_score_dist(dist, lam, n_spikes): 167 | 168 | """Calculate isolation score from a distance matrix 169 | 170 | Parameters 171 | ---------- 172 | dist : array 173 | NxM matrix, where N is number of spikes and M is number of all 174 | events (spikes + noise) 175 | 176 | lam : float 177 | lambda parameter 178 | 179 | n_spikes : int 180 | number of spikes (N) 181 | 182 | Returns 183 | ------- 184 | isolation_score : float 185 | """ 186 | 187 | distSS = dist[:, :n_spikes] 188 | distSN = dist[:, n_spikes:] 189 | 190 | d0 = distSS.mean() 191 | expSS = np.exp(-distSS*lam*1./d0) 192 | expSN = np.exp(-distSN*lam*1./d0) 193 | 194 | sumSS = np.sum(expSS - np.eye(n_spikes), 1) 195 | sumSN = np.sum(expSN, 1) 196 | 197 | correctProbS = sumSS / (sumSS + sumSN) 198 | isolation_score = correctProbS.mean() 199 | 
200 | return isolation_score 201 | 202 | 203 | def calc_isolation_score(spike_waves, noise_waves, spike_type='positive', 204 | lam=10., max_spikes=None): 205 | """Calculate isolation index according to Joshua et al. (2007) 206 | 207 | Parameters 208 | ---------- 209 | spike_waves : dict 210 | noise_waves : dict 211 | sp_win : list or tuple 212 | window used for spike extraction 213 | spike_type : {'positive', 'negative'} 214 | indicates if the spikes occuring at time points given by spt are 215 | positive or negative going 216 | lambda : float 217 | determines the "softness" of clusters 218 | 219 | Returns 220 | ------- 221 | isolation_score : float 222 | a value from the range [0,1] indicating the quality of sorting 223 | (1=ideal isolation of spikes) 224 | """ 225 | 226 | #Memory issue: sample spikes if too many 227 | if max_spikes is not None: 228 | if spike_waves['data'].shape[1] > max_spikes: 229 | i = np.random.rand(max_spikes).argsort() 230 | spike_waves = spike_waves.copy() 231 | spike_waves['data'] = spike_waves['data'][:, i] 232 | if noise_waves['data'].shape[1] > max_spikes: 233 | i = np.random.rand(max_spikes).argsort() 234 | noise_waves = noise_waves.copy() 235 | noise_waves['data'] = noise_waves['data'][:, i] 236 | 237 | n_spikes = spike_waves['data'].shape[1] 238 | 239 | #calculate distance between spikes and all other events 240 | all_waves = {'data': np.concatenate((spike_waves['data'], 241 | noise_waves['data']), 1)} 242 | dist_matrix = cluster.dist_euclidean(spike_waves, all_waves) 243 | #d_0 = dist_matrix[:,:n_spikes].mean() 244 | 245 | isolation_score = _iso_score_dist(dist_matrix, lam, 246 | n_spikes) 247 | 248 | return isolation_score 249 | -------------------------------------------------------------------------------- /src/spike_beans/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | The intial version of this code was adapted from a recipe by Zoran Isailovski 3 | (published under PSF 
License). 4 | 5 | http://code.activestate.com/recipes/413268-dependency-injection-the-python-way/ 6 | """ 7 | 8 | import logging 9 | 10 | 11 | class FeatureBroker(object): 12 | def __init__(self, allowReplace=False): 13 | self.providers = {} 14 | self.allowReplace = allowReplace 15 | 16 | def Provide(self, feature, provider, *args, **kwargs): 17 | if not self.allowReplace: 18 | assert feature not in self.providers, \ 19 | "Duplicate feature: %r" % feature 20 | 21 | if callable(provider): 22 | call = lambda: provider(*args, **kwargs) 23 | else: 24 | call = lambda: provider 25 | self.providers[feature] = call 26 | 27 | def __getitem__(self, feature): 28 | if not feature in self.providers: 29 | raise AttributeError("Unknown feature named %r" % feature) 30 | else: 31 | return self.providers[feature]() 32 | 33 | def __setitem__(self, feature, component): 34 | self.Provide(feature, component) 35 | 36 | def __contains__(self, feature): 37 | return feature in self.providers 38 | 39 | 40 | features = FeatureBroker() 41 | 42 | 43 | def register(feature, component): 44 | """register `component` as providing `feature`""" 45 | features[feature] = component 46 | return component 47 | 48 | ## Representation of Required Features and Feature Assertions 49 | 50 | # Some basic assertions to test the suitability of injected features 51 | 52 | def NoAssertion(): 53 | def test(obj): 54 | return True 55 | return test 56 | 57 | 58 | def IsInstanceOf(*classes): 59 | def test(obj): 60 | return isinstance(obj, classes) 61 | return test 62 | 63 | 64 | def HasAttributes(*attributes): 65 | def test(obj): 66 | for attr_name in attributes: 67 | try: 68 | getattr(obj, attr_name) 69 | except AttributeError: 70 | return False 71 | return True 72 | return test 73 | 74 | 75 | def HasMethods(*methods): 76 | def test(obj): 77 | for each in methods: 78 | try: 79 | attr = getattr(obj, each) 80 | except AttributeError: 81 | return False 82 | if not callable(attr): 83 | return False 84 | return True 85 | 
return test 86 | 87 | # An attribute descriptor to "declare" required features 88 | 89 | class DataAttribute(object): 90 | """A data descriptor that sets and returns values 91 | normally and notifies on value changed. 92 | """ 93 | 94 | def __init__(self, initval=None, name='var'): 95 | self.val = initval 96 | self.name = name 97 | 98 | def __get__(self, obj, objtype): 99 | return self.val 100 | 101 | def __set__(self, obj, val): 102 | self.val = val 103 | for handler in obj.observers: 104 | handler() 105 | 106 | 107 | class RequiredFeature(object): 108 | """Descriptor class for required dependencies. Implements dependency 109 | injection.""" 110 | def __init__(self, feature, assertion=NoAssertion(), 111 | alt_name="_alternative_"): 112 | """ 113 | Parameters 114 | ---------- 115 | feature : string 116 | name of the associated dependency 117 | assertion : function 118 | additional tests for the associated dependency 119 | alt_name : string 120 | in case of renaming the dependency, the variable named 121 | 'alt_name'+'feature' should contain the new name of the dependency 122 | """ 123 | 124 | self.feature = feature 125 | self.alt_name = alt_name 126 | self.assertion = assertion 127 | self.result = None 128 | 129 | def __get__(self, callee, T): 130 | self.result = self.Request(callee) 131 | return self.result # <-- will request the feature upon first call 132 | 133 | def __set__(self, instance, value): 134 | '''Rename the feature''' 135 | if isinstance(value, str): 136 | logging.info("changed %s to %s" % (self.feature, value)) 137 | setattr(instance, self.alt_name + self.feature, value) 138 | else: 139 | raise TypeError("can't change the feature name to non-string type") 140 | 141 | def __getattr__(self, name): 142 | assert name == 'result', \ 143 | "Unexpected attribute request other then 'result'" 144 | return self.result 145 | 146 | def Request(self, callee): 147 | fet_name = self.feature 148 | if hasattr(callee, self.alt_name + self.feature): 149 | fet_name = 
getattr(callee, self.alt_name + self.feature) 150 | obj = features[fet_name] 151 | 152 | try: 153 | obj.register_handler(callee) 154 | except AttributeError: 155 | pass 156 | 157 | isComponentCorrect = self.assertion(obj) 158 | assert isComponentCorrect, \ 159 | "The value %r of %r does not match the specified criteria" \ 160 | % (obj, self.feature) 161 | return obj 162 | 163 | class OptionalFeature(RequiredFeature): 164 | """Descriptor class for optional dependencies. Implements dependency 165 | injection. Acquires None if the dependency is not satisfied""" 166 | def Request(self, callee): 167 | fet_name = self.feature 168 | if hasattr(callee, self.alt_name + self.feature): 169 | fet_name = getattr(callee, self.alt_name + self.feature) 170 | 171 | if not fet_name in features: 172 | return None 173 | 174 | return super(OptionalFeature, self).Request(callee) 175 | 176 | class Component(object): 177 | "Symbolic base class for components" 178 | def __init__(self): 179 | self.observers = [] 180 | 181 | @staticmethod 182 | def _rm_duplicate_deps(deps): 183 | new_deps = [] 184 | for i, d in enumerate(deps): 185 | if not d in deps[i + 1:]: 186 | new_deps.append(d) 187 | return new_deps 188 | 189 | def get_dependencies(self): 190 | deps = [o.get_dependencies() for o in self.observers] 191 | deps = sum(deps, self.observers) 192 | deps = Component._rm_duplicate_deps(deps) 193 | return deps 194 | 195 | def register_handler(self, handler): 196 | if handler not in self.observers: 197 | self.observers.append(handler) 198 | 199 | def unregister_handler(self, handler): 200 | if handler in self.observers: 201 | self.observers.remove(handler) 202 | 203 | def notify_observers(self): 204 | for dep in self.get_dependencies(): 205 | dep._update() 206 | 207 | def _update(self): 208 | pass 209 | 210 | def update(self): 211 | self._update() 212 | self.notify_observers() 213 | 214 | 215 | class dictproperty(object): 216 | """implements collection properties with dictionary-like access. 
217 | Copied and modified from a recipe by Ed Swierk 218 | published under PSF license 219 | `_ 220 | """ 221 | 222 | class _proxy(object): 223 | def __init__(self, obj, fget, fset, fdel): 224 | self._obj = obj 225 | self._fget = fget 226 | self._fset = fset 227 | self._fdel = fdel 228 | 229 | def __getitem__(self, key): 230 | if self._fget is None: 231 | raise TypeError("can't read item") 232 | return self._fget(self._obj, key) 233 | 234 | def __setitem__(self, key, value): 235 | if self._fset is None: 236 | raise TypeError("can't set item") 237 | self._fset(self._obj, key, value) 238 | 239 | def __delitem__(self, key): 240 | if self._fdel is None: 241 | raise TypeError("can't delete item") 242 | self._fdel(self._obj, key) 243 | 244 | def __init__(self, fget=None, fset=None, fdel=None, doc=None): 245 | self._fget = fget 246 | self._fset = fset 247 | self._fdel = fdel 248 | self.__doc__ = doc 249 | 250 | def __get__(self, obj, objtype=None): 251 | if obj is None: 252 | return self 253 | return self._proxy(obj, self._fget, self._fset, self._fdel) 254 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # SpikeSort documentation build configuration file, created by 4 | # sphinx-quickstart on Thu Jan 27 14:09:31 2011. 5 | # 6 | # This file is execfile()d with the current directory set to its containing dir. 7 | # 8 | # Note that not all possible configuration values are present in this 9 | # autogenerated file. 10 | # 11 | # All configuration values have a default; values that are commented out 12 | # serve to show the default. 13 | 14 | import sys, os 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. 
If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | sys.path.append(os.path.abspath('../../src/')) 20 | 21 | # -- General configuration ----------------------------------------------------- 22 | 23 | # If your documentation needs a minimal Sphinx version, state it here. 24 | #needs_sphinx = '1.0' 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be extensions 27 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 28 | extensions = ['sphinx.ext.autodoc', 29 | 'sphinx.ext.doctest', 30 | 'sphinx.ext.coverage', 31 | 'sphinx.ext.pngmath', 32 | 'sphinx.ext.viewcode', 33 | 'sphinx.ext.autosummary', 34 | 'matplotlib.sphinxext.only_directives', 35 | 'matplotlib.sphinxext.plot_directive', 36 | 'numpydoc' 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # The suffix of source filenames. 43 | source_suffix = '.rst' 44 | 45 | # The encoding of source files. 46 | #source_encoding = 'utf-8-sig' 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | #automatically generate autosummary stub files 52 | #autosummary_generate = True 53 | 54 | # General information about the project. 55 | project = u'SpikeSort' 56 | copyright = u'2011, Bartosz Telenczuk, Dmytro Bielievstsov' 57 | 58 | # The version info for the project you're documenting, acts as replacement for 59 | # |version| and |release|, also used in various other places throughout the 60 | # built documents. 61 | # 62 | # The short X.Y version. 63 | version = '0.13' 64 | # The full version, including alpha/beta/rc tags. 65 | release = '0.13dev' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 
69 | #language = None 70 | 71 | # There are two options for replacing |today|: either, you set today to some 72 | # non-false value, then it is used: 73 | #today = '' 74 | # Else, today_fmt is used as the format for a strftime call. 75 | #today_fmt = '%B %d, %Y' 76 | 77 | # List of patterns, relative to source directory, that match files and 78 | # directories to ignore when looking for source files. 79 | exclude_patterns = [] 80 | 81 | # The reST default role (used for this markup: `text`) to use for all documents. 82 | #default_role = None 83 | 84 | # If true, '()' will be appended to :func: etc. cross-reference text. 85 | #add_function_parentheses = True 86 | 87 | # If true, the current module name will be prepended to all description 88 | # unit titles (such as .. function::). 89 | #add_module_names = True 90 | 91 | # If true, sectionauthor and moduleauthor directives will be shown in the 92 | # output. They are ignored by default. 93 | #show_authors = False 94 | 95 | # The name of the Pygments (syntax highlighting) style to use. 96 | pygments_style = 'sphinx' 97 | 98 | # A list of ignored prefixes for module index sorting. 99 | #modindex_common_prefix = [] 100 | 101 | 102 | # -- Options for HTML output --------------------------------------------------- 103 | 104 | # The theme to use for HTML and HTML Help pages. See the documentation for 105 | # a list of builtin themes. 106 | html_theme = 'flask' 107 | 108 | # Theme options are theme-specific and customize the look and feel of a theme 109 | # further. For a list of options available for each theme, see the 110 | # documentation. 111 | html_theme_options = {'index_logo': 'logo.png', 112 | 'index_logo_height': '150px'} 113 | 114 | # Add any paths that contain custom themes here, relative to this directory. 115 | html_theme_path = ['_themes'] 116 | 117 | # The name for this set of Sphinx documents. If None, it defaults to 118 | # " v documentation". 
119 | #html_title = None 120 | 121 | # A shorter title for the navigation bar. Default is the same as html_title. 122 | #html_short_title = None 123 | 124 | # The name of an image file (relative to this directory) to place at the top 125 | # of the sidebar. 126 | #html_logo = None 127 | 128 | # The name of an image file (within the static path) to use as favicon of the 129 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 130 | # pixels large. 131 | #html_favicon = None 132 | 133 | # Add any paths that contain custom static files (such as style sheets) here, 134 | # relative to this directory. They are copied after the builtin static files, 135 | # so a file named "default.css" will overwrite the builtin "default.css". 136 | html_static_path = ['_static'] 137 | 138 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 139 | # using the given strftime format. 140 | #html_last_updated_fmt = '%b %d, %Y' 141 | 142 | # If true, SmartyPants will be used to convert quotes and dashes to 143 | # typographically correct entities. 144 | #html_use_smartypants = True 145 | 146 | # Custom sidebar templates, maps document names to template names. 147 | #html_sidebars = {} 148 | 149 | # Additional templates that should be rendered to pages, maps page names to 150 | # template names. 151 | #html_additional_pages = {} 152 | 153 | # If false, no module index is generated. 154 | #html_domain_indices = True 155 | 156 | # If false, no index is generated. 157 | #html_use_index = True 158 | 159 | # If true, the index is split into individual pages for each letter. 160 | #html_split_index = False 161 | 162 | # If true, links to the reST sources are added to the pages. 163 | #html_show_sourcelink = True 164 | 165 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 166 | #html_show_sphinx = True 167 | 168 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
169 | #html_show_copyright = True 170 | 171 | # If true, an OpenSearch description file will be output, and all pages will 172 | # contain a tag referring to it. The value of this option must be the 173 | # base URL from which the finished HTML is served. 174 | #html_use_opensearch = '' 175 | 176 | # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). 177 | #html_file_suffix = '' 178 | 179 | # Output file base name for HTML help builder. 180 | htmlhelp_basename = 'SpikeSortsdoc' 181 | 182 | 183 | # -- Options for LaTeX output -------------------------------------------------- 184 | 185 | # The paper size ('letter' or 'a4'). 186 | #latex_paper_size = 'letter' 187 | 188 | # The font size ('10pt', '11pt' or '12pt'). 189 | #latex_font_size = '10pt' 190 | 191 | # Grouping the document tree into LaTeX files. List of tuples 192 | # (source start file, target name, title, author, documentclass [howto/manual]). 193 | latex_documents = [ 194 | ('index', 'SpikeSort.tex', u'SpikeSort Documentation', 195 | u'Bartosz Telenczuk', 'manual'), 196 | ] 197 | 198 | # The name of an image file (relative to this directory) to place at the top of 199 | # the title page. 200 | #latex_logo = None 201 | 202 | # For "manual" documents, if this is true, then toplevel headings are parts, 203 | # not chapters. 204 | #latex_use_parts = False 205 | 206 | # If true, show page references after internal links. 207 | #latex_show_pagerefs = False 208 | 209 | # If true, show URL addresses after external links. 210 | #latex_show_urls = False 211 | 212 | # Additional stuff for the LaTeX preamble. 213 | #latex_preamble = '' 214 | 215 | # Documents to append as an appendix to all manuals. 216 | #latex_appendices = [] 217 | 218 | # If false, no module index is generated. 219 | #latex_domain_indices = True 220 | 221 | 222 | # -- Options for manual page output -------------------------------------------- 223 | 224 | # One entry per manual page. 
List of tuples 225 | # (source start file, name, description, authors, manual section). 226 | man_pages = [ 227 | ('index', 'sortspikes', u'SpikeSorts Documentation', 228 | [u'Bartosz Telenczuk'], 1) 229 | ] 230 | -------------------------------------------------------------------------------- /docs/source/_themes/flask/static/flasky.css_t: -------------------------------------------------------------------------------- 1 | /* 2 | * flasky.css_t 3 | * ~~~~~~~~~~~~ 4 | * 5 | * :copyright: Copyright 2010 by Armin Ronacher. 6 | * :license: Flask Design License, see LICENSE for details. 7 | */ 8 | 9 | {% set page_width = '90%' %} 10 | {% set sidebar_width = '220px' %} 11 | 12 | @import url("basic.css"); 13 | 14 | /* -- page layout ----------------------------------------------------------- */ 15 | 16 | body { 17 | font-family: 'Georgia', serif; 18 | font-size: 17px; 19 | background-color: white; 20 | color: #000; 21 | margin: 0; 22 | padding: 0; 23 | } 24 | 25 | div.document { 26 | width: {{ page_width }}; 27 | margin: 30px auto 0 20px; 28 | } 29 | 30 | div.documentwrapper { 31 | float: left; 32 | width: 100%; 33 | } 34 | 35 | div.bodywrapper { 36 | margin: 0 0 0 {{ sidebar_width }}; 37 | } 38 | 39 | div.sphinxsidebar { 40 | width: {{ sidebar_width }}; 41 | } 42 | 43 | hr { 44 | border: 1px solid #B1B4B6; 45 | } 46 | 47 | div.body { 48 | background-color: #ffffff; 49 | color: #3E4349; 50 | padding: 0 30px 0 30px; 51 | } 52 | 53 | img.floatingflask { 54 | padding: 0 0 10px 10px; 55 | float: right; 56 | } 57 | 58 | div.footer { 59 | width: {{ page_width }}; 60 | margin: 20px auto 30px auto; 61 | font-size: 14px; 62 | color: #888; 63 | text-align: right; 64 | } 65 | 66 | div.footer a { 67 | color: #888; 68 | } 69 | 70 | div.related { 71 | display:inline; 72 | } 73 | 74 | div.related h3 { 75 | display:none; 76 | } 77 | 78 | div.sphinxsidebar a { 79 | color: #444; 80 | text-decoration: none; 81 | border-bottom: 1px dotted #999; 82 | } 83 | 84 | div.sphinxsidebar a:hover { 
85 | border-bottom: 1px solid #999; 86 | } 87 | 88 | div.sphinxsidebar { 89 | font-size: 14px; 90 | line-height: 1.5; 91 | } 92 | 93 | div.sphinxsidebarwrapper { 94 | padding: 18px 10px; 95 | } 96 | 97 | div.sphinxsidebarwrapper p.logo { 98 | padding: 0 0 20px 0; 99 | margin: 0; 100 | text-align: center; 101 | } 102 | 103 | div.sphinxsidebar h3, 104 | div.sphinxsidebar h4 { 105 | font-family: 'Garamond', 'Georgia', serif; 106 | color: #444; 107 | font-size: 24px; 108 | font-weight: normal; 109 | margin: 0 0 5px 0; 110 | padding: 0; 111 | } 112 | 113 | div.sphinxsidebar h4 { 114 | font-size: 20px; 115 | } 116 | 117 | div.sphinxsidebar h3 a { 118 | color: #444; 119 | } 120 | 121 | div.sphinxsidebar p.logo a, 122 | div.sphinxsidebar h3 a, 123 | div.sphinxsidebar p.logo a:hover, 124 | div.sphinxsidebar h3 a:hover { 125 | border: none; 126 | } 127 | 128 | div.sphinxsidebar p { 129 | color: #555; 130 | margin: 10px 0; 131 | } 132 | 133 | div.sphinxsidebar ul { 134 | margin: 10px 0; 135 | padding: 0; 136 | color: #000; 137 | } 138 | 139 | div.sphinxsidebar input { 140 | border: 1px solid #ccc; 141 | font-family: 'Georgia', serif; 142 | font-size: 1em; 143 | } 144 | 145 | /* -- body styles ----------------------------------------------------------- */ 146 | 147 | a { 148 | color: #004B6B; 149 | text-decoration: underline; 150 | } 151 | 152 | a:hover { 153 | color: #6D4100; 154 | text-decoration: underline; 155 | } 156 | 157 | div.body h1, 158 | div.body h2, 159 | div.body h3, 160 | div.body h4, 161 | div.body h5, 162 | div.body h6 { 163 | font-family: 'Garamond', 'Georgia', serif; 164 | font-weight: normal; 165 | margin: 30px 0px 10px 0px; 166 | padding: 0; 167 | } 168 | 169 | {% if theme_index_logo %} 170 | div.indexwrapper h1 { 171 | text-indent: -999999px; 172 | background: url({{ theme_index_logo }}) no-repeat center center; 173 | height: {{ theme_index_logo_height }}; 174 | } 175 | {% endif %} 176 | 177 | div.body h1 { margin-top: 0; padding-top: 0; font-size: 240%; } 
178 | div.body h2 { font-size: 180%; } 179 | div.body h3 { font-size: 150%; } 180 | div.body h4 { font-size: 130%; } 181 | div.body h5 { font-size: 100%; } 182 | div.body h6 { font-size: 100%; } 183 | 184 | a.headerlink { 185 | color: #ddd; 186 | padding: 0 4px; 187 | text-decoration: none; 188 | } 189 | 190 | a.headerlink:hover { 191 | color: #444; 192 | background: #eaeaea; 193 | } 194 | 195 | div.body p, div.body dd, div.body li { 196 | line-height: 1.4em; 197 | } 198 | 199 | div.admonition { 200 | background: #fafafa; 201 | margin: 20px -30px; 202 | padding: 10px 30px; 203 | border-top: 1px solid #ccc; 204 | border-bottom: 1px solid #ccc; 205 | } 206 | 207 | div.admonition tt.xref, div.admonition a tt { 208 | border-bottom: 1px solid #fafafa; 209 | } 210 | 211 | dd div.admonition { 212 | margin-left: -60px; 213 | padding-left: 60px; 214 | } 215 | 216 | div.admonition p.admonition-title { 217 | font-family: 'Garamond', 'Georgia', serif; 218 | font-weight: normal; 219 | font-size: 24px; 220 | margin: 0 0 10px 0; 221 | padding: 0; 222 | line-height: 1; 223 | } 224 | 225 | div.admonition p.last { 226 | margin-bottom: 0; 227 | } 228 | 229 | div.highlight { 230 | background-color: white; 231 | } 232 | 233 | dt:target, .highlight { 234 | background: #FAF3E8; 235 | } 236 | 237 | div.note { 238 | background-color: #eee; 239 | border: 1px solid #ccc; 240 | } 241 | 242 | div.seealso { 243 | background-color: rgb(240,240,240); 244 | } 245 | 246 | td.field-body strong { 247 | font-weight: normal; 248 | font-style: italic; 249 | } 250 | div.topic { 251 | background-color: #eee; 252 | } 253 | 254 | p.admonition-title { 255 | display: inline; 256 | } 257 | 258 | p.admonition-title:after { 259 | content: ":"; 260 | } 261 | 262 | pre, tt { 263 | font-family: 'Consolas', 'Menlo', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 264 | font-size: 0.9em; 265 | } 266 | 267 | img.screenshot { 268 | } 269 | 270 | tt.descname, tt.descclassname { 271 | font-size: 0.95em; 272 | 
} 273 | 274 | tt.descname { 275 | padding-right: 0.08em; 276 | } 277 | 278 | img.screenshot { 279 | -moz-box-shadow: 2px 2px 4px #eee; 280 | -webkit-box-shadow: 2px 2px 4px #eee; 281 | box-shadow: 2px 2px 4px #eee; 282 | } 283 | 284 | table.docutils { 285 | border: 1px solid #888; 286 | -moz-box-shadow: 2px 2px 4px #eee; 287 | -webkit-box-shadow: 2px 2px 4px #eee; 288 | box-shadow: 2px 2px 4px #eee; 289 | } 290 | 291 | table.docutils td, table.docutils th { 292 | border: 1px solid #888; 293 | padding: 0.25em 0.7em; 294 | } 295 | 296 | table.field-list, table.footnote { 297 | border: none; 298 | -moz-box-shadow: none; 299 | -webkit-box-shadow: none; 300 | box-shadow: none; 301 | } 302 | 303 | table.footnote { 304 | margin: 15px 0; 305 | width: 100%; 306 | border: 1px solid #eee; 307 | background: #fdfdfd; 308 | font-size: 0.9em; 309 | } 310 | 311 | table.footnote + table.footnote { 312 | margin-top: -15px; 313 | border-top: none; 314 | } 315 | 316 | table.field-list th { 317 | padding: 0 0.8em 0 0; 318 | width: 120px; 319 | } 320 | 321 | table.field-list td { 322 | padding: 0; 323 | } 324 | 325 | table.footnote td.label { 326 | width: 0px; 327 | padding: 0.3em 0 0.3em 0.5em; 328 | } 329 | 330 | table.footnote td { 331 | padding: 0.3em 0.5em; 332 | } 333 | 334 | dl { 335 | margin: 0; 336 | padding: 0; 337 | } 338 | 339 | dl dd { 340 | margin-left: 30px; 341 | } 342 | 343 | dl.method, dl.attribute { 344 | border-top: 1px solid rgb(170,170,170); 345 | } 346 | 347 | dl.class, dl.function { 348 | border-top: 2px solid rgb(136,136,136); 349 | } 350 | 351 | blockquote { 352 | margin: 0 0 0 30px; 353 | padding: 0; 354 | } 355 | 356 | ul, ol { 357 | margin: 10px 0 10px 30px; 358 | padding: 0; 359 | } 360 | 361 | pre { 362 | background: #eee; 363 | padding: 7px 30px; 364 | margin: 15px -30px; 365 | line-height: 1.3em; 366 | } 367 | 368 | dl pre, blockquote pre, li pre { 369 | margin-left: -60px; 370 | padding-left: 60px; 371 | } 372 | 373 | dl dl pre { 374 | margin-left: 
-90px; 375 | padding-left: 90px; 376 | } 377 | 378 | tt { 379 | background-color: #ecf0f3; 380 | color: #222; 381 | /* padding: 1px 2px; */ 382 | } 383 | 384 | tt.xref, a tt { 385 | background-color: #FBFBFB; 386 | border-bottom: 1px solid white; 387 | } 388 | 389 | a.reference { 390 | text-decoration: none; 391 | border-bottom: 1px dotted #004B6B; 392 | } 393 | 394 | a.reference:hover { 395 | border-bottom: 1px solid #6D4100; 396 | } 397 | 398 | a.footnote-reference { 399 | text-decoration: none; 400 | font-size: 0.7em; 401 | vertical-align: top; 402 | border-bottom: 1px dotted #004B6B; 403 | } 404 | 405 | a.footnote-reference:hover { 406 | border-bottom: 1px solid #6D4100; 407 | } 408 | 409 | a:hover tt { 410 | background: #EEE; 411 | } 412 | -------------------------------------------------------------------------------- /src/spike_sort/core/cluster.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | import numpy as np 5 | 6 | from spike_sort.ui import manual_sort 7 | from spike_sort.core.features import requires 8 | 9 | #optional scikits.learn imports 10 | try: 11 | #import scikits.learn >= 0.9 12 | from sklearn import cluster as skcluster 13 | from sklearn import mixture 14 | from sklearn import decomposition 15 | from sklearn import neighbors 16 | except ImportError: 17 | try: 18 | #import scikits.learn < 0.9 19 | from scikits.learn import cluster as skcluster 20 | from scikits.learn import mixture 21 | from scikits.learn import decomposition 22 | from scikits.learn import neighbors 23 | except ImportError: 24 | pass 25 | 26 | @requires(skcluster, "scikits.learn must be installed to use spectral") 27 | def spectral(data, n_clusters=2, affinity='rbf'): 28 | 29 | sp = skcluster.SpectralClustering(k=n_clusters, affinity=affinity) 30 | sp.fit(data) 31 | labels = sp.labels_ 32 | return labels 33 | 34 | @requires(skcluster, "scikits.learn must be installed to use dbsca") 35 | def 
dbscan(data, eps=0.3, min_samples=10): 36 | """DBScan clustering 37 | 38 | Parameters 39 | ---------- 40 | data : float array 41 | features array 42 | 43 | Returns 44 | ------- 45 | cl : int array 46 | cluster indicies 47 | 48 | Notes 49 | ----- 50 | This function requires scikits-learn 51 | """ 52 | 53 | db = skcluster.DBSCAN(eps=eps, min_samples=min_samples).fit(data) 54 | labels = db.labels_ 55 | return labels 56 | 57 | @requires(skcluster, "scikits.learn must be installed to use mean_shift") 58 | def mean_shift(data, bandwith=None, n_samples=500, quantile=0.3): 59 | if bandwith is None: 60 | bandwidth = skcluster.estimate_bandwidth(data, 61 | quantile=quantile, 62 | n_samples=n_samples) 63 | 64 | ms = skcluster.MeanShift(bandwidth=bandwidth).fit(data) 65 | labels = ms.labels_ 66 | return labels 67 | 68 | @requires(skcluster, "scikits.learn must be installed to use k_means_plus") 69 | def k_means_plus(data, K=2, whiten=False): 70 | """k means with smart initialization. 71 | 72 | Notes 73 | ----- 74 | This function requires scikits-learn 75 | 76 | See Also 77 | -------- 78 | kmeans 79 | 80 | """ 81 | if whiten: 82 | pca = decomposition.PCA(whiten=True).fit(data) 83 | data = pca.transform(data) 84 | 85 | clusters = skcluster.k_means(data, n_clusters=K)[1] 86 | 87 | return clusters 88 | 89 | 90 | def gmm(data, k=2, cvtype='full'): 91 | """Cluster based on gaussian mixture models 92 | 93 | Parameters 94 | ---------- 95 | data : dict 96 | features structure 97 | k : int 98 | number of clusters 99 | 100 | Returns 101 | ------- 102 | cl : int array 103 | cluster indicies 104 | 105 | Notes 106 | ----- 107 | This function requires scikits-learn 108 | 109 | """ 110 | 111 | try: 112 | #scikits.learn 0.8 113 | clf = mixture.GMM(n_states=k, cvtype=cvtype) 114 | except TypeError: 115 | try: 116 | clf = mixture.GMM(n_components=k, cvtype=cvtype) 117 | except TypeError: 118 | #scikits.learn 0.11 119 | clf = mixture.GMM(n_components=k, covariance_type=cvtype) 120 | except 
NameError: 121 | raise NotImplementedError( 122 | "scikits.learn must be installed to use gmm") 123 | 124 | clf.fit(data) 125 | cl = clf.predict(data) 126 | return cl 127 | 128 | 129 | def manual(data, n_spikes='all', *args, **kwargs): 130 | """Sort spikes manually by cluster cutting 131 | 132 | Opens a new window in which you can draw cluster of arbitrary 133 | shape. 134 | 135 | Notes 136 | ----- 137 | Only two first features are plotted 138 | """ 139 | if n_spikes=='all': 140 | return manual_sort._cluster(data[:, :2], **kwargs) 141 | else: 142 | idx = np.argsort(np.random.rand(data.shape[0]))[:n_spikes] 143 | labels_subsampled = manual_sort._cluster(data[idx, :2], **kwargs) 144 | try: 145 | neigh = neighbors.KNeighborsClassifier(15) 146 | except NameError: 147 | raise NotImplementedError( 148 | "scikits.learn must be installed to use subsampling") 149 | neigh.fit(data[idx, :2], labels_subsampled) 150 | return neigh.predict(data[:, :2]) 151 | 152 | 153 | 154 | 155 | def none(data): 156 | """Do nothing""" 157 | return np.zeros(data.shape[0], dtype='int16') 158 | 159 | 160 | def _metric_euclidean(data1, data2): 161 | n_pts1, n_dims1 = data1.shape 162 | n_pts2, n_dims2 = data2.shape 163 | if not n_dims1 == n_dims2: 164 | raise TypeError("data1 and data2 must have the same number of columns") 165 | delta = np.zeros((n_pts1, n_pts2), 'd') 166 | for d in xrange(n_dims1): 167 | _data1 = data1[:, d] 168 | _data2 = data2[:, d] 169 | _delta = np.subtract.outer(_data1, _data2) ** 2 170 | delta += _delta 171 | return np.sqrt(delta) 172 | 173 | 174 | def dist_euclidean(spike_waves1, spike_waves2=None): 175 | """Given spike_waves calculate pairwise Euclidean distance between 176 | them""" 177 | 178 | sp_data1 = np.concatenate(spike_waves1['data'], 1) 179 | 180 | if spike_waves2 is None: 181 | sp_data2 = sp_data1 182 | else: 183 | sp_data2 = np.concatenate(spike_waves2['data'], 1) 184 | d = _metric_euclidean(sp_data1, sp_data2) 185 | 186 | return d 187 | 188 | 189 | def 
cluster(method, features, *args, **kwargs): 190 | """Automatically cluster spikes using K means algorithm 191 | 192 | Parameters 193 | ---------- 194 | features : dict 195 | spike features datastructure 196 | n_clusters : int 197 | number of clusters to identify 198 | args, kwargs : 199 | optional arguments that are passed to the clustering algorithm 200 | 201 | Returns 202 | ------- 203 | labels : array 204 | array of cluster (unit) label - one for each cell 205 | 206 | Examples 207 | -------- 208 | Create a sample feature dataset and use k-means clustering to find 209 | groups of spikes (units) 210 | 211 | >>> import spike_sort 212 | >>> import numpy as np 213 | >>> np.random.seed(1234) #k_means uses random initialization 214 | >>> features = {'data':np.array([[0.,0.], 215 | ... [0, 1.], 216 | ... [0, 0.9], 217 | ... [0.1,0]])} 218 | >>> labels = spike_sort.cluster.cluster('k_means', features, 2) 219 | >>> print labels 220 | [0 1 1 0] 221 | """ 222 | try: 223 | cluster_func = eval(method) 224 | except NameError: 225 | raise NotImplementedError( 226 | "clustering method %s is not implemented" % method) 227 | 228 | data = features['data'] 229 | mask = features.get('is_valid') 230 | if mask is not None: 231 | valid_data = data[mask, :] 232 | cl = cluster_func(valid_data, *args, **kwargs) 233 | labels = np.zeros(data.shape[0], dtype='int') - 1 234 | labels[mask] = cl 235 | else: 236 | labels = cluster_func(data, *args, **kwargs) 237 | return labels 238 | 239 | 240 | def k_means(features, K=2): 241 | """Perform K means clustering 242 | 243 | Parameters 244 | ---------- 245 | data : dict 246 | data vectors (n,m) where n is the number of datapoints and m is 247 | the number of variables 248 | K : int 249 | number of distinct clusters to identify 250 | 251 | Returns 252 | ------- 253 | partition : array 254 | vector of cluster labels (ints) for each datapoint from `data` 255 | """ 256 | n_dim = features.shape[1] 257 | centers = np.random.rand(K, n_dim) 258 | centers_new 
= np.random.rand(K, n_dim) 259 | partition = np.zeros(features.shape[0], dtype=np.int) 260 | while not (centers_new == centers).all(): 261 | centers = centers_new.copy() 262 | 263 | distances = (centers[:, np.newaxis, :] - features) 264 | distances *= distances 265 | distances = distances.sum(axis=2) 266 | partition = distances.argmin(axis=0) 267 | 268 | for i in range(K): 269 | if np.sum(partition == i) > 0: 270 | centers_new[i, :] = features[partition == i, :].mean(0) 271 | return partition 272 | 273 | 274 | def split_cells(spt_dict, idx, which='all'): 275 | """return the spike times belonging to the cluster and the rest""" 276 | 277 | if which == 'all': 278 | classes = np.unique(idx) 279 | else: 280 | classes = which 281 | spt = spt_dict['data'] 282 | spt_dicts = dict([(cl, {'data': spt[idx == cl]}) for cl in classes]) 283 | return spt_dicts 284 | -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import filecmp 3 | import json 4 | import glob 5 | import tempfile 6 | 7 | from nose.tools import ok_, eq_, raises 8 | import tables 9 | import numpy as np 10 | 11 | from spike_sort.io.filters import BakerlabFilter, PyTablesFilter 12 | from spike_sort.io import export 13 | from spike_sort.io import neo_filters 14 | 15 | class TestNeo: 16 | 17 | def setUp(self): 18 | path = os.path.dirname(os.path.abspath(__file__)) 19 | self.samples_dir = os.path.join(path, 'samples') 20 | self.file_names = {'Axon': 'axonio.abf'} 21 | 22 | def test_read_abf(self): 23 | fname = self.file_names['Axon'] 24 | file_path = os.path.join(self.samples_dir, fname) 25 | abf = neo_filters.AxonFilter(file_path) 26 | sp = abf.read_sp() 27 | assert len(np.array(sp['data']))>0 28 | assert sp['FS'] == 125000 29 | 30 | def test_read_abf_via_component(self): 31 | fname = self.file_names['Axon'] 32 | file_path = os.path.join(self.samples_dir, fname) 33 | 
source = neo_filters.NeoSource(file_path) 34 | sp = source.read_sp() 35 | assert len(np.array(sp['data']))>0 36 | 37 | class TestHDF(object): 38 | def setUp(self): 39 | self.data = np.random.randint(1000, size=(4, 100)) 40 | self.spt = np.random.randint(0, 100, (10,)) / 200.0 41 | self.el_node = '/Subject/Session/Electrode' 42 | self.fname = 'test.h5' 43 | self.cell_node = self.el_node + '/cell' 44 | self.h5f = tables.openFile(self.fname, 'a') 45 | 46 | self.spt.sort() 47 | atom = tables.Atom.from_dtype(self.data.dtype) 48 | shape = self.data.shape 49 | filter = tables.Filters(complevel=0, complib='zlib') 50 | new_array = self.h5f.createCArray(self.el_node, "raw", atom, shape, 51 | filters=filter, 52 | createparents=True) 53 | self.sampfreq = 5.0E3 54 | new_array.attrs['sampfreq'] = self.sampfreq 55 | new_array[:] = self.data 56 | self.h5f.createArray(self.el_node, "cell", self.spt, 57 | title="", 58 | createparents="True") 59 | self.h5f.close() 60 | 61 | def tearDown(self): 62 | self.filter.close() 63 | os.unlink(self.fname) 64 | 65 | def test_write(self): 66 | sp_dict = {'data': self.data, 'FS': self.sampfreq} 67 | spt_dict = {'data': self.spt} 68 | self.filter = PyTablesFilter("test2.h5") 69 | self.filter.write_sp(sp_dict, self.el_node + "/raw") 70 | self.filter.write_spt(spt_dict, self.cell_node) 71 | self.filter.close() 72 | exit_code = os.system('h5diff ' + self.fname + ' test2.h5') 73 | os.unlink("test2.h5") 74 | ok_(exit_code == 0) 75 | 76 | def test_read_sp(self): 77 | self.filter = PyTablesFilter(self.fname) 78 | sp = self.filter.read_sp(self.el_node) 79 | ok_((sp['data'][:] == self.data).all()) 80 | 81 | def test_read_sp_attr(self): 82 | #check n_contacts attribute 83 | self.filter = PyTablesFilter(self.fname) 84 | sp = self.filter.read_sp(self.el_node) 85 | n_contacts = sp['n_contacts'] 86 | ok_(n_contacts == self.data.shape[0]) 87 | 88 | def test_read_spt(self): 89 | self.filter = PyTablesFilter(self.fname) 90 | spt = 
self.filter.read_spt(self.cell_node) 91 | ok_((spt['data'] == self.spt).all()) 92 | 93 | 94 | class TestBakerlab(object): 95 | def setup(self): 96 | file_descr = {"fspike": "{ses_id}{el_id}.sp", 97 | "cell": "{ses_id}{el_id}{cell_id}.spt", 98 | "dirname": ".", 99 | "FS": 5.E3, 100 | "n_contacts": 1} 101 | self.el_node = '/Test/s32test01/el1' 102 | self.cell_node = self.el_node + '/cell1' 103 | self.data = np.random.randint(-1000, 1000, (100,)) 104 | self.spt_data = np.random.randint(0, 100, (10,)) / 200.0 105 | self.spt_metadata = {'element1': 5, 'element2': 10} 106 | self.conf_file = 'test.conf' 107 | self.fname = "32test011.sp" 108 | self.spt_fname = "32test0111.spt" 109 | self.spt_log_fname = "32test0111.log" 110 | 111 | with open(self.conf_file, 'w') as fp: 112 | json.dump(file_descr, fp) 113 | 114 | self.data.astype(np.int16).tofile(self.fname) 115 | (self.spt_data * 200).astype(np.int32).tofile(self.spt_fname) 116 | 117 | with open(self.spt_log_fname, 'w') as lf: 118 | json.dump(self.spt_metadata, lf) 119 | 120 | def tearDown(self): 121 | os.unlink(self.conf_file) 122 | os.unlink(self.fname) 123 | os.unlink(self.spt_fname) 124 | os.unlink(self.spt_log_fname) 125 | 126 | def test_write_spt(self): 127 | cell_node_tmp = '/Test/s32test01/el2/cell1' 128 | spt_dict = {'data': self.spt_data} 129 | filter = BakerlabFilter(self.conf_file) 130 | filter.write_spt(spt_dict, cell_node_tmp) 131 | files_eq = filecmp.cmp(self.spt_fname, "32test0121.spt", shallow=0) 132 | os.unlink("32test0121.spt") 133 | ok_(files_eq) 134 | 135 | def test_write_spt_metadata(self): 136 | cell_node_tmp = '/Test/s32test01/el2/cell1' 137 | spt_dict = {'data': self.spt_data, 138 | 'metadata': self.spt_metadata} 139 | filter = BakerlabFilter(self.conf_file) 140 | filter.write_spt(spt_dict, cell_node_tmp, overwrite=True) 141 | files_eq = filecmp.cmp(self.spt_log_fname, "32test0121.log", shallow=0) 142 | os.unlink("32test0121.spt") 143 | os.unlink("32test0121.log") 144 | ok_(files_eq) 145 | 146 | 
@raises(IOError) 147 | def test_writespt_overwrite_exc(self): 148 | cell_node_tmp = '/Test/s32test01/el1/cell1' 149 | spt_dict = {'data': self.spt_data} 150 | filter = BakerlabFilter(self.conf_file) 151 | filter.write_spt(spt_dict, cell_node_tmp) 152 | 153 | def test_read_spt(self): 154 | filter = BakerlabFilter(self.conf_file) 155 | sp = filter.read_spt(self.cell_node) 156 | read_data = sp['data'] 157 | ok_((np.abs(read_data - self.spt_data) <= 1 / 200.0).all()) 158 | 159 | def test_write_sp(self): 160 | el_node_tmp = '/Test/s32test01/el2' 161 | sp_dict = {'data': self.data[np.newaxis, :]} 162 | filter = BakerlabFilter(self.conf_file) 163 | filter.write_sp(sp_dict, el_node_tmp) 164 | files_eq = filecmp.cmp("32test011.sp", "32test012.sp", shallow=0) 165 | os.unlink("32test012.sp") 166 | ok_(files_eq) 167 | 168 | def test_write_multichan(self): 169 | n_contacts = 4 170 | data = np.repeat(self.data[np.newaxis, :], n_contacts, 0) 171 | sp_dict = {'data': data} 172 | with open(self.conf_file, 'r+') as fid: 173 | file_desc = json.load(fid) 174 | file_desc['n_contacts'] = 4 175 | file_desc["fspike"] = "test{contact_id}.sp" 176 | fid.seek(0) 177 | json.dump(file_desc, fid) 178 | filter = BakerlabFilter(self.conf_file) 179 | filter.write_sp(sp_dict, self.el_node) 180 | all_chan_files = glob.glob("test?.sp") 181 | [os.unlink(p) for p in all_chan_files] 182 | eq_(len(all_chan_files), n_contacts) 183 | 184 | def test_read_sp(self): 185 | filter = BakerlabFilter(self.conf_file) 186 | sp = filter.read_sp(self.el_node) 187 | read_data = sp['data'][0, :] 188 | print read_data.shape 189 | ok_((np.abs(read_data - self.data) <= 1 / 200.0).all()) 190 | 191 | def test_sp_shape(self): 192 | with open(self.conf_file, 'r+') as fid: 193 | file_desc = json.load(fid) 194 | file_desc['n_contacts'] = 4 195 | fid.seek(0) 196 | json.dump(file_desc, fid) 197 | filter = BakerlabFilter(self.conf_file) 198 | sp = filter.read_sp(self.el_node) 199 | data = sp['data'] 200 | ok_(data.shape == (4, 
len(self.data))) 201 | 202 | 203 | class TestExport(object): 204 | def test_export_cells(self): 205 | n_cells = 4 206 | self.spt_data = np.random.randint(0, 10000, (100, n_cells)) 207 | self.spt_data.sort(0) 208 | self.cells_dict = dict([(i, {"data": self.spt_data[:, i]}) 209 | for i in range(n_cells)]) 210 | tempdir = tempfile.mkdtemp() 211 | fname = os.path.join(tempdir, "test.h5") 212 | ptfilter = PyTablesFilter(fname) 213 | tmpl = "/Subject/Session/Electrode/Cell{cell_id}" 214 | export.export_cells(ptfilter, tmpl, self.cells_dict) 215 | test = [] 216 | for i in range(n_cells): 217 | spt_dict = ptfilter.read_spt(tmpl.format(cell_id=i)) 218 | test.append((spt_dict['data'] == self.spt_data[:, i]).all()) 219 | test = np.array(test) 220 | ptfilter.close() 221 | 222 | os.unlink(fname) 223 | os.rmdir(tempdir) 224 | 225 | ok_(test.all()) 226 | -------------------------------------------------------------------------------- /docs/source/datastructures.rst: -------------------------------------------------------------------------------- 1 | .. testsetup:: 2 | 3 | import numpy 4 | numpy.random.seed(1221) 5 | 6 | 7 | Data Structures 8 | =============== 9 | 10 | .. _raw_recording: 11 | 12 | To achieve best compatibility with external libraries most of the data 13 | structures are standard Python *dictionaries*, with at least one key -- `data`. The 14 | `data` key contains the actual data in an array-like object (such as 15 | NumPy array). Other attributes provide metadata that are required by 16 | some methods. 17 | 18 | 19 | Raw recording 20 | ------------- 21 | 22 | Raw electrophysiological data sampled at equally spaced time points. It 23 | can contain multiple channels, but all of them need to be of the same 24 | sampling frequency and duration (for example, multiple contacts of 25 | a tetrode). 
26 | 27 | The following keys are defined: 28 | 29 | :data: *array*, required 30 | 31 | array-like object (for example :py:class:`numpy.ndarray`) of 32 | dimensions (N_channels, N_samples) 33 | 34 | :FS: *int*, required 35 | 36 | sampling frequency in Hz 37 | 38 | :n_contacts: *int*, required 39 | 40 | number of channels (tetrode contacts). It is equal to the size 41 | of the first dimension of `data`. 42 | 43 | .. note:: 44 | 45 | You may read/write the data with your own functions, but to make the 46 | interface with SpikeSort a bit cleaner, you might also want to 47 | define your custom IO filters (see :ref:`io_filters`) 48 | 49 | .. rubric:: Example 50 | 51 | We will read the raw tetrode data from :ref:`tutorial_data` using the 52 | standard :py:class:`~spike_sort.io.filters.PyTablesFilter`: 53 | 54 | >>> from spike_sort.io.filters import PyTablesFilter 55 | >>> io_filter = PyTablesFilter('../data/tutorial.h5') 56 | >>> raw_data = io_filter.read_sp('/SubjectA/session01/el1') 57 | >>> print(raw_data.keys()) # print all keys 58 | ['n_contacts', 'FS', 'data'] 59 | >>> shape = raw_data['data'].shape # check size 60 | >>> print("{0} channels, {1} samples".format(*shape)) 61 | 4 channels, 23512500 samples 62 | >>> print(raw_data['FS']) # check sampling frequency 63 | 25000 64 | 65 | 66 | .. _spike_times: 67 | 68 | Spike times 69 | ----------- 70 | 71 | A sequence of (sorted) time readings at which spikes were generated 72 | (or other discrete events happened). 73 | 74 | The data is stored in a dictionary with the following keys: 75 | 76 | :data: *array*, required 77 | 78 | one-dimensional array-like object with event times (in 79 | milliseconds) 80 | 81 | :is_valid: *array*, optional 82 | 83 | boolean array of the same size as `data` -- if an element is False 84 | the event of the same index is masked (or invalid) 85 | 86 | ..
note:: 87 | 88 | You may read/write the data with your own functions, but to make the 89 | interface with SpikeSort a bit cleaner, you might also want to 90 | define your custom IO filters (see :ref:`io_filters`) 91 | 92 | .. rubric:: Example 93 | 94 | :py:func:`spike_sort.core.extract.detect_spikes` is one of the functions 95 | that take raw recordings and return a spike times dictionary: 96 | 97 | 98 | >>> import numpy as np 99 | >>> raw_dict = { 100 | ... 'data': np.array([[0,1,0,0,0,1]]), 101 | ... 'FS' : 10000, 102 | ... 'n_contacts': 1 103 | ... } 104 | >>> from spike_sort.core.extract import detect_spikes 105 | >>> spt_dict = detect_spikes(raw_dict, thresh=0.8) 106 | >>> print(spt_dict.keys()) 107 | ['thresh', 'contact', 'data'] 108 | >>> print('Spike times (ms): {0}'.format(spt_dict['data'])) 109 | Spike times (ms): [ 0. 0.4] 110 | 111 | 112 | Note that in addition to the required data key, 113 | :py:func:`~spike_sort.core.extract.detect_spikes` 114 | appends some extra attributes: :py:attr:`thresh` (detection threshold) 115 | and :py:attr:`contact` (contact on which spikes were detected). These 116 | attributes are ignored by other methods. 117 | 118 | .. _spike_wave: 119 | 120 | Spike waveforms 121 | --------------- 122 | 123 | The spike waveform structure contains waveforms of extracted spikes. It may be 124 | any mapping data structure (usually a dictionary) with the following keys: 125 | 126 | :data: *array*, required 127 | 128 | three-dimensional array-like object of size (N_points, N_spikes, 129 | N_contacts), where: 130 | 131 | * `N_points` -- the number of data points in a single waveform, 132 | * `N_spikes` -- the total number of spikes and 133 | * `N_contacts` -- the number of independent channels (for example 4 in a 134 | tetrode) 135 | 136 | :time: *array*, required 137 | 138 | Timeline of the spike waveshapes (in milliseconds). It must be of 139 | the same size as the first dimension of data (`N_points`).
140 | 141 | :FS: *int*, optional 142 | 143 | Sampling frequency. 144 | 145 | 146 | :n_contacts: *int*, optional 147 | 148 | Number of independent channels with spike 149 | waveshapes (see also :ref:`raw_recording`). 150 | 151 | :is_valid: *array*, optional 152 | 153 | boolean array of the size of the second dimension of `data` (N_spikes) -- if an element is False 154 | the spike with the same index is masked (or invalid) 155 | 156 | 157 | .. rubric:: Example 158 | 159 | Spike waveforms can be extracted from raw recordings (see :ref:`raw_recording`) 160 | given a sequence of spike times (see :ref:`spike_times`) by means of the 161 | :py:func:`spike_sort.core.extract.extract_spikes` function: 162 | 163 | >>> from spike_sort.core.extract import extract_spikes 164 | >>> raw_dict = { 165 | ... 'data': np.array([[0,1,1,0,0,0,1,-1,0,0, 0]]), 166 | ... 'FS' : 10000, 167 | ... 'n_contacts': 1 168 | ... } # raw signal 169 | >>> spt_dict = { 170 | ... 'data': np.array([0.15, 0.65, 1]) 171 | ... } # timestamps of three spikes 172 | >>> sp_win = [0, 0.4] # window in which spikes should be extracted 173 | >>> waves_dict = extract_spikes(raw_dict, spt_dict, sp_win) 174 | 175 | Now let us investigate the returned spike waveforms structure: 176 | 177 | * keys: 178 | 179 | >>> print(waves_dict.keys()) 180 | ['is_valid', 'FS', 'data', 'time'] 181 | 182 | * data array shape: 183 | 184 | >>> print(waves_dict['data'].shape) 185 | (4, 3, 1) 186 | 187 | * extracted spikes: 188 | 189 | >>> print(waves_dict['data'][:,:,0].T) # data contains three spikes 190 | [[ 1. 1. 0. 0.] 191 | [ 1. -1. 0. 0.] 192 | [ 0. 0. 0. 0.]] 193 | >>> print(waves_dict['time']) # defined over 4 time points 194 | [ 0. 0.1 0.2 0.3] 195 | 196 | * and potentially invalid (truncated) spikes: 197 | 198 | >>> print(waves_dict['is_valid']) # last spike is invalid (truncated) 199 | [ True True False] 200 | 201 | Note that the :py:attr:`is_valid` element of a truncated spike is 202 | :py:data:`False`. 203 | 204 | ..
_spike_features: 205 | 206 | Spike features 207 | -------------- 208 | 209 | This data structure contains features calculated from spike waveforms 210 | using one of the methods defined in the :py:mod:`spike_sort.core.features` module 211 | (one of the :py:func:`fet*` functions, see :ref:`features_doc`). 212 | 213 | The spike features dictionary consists of the following keys: 214 | 215 | :data: *array*, required 216 | 217 | two-dimensional array of size (N_spikes, N_features) that contains 218 | the actual feature values 219 | 220 | :names: *list of str*, required 221 | 222 | list of length N_features containing feature labels 223 | 224 | :is_valid: *array*, optional 225 | 226 | boolean array of length N_spikes; if an element is False 227 | the spike with the same index is masked (or invalid, see also 228 | :ref:`spike_wave`) 229 | 230 | 231 | .. rubric:: Example 232 | 233 | Let us try to calculate peak-to-peak amplitude from some spikes 234 | extracted in :ref:`spike_wave`: 235 | 236 | >>> from spike_sort.core.features import fetP2P 237 | >>> print(waves_dict['data'].shape) # 3 spikes, 4 data points each 238 | (4, 3, 1) 239 | >>> feature_dict = fetP2P(waves_dict) 240 | >>> print(feature_dict.keys()) 241 | ['is_valid', 'data', 'names'] 242 | >>> print(feature_dict['data'].shape) 243 | (3, 1) 244 | 245 | We now have one feature for each of the 3 spikes. Let us check whether the peak-to-peak amplitudes 246 | are correctly calculated: 247 | 248 | >>> print(feature_dict['data']) 249 | [[ 1.] 250 | [ 2.] 251 | [ 0.]] 252 | 253 | as expected (compare with the example above). There is only one 254 | peak-to-peak (`P2P`) feature on a single channel (`Ch0`) and its name 255 | is: 256 | 257 | >>> print(feature_dict['names']) 258 | ['Ch0:P2P'] 259 | 260 | The mask array is inherited from :py:data:`waves_dict`: 261 | 262 | >>> print(feature_dict['is_valid']) 263 | [ True True False] 264 | 265 | ..
_spike_labels: 266 | 267 | Spike labels 268 | ------------ 269 | 270 | Spike labels are the identifiers of a cell (unit) each spike was 271 | classified to. Spike labels are **not** dictionaries, but arrays of 272 | integers -- one cluster index per spike. 273 | 274 | .. rubric:: Example 275 | 276 | Let us try to cluster the spikes described by `Sample` feature using 277 | K-means with K=2: 278 | 279 | >>> from spike_sort.core.cluster import cluster 280 | >>> feature_dict = { 281 | ... 'data' : np.array([[1],[-1], [1]]), 282 | ... 'names' : ['Sample'] 283 | ... } 284 | >>> labels = cluster('k_means', feature_dict, 2) 285 | >>> print(labels) 286 | [1 0 1] 287 | 288 | As expected :py:data:`labels` is an array describing two clusters: 0 and 1. 289 | 290 | -------------------------------------------------------------------------------- /tests/test_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spike_sort as ss 3 | 4 | from nose.tools import ok_, eq_, raises 5 | from numpy.testing import assert_array_almost_equal as almost_equal 6 | 7 | 8 | class TestFeatures(object): 9 | def setup(self): 10 | self.gain = 1000 11 | 12 | n_pts = 100 13 | FS = 5E3 14 | time = np.arange(n_pts, dtype=np.float32) / n_pts - 0.5 15 | cells = self.gain * np.vstack( 16 | (np.sin(time * 2 * np.pi) / 2.0, 17 | np.abs(time) * 2 - 1 18 | ) 19 | ) 20 | self.cells = cells.astype(np.int) 21 | 22 | #define data interface (dictionary) 23 | self.spikes_dict = {"data": self.cells.T[:, :, np.newaxis], 24 | "time": time, "FS": FS} 25 | 26 | def test_fetP2P(self): 27 | spikes_dict = self.spikes_dict.copy() 28 | 29 | n_spikes = 200 30 | amps = np.random.randint(1, 100, n_spikes) 31 | amps = amps[:, np.newaxis] 32 | spikes = amps * self.cells[0, :] 33 | spikes_dict['data'] = spikes.T[:, :, np.newaxis] 34 | p2p = ss.features.fetP2P(spikes_dict) 35 | 36 | ok_((p2p['data'] == amps * self.gain).all()) 37 | 38 | def test_PCA(self): 39 | 
n_dim = 2 40 | n_obs = 100 41 | raw_data = np.random.randn(n_dim, n_obs) 42 | mixing = np.array([[-1, 1], [1, 1]]) 43 | 44 | mix_data = np.dot(mixing, raw_data) 45 | mix_data = mix_data - mix_data.mean(1)[:, np.newaxis] 46 | 47 | evals, evecs, score = ss.features.PCA(mix_data, n_dim) 48 | 49 | proj_cov = np.cov(score) 50 | error = np.mean((proj_cov - np.eye(n_dim)) ** 2) 51 | ok_(error < 0.01) 52 | 53 | def test_fetPCA(self): 54 | spikes_dict = self.spikes_dict.copy() 55 | n_spikes = 200 56 | n_cells = 2 57 | amp_var_fact = 0.4 58 | amp_var = 1 + amp_var_fact * np.random.rand(n_spikes, 1) 59 | _amps = np.random.rand(n_spikes) > 0.5 60 | amps = amp_var * np.vstack((_amps >= 0.5, _amps < 0.5)).T 61 | 62 | spikes = np.dot(amps, self.cells).T 63 | spikes = spikes.astype(np.float32) 64 | spikes_dict['data'] = spikes[:, :, np.newaxis] 65 | 66 | pcs = ss.features.fetPCA(spikes_dict, ncomps=1) 67 | pcs = pcs['data'] 68 | compare = ~np.logical_xor(pcs[:, 0].astype(int) + 1, _amps) 69 | correct = np.sum(compare) 70 | eq_(n_spikes, correct) 71 | 72 | def test_fetPCA_multichannel_labels(self): 73 | # tests whether features are properly labeled, based on the linearity 74 | # of PCA 75 | 76 | spike1, spike2 = self.cells 77 | ch0_spikes = np.vstack((spike1, spike2, spike1 + spike2)).T 78 | ch1_spikes = np.vstack((spike1*2, spike2/2., spike1*2 - spike2/2.)).T 79 | 80 | spikes = np.empty((ch0_spikes.shape[0], ch0_spikes.shape[1], 2)) 81 | spikes[:, :, 0] = ch0_spikes 82 | spikes[:, :, 1] = ch1_spikes 83 | spikes_dict = {'data' : spikes} 84 | 85 | features = ss.features.fetPCA(spikes_dict, ncomps=3) 86 | names = features['names'] 87 | 88 | ch0_feats = np.empty((3, 3)) 89 | ch1_feats = np.empty((3, 3)) 90 | # sort features by channel using labels 91 | for fidx in range(3): 92 | i0 = names.index('Ch%d:PC%d' % (0, fidx)) 93 | i1 = names.index('Ch%d:PC%d' % (1, fidx)) 94 | ch0_feats[fidx, :] = features['data'][:, i0] 95 | ch1_feats[fidx, :] = features['data'][:, i1] 96 | 97 | # spike 
combinations on channel 0 98 | almost_equal(ch0_feats[:, 2], ch0_feats[:, 0] + ch0_feats[:, 1]) 99 | # spike combinations on channel 1 100 | almost_equal(ch1_feats[:, 2], ch1_feats[:, 0] - ch1_feats[:, 1]) 101 | 102 | def test_getSpProjection(self): 103 | spikes_dict = self.spikes_dict.copy() 104 | cells = spikes_dict['data'] 105 | spikes_dict['data'] = np.repeat(cells, 10, 1) 106 | labels = np.repeat([0, 1], 10, 0) 107 | 108 | feat = ss.features.fetSpProjection(spikes_dict, labels) 109 | ok_(((feat['data'][:, 0] > 0.5) == labels).all()) 110 | 111 | def test_fetMarkers(self): 112 | spikes_dict = self.spikes_dict.copy() 113 | 114 | n_spikes = 200 115 | n_pts = self.cells.shape[1] 116 | amps = np.random.randint(1, 100, n_spikes) 117 | amps = amps[:, np.newaxis] 118 | spikes = amps * self.cells[0, :] 119 | spikes_dict['data'] = spikes.T[:, :, np.newaxis] 120 | time = spikes_dict['time'] 121 | indices = np.array([0, n_pts/2, n_pts-1]) 122 | values = ss.features.fetMarkers(spikes_dict, time[indices]) 123 | 124 | ok_((values['data'] == spikes[:, indices]).all()) 125 | 126 | def test_WT(self): 127 | "simple test for linearity of wavelet transform" 128 | spike1, spike2 = self.cells 129 | spike3 = 0.1 * spike1 + 0.7 * spike2 130 | spikes = np.vstack((spike1, spike2, spike3)).T 131 | spikes = spikes[:, :, np.newaxis] # WT only accepts 3D arrays 132 | 133 | wavelet = 'db3' 134 | wt = ss.features.WT(spikes, wavelet) 135 | wt1, wt2, wt3 = wt.squeeze().T 136 | 137 | almost_equal(wt3, 0.1 * wt1 + 0.7 * wt2) 138 | 139 | def test_fetWT_multichannel_labels(self): 140 | # tests whether features are properly labeled, based on linearity of 141 | # Wavelet Transform 142 | 143 | spike1, spike2 = self.cells 144 | ch0_spikes = np.vstack((spike1, spike2)).T 145 | ch1_spikes = np.vstack((spike1 + spike2, spike1 - spike2)).T 146 | 147 | spikes = np.empty((ch0_spikes.shape[0], ch0_spikes.shape[1], 2)) 148 | spikes[:, :, 0] = ch0_spikes 149 | spikes[:, :, 1] = ch1_spikes 150 | spikes_dict = 
{'data' : spikes} 151 | 152 | features = ss.features.fetWT(spikes_dict, 3, wavelet='haar', select_method=None) 153 | names = features['names'] 154 | 155 | ch0_feats = np.empty((3, 2)) 156 | ch1_feats = np.empty((3, 2)) 157 | # sort features by channel using labels 158 | for fidx in range(3): 159 | i0 = names.index('Ch%d:haarWC%d' % (0, fidx)) 160 | i1 = names.index('Ch%d:haarWC%d' % (1, fidx)) 161 | ch0_feats[fidx, :] = features['data'][:, i0] 162 | ch1_feats[fidx, :] = features['data'][:, i1] 163 | 164 | almost_equal(ch1_feats[:, 0], ch0_feats[:, 0] + ch0_feats[:, 1]) 165 | almost_equal(ch1_feats[:, 1], ch0_feats[:, 0] - ch0_feats[:, 1]) 166 | 167 | def test_fetWT_math(self): 168 | n_samples = 256 169 | 170 | # upsampled haar wavelet 171 | spike1 = np.hstack((np.ones(n_samples / 2), -1 * np.ones(n_samples / 2))) 172 | # upsampled haar scaling function 173 | spike2 = np.ones(n_samples) 174 | 175 | spikes = np.vstack((spike1, spike2)).T 176 | spikes = spikes[:, :, np.newaxis] 177 | spikes_dict = {'data' : spikes} 178 | 179 | features = ss.features.fetWT(spikes_dict, n_samples, wavelet='haar', select_method=None) 180 | idx = np.nonzero(features['data']) # nonzero indices 181 | 182 | # if nonzero elements are ONLY at (0,1) and (1,0), 183 | # this should be eye(2) 184 | eye = np.fliplr(np.vstack(idx)) 185 | 186 | ok_((eye == np.eye(2)).all()) 187 | 188 | def test_fetWT_selection(self): 189 | n_samples = 30 190 | n_channels = 2 191 | n_spikes = 50 192 | n_features = 10 193 | methods = [None, 'std', 'std_r', 'ks', 'dip', 'ksPCA', 'dipPCA'] 194 | 195 | spikes = np.random.randn(n_samples, n_spikes, n_channels) 196 | spikes_dict = {'data' : spikes} 197 | 198 | shapes = [(n_spikes, n_features * n_channels)] 199 | 200 | for met in methods: 201 | wt = ss.features.fetWT(spikes_dict, n_features, wavelet='haar', select_method=met) 202 | shapes.append(wt['data'].shape) 203 | 204 | equal = lambda x, y: x == y and y or False 205 | success = bool(reduce(equal, shapes)) # returned 
shapes for all methods are correct 206 | 207 | ok_(success) 208 | 209 | def test_add_mask_decorator(self): 210 | spikes_dict = {'data': np.zeros((10, 2)), 211 | 'is_valid': np.zeros(2,)} 212 | 213 | fetIdentity = lambda x: {'data': x['data'], 'names': 'Identity'} 214 | deco_fet = ss.features.add_mask(fetIdentity) 215 | features = deco_fet(spikes_dict) 216 | 217 | ok_((features['is_valid'] == spikes_dict['is_valid']).all()) 218 | 219 | def test_combine_features_without_mask(self): 220 | feature1 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature1']} 221 | feature2 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature2']} 222 | combined = ss.features.combine((feature1, feature2)) 223 | ok_('is_valid' not in combined) 224 | 225 | def test_combine_features_with_one_mask(self): 226 | feature1 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature1']} 227 | feature2 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature2'], 228 | 'is_valid': np.ones(5, dtype=np.bool)} 229 | combined = ss.features.combine((feature1, feature2)) 230 | ok_((combined['is_valid'] == feature2['is_valid']).all()) 231 | 232 | def test_combine_features_with_different_masks(self): 233 | mask1 = np.ones(5, dtype=np.bool) 234 | mask1[-1] = False 235 | mask2 = mask1.copy() 236 | mask2[:2] = False 237 | feature1 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature1'], 238 | 'is_valid': mask1} 239 | feature2 = {'data': np.random.uniform(size=(5, 1)), 'names': ['feature2'], 240 | 'is_valid': mask2} 241 | combined = ss.features.combine((feature1, feature2)) 242 | ok_((combined['is_valid'] == (mask1 & mask2)).all()) 243 | 244 | def test_add_method_suffix(self): 245 | names = ['methA', 'methB', 'methA_1', 'methC', 'methA_2', 'methD'] 246 | 247 | test1 = ss.features._add_method_suffix('methE', names) == 'methE' 248 | test2 = ss.features._add_method_suffix('methB', names) == 'methB_1' 249 | test3 = ss.features._add_method_suffix('methA', names) == 'methA_3' 250 | 251 
| ok_(test1 and test2 and test3) 252 | -------------------------------------------------------------------------------- /docs/source/tutorials/tutorial_manual.rst: -------------------------------------------------------------------------------- 1 | .. _lowlevel_tutorial: 2 | 3 | Using low-level interface 4 | ========================== 5 | 6 | .. testsetup:: 7 | 8 | import numpy 9 | numpy.random.seed(1221) 10 | import os, shutil, tempfile 11 | 12 | temp_path = tempfile.mkdtemp() 13 | data_path = os.path.join(temp_path, 'data') 14 | os.mkdir(data_path) 15 | shutil.copyfile('../data/tutorial.h5', os.path.join(data_path, 'tutorial.h5')) 16 | os.chdir(temp_path) 17 | 18 | .. testcleanup:: 19 | 20 | shutil.rmtree(temp_path) 21 | 22 | 23 | In this tutorial we will go deeper into the lower-level interface of 24 | SpikeSort: :py:mod:`spike_sort.core`. This interface is a bit more 25 | complex than the SpikeBeans (see :ref:`beans_tutorial`), but it offers 26 | more flexibility and allows you to embbedd SpikeSort in your own 27 | programs. 28 | 29 | To start this tutorial you will need: 30 | 31 | * working installation of SpikeSort 32 | 33 | * the sample :ref:`tutorial_data` 34 | 35 | 36 | 1. Read data 37 | ------------ 38 | 39 | 40 | We will assume that you downloaded the sample data file :file:`tutorial.h5` and saved it to the :file:`data` 41 | directory. 42 | 43 | You can load this file using one of I/O fiters from 44 | :py:mod:`spike_sort.io.filter` module: 45 | 46 | .. doctest:: 47 | 48 | >>> from spike_sort.io.filters import PyTablesFilter 49 | >>> dataset = '/SubjectA/session01/el1' 50 | >>> io_filter = PyTablesFilter('data/tutorial.h5') 51 | >>> raw = io_filter.read_sp(dataset) 52 | 53 | :py:data:`raw` is a dictionary which contains the raw data (in this case it is 54 | a pytables compressed array) under :py:attr:`data` 55 | key: 56 | 57 | .. 
doctest:: 58 | 59 | >>> print raw['data'] 60 | /SubjectA/session01/el1/raw (CArray(4, 23512500)) '' 61 | 62 | The size of the data is 23512500 samples in 4 independent channels (`contacts` 63 | in the tetrode). 64 | 65 | .. note:: 66 | 67 | HDF5 are organised hierarchically and may contain multiple 68 | datasets. You can access the datasets via simple paths - in this 69 | case `/SubjectA/session01/el1` which means dataset of SubjectA 70 | recorded in session01 from el1 71 | 72 | 2. Detect spikes 73 | ---------------- 74 | 75 | 76 | The first step of spike sorting is spike detection. It is usually done by 77 | thresholding the raw recordings. Let us use an automatic threshold on 78 | 4th contact i.e. index 3 (channel indexing always starts with 0!): 79 | 80 | .. doctest:: 81 | 82 | >>> from spike_sort import extract 83 | >>> spt = extract.detect_spikes(raw, contact=3, thresh='auto') 84 | 85 | Let us see now how many events were detected: 86 | 87 | .. doctest:: 88 | 89 | >>> print len(spt['data']) 90 | 16293 91 | 92 | We should make sure that all events are aligned to the same point of reference, 93 | for example, the maximum amplitude. To this end we first define a window 94 | around which spikes should be centered and then recalculate aligned event times: 95 | 96 | .. doctest:: 97 | 98 | >>> sp_win = [-0.2, 0.8] 99 | >>> spt = extract.align_spikes(raw, spt, sp_win, type="max", 100 | ... resample=10) 101 | 102 | ``resample`` is optional - it enables upsampling (in this case 10-fold) 103 | of the original waveforms to obtain better resolution of event times. 104 | 105 | After spike detection and alignment we can finally extract the spike waveforms: 106 | 107 | .. doctest:: 108 | 109 | >>> sp_waves = extract.extract_spikes(raw, spt, sp_win) 110 | 111 | The resulting structure is a dictionary whose :py:attr:`data` key is an array 112 | containing the spike waveshapes. 
Note that the array is three-dimensional and 113 | sizes of its dimensions reflect: 114 | 115 | * 1st dimension: number of samples in each waveform, 116 | * 2nd: number of spikes, 117 | * 3rd: number of contacts 118 | 119 | .. doctest:: 120 | 121 | >>> print sp_waves['data'].shape 122 | (25, 15537, 4) 123 | 124 | In practice, you do not need to take care of such details. However, it is always 125 | a good idea to take a look at the obtained waveforms. 126 | The :py:mod:`spike_sort.ui.plotting` module contains various functions which will 127 | help you to visualize the data. To plot waveshapes you can use the :py:func:`plot_spikes` function from this module: 128 | 129 | .. doctest:: 130 | 131 | >>> from spike_sort.ui import plotting 132 | >>> plotting.plot_spikes(sp_waves, n_spikes=200) 133 | 134 | .. figure:: images_manual/tutorial_spikes.png 135 | 136 | It is apparent from the plot that the spike waveforms of a few different cells 137 | and also some artifacts were detected. In order to separate these activities, 138 | in the next step we will perform *spike clustering*. 139 | 140 | 3. Calculate features 141 | --------------------- 142 | 143 | Before we can cluster spikes, we should calculate some characteristic features 144 | that may be used to differentiate between the waveshapes. Module 145 | :py:mod:`~spike_sort.core.features` defines several such features, for example 146 | peak-to-peak amplitude (:py:func:`fetP2P`) and projections on principal 147 | components (:py:func:`fetPCA`). Now, we will calculate peak-to-peak amplitudes 148 | and PC projections on each of the contacts, and then combine them into a single 149 | object: 150 | 151 | .. doctest:: 152 | 153 | >>> from spike_sort import features 154 | >>> sp_feats = features.combine( 155 | ... ( 156 | ... features.fetP2P(sp_waves), 157 | ... features.fetPCA(sp_waves) 158 | ... ) 159 | ... ) 160 | 161 | To help the user identify the features, abbreviated 162 | labels are assigned to all features: 163 | 164 | ..
doctest:: 165 | 166 | >>> print sp_feats['names'] 167 | ['Ch0:P2P' 'Ch1:P2P' 'Ch2:P2P' 'Ch3:P2P' 'Ch0:PC0' 'Ch1:PC0' 'Ch2:PC0' 168 | 'Ch3:PC0' 'Ch0:PC1' 'Ch1:PC1' 'Ch2:PC1' 'Ch3:PC1'] 169 | 170 | For example, feature ``Ch0:P2P`` denotes peak-to-peak amplitude in contact 171 | (channel) 0. 172 | 173 | Let us plot the two-dimensional 174 | projections of the feature space and histograms of features: 175 | 176 | .. doctest:: 177 | 178 | >>> plotting.plot_features(sp_feats) 179 | 180 | .. figure:: images_manual/tutorial_features.png 181 | 182 | 4. Cluster spikes 183 | ----------------- 184 | 185 | Finally, based on the calculated features we can perform spike clustering. This 186 | step is a little bit more complex and the best settings have to be identified 187 | using a trial-and-error procedure. 188 | 189 | There are several automatic, semi-automatic and manual methods for clustering. 190 | Their performance and accuracy depend to a large degree on a particular dataset 191 | and recording setup. In SpikeSort you can choose from several available methods, 192 | whose names are given as the first argument of the :py:func:`~spike_sort.core.cluster.cluster` 193 | method. 194 | 195 | We will start with an automatic clustering :py:func:`~spike_sort.core.cluster.gmm`, which requires only the feature object :py:data:`sp_feats` and the number of clusters to identify. 196 | It attempts to find a mixture of gaussian distributions which best approximates the 197 | distribution of spike features (gaussian mixture model). 198 | Since we do not know how many cells were picked up by the electrode, we guess 199 | an initial number of clusters, which we can modify later on: 200 | 201 | .. doctest:: 202 | 203 | >>> from spike_sort import cluster 204 | >>> clust_idx = cluster.cluster("gmm",sp_feats,4) 205 | 206 | The result simply assigns a number (cluster index) to each spike from 207 | the feature array :py:data:`sp_feats`.
208 | 209 | You can use the plotting module to draw the 210 | feature vectors with color reflecting groups to which each spike was assigned: 211 | 212 | .. doctest:: 213 | 214 | >>> plotting.plot_features(sp_feats, clust_idx) 215 | 216 | .. figure:: images_manual/tutorial_clusters.png 217 | 218 | or you can see the spike waveshapes: 219 | 220 | .. doctest:: 221 | 222 | >>> plotting.plot_spikes(sp_waves, clust_idx, n_spikes=200) 223 | >>> plotting.show() 224 | 225 | .. figure:: images_manual/tutorial_cells.png 226 | 227 | If you are not satisfied with the results or you think you might do better, 228 | you can also try manual sorting using the cluster-cutting method:: 229 | 230 | >>> from spike_sort.ui import manual_sort 231 | >>> cluster_idx = manual_sort.show(sp_feats, sp_waves, 232 | ... ['Ch0:P2P','Ch3:P2P'], 233 | ... show_spikes=True) 234 | 235 | This function will open a window in which you can draw clusters of arbitrary 236 | shapes, but beware: you can draw only on a two-dimensional plane, so you 237 | are limited to only two features (``Ch0:P2P`` and ``Ch3:P2P`` in this case)! 238 | 239 | 5. Export data 240 | -------------- 241 | 242 | Once you are done with spike sorting, you can export the results to a file. 243 | To this end you can use the same :py:mod:`~spike_sort.io.filters` module we used 244 | for reading. Here, we will save the spike times of the discriminated cells 245 | back to the file we read the data from. 246 | 247 | First, we need to extract the spike times 248 | of the discriminated cells: 249 | 250 | .. doctest:: 251 | 252 | >>> spt_clust = cluster.split_cells(spt, clust_idx) 253 | 254 | It will create a dictionary whose keys are the cell labels pointing 255 | to spike times of the specific cell. For example, to extract spike 256 | times of cell 0: 257 | 258 | .. doctest:: 259 | 260 | >>> print spt_clust[0] 261 | {'data': array([ 5.68152000e+02, 1.56978000e+03, 2.23985200e+03, 262 | ...
263 | 9.24276876e+05, 9.33539168e+05])} 264 | 265 | 266 | Then we may export them to the datafile: 267 | 268 | .. doctest:: 269 | 270 | >>> from spike_sort.io import export 271 | >>> cell_template = dataset + '/cell{cell_id}' 272 | >>> export.export_cells(io_filter, cell_template, spt_clust, overwrite=True) 273 | 274 | This will create new nodes in :file:`tutorial.h5` containing spike times of 275 | the discriminated cells ``/SubjectA/session01/el1/cell{1-4}``, 276 | which you can use for further analysis. 277 | 278 | Do not forget to close the I/O filter at the end of your analysis: 279 | 280 | .. doctest:: 281 | 282 | >>> io_filter.close() 283 | 284 | Good luck!!! 285 | 286 | 287 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | What is spike sorting? 5 | ---------------------- 6 | 7 | *Spike sorting is a class of techniques used in the analysis of 8 | electrophysiological data. Spike sorting algorithms use the 9 | shape(s) of waveforms collected with one or more electrodes in the 10 | brain to distinguish the activity of one or more neurons from 11 | background electrical noise.* 12 | 13 | Source: Spike sorting, Wikipedia_ 14 | 15 | Spike sorting usually consists of the following steps [Quiroga2004]_: 16 | 17 | 1. Spike detection (*detect*) 18 | 19 | Spikes are very rapid and often sparse events, so that they occupy 20 | only a tiny fraction of the recordings. To achieve a sort of 21 | compression and ease the subsequent analysis, the times of 22 | occurrence of putative spikes are first identified in the continuous signal 23 | by means of thresholding. This method returns only events that cross 24 | a specified threshold (selected based on some data statistics, such 25 | as standard deviation, or visually). 26 | 27 | #.
Spike waveform extraction (*extract*) 28 | 29 | Based on the series of spike times identified in the previous step, 30 | we now may proceed to extract the spike waveforms by taking a small 31 | segment of the signal around each spike time. Such segments are 32 | usually automatically aligned to a specific feature of the 33 | waveform, for example maximum or minimum. 34 | 35 | #. Feature extraction (*feature*) 36 | 37 | This is one of the most important steps in which the salient 38 | features of the spikes such as peak-to-peak amplitude or spike 39 | width are calculated based on spike waveshapes. The features should 40 | preferably be low-dimensional and should differentiate well between spikes 41 | of different cells and noise. 42 | 43 | #. Clustering (*cluster*) 44 | 45 | At the heart of spike sorting is the clustering that uses 46 | automatic, semi-automatic or manual methods to identify groups of 47 | spikes belonging to the same cell (often called unit). The 48 | procedure is usually applied in the n-dimensional space of spike 49 | features, where each feature is a single dimension. Since it is 50 | very difficult to do that visually for more than 2 features (on a 51 | plane), one usually resorts to different clustering algorithms, such 52 | as K-means or Gaussian Mixture Models that can handle a large number 53 | of dimensions. 54 | 55 | 56 | #. Sorting evaluation (*evaluate*) 57 | 58 | After sorting is done and we have determined spike times of 59 | different units, we have to make sure that the quality of sorting 60 | is sufficient for further analysis. There are multiple visual and 61 | statistical methods that help in this evaluation, such as 62 | inter-spike interval histograms, waveform overlays, 63 | cross-correlograms etc. [Hill2011]_. 64 | 65 | SpikeSort is a comprehensive library whose goal is to accompany you through 66 | the entire process of spike sorting - from detection to evaluation. 67 | It is not a fully automatic or on-line sorting program.
Although we 68 | include lots of functions, that help to automatize some of the 69 | repetitive task, good sorting will always require human supervision. 70 | 71 | Design goals 72 | ------------ 73 | 74 | There are several of spike sorting programs on the market both commercial and free 75 | and open source. We started this projects to offer an alternative that 76 | would be free, scriptable and preferably in our favourite language: 77 | Python. However, we did not try to re-invent the 78 | wheel and whenever we could we leveraged the established libraries. 79 | 80 | We had a few design goals when we worked on SpikeSort; 81 | 82 | * modular 83 | 84 | Spike sorting is modular - there are several steps and different 85 | techniques that can and 86 | should be mixed-and-matched to adjust the process to the data we 87 | try to analyze. Spike sorting library should be modular as well to allow for 88 | much flexibility. This is achieved by composing the library of 89 | independent components that can be easily inserted into or deleted 90 | from the spike sorting workflow. 91 | 92 | * customizable 93 | 94 | No two data sets are the same. Different experimental protocols, 95 | different acquisition systems, different neural systems result in 96 | different properties of the data and thus require different methods. 97 | Therefore, a good spike sorting library should allow for easy and 98 | flexible customizations of the algorithms used at each stage of spike 99 | sorting and their parameters. 100 | 101 | * easy-to-use 102 | 103 | Flexibility usually comes at price: the usability. Nevertheless, a 104 | transparent design can allow complex systems to be user-friendly. 105 | The interface should allow to focus on the data and not on the 106 | peculiarities of specific software solutions. 107 | 108 | * fast 109 | 110 | In practice, one will try to discriminate thousands of spikes of 111 | tens of different neurons. 
Any performance optimizations will save 112 | you precious minutes (hours or even days) for the task you are most 113 | interested in: decoding what the cells actually do. 114 | 115 | * compatible with standard libraries (NumPy, SciPy, matplotlib) 116 | 117 | There is no need to reinvent the wheel. Python developers provide 118 | hundreds of optimized, well-tested and widely-used libraries with 119 | great community support. Why not use them? Moreover, any data coming 120 | from the spike sorting libraries should be easily pluggable into 121 | third-party analysis routines and databases. 122 | 123 | Although still much work is required to meet all the goals, we kept 124 | them all in mind while designing the SpikeSort. As a result, SpikeSort 125 | is already an usable and powerful framework that will help you to 126 | get most of your data. 127 | 128 | Why Python? 129 | ----------- 130 | 131 | Python is a interpreted and very dynamic language with huge very 132 | enthusiastic community. It grew to be the de-facto standard in 133 | computational science [Langtangen2009]_ and it rapidly gains momentum in 134 | experimental disciplines [Hanke2009]_. Python is also completely free and available 135 | for multiple platforms - it means that you can run it at home, at work 136 | or give it to your students with no additional costs. Last but not 137 | least Python can be easily interfaced with other languages making it a 138 | ''glueing'' language that can make two independent libraries to 139 | communicate. 
140 | 141 | Installation 142 | ------------ 143 | 144 | You can download the most recent release of SpikeSort from GitHub:: 145 | 146 | git clone https://github.com/btel/SpikeSort.git 147 | 148 | In order to install SpikeSort you need the following libraries: 149 | 150 | * python 2.6 or 2.7 151 | * setuptools 152 | * scipy 153 | * numpy 154 | * pytables 155 | * matplotlib (only for plotting) 156 | 157 | Optional dependencies are: 158 | 159 | * scikits.learn - clustering algorithms 160 | * neurotools - spike train analysis 161 | * ipython - enhanced python shell 162 | 163 | If some of the python packages are not available on your system you 164 | can install them with easy-install:: 165 | 166 | easy_install numpy scipy pytables matplotlib 167 | 168 | .. note:: 169 | 170 | If you are not familiar with the Python packaging system we recommend 171 | installing a complete Python distribution from a company called 172 | Enthought: `EPD `_ 173 | (there are free academic licenses). Installers for Windows, MacOSX 174 | and Linux are available. 175 | 176 | If you have the above libraries you can install SpikeSort simply by 177 | issuing the command:: 178 | 179 | python setup.py install 180 | 181 | If you prefer to install it in your home directory you may try:: 182 | 183 | python setup.py install --user 184 | 185 | but remember to add :file:`$HOME/.local/lib/python2.6/site-packages` to your python 186 | path. 187 | 188 | After a successful installation you can run the supplied tests:: 189 | 190 | python setup.py nosetests 191 | 192 | If you don't have all the optional dependencies, be prepared for some 193 | test errors.
194 | 195 | Examples 196 | -------- 197 | 198 | In the :file:`examples/sorting` subdirectory you will find some sample scripts 199 | which use SpikeSort for spike sorting: 200 | 201 | * :file:`cluster_manual.py` - sort spikes by manual cluster cutting 202 | * :file:`cluster_auto.py` - automatically cluster with GMM (Gaussian 203 | Mixture Models) algorithm (see our tutorial 204 | :ref:`lowlevel_tutorial`) 205 | * :file:`cluster_beans.py` - run the full-stack spike-sorting environment 206 | and show spikes in a spike browser (see our tutorial :ref:`beans_tutorial`) 207 | 208 | In order to run these examples, you need to download :ref:`tutorial_data` and define an environment variable ``DATAPATH``:: 209 | 210 | export DATAPATH=/path/to/data/directory 211 | 212 | where ``/path/to/data/directory`` points to the directory where you 213 | downloaded the data file. 214 | 215 | Once you have the tutorial data, you may run any of the above scripts, for example:: 216 | 217 | python -i cluster_auto.py 218 | 219 | .. note:: 220 | 221 | The ``-i`` flag in the python command will leave a Python interpreter open 222 | for interactive exploration - read more in our tutorials. 223 | 224 | Similar software 225 | ---------------- 226 | 227 | There are a few open source packages for spike sorting with 228 | different designs and use cases: 229 | 230 | * `spyke `_ 231 | 232 | * `OpenElectrophy `_ 233 | 234 | * `spikepy `_ 235 | 236 | * `Klusters `_ 237 | 238 | 239 | 240 | 241 | References 242 | ---------- 243 | 244 | .. _tutorial.h5: https://github.com/btel/SpikeSort/releases/download/v0.12/tutorial.h5 245 | 246 | .. _Wikipedia: http://en.wikipedia.org/wiki/Spike_sorting 247 | 248 | .. [Quiroga2004] Quiroga, RQ, Z. Nadasdy, Y. Ben-Shaul, and others. *“Unsupervised Spike Detection and Sorting with Wavelets and Superparamagnetic Clustering.”* Neural Computation **16**, no. 8 (2004):1661. ``_ 249 | 250 | .. [Hill2011] Hill, Daniel N, Samar B Mehta, and David Kleinfeld.
*“Quality Metrics to Accompany Spike Sorting of Extracellular Signals.”* The Journal of Neuroscience **31**, no. 24 (2011): 8699-8705. ``_ 251 | 252 | .. [Hanke2009] Hanke, Michael, Yaroslav O. Halchenko, Per B. Sederberg, Emanuele Olivetti, Ingo Fründ, Jochem W. Rieger, Christoph S. Herrmann, James V. Haxby, Stephen José Hanson, and Stefan Pollmann. *“PyMVPA: a Unifying Approach to the Analysis of  Neuroscientific Data.”* Frontiers in Neuroinformatics **3** (2009): 3. 253 | 254 | .. [Langtangen2009] Langtangen, Hans Petter. *Python Scripting for Computational Science*. 3rd ed. Springer, 2009. 255 | 256 | -------------------------------------------------------------------------------- /src/spike_sort/core/extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding=utf-8 3 | 4 | from warnings import warn 5 | import operator 6 | 7 | from scipy import interpolate 8 | import numpy as np 9 | 10 | 11 | def split_cells(spikes, idx, which='all'): 12 | """Return the spike features splitted into separate cells 13 | """ 14 | 15 | if which == 'all': 16 | classes = np.unique(idx) 17 | else: 18 | classes = which 19 | 20 | data = spikes['data'] 21 | time = spikes['time'] 22 | spikes_dict = dict([(cl, {'data': data[:, idx == cl, :], 'time': time}) 23 | for cl in classes]) 24 | return spikes_dict 25 | 26 | 27 | def remove_spikes(spt_dict, remove_dict, tolerance): 28 | """Remove spikes with given spike times from the spike time 29 | structure """ 30 | spt_data = spt_dict['data'] 31 | spt_remove = remove_dict['data'] 32 | 33 | mn, mx = tolerance 34 | 35 | for t in spt_remove: 36 | spt_data = spt_data[(spt_data > (t + mx)) | (spt_data < (t + mn))] 37 | 38 | spt_ret = spt_dict.copy() 39 | spt_ret['data'] = spt_data 40 | return spt_ret 41 | 42 | 43 | def detect_spikes(spike_data, thresh='auto', edge="rising", 44 | contact=0): 45 | r"""Detects spikes in extracellular data using amplitude thresholding. 
46 | 47 | Parameters 48 | ---------- 49 | spike_data : dict 50 | extracellular waveforms 51 | thresh : float or 'auto' 52 | threshold for detection. if thresh is 'auto' it will be 53 | estimated from the data. 54 | edge : {'rising', 'falling'} 55 | which edge to trigger on 56 | contact : int, optional 57 | index of tetrode contact to use for detection, defaults to 58 | first contact 59 | 60 | Returns 61 | ------- 62 | spt_dict : dict 63 | dictionary with 'data' key which contains detected threshold 64 | crossing in miliseconds 65 | 66 | """ 67 | 68 | sp_data = spike_data['data'][contact, :] 69 | 70 | FS = spike_data['FS'] 71 | 72 | edges = ('rising', 'max', 'falling', 'min') 73 | if isinstance(thresh, basestring): 74 | if thresh == 'auto': 75 | thresh_frac = 8.0 76 | else: 77 | thresh_frac = float(thresh) 78 | 79 | thresh = thresh_frac * np.sqrt(sp_data[:10 * FS].var(dtype=np.float64)) 80 | if edge in edges[2:]: 81 | thresh = -thresh 82 | 83 | if edge not in edges: 84 | raise TypeError("'edge' parameter must be 'rising' or 'falling'") 85 | 86 | op1, op2 = operator.lt, operator.gt 87 | 88 | if edge in edges[2:]: 89 | op1, op2 = op2, op1 90 | 91 | i, = np.where(op1(sp_data[:-1], thresh) & op2(sp_data[1:], thresh)) 92 | 93 | spt = i * 1000.0 / FS 94 | return {'data': spt, 'thresh': thresh, 'contact': contact} 95 | 96 | 97 | def filter_spt(spike_data, spt_dict, sp_win): 98 | spt = spt_dict['data'] 99 | sp_data = spike_data['data'] 100 | FS = spike_data['FS'] 101 | 102 | try: 103 | n_pts = sp_data.shape[1] 104 | except IndexError: 105 | n_pts = len(sp_data) 106 | max_time = n_pts * 1000.0 / FS 107 | 108 | t_min = np.max((-sp_win[0], 0)) 109 | t_max = np.min((max_time, max_time - sp_win[1])) 110 | idx, = np.nonzero((spt >= t_min) & (spt <= t_max)) 111 | return idx 112 | 113 | 114 | def extract_spikes(spike_data, spt_dict, sp_win, 115 | resample=1, contacts='all'): 116 | """Extract spikes from recording. 
117 | 118 | Parameters 119 | ---------- 120 | spike_data : dict 121 | extracellular data (see :ref:`raw_recording`) 122 | spt : dict 123 | spike times structure (see :ref:`spike_times`) 124 | sp_win : list of int 125 | temporal extent of the wave shape 126 | 127 | Returns 128 | ------- 129 | wavedict : dict 130 | spike waveforms structure (see :ref:`spike_wave`) 131 | 132 | 133 | """ 134 | sp_data = spike_data['data'] 135 | n_contacts = spike_data['n_contacts'] 136 | 137 | if contacts == "all": 138 | contacts = np.arange(n_contacts) 139 | elif isinstance(contacts, int): 140 | contacts = np.array([contacts]) 141 | else: 142 | contacts = np.asarray(contacts) 143 | 144 | FS = spike_data['FS'] 145 | spt = spt_dict['data'] 146 | idx = np.arange(len(spt)) 147 | inner_idx = filter_spt(spike_data, spt_dict, sp_win) 148 | outer_idx = idx[~np.in1d(idx, inner_idx)] 149 | 150 | indices = (spt / 1000.0 * FS).astype(np.int32) 151 | win = (np.asarray(sp_win) / 1000.0 * FS).astype(np.int32) 152 | time = np.arange(win[1] - win[0]) * 1000.0 / FS + sp_win[0] 153 | n_contacts, n_pts = sp_data.shape 154 | 155 | # auxiliary function to find a valid spike window within data range 156 | minmax = lambda x: np.max([np.min([n_pts, x]), 0]) 157 | spWave = np.zeros((len(time), len(spt), len(contacts)), 158 | dtype=np.float32) 159 | 160 | for i in inner_idx: 161 | sp = indices[i] 162 | spWave[:, i, :] = np.atleast_2d(sp_data[contacts, 163 | sp + win[0]:sp + win[1]]).T 164 | 165 | for i in outer_idx: 166 | sp = indices[i] 167 | l, r = map(minmax, sp + win) 168 | if l != r: 169 | spWave[(l - sp) - win[0]:(r - sp) - win[0], i, :] = \ 170 | sp_data[contacts, l:r].T 171 | 172 | wavedict = {"data": spWave, "time": time, "FS": FS} 173 | 174 | if len(idx) != len(inner_idx): 175 | is_valid = np.zeros(len(spt), dtype=np.bool) 176 | is_valid[inner_idx] = True 177 | wavedict['is_valid'] = is_valid 178 | 179 | if resample != 1: 180 | warn("resample argument is deprecated." 
181 | "Please update your code to use function" 182 | "resample_spikes", DeprecationWarning) 183 | wavedict = resample_spikes(wavedict, FS * resample) 184 | return wavedict 185 | 186 | 187 | def resample_spikes(spikes_dict, FS_new): 188 | """Upsample spike waveforms using spline interpolation""" 189 | 190 | sp_waves = spikes_dict['data'] 191 | time = spikes_dict['time'] 192 | FS = spikes_dict['FS'] 193 | 194 | resamp_time = np.arange(time[0], time[-1], 1000.0 / FS_new) 195 | n_pts, n_spikes, n_contacts = sp_waves.shape 196 | 197 | spike_resamp = np.empty((len(resamp_time), n_spikes, n_contacts)) 198 | 199 | for i in range(n_spikes): 200 | for contact in range(n_contacts): 201 | tck = interpolate.splrep(time, sp_waves[:, i, contact], s=0) 202 | spike_resamp[:, i, contact] = interpolate.splev(resamp_time, 203 | tck, der=0) 204 | 205 | return {"data": spike_resamp, "time": resamp_time, "FS": FS} 206 | 207 | 208 | def align_spikes(spike_data, spt_dict, sp_win, type="max", resample=1, 209 | contact=0, remove=True): 210 | """Aligns spike waves and returns corrected spike times 211 | 212 | Parameters 213 | ---------- 214 | spike_data : dict 215 | spt_dict : dict 216 | sp_win : list of int 217 | type : {'max', 'min'}, optional 218 | resample : int, optional 219 | contact : int, optional 220 | remove : bool, optiona 221 | 222 | Returns 223 | ------- 224 | ret_dict : dict 225 | spike times of aligned spikes 226 | 227 | """ 228 | 229 | tol = 0.1 230 | 231 | if (sp_win[0] > -tol) or (sp_win[1] < tol): 232 | warn('You are using very short sp_win. 
' 233 | 'This may lead to alignment problems.') 234 | 235 | spt = spt_dict['data'].copy() 236 | 237 | idx_align = np.arange(len(spt)) 238 | 239 | #go in a loop until all spikes are correctly aligned 240 | iter_id = 0 241 | while len(idx_align) > 0: 242 | spt_align = {'data': spt[idx_align]} 243 | spt_inbound = filter_spt(spike_data, spt_align, sp_win) 244 | idx_align = idx_align[spt_inbound] 245 | sp_waves_dict = extract_spikes(spike_data, spt_align, sp_win, 246 | resample=resample, contacts=contact) 247 | 248 | sp_waves = sp_waves_dict['data'][:, spt_inbound, 0] 249 | time = sp_waves_dict['time'] 250 | 251 | if type == "max": 252 | i = sp_waves.argmax(0) 253 | elif type == "min": 254 | i = sp_waves.argmin(0) 255 | 256 | #move spike markers 257 | shift = time[i] 258 | spt[idx_align] += shift 259 | 260 | #if spike maximum/minimum was at the edge we have to extract it at the 261 | # new marker and repeat the alignment 262 | 263 | idx_align = idx_align[(shift < (sp_win[0] + tol)) | 264 | (shift > (sp_win[1] - tol))] 265 | iter_id += 1 266 | 267 | ret_dict = {'data': spt} 268 | 269 | if remove: 270 | #remove double spikes 271 | FS = spike_data['FS'] 272 | ret_dict = remove_doubles(ret_dict, 1000.0 / FS) 273 | 274 | return ret_dict 275 | 276 | 277 | def remove_doubles(spt_dict, tol): 278 | new_dict = spt_dict.copy() 279 | spt = spt_dict['data'] 280 | 281 | isi = np.diff(spt) 282 | intisi = (isi/tol).astype('int') 283 | 284 | if len(spt) > 0: 285 | spt = spt[np.concatenate(([True], intisi > 1))] 286 | 287 | new_dict['data'] = spt 288 | return new_dict 289 | 290 | 291 | def merge_spikes(spike_waves1, spike_waves2): 292 | """Merges two sets of spike waves 293 | 294 | Parameters 295 | ---------- 296 | 297 | spike_waves1 : dict 298 | spike_waves2 : dict 299 | spike wavefroms to merge; both spike wave sets must be defined 300 | within the same time window and with the same sampling 301 | frequency 302 | 303 | Returns 304 | ------- 305 | 306 | spike_waves : dict 307 | merged 
spike waveshapes 308 | 309 | clust_idx : array 310 | labels denoting to which set the given spike originally belonged to 311 | """ 312 | 313 | sp_data1 = spike_waves1['data'] 314 | sp_data2 = spike_waves2['data'] 315 | 316 | sp_data = np.hstack((sp_data1, sp_data2)) 317 | spike_waves = spike_waves1.copy() 318 | spike_waves['data'] = sp_data 319 | 320 | clust_idx = np.concatenate((np.ones(sp_data1.shape[1]), 321 | np.zeros(sp_data2.shape[1]))) 322 | 323 | return spike_waves, clust_idx 324 | 325 | 326 | def merge_spiketimes(spt1, spt2, sort=True): 327 | """Merges two sets of spike times 328 | 329 | Parameters 330 | ---------- 331 | spt1 : dict 332 | spt2 : dict 333 | sort : bool, optional 334 | False if you don't want to be the spike times sorted. 335 | 336 | Returns 337 | ------- 338 | spt : dict 339 | dictionary with merged spike time arrrays under data key 340 | clust_idx : array 341 | labels denoting to which set the given spike originally belonged 342 | to 343 | 344 | """ 345 | 346 | spt_data1 = spt1['data'] 347 | spt_data2 = spt2['data'] 348 | 349 | spt_data = np.concatenate((spt_data1, spt_data2)) 350 | i = spt_data.argsort() 351 | spt_data = spt_data[i] 352 | 353 | clust_idx = np.concatenate((np.ones(spt_data1.shape[0]), 354 | np.zeros(spt_data2.shape[0]))) 355 | clust_idx = clust_idx[i] 356 | spt_dict = {"data": spt_data} 357 | 358 | return spt_dict, clust_idx 359 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spike_sort as ss 3 | 4 | from nose.tools import ok_, eq_, raises 5 | from numpy.testing import assert_array_almost_equal as almost_equal 6 | 7 | import os 8 | 9 | class TestFilter(object): 10 | def __init__(self): 11 | self.n_spikes = 100 12 | 13 | self.FS = 25E3 14 | self.period = 1000.0 / self.FS * 100 15 | self.time = np.arange(0, self.period * self.n_spikes, 1000.0 / self.FS) 16 
| self.spikes = np.sin(2 * np.pi / self.period * self.time)[np.newaxis, :] 17 | self.spk_data = {"data": self.spikes, "n_contacts": 1, "FS": self.FS} 18 | 19 | def test_filter_proxy(self): 20 | sp_freq = 1000.0 / self.period 21 | filter = ss.filters.Filter(sp_freq * 0.5, sp_freq * 0.4, 1, 10, 'ellip') 22 | spk_filt = ss.filters.filter_proxy(self.spk_data, filter) 23 | ok_(self.spk_data['data'].shape == spk_filt['data'].shape) 24 | 25 | def test_remove_proxy_files_after_exit(self): 26 | filter_func = lambda x, f: x 27 | spk_filt = ss.filters.filter_proxy(self.spk_data, filter_func) 28 | fname = spk_filt['data']._v_file.filename 29 | ss.filters.clean_after_exit() 30 | assert not os.path.isfile(fname) 31 | 32 | 33 | def test_LiearIIR_detect(self): 34 | n_spikes = self.n_spikes 35 | period = self.period 36 | threshold = 0.5 37 | sp_freq = 1000.0 / period 38 | 39 | self.spk_data['data'] += 2 40 | spk_filt = ss.filters.fltLinearIIR(self.spk_data, sp_freq * 0.5, sp_freq * 0.4, 1, 10, 'ellip') 41 | 42 | spt = ss.extract.detect_spikes(spk_filt, thresh=threshold) 43 | ok_(len(spt['data']) == n_spikes) 44 | 45 | class TestExtract(object): 46 | def __init__(self): 47 | self.n_spikes = 100 48 | 49 | self.FS = 25E3 50 | self.period = 1000.0 / self.FS * 100 51 | self.time = np.arange(0, self.period * self.n_spikes, 1000.0 / self.FS) 52 | self.spikes = np.sin(2 * np.pi / self.period * self.time)[np.newaxis, :] 53 | self.spk_data = {"data": self.spikes, "n_contacts": 1, "FS": self.FS} 54 | 55 | def test_detect(self): 56 | n_spikes = self.n_spikes 57 | period = self.period 58 | FS = self.FS 59 | time = self.time 60 | spikes = self.spikes 61 | threshold = 0.5 62 | crossings_real = period / 12.0 + np.arange(n_spikes) * period 63 | spt = ss.extract.detect_spikes(self.spk_data, thresh=threshold) 64 | ok_((np.abs(spt['data'] - crossings_real) <= 1000.0 / FS).all()) 65 | 66 | def test_align(self): 67 | #check whether spikes are correctly aligned to maxima 68 | maxima_idx = self.period * 
(1 / 4.0 + np.arange(self.n_spikes)) 69 | thr_crossings = self.period * (1 / 6.0 + np.arange(self.n_spikes)) 70 | spt_dict = {"data": thr_crossings} 71 | sp_win = [-self.period / 6.0, self.period / 3.0] 72 | spt = ss.extract.align_spikes(self.spk_data, spt_dict, sp_win) 73 | ok_((np.abs(spt['data'] - maxima_idx) <= 1000.0 / self.FS).all()) 74 | 75 | def test_align_short_win(self): 76 | #test spike alignment with windows shorter than total spike duration 77 | maxima_idx = self.period * (1 / 4.0 + np.arange(self.n_spikes)) 78 | thr_crossings = self.period * (1 / 6.0 + np.arange(self.n_spikes)) 79 | spt_dict = {"data": thr_crossings} 80 | sp_win = [-self.period / 24.0, self.period / 12.0] 81 | spt = ss.extract.align_spikes(self.spk_data, spt_dict, sp_win) 82 | ok_((np.abs(spt['data'] - maxima_idx) <= 1000.0 / self.FS).all()) 83 | 84 | def test_align_edge(self): 85 | #??? 86 | spikes = np.sin(2 * np.pi / self.period * self.time + np.pi / 2.0)[np.newaxis, :] 87 | maxima_idx = self.period * (np.arange(1, self.n_spikes + 1)) 88 | thr_crossings = self.period * (-1 / 6.0 + np.arange(1, self.n_spikes + 1)) 89 | spt_dict = {"data": thr_crossings} 90 | sp_win = [-self.period / 24.0, self.period / 12.0] 91 | spk_data = {"data": spikes, "n_contacts": 1, "FS": self.FS} 92 | spt = ss.extract.align_spikes(spk_data, spt_dict, sp_win) 93 | last = spt['data'][-1] 94 | ok_((last >= (self.time[-1] - sp_win[1])) & (last <= self.time[-1])) 95 | 96 | def test_align_double_spikes(self): 97 | #double detections of the same spike should be removed 98 | maxima_idx = self.period * (1 / 4.0 + np.arange(self.n_spikes)) 99 | thr_crossings = self.period * (1 / 6.0 + np.arange(0, self.n_spikes, 0.5)) 100 | spt_dict = {"data": thr_crossings} 101 | sp_win = [-self.period / 24.0, self.period / 12.0] 102 | spt = ss.extract.align_spikes(self.spk_data, spt_dict, sp_win) 103 | ok_((np.abs(spt['data'] - maxima_idx) <= 1000.0 / self.FS).all()) 104 | 105 | def test_remove_doubles_roundoff(self): 106 | # 
remove_doubles should account for slight variations around 107 | # the 'tolerance' which may occur due to some round-off errors 108 | # during spike detection/alignment. This won't affect any useful 109 | # information, because the tolerance is one sample large in this 110 | # test. 111 | 112 | tol = 1000.0 / self.FS # [ms], one sample 113 | data = [1.0, 1.0 + tol + tol * 0.01] 114 | spt_dict = {"data": np.array(data)} 115 | clean_spt_dict = ss.extract.remove_doubles(spt_dict, tol) 116 | ok_(len(clean_spt_dict['data']) == 1) # duplicate removed 117 | 118 | def test_extract(self): 119 | zero_crossing = self.period * np.arange(self.n_spikes) 120 | zero_crossing += 1000.0 / self.FS / 2.0 # move by half a sample to avoid round-off errors 121 | spt_dict = {"data": zero_crossing} 122 | sp_win = [0, self.period] 123 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win) 124 | ref_sp = np.sin(2 * np.pi / self.period * sp_waves['time']) 125 | ok_((np.abs(ref_sp[:, np.newaxis] - sp_waves['data'][:, :, 0]) < 1E-6).all()) 126 | #ok_((np.abs(sp_waves['data'][:,:,0].mean(1)-ref_sp)<2*1000*np.pi/(self.FS*self.period)).all()) 127 | #ok_(np.abs(np.sum(sp_waves['data'][:,:,0].mean(1)-ref_sp))<1E-6) 128 | 129 | #def test_extract_resample_deprecation(self): 130 | # zero_crossing = self.period*np.arange(self.n_spikes) 131 | # spt_dict = {"data":zero_crossing} 132 | # sp_win = [0, self.period] 133 | # with warnings.catch_warnings(True) as w: 134 | # sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win, 135 | # resample=2.) 
136 | # ok_(len(w)>=1) 137 | 138 | def test_extract_and_resample(self): 139 | zero_crossing = self.period * np.arange(self.n_spikes) 140 | zero_crossing += 1000.0 / self.FS / 2.0 # move by half a sample to avoid round-off errors 141 | spt_dict = {"data": zero_crossing} 142 | sp_win = [0, self.period] 143 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win) 144 | sp_resamp = ss.extract.resample_spikes(sp_waves, self.FS * 2) 145 | ref_sp = np.sin(2 * np.pi / self.period * sp_resamp['time']) 146 | ok_((np.abs(ref_sp[:, np.newaxis] - sp_resamp['data'][:, :, 0]) < 1E-6).all()) 147 | 148 | def test_mask_of_truncated_spikes(self): 149 | zero_crossing = self.period * np.arange(self.n_spikes + 1) 150 | spt_dict = {"data": zero_crossing} 151 | sp_win = [0, self.period] 152 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win) 153 | correct_mask = np.ones(self.n_spikes + 1, np.bool) 154 | correct_mask[-1] = False 155 | ok_((sp_waves['is_valid'] == correct_mask).all()) 156 | #ok_(np.abs(np.sum(sp_waves['data'][:,:,0].mean(1)-ref_sp))<1E-6) 157 | 158 | def test_extract_truncated_spike_end(self): 159 | zero_crossing = np.array([self.period * (self.n_spikes - 0.5)]) 160 | spt_dict = {"data": zero_crossing} 161 | sp_win = [0, self.period] 162 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win) 163 | ref_sp = -np.sin(2 * np.pi / self.period * sp_waves['time']) 164 | ref_sp[len(ref_sp) / 2:] = 0 165 | almost_equal(sp_waves['data'][:, 0, 0], ref_sp) 166 | 167 | def test_extract_truncated_spike_end(self): 168 | zero_crossing = np.array([-self.period * 0.5]) 169 | spt_dict = {"data": zero_crossing} 170 | sp_win = [0, self.period] 171 | sp_waves = ss.extract.extract_spikes(self.spk_data, spt_dict, sp_win) 172 | ref_sp = -np.sin(2 * np.pi / self.period * sp_waves['time']) 173 | ref_sp[:len(ref_sp) / 2] = 0 174 | almost_equal(sp_waves['data'][:, 0, 0], ref_sp) 175 | 176 | def test_filter_spt(self): 177 | #out of band spikes should 
be removed 178 | zero_crossing = self.period * (np.arange(self.n_spikes)) 179 | spt_dict = {"data": zero_crossing} 180 | sp_win = [0, self.period] 181 | spt_filt = ss.extract.filter_spt(self.spk_data, spt_dict, sp_win) 182 | ok_(len(spt_filt) == self.n_spikes) 183 | 184 | def test_filter_spt_shorten_left(self): 185 | #remove out-of-band spikes from the beginning of the train 186 | zero_crossing = self.period * (np.arange(self.n_spikes)) 187 | spt_dict = {"data": zero_crossing} 188 | sp_win = [-self.period / 8, self.period / 8.0] 189 | spt_filt = ss.extract.filter_spt(self.spk_data, spt_dict, sp_win) 190 | ok_(len(spt_filt) == (self.n_spikes - 1)) 191 | 192 | def test_filter_spt_shorten_right(self): 193 | #remove out-of-band spikes from the end of the train 194 | zero_crossing = self.period * (np.arange(self.n_spikes)) 195 | spt_dict = {"data": zero_crossing} 196 | sp_win = [0, self.period + 1000.0 / self.FS] 197 | spt_filt = ss.extract.filter_spt(self.spk_data, spt_dict, sp_win) 198 | ok_(len(spt_filt) == (self.n_spikes - 1)) 199 | 200 | class TestCluster(object): 201 | """test clustering algorithms""" 202 | 203 | def _cmp_bin_partitions(self, cl1, cl2): 204 | return (~np.logical_xor(cl1, cl2)).all() or (np.logical_xor(cl1, cl2)).all() 205 | 206 | def setup(self): 207 | 208 | self.K = 2 209 | 210 | n_dim = 2 211 | pts_in_clust = 100 212 | np.random.seed(1234) 213 | data = np.vstack((np.random.rand(pts_in_clust, n_dim), 214 | np.random.rand(pts_in_clust, n_dim) + 2 * np.ones(n_dim))) 215 | self.labels = np.concatenate((np.zeros(pts_in_clust, dtype=int), 216 | np.ones(pts_in_clust, dtype=int))) 217 | feature_labels = ["feat%d" % i for i in range(n_dim)] 218 | self.features = {"data": data, "names": feature_labels} 219 | 220 | def test_k_means(self): 221 | """test own k-means algorithm""" 222 | 223 | cl = ss.cluster.cluster('k_means', self.features, self.K) 224 | ok_(self._cmp_bin_partitions(cl, self.labels)) 225 | 226 | def test_k_means_plus(self): 227 | """test 
scikits k-means plus algorithm""" 228 | 229 | cl = ss.cluster.cluster('k_means_plus', self.features, self.K) 230 | ok_(self._cmp_bin_partitions(cl, self.labels)) 231 | 232 | def test_gmm(self): 233 | """test gmm clustering algorithm""" 234 | 235 | cl = ss.cluster.cluster('gmm', self.features, self.K) 236 | ok_(self._cmp_bin_partitions(cl, self.labels)) 237 | 238 | def test_random(self): 239 | cl = np.random.rand(len(self.labels)) > 0.5 240 | ok_(~self._cmp_bin_partitions(cl, self.labels)) 241 | 242 | @raises(NotImplementedError) 243 | def test_method_notimplemented(self): 244 | cl = ss.cluster.cluster("notimplemented", self.features) 245 | --------------------------------------------------------------------------------