├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── setup.py └── spiketoolkit ├── __init__.py ├── curation ├── __init__.py └── threshold_metrics.py ├── postprocessing ├── __init__.py ├── features.py ├── postprocessing_tools.py └── utils.py ├── preprocessing ├── __init__.py ├── bandpass_filter.py ├── basepreprocessorrecording.py ├── blank_saturation.py ├── center.py ├── clip.py ├── common_reference.py ├── filterrecording.py ├── highpass_filter.py ├── mask.py ├── normalize_by_quantile.py ├── notch_filter.py ├── preprocessinglist.py ├── rectify.py ├── remove_artifacts.py ├── remove_bad_channels.py ├── resample.py ├── transform.py └── whiten.py ├── sortingcomponents ├── __init__.py └── detection.py ├── tests ├── .gitignore ├── __init__.py ├── test_curation.py ├── test_curation_extractor.py ├── test_postprocessing.py ├── test_preprocessing.py ├── test_sortingcomponents.py ├── test_validation.py └── utils.py ├── utils.py ├── validation ├── __init__.py ├── curation_list.py ├── quality_metric_classes │ ├── __init__.py │ ├── amplitude_cutoff.py │ ├── d_prime.py │ ├── drift_metric.py │ ├── firing_rate.py │ ├── isi_violation.py │ ├── isolation_distance.py │ ├── l_ratio.py │ ├── metric_data.py │ ├── nearest_neighbor.py │ ├── noise_overlap.py │ ├── num_spikes.py │ ├── parameter_dictionaries.py │ ├── presence_ratio.py │ ├── quality_metric.py │ ├── silhouette_score.py │ ├── snr.py │ └── utils │ │ ├── __init__.py │ │ ├── curationsortingextractor.py │ │ ├── thresholdcurator.py │ │ └── validation_tools.py ├── quality_metrics.py └── validation_list.py └── version.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python Package using Conda 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build-and-test: 7 | name: Test on (${{ matrix.os }}) 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | os: ["ubuntu-latest", 
"macos-latest", "windows-latest"] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: s-weigand/setup-conda@v1 16 | with: 17 | python-version: 3.8 18 | - name: Which python 19 | run: | 20 | conda --version 21 | which python 22 | - name: Install dependencies 23 | run: | 24 | pip install https://github.com/SpikeInterface/spikeextractors/archive/master.zip 25 | pip install https://github.com/SpikeInterface/spikemetrics/archive/master.zip 26 | pip install https://github.com/SpikeInterface/spikefeatures/archive/master.zip 27 | pip install -e . 28 | pip install pytest 29 | - name: Test with pytest and build coverage report 30 | run: | 31 | pytest 32 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Test and Upload Python Package 5 | 6 | on: 7 | push: 8 | tags: 9 | - '*' 10 | 11 | jobs: 12 | deploy: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python 3.8 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.8 22 | - name: Add conda to system path 23 | run: | 24 | # $CONDA is an environment variable pointing to the root of the miniconda directory 25 | echo $CONDA/bin >> $GITHUB_PATH 26 | - name: Install dependencies 27 | run: | 28 | pip install "numpy>=1.20" 29 | pip install "pandas>=1.2" 30 | pip install -e . 
31 | pip install pytest 32 | pip install setuptools wheel twine 33 | - name: Test with pytest and build coverage report 34 | run: | 35 | pytest 36 | - name: Publish on PyPI 37 | env: 38 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 39 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 40 | run: | 41 | python setup.py sdist bdist_wheel 42 | twine upload dist/* 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | .vscode 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 SpikeInterface 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spiketoolkit (LEGACY) 2 | 3 | The `spiketoolkit` package has now been integrated into [spikeinterface](https://github.com/SpikeInterface/spikeinterface). 4 | 5 | This package will be maintained for a while for bug fixes only, then it will be deprecated. 6 | 7 | New features and improvements will only be implemented for the `spikeinterface` package. 
# ===== setup.py =====
from setuptools import setup, find_packages

# Read the version by executing spiketoolkit/version.py into a private namespace
# (presumably so the package, and its dependencies, need not be importable at build time).
d = {}
with open("spiketoolkit/version.py", encoding="utf-8") as version_file:
    exec(version_file.read(), None, d)
version = d['version']

with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

pkg_name = "spiketoolkit"

setup(
    name=pkg_name,
    version=version,
    author="Alessio Buccino, Cole Hurwitz, Samuel Garcia, Jeremy Magland, Matthias Hennig",
    author_email="alessiop.buccino@gmail.com",
    description="Python toolkit for analysis, visualization, and comparison of spike sorting output",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/alejoe91/spiketoolkit",
    packages=find_packages(),
    package_data={},
    install_requires=[
        'numpy',
        'spikeextractors>=0.9.7',
        'spikemetrics>=0.2.4',
        'spikefeatures',
        'scikit-learn',
        'scipy',
        'pandas',
        'networkx',
        'joblib'
    ],
    classifiers=[
        "Programming Language :: Python :: 3",
        # the bundled LICENSE file is the MIT license
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ]
)

# ===== spiketoolkit/__init__.py =====
from . import postprocessing
from . import preprocessing
from . import validation
from . import curation
from . import sortingcomponents

from .version import version as __version__

# ===== spiketoolkit/curation/__init__.py =====
from .threshold_metrics import threshold_num_spikes
from .threshold_metrics import threshold_firing_rates
from .threshold_metrics import threshold_presence_ratios
from .threshold_metrics import threshold_isi_violations
from .threshold_metrics import threshold_snrs
from .threshold_metrics import threshold_silhouette_scores
from .threshold_metrics import threshold_d_primes
from .threshold_metrics import threshold_l_ratios
from .threshold_metrics import threshold_isolation_distances
from .threshold_metrics import threshold_nn_metrics
from .threshold_metrics import threshold_drift_metrics
from .threshold_metrics import threshold_amplitude_cutoffs
from .threshold_metrics import threshold_noise_overlaps
from ..validation import get_validation_params as get_curation_params
from ..validation.quality_metric_classes.utils.curationsortingextractor import CurationSortingExtractor

# ===== spiketoolkit/postprocessing/__init__.py =====
from .postprocessing_tools import get_unit_waveforms, get_unit_templates, get_unit_max_channels, get_unit_amplitudes,\
    compute_unit_pca_scores, export_to_phy, set_unit_properties_by_max_channel_properties,\
    compute_channel_spiking_activity, compute_unit_centers_of_mass

from .features import compute_unit_template_features, get_template_features_list

from .utils import get_waveforms_params, get_pca_params, get_amplitudes_params, get_common_params, \
    get_postprocessing_params
-------------------------------------------------------------------------------- /spiketoolkit/postprocessing/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uses the functions in SpikeInterface/spikefeatures to compute 3 | unit template features 4 | """ 5 | 6 | import pandas 7 | import spikefeatures as sf 8 | from scipy.signal import resample_poly 9 | from .postprocessing_tools import get_unit_templates, get_unit_max_channels 10 | from .utils import update_all_param_dicts_with_kwargs, select_max_channels_from_templates 11 | import numpy as np 12 | 13 | 14 | def get_template_features_list(): 15 | return sf.all_1D_features 16 | 17 | 18 | def compute_unit_template_features(recording, sorting, unit_ids=None, channel_ids=None, feature_names=None, 19 | max_channels_per_features=1, recovery_slope_window=0.7, upsampling_factor=1, 20 | invert_waveforms=False, as_dataframe=False, **kwargs): 21 | """ 22 | Use SpikeInterface/spikefeatures to compute features for the unit template. 23 | 24 | These consist of a set of 1D features: 25 | - peak to valley (peak_to_valley), time between peak and valley 26 | - halfwidth (halfwidth), width of peak at half its amplitude 27 | - peak trough ratio (peak_trough_ratio), amplitude of peak over amplitude of trough 28 | - repolarization slope (repolarization_slope), slope between trough and return to base 29 | - recovery slope (recovery_slope), slope after peak towards baseline 30 | 31 | And 2D features: 32 | - unit_spread 33 | - propagation velocity 34 | To be implemented 35 | 36 | The metrics are computed on 'negative' waveforms, if templates are saved as 37 | positive, pass keyword 'invert_waveforms'. 
38 | 39 | Parameters 40 | ---------- 41 | recording: RecordingExtractor 42 | The recording extractor 43 | sorting: SortingExtractor 44 | The sorting extractor 45 | unit_ids: list 46 | List of unit ids to compute features 47 | channel_ids: list 48 | List of channels ids to compute templates on which features are computed 49 | feature_names: list 50 | List of feature names to be computed. If None, all features are computed 51 | max_channels_per_features: int 52 | Maximum number of channels to compute features on (default 1). If channel_ids is used, this parameter 53 | is ignored 54 | upsampling_factor: int 55 | Factor with which to upsample the template resolution (default 1) 56 | invert_waveforms: bool 57 | Invert templates before computing features (default False) 58 | recovery_slope_window: float 59 | Window after peak in ms wherein to compute recovery slope (default 0.7) 60 | as_dataframe: bool 61 | IfTrue, output is returned as a pandas dataframe, otherwise as a dictionary 62 | **kwargs: Keyword arguments 63 | A dictionary with default values can be retrieved with: 64 | st.postprocessing.get_waveforms_params(): 65 | grouping_property: str 66 | Property to group channels. E.g. if the recording extractor has the 'group' property and 67 | 'grouping_property' is 'group', then waveforms are computed group-wise. 68 | ms_before: float 69 | Time period in ms to cut waveforms before the spike events 70 | ms_after: float 71 | Time period in ms to cut waveforms after the spike events 72 | dtype: dtype 73 | The numpy dtype of the waveforms 74 | compute_property_from_recording: bool 75 | If True and 'grouping_property' is given, the property of each unit is assigned as the corresponding 76 | property of the recording extractor channel on which the average waveform is the largest 77 | max_channels_per_waveforms: int or None 78 | Maximum channels per waveforms to return. 
If None, all channels are returned 79 | n_jobs: int 80 | Number of parallel jobs (default 1) 81 | max_spikes_per_unit: int 82 | The maximum number of spikes to extract per unit 83 | memmap: bool 84 | If True, waveforms are saved as memmap object (recommended for long recordings with many channels) 85 | seed: int 86 | Random seed for extracting random waveforms 87 | save_property_or_features: bool 88 | If True (default), waveforms are saved as features of the sorting extractor object 89 | recompute_info: bool 90 | If True, waveforms are recomputed (default False) 91 | verbose: bool 92 | If True output is verbose 93 | 94 | 95 | Returns 96 | ------- 97 | features: dict or pandas.DataFrame 98 | The computed features as a dictionary or a pandas.DataFrame (if as_dataframe is True) 99 | """ 100 | 101 | # ------------------- SETUP ------------------------------ 102 | if isinstance(unit_ids, (int, np.integer)): 103 | unit_ids = [unit_ids] 104 | elif unit_ids is None: 105 | unit_ids = sorting.get_unit_ids() 106 | elif not isinstance(unit_ids, (list, np.ndarray)): 107 | raise Exception("unit_ids is is invalid") 108 | if isinstance(channel_ids, (int, np.integer)): 109 | channel_ids = [channel_ids] 110 | 111 | if channel_ids is None: 112 | channel_ids = recording.get_channel_ids() 113 | 114 | assert np.all([u in sorting.get_unit_ids() for u in unit_ids]), "Invalid unit_ids" 115 | assert np.all([ch in recording.get_channel_ids() for ch in channel_ids]), "Invalid channel_ids" 116 | 117 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 118 | save_property_or_features = params_dict['save_property_or_features'] 119 | 120 | if feature_names is None: 121 | feature_names = sf.all_1D_features 122 | else: 123 | bad_features = [] 124 | for m in feature_names: 125 | if m not in sf.all_1D_features: 126 | bad_features.append(m) 127 | if len(bad_features) > 0: 128 | raise ValueError(f"Improper feature names: {str(bad_features)}. 
The following features names can be " 129 | f"calculated: {str(sf.all_1D_features)}") 130 | 131 | templates = np.array(get_unit_templates(recording, sorting, unit_ids=unit_ids, channel_ids=channel_ids, 132 | mode='median', **kwargs)) 133 | 134 | # deal with templates with different shapes 135 | shape_0 = templates[0].shape 136 | if np.all([t.shape == shape_0 for t in templates]): 137 | same_shape = True 138 | else: 139 | same_shape = False 140 | 141 | # -------------------- PROCESS TEMPLATES ----------------------------- 142 | if upsampling_factor > 1: 143 | upsampling_factor = int(upsampling_factor) 144 | if same_shape: 145 | processed_templates = resample_poly(templates, up=upsampling_factor, down=1, axis=2) 146 | else: 147 | processed_templates = [] 148 | for temp in templates: 149 | processed_templates.append(resample_poly(temp, up=upsampling_factor, down=1, axis=1)) 150 | resampled_fs = recording.get_sampling_frequency() * upsampling_factor 151 | else: 152 | processed_templates = templates 153 | resampled_fs = recording.get_sampling_frequency() 154 | 155 | if invert_waveforms: 156 | processed_templates = -processed_templates 157 | 158 | features_dict = dict() 159 | for feat in feature_names: 160 | features_dict[feat] = [] 161 | # --------------------- COMPUTE FEATURES ------------------------------ 162 | for unit_id, unit in enumerate(unit_ids): 163 | template = processed_templates[unit_id] 164 | max_channel_idxs = select_max_channels_from_templates(template, recording, max_channels_per_features) 165 | template_channels = template[max_channel_idxs] 166 | if len(template_channels.shape) == 1: 167 | template_channels = template_channels[np.newaxis, :] 168 | feat_list = sf.calculate_features(waveforms=template_channels, 169 | sampling_frequency=resampled_fs, 170 | feature_names=feature_names, 171 | recovery_slope_window=recovery_slope_window) 172 | 173 | for feat, feat_val in feat_list.items(): 174 | features_dict[feat].append(feat_val) 175 | 176 | # 
---------------------- DEAL WITH OUTPUT ------------------------- 177 | if save_property_or_features: 178 | for feat_name, feat_val in features_dict.items(): 179 | for i_u, unit in enumerate(sorting.get_unit_ids()): 180 | if len(feat_val[i_u]) == 1: 181 | feat_val[i_u] = feat_val[i_u][0] 182 | sorting.set_unit_property(unit, 183 | property_name=feat_name, 184 | value=feat_val[i_u]) 185 | if as_dataframe: 186 | features = pandas.DataFrame.from_dict(features_dict) 187 | features = features.rename(index={original_idx: unit_ids[i] for 188 | i, original_idx in enumerate(range(len(features)))}) 189 | else: 190 | features = features_dict 191 | return features 192 | -------------------------------------------------------------------------------- /spiketoolkit/postprocessing/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import spikeextractors as se 3 | import numpy as np 4 | 5 | waveforms_params_dict = OrderedDict([('grouping_property', None), ('ms_before', 3.), ('ms_after', 3.), ('dtype', None), 6 | ('compute_property_from_recording', False), 7 | ('n_jobs', None), ('max_channels_per_waveforms', None)]) 8 | 9 | amplitudes_params_dict = OrderedDict([('method', 'absolute'), ('peak', 'both'), ('frames_before', 3), 10 | ('frames_after', 3)]) 11 | 12 | pca_params_dict = OrderedDict([('n_comp', 3), ('by_electrode', True), ('max_spikes_for_pca', 5000), 13 | ('whiten', False)]) 14 | 15 | common_params_dict = OrderedDict([('max_spikes_per_unit', 300), ('recompute_info', False), 16 | ('save_property_or_features', True), ('memmap', True), ('seed', 0), 17 | ('verbose', False), ('joblib_backend', 'loky')]) 18 | 19 | 20 | def get_waveforms_params(): 21 | return waveforms_params_dict.copy() 22 | 23 | 24 | def get_amplitudes_params(): 25 | return amplitudes_params_dict.copy() 26 | 27 | 28 | def get_pca_params(): 29 | return pca_params_dict.copy() 30 | 31 | 32 | def get_common_params(): 33 | return 
common_params_dict.copy() 34 | 35 | 36 | def get_postprocessing_params(): 37 | ''' 38 | Returns all available keyword argument params 39 | 40 | Returns 41 | ------- 42 | all_params: dict 43 | Dictionary with all available keyword arguments for postprocessing module 44 | ''' 45 | all_params = {} 46 | all_params.update(get_waveforms_params()) 47 | all_params.update(get_amplitudes_params()) 48 | all_params.update(get_pca_params()) 49 | all_params.update(get_common_params()) 50 | 51 | return all_params 52 | 53 | 54 | def update_all_param_dicts_with_kwargs(kwargs): 55 | all_params = get_postprocessing_params() 56 | 57 | if np.any([k in all_params.keys() for k in kwargs.keys()]): 58 | for k in kwargs.keys(): 59 | if k in all_params.keys(): 60 | all_params[k] = kwargs[k] 61 | 62 | return all_params 63 | 64 | 65 | def select_max_channels_from_waveforms(wf, recording, max_channels): 66 | template = np.mean(wf, axis=0) 67 | # select based on adjacency 68 | if max_channels < recording.get_num_channels(): 69 | if 'location' in recording.get_shared_channel_property_names(): 70 | max_channel_idx = np.unravel_index(np.argmax(np.abs(template)), 71 | template.shape)[0] 72 | locs = recording.get_channel_locations() 73 | loc_max = locs[max_channel_idx] 74 | distances = [np.linalg.norm(l - loc_max) for l in locs] 75 | max_channel_idxs = np.argsort(distances)[:max_channels] 76 | else: # select based on amplitude 77 | peak_idx = np.unravel_index(np.argmax(np.abs(template)), 78 | template.shape)[1] 79 | max_channel_idxs = np.argsort(np.abs( 80 | template[:, peak_idx]))[::-1][:max_channels] 81 | else: 82 | max_channel_idxs = np.arange(recording.get_num_channels()) 83 | 84 | return max_channel_idxs 85 | 86 | 87 | def select_max_channels_from_templates(template, recording, max_channels): 88 | # select based on adjacency 89 | if max_channels < recording.get_num_channels(): 90 | if 'location' in recording.get_shared_channel_property_names(): 91 | max_channel_idx = 
np.unravel_index(np.argmax(np.abs(template)), 92 | template.shape)[0] 93 | locs = recording.get_channel_locations() 94 | loc_max = locs[max_channel_idx] 95 | distances = [np.linalg.norm(l - loc_max) for l in locs] 96 | max_channel_idxs = np.argsort(distances)[:max_channels] 97 | else: # select based on amplitude 98 | peak_idx = np.unravel_index(np.argmax(np.abs(template)), 99 | template.shape)[1] 100 | max_channel_idxs = np.argsort(np.abs( 101 | template[:, peak_idx]))[::-1][:max_channels] 102 | else: 103 | max_channel_idxs = np.arange(recording.get_num_channels()) 104 | 105 | return max_channel_idxs 106 | 107 | 108 | def get_max_channels_per_waveforms(recording, grouping_property, channel_ids, max_channels_per_waveforms): 109 | if grouping_property is None: 110 | if max_channels_per_waveforms is None: 111 | n_channels = len(channel_ids) 112 | elif max_channels_per_waveforms >= len(channel_ids): 113 | n_channels = len(channel_ids) 114 | else: 115 | n_channels = max_channels_per_waveforms 116 | else: 117 | rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids) 118 | rec_groups = np.array([rec.get_channel_property(ch, grouping_property) for ch in rec.get_channel_ids()]) 119 | groups, count = np.unique(rec_groups, return_counts=True) 120 | if max_channels_per_waveforms is None: 121 | n_channels = np.max(count) 122 | elif max_channels_per_waveforms >= np.max(count): 123 | n_channels = np.max(count) 124 | else: 125 | n_channels = max_channels_per_waveforms 126 | return n_channels 127 | 128 | 129 | def extract_snippet_from_traces( 130 | traces, 131 | start_frame, 132 | end_frame, 133 | ): 134 | if (0 <= start_frame) and (end_frame <= traces.shape[1]): 135 | x = traces[:, start_frame:end_frame] 136 | else: 137 | # handle edge cases 138 | x = np.zeros((traces.shape[0], end_frame - start_frame), dtype=traces.dtype) 139 | i1 = int(max(0, start_frame)) 140 | i2 = int(min(traces.shape[1], end_frame)) 141 | x[:, (i1 - start_frame):(i2 - start_frame)] = traces[:, 
i1:i2] 142 | return x 143 | 144 | 145 | def get_unit_waveforms_for_chunk( 146 | recording, 147 | chunk, 148 | unit_ids, 149 | snippet_len, 150 | times_in_chunk, 151 | return_scaled=True 152 | ): 153 | # chunks are chosen small enough so that all traces can be loaded into memory 154 | traces = recording.get_traces(return_scaled=return_scaled) 155 | frame_offset = chunk['istart'] - chunk['istart_with_padding'] 156 | 157 | unit_waveforms = [] 158 | for i_unit, unit_id in enumerate(unit_ids): 159 | # find indexes in chunk 160 | if len(times_in_chunk[i_unit]) > 0: 161 | # Adjust time with padding 162 | try: 163 | snippets = [extract_snippet_from_traces(traces, 164 | start_frame=frame_offset + int(t) - snippet_len[0], 165 | end_frame=frame_offset + int(t) + snippet_len[1]) 166 | for t in times_in_chunk[i_unit] - chunk['istart']] 167 | except: 168 | raise Exception 169 | unit_waveforms.append(np.stack(snippets)) 170 | else: 171 | unit_waveforms.append(np.zeros((0, recording.get_num_channels(), 172 | snippet_len[0] + snippet_len[1]), dtype=traces.dtype)) 173 | 174 | return unit_waveforms 175 | 176 | 177 | def divide_recording_into_time_chunks(num_frames, chunk_size, padding_size): 178 | chunks = [] 179 | ii = 0 180 | while ii < num_frames: 181 | ii2 = int(min(ii + chunk_size, num_frames)) 182 | chunks.append(dict( 183 | istart=ii, 184 | iend=ii2, 185 | istart_with_padding=int(max(0, ii - padding_size)), 186 | iend_with_padding=int(min(num_frames, ii2 + padding_size)) 187 | )) 188 | ii = ii2 189 | return chunks 190 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessinglist import * 2 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/bandpass_filter.py: -------------------------------------------------------------------------------- 1 | from 
.filterrecording import FilterRecording 2 | import numpy as np 3 | import scipy.signal as ss 4 | from scipy import special 5 | import spikeextractors as se 6 | 7 | 8 | class BandpassFilterRecording(FilterRecording): 9 | preprocessor_name = 'BandpassFilter' 10 | 11 | def __init__(self, recording, freq_min=300, freq_max=6000, freq_wid=1000, filter_type='fft', order=3, 12 | chunk_size=30000, cache_chunks=False, dtype=None): 13 | self._freq_min = freq_min 14 | self._freq_max = freq_max 15 | self._freq_wid = freq_wid 16 | self._type = filter_type 17 | self._order = order 18 | self._chunk_size = chunk_size 19 | 20 | if self._type == 'butter': 21 | fn = recording.get_sampling_frequency() / 2. 22 | band = np.array([self._freq_min, self._freq_max]) / fn 23 | 24 | self._b, self._a = ss.butter(self._order, band, btype='bandpass') 25 | 26 | if not np.all(np.abs(np.roots(self._a)) < 1): 27 | raise ValueError('Filter is not stable') 28 | FilterRecording.__init__(self, recording=recording, chunk_size=chunk_size, cache_chunks=cache_chunks, 29 | dtype=dtype) 30 | self.is_filtered = True 31 | self._kwargs = {'recording': recording.make_serialized_dict(), 'freq_min': freq_min, 'freq_max': freq_max, 32 | 'freq_wid': freq_wid, 'filter_type': filter_type, 'order': order, 33 | 'chunk_size': chunk_size, 'cache_chunks': cache_chunks} 34 | 35 | def filter_chunk(self, start_frame, end_frame, channel_ids, return_scaled): 36 | padding = 3000 37 | i1 = start_frame - padding 38 | i2 = end_frame + padding 39 | padded_chunk = self._read_chunk(i1, i2, channel_ids, return_scaled) 40 | filtered_padded_chunk = self._do_filter(padded_chunk) 41 | return filtered_padded_chunk[:, start_frame - i1:end_frame - i1] 42 | 43 | def _do_filter(self, chunk): 44 | sampling_frequency = self._recording.get_sampling_frequency() 45 | M = chunk.shape[0] 46 | chunk2 = chunk 47 | # Do the actual filtering with a DFT with real input 48 | if self._type == 'fft': 49 | chunk_fft = np.fft.rfft(chunk2) 50 | kernel = 
_create_filter_kernel( 51 | chunk2.shape[1], 52 | sampling_frequency, 53 | self._freq_min, self._freq_max, self._freq_wid 54 | ) 55 | kernel = kernel[0:chunk_fft.shape[1]] # because this is the DFT of real data 56 | chunk_fft = chunk_fft * np.tile(kernel, (M, 1)) 57 | chunk_filtered = np.fft.irfft(chunk_fft) 58 | elif self._type == 'butter': 59 | chunk_filtered = ss.filtfilt(self._b, self._a, chunk2, axis=1) 60 | 61 | return chunk_filtered 62 | 63 | 64 | def _create_filter_kernel(N, sampling_frequency, freq_min, freq_max, freq_wid=1000): 65 | # Matches ahb's code /matlab/processors/ms_bandpass_filter.m 66 | # improved ahb, changing tanh to erf, correct -3dB pts 6/14/16 67 | T = N / sampling_frequency # total time 68 | df = 1 / T # frequency grid 69 | relwid = 3.0 # relative bottom-end roll-off width param, kills low freqs by factor 1e-5. 70 | 71 | k_inds = np.arange(0, N) 72 | k_inds = np.where(k_inds <= (N + 1) / 2, k_inds, k_inds - N) 73 | 74 | fgrid = df * k_inds 75 | absf = np.abs(fgrid) 76 | 77 | val = np.ones(fgrid.shape) 78 | if freq_min != 0: 79 | val = val * (1 + special.erf(relwid * (absf - freq_min) / freq_min)) / 2 80 | val = np.where(np.abs(k_inds) < 0.1, 0, val) # kill DC part exactly 81 | if freq_max != 0: 82 | val = val * (1 - special.erf((absf - freq_max) / freq_wid)) / 2; 83 | val = np.sqrt(val) # note sqrt of filter func to apply to spectral intensity not ampl 84 | return val 85 | 86 | 87 | def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filter_type='fft', order=3, 88 | chunk_size=30000, cache_chunks=False, dtype=None): 89 | ''' 90 | Performs a lazy filter on the recording extractor traces. 91 | 92 | Parameters 93 | ---------- 94 | recording: RecordingExtractor 95 | The recording extractor to be filtered. 96 | freq_min: int or float 97 | High-pass cutoff frequency. 98 | freq_max: int or float 99 | Low-pass cutoff frequency. 100 | freq_wid: int or float 101 | Width of the filter (when type is 'fft'). 
102 | filter_type: str 103 | 'fft' or 'butter'. The 'fft' filter uses a kernel in the frequency domain. The 'butter' filter uses 104 | scipy butter and filtfilt functions. 105 | order: int 106 | Order of the filter (if 'butter'). 107 | chunk_size: int 108 | The chunk size to be used for the filtering. 109 | cache_chunks: bool (default False). 110 | If True then each chunk is cached in memory (in a dict) 111 | dtype: dtype 112 | The dtype of the traces 113 | 114 | Returns 115 | ------- 116 | filter_recording: BandpassFilterRecording 117 | The filtered recording extractor object 118 | ''' 119 | bpf_recording = BandpassFilterRecording( 120 | recording=recording, 121 | freq_min=freq_min, 122 | freq_max=freq_max, 123 | freq_wid=freq_wid, 124 | filter_type=filter_type, 125 | order=order, 126 | chunk_size=chunk_size, 127 | cache_chunks=cache_chunks, 128 | dtype=dtype 129 | ) 130 | return bpf_recording 131 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/basepreprocessorrecording.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | 4 | 5 | class BasePreprocessorRecordingExtractor(RecordingExtractor): 6 | installed = True # check at class level if installed or not 7 | installation_mesg = "" # err 8 | 9 | def __init__(self, recording, copy_times=True): 10 | assert isinstance(recording, RecordingExtractor), "'recording' must be a RecordingExtractor" 11 | RecordingExtractor.__init__(self) 12 | self._recording = recording 13 | self.copy_channel_properties(recording) 14 | self.copy_epochs(recording) 15 | if copy_times: 16 | self.copy_times(recording) 17 | 18 | # avoid rescaling twice 19 | self.set_channel_gains(1) 20 | self.set_channel_offsets(0) 21 | 22 | self.is_filtered = recording.is_filtered 23 | if hasattr(recording, "has_unscaled"): 24 | self.has_unscaled 
= recording.has_unscaled 25 | else: 26 | self.has_unscaled = False 27 | 28 | def get_channel_ids(self): 29 | return self._recording.get_channel_ids() 30 | 31 | def get_num_frames(self): 32 | return self._recording.get_num_frames() 33 | 34 | def get_sampling_frequency(self): 35 | return self._recording.get_sampling_frequency() 36 | 37 | def time_to_frame(self, times): 38 | return self._recording.time_to_frame(times) 39 | 40 | def frame_to_time(self, frames): 41 | return self._recording.frame_to_time(frames) 42 | 43 | @check_get_traces_args 44 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 45 | raise NotImplementedError 46 | 47 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/blank_saturation.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | import numpy as np 5 | 6 | 7 | class BlankSaturationRecording(BasePreprocessorRecordingExtractor): 8 | preprocessor_name = 'BlankSaturation' 9 | 10 | def __init__(self, recording, threshold=None, seed=0): 11 | if not isinstance(recording, RecordingExtractor): 12 | raise ValueError("'recording' must be a RecordingExtractor") 13 | BasePreprocessorRecordingExtractor.__init__(self, recording) 14 | random_data = self._get_random_data_for_scaling(seed=seed).ravel() 15 | q = np.quantile(random_data, [0.001, 0.5, 1 - 0.001]) 16 | if 2 * q[1] - q[0] - q[2] < 2 * np.min([q[1] - q[0], q[2] - q[1]]): 17 | print('Warning, narrow signal range suggests artefact-free data.') 18 | self._median = q[1] 19 | if threshold is None: 20 | if np.abs(q[1] - q[0]) > np.abs(q[1] - q[2]): 21 | self._threshold = q[0] 22 | self._lower = True 23 | else: 24 | self._threshold = q[2] 25 | self._lower = False 26 | else: 
27 | self._threshold = threshold 28 | if q[1] - threshold < 0: 29 | self._lower = False 30 | else: 31 | self._lower = True 32 | self.has_unscaled = False 33 | 34 | self._kwargs = {'recording': recording.make_serialized_dict(), 'threshold': threshold, 'seed': seed} 35 | 36 | def _get_random_data_for_scaling(self, num_chunks=50, chunk_size=500, seed=0): 37 | N = self._recording.get_num_frames() 38 | random_ints = np.random.RandomState(seed=seed).randint(0, N - chunk_size, size=num_chunks) 39 | chunk_list = [] 40 | for ff in random_ints: 41 | chunk = self._recording.get_traces(start_frame=ff, 42 | end_frame=ff + chunk_size) 43 | chunk_list.append(chunk) 44 | return np.concatenate(chunk_list, axis=1) 45 | 46 | @check_get_traces_args 47 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 48 | assert return_scaled, "'blank_saturation' only supports return_scaled=True" 49 | 50 | traces = self._recording.get_traces(channel_ids=channel_ids, 51 | start_frame=start_frame, 52 | end_frame=end_frame, 53 | return_scaled=return_scaled) 54 | traces = traces.copy() 55 | if self._lower: 56 | traces[traces <= self._threshold] = self._median 57 | else: 58 | traces[traces >= self._threshold] = self._median 59 | return traces 60 | 61 | 62 | def blank_saturation(recording, threshold=None, seed=0): 63 | ''' 64 | Find and remove parts of the signal with extereme values. Some arrays 65 | may produce these when amplifiers enter saturation, typically for 66 | short periods of time. To remove these artefacts, values below or above 67 | a threshold are set to the median signal value. 68 | The threshold is either be estimated automatically, using the lower and upper 69 | 0.1 signal percentile with the largest deviation from the median, or specificed. 70 | Use this function with caution, as it may clip uncontaminated signals. A warning is 71 | printed if the data range suggests no artefacts. 
72 | 73 | Parameters 74 | ---------- 75 | recording: RecordingExtractor 76 | The recording extractor to be transformed 77 | Minimum value. If `None`, clipping is not performed on lower 78 | interval edge. 79 | threshold: float or 'None' (default `None`) 80 | Threshold value (in absolute units) for saturation artifacts. 81 | If `None`, the threshold will be determined from the 0.1 signal percentile. 82 | seed: int 83 | Random seed for reproducibility 84 | Returns 85 | ------- 86 | rescaled_traces: BlankSaturationRecording 87 | The filtered traces recording extractor object 88 | ''' 89 | return BlankSaturationRecording( 90 | recording=recording, 91 | threshold=threshold, 92 | seed=seed 93 | ) 94 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/center.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from .transform import TransformRecording 3 | import numpy as np 4 | 5 | 6 | class CenterRecording(TransformRecording): 7 | preprocessor_name = 'Center' 8 | 9 | def __init__(self, recording, mode, seconds, n_snippets): 10 | if not isinstance(recording, RecordingExtractor): 11 | raise ValueError("'recording' must be a RecordingExtractor") 12 | self._scalar = 1 13 | self._mode = mode 14 | self._seconds = seconds 15 | self._n_snippets = n_snippets 16 | assert self._mode in ['mean', 'median'], "'mode' can be 'mean' or 'median'" 17 | 18 | # use n_snippets of equal duration equally distributed on the recording 19 | n_snippets = int(n_snippets) 20 | assert n_snippets > 0, "'n_snippets' must be positive" 21 | snip_len = seconds / n_snippets * recording.get_sampling_frequency() 22 | 23 | if seconds * recording.get_sampling_frequency() >= recording.get_num_frames(): 24 | traces = recording.get_traces() 25 | else: 26 | # skip initial and final part 27 | snip_start = np.linspace(snip_len // 2, 
recording.get_num_frames()-int(1.5*snip_len), n_snippets) 28 | traces_snippets = recording.get_snippets(reference_frames=snip_start, snippet_len=snip_len) 29 | traces_snippets = traces_snippets.swapaxes(0, 1) 30 | traces = traces_snippets.reshape((traces_snippets.shape[0], 31 | traces_snippets.shape[1] * traces_snippets.shape[2])) 32 | if self._mode == 'mean': 33 | self._offset = -np.mean(traces, axis=1) 34 | else: 35 | self._offset = -np.median(traces, axis=1) 36 | dtype = np.dtype(recording.get_dtype()).name 37 | if 'uint' in dtype: 38 | dtype = dtype[1:] 39 | TransformRecording.__init__(self, recording, scalar=self._scalar, offset=self._offset, dtype=dtype) 40 | self._kwargs = {'recording': recording.make_serialized_dict(), 'mode': mode, 'seconds': seconds, 41 | 'n_snippets': n_snippets} 42 | 43 | 44 | def center(recording, mode='median', seconds=10., n_snippets=10): 45 | ''' 46 | Removes the offset of the traces channel by channel. 47 | 48 | Parameters 49 | ---------- 50 | recording: RecordingExtractor 51 | The recording extractor to be transformed 52 | mode: str 53 | 'median' (default) or 'mean' 54 | seconds: float 55 | Number of seconds used to compute center 56 | n_snippets: int 57 | Number of snippets in which the total 'seconds' are divided spanning the recording duration 58 | 59 | Returns 60 | ------- 61 | center: CenterRecording 62 | The output recording extractor object 63 | ''' 64 | return CenterRecording(recording=recording, mode=mode, seconds=seconds, n_snippets=n_snippets) 65 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/clip.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | 5 | 6 | class ClipRecording(BasePreprocessorRecordingExtractor): 7 | 
preprocessor_name = 'Clip' 8 | installed = True # check at class level if installed or not 9 | installation_mesg = "" # err 10 | 11 | def __init__(self, recording, a_min=None, a_max=None): 12 | if not isinstance(recording, RecordingExtractor): 13 | raise ValueError("'recording' must be a RecordingExtractor") 14 | self._a_min = a_min 15 | self._a_max = a_max 16 | BasePreprocessorRecordingExtractor.__init__(self, recording) 17 | self.has_unscaled = False 18 | self._kwargs = {'recording': recording.make_serialized_dict(), 'a_min': a_min, 'a_max': a_max} 19 | 20 | @check_get_traces_args 21 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 22 | assert return_scaled, "'clip' only supports return_scaled=True" 23 | 24 | traces = self._recording.get_traces(channel_ids=channel_ids, 25 | start_frame=start_frame, 26 | end_frame=end_frame, 27 | return_scaled=return_scaled) 28 | if self._a_min is not None: 29 | traces[traces < self._a_min] = self._a_min 30 | if self._a_max is not None: 31 | traces[traces > self._a_max] = self._a_max 32 | return traces 33 | 34 | 35 | def clip(recording, a_min=None, a_max=None): 36 | ''' 37 | Limit the values of the data between a_min and a_max. Values exceeding the 38 | range will be set to the minimum or maximum, respectively. 39 | 40 | Parameters 41 | ---------- 42 | recording: RecordingExtractor 43 | The recording extractor to be transformed 44 | a_min: float or `None` (default `None`) 45 | Minimum value. If `None`, clipping is not performed on lower 46 | interval edge. 47 | a_max: float or `None` (default `None`) 48 | Maximum value. If `None`, clipping is not performed on upper 49 | interval edge. 
50 | 51 | Returns 52 | ------- 53 | rescaled_traces: ClipTracesRecording 54 | The clipped traces recording extractor object 55 | ''' 56 | return ClipRecording( 57 | recording=recording, a_min=a_min, a_max=a_max 58 | ) 59 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/common_reference.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | import numpy as np 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | from spikeextractors.extraction_tools import check_get_traces_args 5 | 6 | from ..utils import get_closest_channels 7 | 8 | 9 | class CommonReferenceRecording(BasePreprocessorRecordingExtractor): 10 | preprocessor_name = 'CommonReference' 11 | 12 | def __init__(self, recording, reference='median', groups=None, ref_channels=None, 13 | local_radius=(30, 55), dtype=None, verbose=False): 14 | 15 | if not isinstance(recording, RecordingExtractor): 16 | raise ValueError("'recording' must be a RecordingExtractor") 17 | if reference not in ['median', 'average', 'single', 'local']: 18 | raise ValueError("'reference' must be either 'median', 'average', 'single' or 'local'") 19 | self._ref = reference 20 | self._groups = groups 21 | if self._ref == 'single': 22 | assert ref_channels is not None, "With 'single' reference, provide 'ref_channels'" 23 | if self._groups is not None: 24 | assert len(ref_channels) == len(self._groups), "'ref_channel' and 'groups' must have the " \ 25 | "same length" 26 | else: 27 | if isinstance(ref_channels, (list, np.ndarray)): 28 | assert len(ref_channels) == 1, "'ref_channel' with no 'groups' can be int or a list of one element" 29 | else: 30 | assert isinstance(ref_channels, (int, np.integer)), "'ref_channels' must be int" 31 | ref_channels = [ref_channels] 32 | elif self._ref == 'local': 33 | assert groups is None, "With 'local' CAR, the group option should not be 
used." 34 | closest_inds, dist = get_closest_channels(recording, recording.get_channel_ids()) 35 | 36 | self.neighbors = {} 37 | for i in range(recording.get_num_channels()): 38 | mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) 39 | self.neighbors[i] = closest_inds[i, mask] 40 | assert len(self.neighbors[i]) > 0, "No reference channels are inside the local annulus chosen for reference selection." 41 | 42 | self._ref_channel = ref_channels 43 | self._local_radius = local_radius 44 | if dtype is None: 45 | self._dtype = recording.get_dtype() 46 | else: 47 | self._dtype = dtype 48 | self.verbose = verbose 49 | BasePreprocessorRecordingExtractor.__init__(self, recording) 50 | self._kwargs = {'recording': recording.make_serialized_dict(), 'reference': reference, 'groups': groups, 51 | 'ref_channels': ref_channels, 'local_radius': local_radius, 52 | 'dtype': dtype, 'verbose': verbose} 53 | 54 | @check_get_traces_args 55 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 56 | 57 | selected_groups, selected_channels = self._create_channel_groups(channel_ids) 58 | traces = None 59 | 60 | if self._ref == 'median': 61 | if self.verbose: 62 | if self._groups is None: 63 | print('Common median reference using all channels') 64 | else: 65 | print('Common median in groups: ', selected_groups) 66 | 67 | traces = np.vstack(np.array([self._recording.get_traces(channel_ids=split_channel, 68 | start_frame=start_frame, end_frame=end_frame, 69 | return_scaled=return_scaled) 70 | - np.median(self._recording.get_traces(channel_ids=split_group, 71 | start_frame=start_frame, 72 | end_frame=end_frame, 73 | return_scaled=return_scaled), 74 | axis=0, keepdims=True) for (split_channel, split_group) in 75 | zip(selected_channels, selected_groups)])) 76 | elif self._ref == 'average': 77 | if self.verbose: 78 | if self._groups is None: 79 | print('Common average reference using all channels') 80 | else: 81 | print('Common average 
in groups: ', selected_groups) 82 | 83 | traces = np.vstack(np.array([self._recording.get_traces(channel_ids=split_channel, 84 | start_frame=start_frame, 85 | end_frame=end_frame, 86 | return_scaled=return_scaled) 87 | - np.mean(self._recording.get_traces(channel_ids=split_group, 88 | start_frame=start_frame, 89 | end_frame=end_frame, 90 | return_scaled=return_scaled), 91 | axis=0, keepdims=True) for (split_channel, split_group) in 92 | zip(selected_channels, selected_groups)])) 93 | elif self._ref == 'single': 94 | if self.verbose: 95 | if self._groups is None: 96 | print('Reference to channel', self._ref_channel) 97 | else: 98 | print('Reference', selected_groups, 'to channels', self._ref_channel) 99 | 100 | traces = np.vstack(np.array([self._recording.get_traces(channel_ids=split_channel, 101 | start_frame=start_frame, end_frame=end_frame, 102 | return_scaled=return_scaled) 103 | - self._recording.get_traces(channel_ids=[ref], start_frame=start_frame, 104 | end_frame=end_frame, 105 | return_scaled=return_scaled) 106 | for (split_channel, ref) in zip(selected_channels, self._ref_channel)])) 107 | 108 | elif self._ref == 'local': 109 | if self.verbose: 110 | print('Local Common average using as reference channels in a ring-shape region with radius: ' + str(self._local_radius)) 111 | traces = self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, 112 | end_frame=end_frame, 113 | return_scaled=return_scaled) \ 114 | - np.vstack(np.array([np.average( 115 | self._recording.get_traces( 116 | channel_ids=self.neighbors[self._recording.get_channel_ids().index(id)], 117 | start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled), axis=0) 118 | for id in channel_ids])) 119 | 120 | return np.array(traces).astype(self._dtype) 121 | 122 | def _create_channel_groups(self, channel_ids): 123 | selected_groups = [] 124 | selected_channels = [] 125 | if self._groups: 126 | for g in self._groups: 127 | new_chans = [] 128 | for chan in g: 129 | 
if chan in self._recording.get_channel_ids(): 130 | new_chans.append(chan) 131 | selected_channel_for_group = [ch for ch in channel_ids if ch in new_chans] 132 | if len(selected_channel_for_group) > 0: 133 | selected_groups.append(new_chans) 134 | selected_channels.append(selected_channel_for_group) 135 | else: 136 | selected_groups = [self._recording.get_channel_ids()] 137 | selected_channels = [channel_ids] 138 | return selected_groups, selected_channels 139 | 140 | 141 | def common_reference(recording, reference='median', groups=None, ref_channels=None, local_radius=(30, 55), dtype=None, 142 | verbose=False): 143 | ''' 144 | Re-references the recording extractor traces. 145 | 146 | Parameters 147 | ---------- 148 | recording: RecordingExtractor 149 | The recording extractor to be re-referenced 150 | reference: str 151 | 'median', 'average', 'single' or 'local' 152 | If 'median', common median reference (CMR) is implemented (the median of 153 | the selected channels is removed for each timestamp). 154 | If 'average', common average reference (CAR) is implemented (the mean of the selected channels is removed 155 | for each timestamp). 156 | If 'single', the selected channel(s) is remove from all channels. 157 | If 'local', an average CAR is implemented with only k channels selected the nearest outside of a radius around each channel 158 | groups: list 159 | List of lists containing the channels for splitting the reference. The CMR, CAR, or referencing with respect to 160 | single channels are applied group-wise. However, this is not applied for the local CAR. 161 | It is useful when dealing with different channel groups, e.g. multiple tetrodes. 162 | ref_channels: list or int 163 | If no 'groups' are specified, all channels are referenced to 'ref_channels'. If 'groups' is provided, then a 164 | list of channels to be applied to each group is expected. If 'single' reference, a list of one channel or an 165 | int is expected. 
166 | local_radius: tuple(int, int) 167 | Use in the local CAR implementation as the selecting annulus (exclude radius, include radius) 168 | dtype: str 169 | dtype of the returned traces. If None, dtype is maintained 170 | verbose: bool 171 | If True, output is verbose 172 | 173 | Returns 174 | ------- 175 | referenced_recording: CommonReferenceRecording 176 | The re-referenced recording extractor object 177 | ''' 178 | return CommonReferenceRecording( 179 | recording=recording, reference=reference, groups=groups, ref_channels=ref_channels, local_radius=local_radius, 180 | dtype=dtype, verbose=verbose 181 | ) 182 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/filterrecording.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import numpy as np 3 | from .transform import TransformRecording 4 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 5 | from spikeextractors.extraction_tools import check_get_traces_args 6 | 7 | 8 | class FilterRecording(BasePreprocessorRecordingExtractor): 9 | def __init__(self, recording, chunk_size=10000, cache_chunks=False, dtype=None): 10 | self._chunk_size = chunk_size 11 | self._cache_chunks = cache_chunks 12 | if cache_chunks: 13 | self._filtered_cache_chunks = FilteredChunkCache() 14 | else: 15 | self._filtered_cache_chunks = None 16 | self._traces = None 17 | if dtype is None: 18 | dtype = str(recording.get_dtype()) 19 | if 'uint' in dtype: 20 | if 'numpy' in dtype: 21 | dtype = str(dtype).replace("", "") 22 | # drop 'numpy' 23 | dtype = dtype.split('.')[1] 24 | dtype_signed = dtype[1:] 25 | exp_idx = dtype.find('int') + 3 26 | exp = int(dtype[exp_idx:]) 27 | offset = - 2**(exp - 1) 28 | recording_base = TransformRecording(recording, offset=offset, dtype=dtype_signed) 29 | print(f"dtype converted from {dtype} to {dtype_signed} before filtering") 30 | self._dtype = 
dtype_signed 31 | else: 32 | self._dtype = dtype 33 | recording_base = recording 34 | BasePreprocessorRecordingExtractor.__init__(self, recording_base) 35 | 36 | # avoid filtering one sample 37 | def get_dtype(self, return_scaled=True): 38 | return self._dtype 39 | 40 | @check_get_traces_args 41 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 42 | if self._chunk_size is not None: 43 | ich1 = int(start_frame / self._chunk_size) 44 | ich2 = int((end_frame - 1) / self._chunk_size) 45 | dt = self.get_dtype() 46 | filtered_chunk = np.zeros((len(channel_ids), int(end_frame-start_frame)), dtype=dt) 47 | pos = 0 48 | for ich in range(ich1, ich2 + 1): 49 | filtered_chunk0 = self._get_filtered_chunk(ich, channel_ids, return_scaled) 50 | if ich == ich1: 51 | start0 = start_frame - ich * self._chunk_size 52 | else: 53 | start0 = 0 54 | if ich == ich2: 55 | end0 = end_frame - ich * self._chunk_size 56 | else: 57 | end0 = self._chunk_size 58 | filtered_chunk[:, pos:pos+end0-start0] = filtered_chunk0[:, start0:end0] 59 | pos += (end0-start0) 60 | else: 61 | filtered_chunk = self.filter_chunk(start_frame=start_frame, end_frame=end_frame, channel_ids=channel_ids, 62 | return_scaled=return_scaled) 63 | return filtered_chunk.astype(self._dtype) 64 | 65 | @abstractmethod 66 | def filter_chunk(self, *, start_frame, end_frame, channel_ids, return_scaled): 67 | raise NotImplementedError('filter_chunk not implemented') 68 | 69 | def _read_chunk(self, i1, i2, channel_ids, return_scaled=True): 70 | num_frames = self._recording.get_num_frames() 71 | if i1 < 0: 72 | i1b = 0 73 | else: 74 | i1b = i1 75 | if i2 > num_frames: 76 | i2b = num_frames 77 | else: 78 | i2b = i2 79 | chunk = np.zeros((len(channel_ids), i2 - i1)) 80 | chunk[:, i1b - i1:i2b - i1] = self._recording.get_traces(start_frame=i1b, end_frame=i2b, 81 | channel_ids=channel_ids, return_scaled=return_scaled) 82 | 83 | return chunk 84 | 85 | def _get_filtered_chunk(self, ind, 
channel_ids, return_scaled): 86 | if self._cache_chunks: 87 | code = str(ind) 88 | chunk0 = self._filtered_cache_chunks.get(code) 89 | else: 90 | chunk0 = None 91 | 92 | if chunk0 is not None: 93 | if chunk0.shape[0] == len(channel_ids): 94 | return chunk0 95 | else: 96 | channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) 97 | return chunk0[channel_idxs] 98 | 99 | start0 = ind * self._chunk_size 100 | end0 = (ind + 1) * self._chunk_size 101 | 102 | if self._cache_chunks: 103 | # filter all channels if cache_chunks is used 104 | chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=self.get_channel_ids()) 105 | self._filtered_cache_chunks.add(code, chunk1) 106 | channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) 107 | chunk1 = chunk1[channel_idxs] 108 | else: 109 | # otherwise, only filter requested channels 110 | chunk1 = self.filter_chunk(start_frame=start0, end_frame=end0, channel_ids=channel_ids, 111 | return_scaled=return_scaled) 112 | 113 | return chunk1 114 | 115 | 116 | class FilteredChunkCache: 117 | def __init__(self): 118 | self._chunks_by_code = dict() 119 | self._codes = [] 120 | self._total_size = 0 121 | self._max_size = 1024 * 1024 * 100 122 | 123 | def add(self, code, chunk): 124 | self._chunks_by_code[code] = chunk 125 | self._codes.append(code) 126 | self._total_size = self._total_size + chunk.size 127 | if self._total_size > self._max_size: 128 | ii = 0 129 | while (ii < len(self._codes)) and (self._total_size > self._max_size / 2): 130 | self._total_size = self._total_size - self._chunks_by_code[self._codes[ii]].size 131 | del self._chunks_by_code[self._codes[ii]] 132 | ii = ii + 1 133 | self._codes = self._codes[ii:] 134 | 135 | def get(self, code): 136 | if code in self._chunks_by_code: 137 | return self._chunks_by_code[code] 138 | else: 139 | return None 140 | -------------------------------------------------------------------------------- 
/spiketoolkit/preprocessing/highpass_filter.py: -------------------------------------------------------------------------------- 1 | from .filterrecording import FilterRecording 2 | import numpy as np 3 | import scipy.signal as ss 4 | from scipy import special 5 | import spikeextractors as se 6 | 7 | 8 | class HighpassFilterRecording(FilterRecording): 9 | preprocessor_name = 'HighpassFilter' 10 | 11 | def __init__(self, recording, freq_min=300, freq_wid=1000, filter_type='butter', order=1, 12 | chunk_size=30000, cache_chunks=False, dtype=None): 13 | self._freq_min = freq_min 14 | self._freq_wid = freq_wid 15 | self._type = filter_type 16 | self._order = order 17 | self._chunk_size = chunk_size 18 | self._padding = 3000 19 | 20 | if self._type == 'butter': 21 | fn = recording.get_sampling_frequency() / 2. 22 | band = self._freq_min / fn 23 | 24 | self._b, self._a = ss.butter(self._order, band, btype='highpass') 25 | 26 | if not np.all(np.abs(np.roots(self._a)) < 1): 27 | raise ValueError('Filter is not stable') 28 | elif self._type == 'fft': 29 | self._kernel = _create_filter_kernel( 30 | self._chunk_size+2*self._padding, 31 | recording.get_sampling_frequency(), 32 | self._freq_min, self._freq_wid 33 | ) 34 | self._kernel = self._kernel[0:(chunk_size+2*self._padding)//2+1] # because this is the DFT of real data 35 | else: 36 | raise NotImplementedError('filter type {} not implemented.'.format(filter_type)) 37 | 38 | FilterRecording.__init__(self, recording=recording, chunk_size=chunk_size, cache_chunks=cache_chunks, 39 | dtype=dtype) 40 | self.is_filtered = True 41 | self._kwargs = {'recording': recording.make_serialized_dict(), 'freq_min': freq_min, 42 | 'freq_wid': freq_wid, 'filter_type': filter_type, 'order': order, 43 | 'chunk_size': chunk_size, 'cache_chunks': cache_chunks} 44 | 45 | def filter_chunk(self, *, start_frame, end_frame, channel_ids, return_scaled): 46 | padding = 3000 47 | i1 = start_frame - self._padding 48 | i2 = end_frame + self._padding 49 | 
padded_chunk = self._read_chunk(i1, i2, channel_ids, return_scaled) 50 | filtered_padded_chunk = self._do_filter(padded_chunk) 51 | return filtered_padded_chunk[:, start_frame - i1:end_frame - i1] 52 | 53 | def _do_filter(self, chunk): 54 | sampling_frequency = self._recording.get_sampling_frequency() 55 | M = chunk.shape[0] 56 | chunk2 = chunk 57 | # Do the actual filtering with a DFT with real input 58 | if self._type == 'fft': 59 | chunk_fft = np.fft.rfft(chunk2) 60 | chunk_fft = chunk_fft * np.tile(self._kernel, (M, 1)) 61 | chunk_filtered = np.fft.irfft(chunk_fft) 62 | elif self._type == 'butter': 63 | chunk_filtered = ss.filtfilt(self._b, self._a, chunk2, axis=1) 64 | 65 | return chunk_filtered 66 | 67 | 68 | def _create_filter_kernel(N, sampling_frequency, freq_min, freq_wid=1000): 69 | # Matches ahb's code /matlab/processors/ms_bandpass_filter.m 70 | # improved ahb, changing tanh to erf, correct -3dB pts 6/14/16 71 | T = N / sampling_frequency # total time 72 | df = 1 / T # frequency grid 73 | relwid = 3.0 # relative bottom-end roll-off width param, kills low freqs by factor 1e-5. 74 | 75 | k_inds = np.arange(0, N) 76 | k_inds = np.where(k_inds <= (N + 1) / 2, k_inds, k_inds - N) 77 | 78 | fgrid = df * k_inds 79 | absf = np.abs(fgrid) 80 | 81 | val = np.ones(fgrid.shape) 82 | if freq_min != 0: 83 | val = val * (1 + special.erf(relwid * (absf - freq_min) / freq_min)) / 2 84 | val = np.where(np.abs(k_inds) < 0.1, 0, val) # kill DC part exactly 85 | val = np.sqrt(val) # note sqrt of filter func to apply to spectral intensity not ampl 86 | return val 87 | 88 | 89 | def highpass_filter(recording, freq_min=300, freq_wid=1000, filter_type='butter', order=1, 90 | chunk_size=30000, cache_chunks=False, dtype=None): 91 | ''' 92 | Performs a lazy filter on the recording extractor traces. 93 | 94 | Parameters 95 | ---------- 96 | recording: RecordingExtractor 97 | The recording extractor to be filtered. 98 | freq_min: int or float 99 | High-pass cutoff frequency. 
100 | freq_wid: int or float 101 | Width of the filter (when type is 'fft'). 102 | filter_type: str 103 | 'fft' or 'butter'. The 'fft' filter uses a kernel in the frequency domain. The 'butter' filter uses 104 | scipy butter and filtfilt functions. 105 | order: int 106 | Order of the filter (if 'butter'). 107 | chunk_size: int 108 | The chunk size to be used for the filtering. 109 | cache_chunks: bool (default False). 110 | If True then each chunk is cached in memory (in a dict) 111 | dtype: dtype 112 | The dtype of the traces 113 | 114 | Returns 115 | ------- 116 | filter_recording: HighpassFilterRecording 117 | The filtered recording extractor object 118 | ''' 119 | hp_recording = HighpassFilterRecording( 120 | recording=recording, 121 | freq_min=freq_min, 122 | freq_wid=freq_wid, 123 | filter_type=filter_type, 124 | order=order, 125 | chunk_size=chunk_size, 126 | cache_chunks=cache_chunks, 127 | dtype=dtype 128 | ) 129 | return hp_recording 130 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/mask.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | import numpy as np 4 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 5 | 6 | 7 | class MaskRecording(BasePreprocessorRecordingExtractor): 8 | preprocessor_name = 'Mask' 9 | 10 | def __init__(self, recording, bool_mask): 11 | if not isinstance(recording, RecordingExtractor): 12 | raise ValueError("'recording' must be a RecordingExtractor") 13 | self._mask = bool_mask 14 | assert len(bool_mask) == recording.get_num_frames(), "'bool_mask' should be a boolean array with length of " \ 15 | "number of frames" 16 | assert np.array(bool_mask).dtype == bool, "'bool_mask' should be a boolean array" 17 | self.is_dumpable = False 18 | BasePreprocessorRecordingExtractor.__init__(self, 
recording) 19 | self._kwargs = {'recording': recording.make_serialized_dict(), 'bool_mask': bool_mask} 20 | 21 | @check_get_traces_args 22 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 23 | traces = self._recording.get_traces(channel_ids=channel_ids, 24 | start_frame=start_frame, 25 | end_frame=end_frame, 26 | return_scaled=return_scaled) 27 | 28 | traces = traces.copy() # takes care of memmap objects: the in-place zeroing below must not write into the parent's buffer 29 | traces[:, ~self._mask[start_frame:end_frame]] = 0.0  # zero every frame whose mask entry is False (the mask is indexed by absolute frame) 30 | return traces 31 | 32 | 33 | def mask(recording, bool_mask): 34 | ''' 35 | Apply a boolean mask to the recording, where False elements of the mask cause the associated recording frames to 36 | be set to 0 37 | 38 | Parameters 39 | ---------- 40 | recording: RecordingExtractor 41 | The recording extractor to be masked 42 | bool_mask: list or numpy array 43 | Boolean values of the same length as the recording (one value per frame) 44 | 45 | Returns 46 | ------- 47 | masked_recording: MaskRecording 48 | The masked recording extractor object 49 | ''' 50 | return MaskRecording( 51 | recording=recording, bool_mask=bool_mask 52 | ) 53 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/normalize_by_quantile.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | import numpy as np 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | from spikeextractors.extraction_tools import check_get_traces_args 5 | 6 | 7 | class NormalizeByQuantileRecording(BasePreprocessorRecordingExtractor): 8 | preprocessor_name = 'NormalizeByQuantile' 9 | 10 | def __init__(self, recording, scale=1.0, median=0.0, q1=0.01, q2=0.99, seed=0): 11 | BasePreprocessorRecordingExtractor.__init__(self, recording) 12 | 13 | random_data = 
self._get_random_data_for_scaling(seed=seed).ravel()  # samples of all channels are pooled, so a single global scalar/offset is fitted 14 | loc_q1, pre_median, loc_q2 =
np.quantile(random_data, q=[q1, 0.5, q2]) 15 | pre_scale = abs(loc_q2 - loc_q1) 16 | 17 | self._scalar = scale / pre_scale  # maps the measured q1..q2 spread onto `scale` (NOTE(review): divides by zero for a constant signal) 18 | self._offset = median - pre_median * self._scalar  # ...and the measured median onto `median` 19 | self.has_unscaled = False 20 | self._kwargs = {'recording': recording.make_serialized_dict(), 'scale': scale, 'median': median, 21 | 'q1': q1, 'q2': q2, 'seed': seed} 22 | 23 | def _get_random_data_for_scaling(self, num_chunks=50, chunk_size=500, seed=0): 24 | N = self._recording.get_num_frames() 25 | random_ints = np.random.RandomState(seed=seed).randint(0, N - chunk_size, size=num_chunks)  # NOTE(review): raises ValueError when the recording has <= chunk_size frames 26 | chunk_list = [] 27 | for ff in np.sort(random_ints): 28 | chunk = self._recording.get_traces(start_frame=ff, 29 | end_frame=ff + chunk_size) 30 | chunk_list.append(chunk) 31 | return np.concatenate(chunk_list, axis=1) 32 | 33 | @check_get_traces_args 34 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 35 | assert return_scaled, "'normalize_by_quantile' only supports return_scaled=True" 36 | 37 | traces = self._recording.get_traces(channel_ids=channel_ids, 38 | start_frame=start_frame, 39 | end_frame=end_frame, 40 | return_scaled=return_scaled) 41 | return traces * self._scalar + self._offset 42 | 43 | 44 | def normalize_by_quantile(recording, scale=1.0, median=0.0, q1=0.01, q2=0.99, seed=0): 45 | ''' 46 | Rescale the traces from the given recording extractor with a scalar 47 | and offset. First, the median and quantiles of the distribution are estimated. 48 | Then the distribution is rescaled and offset so that the distance between the 49 | q1 and q2 quantiles (1st and 99th percentiles by default) is set to `scale`, 50 | and the median is set to the given median.
51 | 52 | Parameters 53 | ---------- 54 | recording: RecordingExtractor 55 | The recording extractor to be normalized 56 | scale: float 57 | Target distance between the q1 and q2 quantiles of the output distribution 58 | median: float 59 | Median for the output distribution 60 | q1: float (default 0.01) 61 | Lower quantile used for measuring the scale 62 | q2: float (default 0.99) 63 | Upper quantile used for measuring the scale 64 | seed: int 65 | Random seed for reproducibility 66 | Returns 67 | ------- 68 | rescaled_recording: NormalizeByQuantileRecording 69 | The rescaled traces recording extractor object 70 | ''' 71 | return NormalizeByQuantileRecording( 72 | recording=recording, 73 | scale=scale, 74 | median=median, 75 | q1=q1, 76 | q2=q2, 77 | seed=seed 78 | ) 79 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/notch_filter.py: -------------------------------------------------------------------------------- 1 | from .filterrecording import FilterRecording 2 | import spikeextractors as se 3 | import numpy as np 4 | import scipy.signal as ss 5 | 6 | 7 | class NotchFilterRecording(FilterRecording): 8 | preprocessor_name = 'NotchFilter' 9 | 10 | def __init__(self, recording, freq=3000, q=30, chunk_size=30000, cache_chunks=False): 11 | self._freq = freq 12 | self._q = q 13 | fn = 0.5 * float(recording.get_sampling_frequency())  # Nyquist frequency 14 | self._b, self._a = ss.iirnotch(self._freq / fn, self._q)  # iirnotch takes the notch frequency normalized to Nyquist 15 | 16 | if not np.all(np.abs(np.roots(self._a)) < 1):  # every pole (root of the denominator) must lie strictly inside the unit circle 17 | raise ValueError('Filter is not stable') 18 | FilterRecording.__init__(self, recording=recording, chunk_size=chunk_size, cache_chunks=cache_chunks) 19 | self._kwargs = {'recording': recording.make_serialized_dict(), 'freq': freq, 20 | 'q': q, 'chunk_size': chunk_size, 'cache_chunks': cache_chunks} 21 | 22 | def filter_chunk(self, start_frame, 
end_frame, channel_ids, return_scaled): 23 | padding = 3000  # extra samples read on each side and trimmed after filtering, presumably to absorb filtfilt edge transients 24 | i1 = start_frame - padding 25 | i2 = end_frame + padding 26 | padded_chunk = self._read_chunk(i1, i2,
channel_ids, return_scaled) 27 | filtered_padded_chunk = self._do_filter(padded_chunk) 28 | return filtered_padded_chunk[:, start_frame - i1:end_frame - i1]  # drop the padding again 29 | 30 | def _do_filter(self, chunk): 31 | chunk_filtered = ss.filtfilt(self._b, self._a, chunk, axis=1)  # zero-phase (forward-backward) filtering along the time axis (axis 1) 32 | 33 | return chunk_filtered 34 | 35 | 36 | def notch_filter(recording, freq=3000, q=30, chunk_size=30000, cache_chunks=False): 37 | ''' 38 | Applies a notch filter (designed with scipy iirnotch, applied with scipy filtfilt) to the recording extractor traces. 39 | 40 | Parameters 41 | ---------- 42 | recording: RecordingExtractor 43 | The recording extractor to be notch-filtered. 44 | freq: int or float 45 | The frequency removed by the notch filter (same units as the sampling frequency). 46 | q: int or float 47 | The quality factor of the notch filter. 48 | chunk_size: int 49 | The chunk size to be used for the filtering. 50 | cache_chunks: bool (default False). 51 | If True then each chunk is cached in memory (in a dict) 52 | Returns 53 | ------- 54 | filter_recording: NotchFilterRecording 55 | The notch-filtered recording extractor object 56 | ''' 57 | 58 | notch_recording = NotchFilterRecording( 59 | recording=recording, 60 | freq=freq, 61 | q=q, 62 | chunk_size=chunk_size, 63 | cache_chunks=cache_chunks, 64 | ) 65 | return notch_recording 66 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/preprocessinglist.py: -------------------------------------------------------------------------------- 1 | from .highpass_filter import highpass_filter, HighpassFilterRecording 2 | from .bandpass_filter import bandpass_filter, BandpassFilterRecording 3 | from .notch_filter import notch_filter, NotchFilterRecording 4 | from .whiten import whiten, WhitenRecording 5 | from .common_reference import common_reference, CommonReferenceRecording 6 | from .resample import resample, ResampleRecording 7 | from .rectify import 
rectify, RectifyRecording 8 | from .remove_artifacts import remove_artifacts, RemoveArtifactsRecording 9 | from
.transform import transform, TransformRecording 10 | from .remove_bad_channels import remove_bad_channels, RemoveBadChannelsRecording 11 | from .normalize_by_quantile import normalize_by_quantile, NormalizeByQuantileRecording 12 | from .clip import clip, ClipRecording 13 | from .blank_saturation import blank_saturation, BlankSaturationRecording 14 | from .center import center, CenterRecording 15 | from .mask import mask, MaskRecording 16 | 17 | preprocessers_full_list = [  # every preprocessor class imported above 18 | HighpassFilterRecording, 19 | BandpassFilterRecording, 20 | NotchFilterRecording, 21 | WhitenRecording, 22 | CommonReferenceRecording, 23 | ResampleRecording, 24 | RectifyRecording, 25 | RemoveArtifactsRecording, 26 | RemoveBadChannelsRecording, 27 | TransformRecording, 28 | NormalizeByQuantileRecording, 29 | ClipRecording, 30 | BlankSaturationRecording, 31 | CenterRecording, 32 | MaskRecording 33 | ] 34 | 35 | installed_preprocessers_list = [pp for pp in preprocessers_full_list if pp.installed]  # only classes whose class-level `installed` flag is true (e.g. scipy importable) 36 | preprocesser_dict = {pp_class.preprocessor_name: pp_class for pp_class in preprocessers_full_list}  # preprocessor_name (e.g. 'NotchFilter') -> class; built from the full list, not only installed ones 37 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/rectify.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | import numpy as np 5 | 6 | 7 | class RectifyRecording(BasePreprocessorRecordingExtractor): 8 | preprocessor_name = 'Rectify' 9 | 10 | def __init__(self, recording): 11 | BasePreprocessorRecordingExtractor.__init__(self, recording) 12 | self._kwargs = {'recording': recording.make_serialized_dict()} 13 | 14 | @check_get_traces_args 15 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, 
return_scaled=True): 16 | return np.abs(self._recording.get_traces(channel_ids=channel_ids,
start_frame=start_frame, end_frame=end_frame, 17 | return_scaled=return_scaled)) 18 | 19 | 20 | def rectify(recording): 21 | ''' 22 | Rectifies (takes the absolute value of) the recording extractor traces. It is useful, in combination with 'resample', to compute multi-unit 23 | activity (MUA). 24 | 25 | Parameters 26 | ---------- 27 | recording: RecordingExtractor 28 | The recording extractor object to be rectified 29 | 30 | Returns 31 | ------- 32 | rectified_recording: RectifyRecording 33 | The rectified recording extractor object 34 | 35 | ''' 36 | return RectifyRecording( 37 | recording=recording 38 | ) 39 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/remove_artifacts.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | import numpy as np 5 | from scipy.interpolate import interp1d 6 | 7 | 8 | class RemoveArtifactsRecording(BasePreprocessorRecordingExtractor): 9 | preprocessor_name = 'RemoveArtifacts' 10 | 11 | def __init__(self, recording, triggers, ms_before=0.5, ms_after=3.0, mode='zeros', fit_sample_spacing=1.): 12 | self._triggers = np.array(triggers) 13 | self._ms_before = ms_before 14 | self._ms_after = ms_after 15 | self._mode = mode 16 | self._fit_sample_spacing = fit_sample_spacing 17 | BasePreprocessorRecordingExtractor.__init__(self, recording) 18 | self._kwargs = {'recording': recording.make_serialized_dict(), 'triggers': triggers, 19 | 'ms_before': ms_before, 'ms_after': ms_after, 'mode': mode, 20 | 'fit_sample_spacing': fit_sample_spacing} 21 | 22 | @check_get_traces_args 23 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 24 | traces = self._recording.get_traces(channel_ids=channel_ids, 25 | start_frame=start_frame, 26 |
end_frame=end_frame, 27 | return_scaled=return_scaled) 28 | triggers = self._triggers[(self._triggers > start_frame) & (self._triggers < end_frame)] - start_frame 29 | 30 | pad = [int(self._ms_before * self.get_sampling_frequency() / 1000), 31 | int(self._ms_after * self.get_sampling_frequency() / 1000)] 32 | 33 | traces = traces.copy() 34 | if self._mode == 'zeros': 35 | for trig in triggers: 36 | if trig - pad[0] > 0 and trig + pad[1] < end_frame - start_frame: 37 | traces[:, trig - pad[0]:trig + pad[1]] = 0 38 | elif trig - pad[0] <= 0 and trig + pad[1] >= end_frame - start_frame: 39 | traces = 0 40 | elif trig - pad[0] <= 0: 41 | traces[:, :trig + pad[1]] = 0 42 | elif trig + pad[1] >= end_frame - start_frame: 43 | traces[:, trig - pad[0]:] = 0 44 | else: 45 | sample_freq = self._recording.get_sampling_frequency() 46 | 47 | # generate indices for evenly spaced fit points before and after gap 48 | fit_sample_range = int(((sample_freq / 1000) * self._fit_sample_spacing * 2) + 1) 49 | fit_sample_interval = int(self._fit_sample_spacing * (sample_freq / 1000)) 50 | 51 | fit_samples = np.array(range(0, fit_sample_range, fit_sample_interval)) 52 | rev_fit_samples = fit_sample_range - fit_samples 53 | triggers = np.array(triggers).astype(int) 54 | for trig in triggers: 55 | pre_data_end_idx = trig - pad[0] - 1 56 | post_data_start_idx = trig + pad[1] + 1 57 | 58 | # Generate fit points from the sample points determined 59 | pre_idx = pre_data_end_idx - rev_fit_samples + 1 60 | post_idx = post_data_start_idx + fit_samples 61 | 62 | # Get indices of the gap to fill 63 | gap_idx = np.array(range(pre_data_end_idx + 1, post_data_start_idx + 0)) 64 | 65 | # Make sure we are not going out of bounds 66 | gap_idx = gap_idx[gap_idx >= 0] 67 | gap_idx = gap_idx[gap_idx < len(traces[0])] 68 | 69 | # correct for out of bounds indices on both sides: 70 | if np.max(post_idx) >= len(traces[0]): 71 | post_idx = post_idx[post_idx < len(traces[0])] 72 | 73 | if np.min(pre_idx) < 0: 74 | 
pre_idx = pre_idx[pre_idx >= 0] 75 | 76 | # fit x values 77 | all_idx = np.hstack((pre_idx, post_idx)) 78 | 79 | # fit y values 80 | interp_traces = traces[:, all_idx] 81 | 82 | # Get the median value from 5 samples around each fit point 83 | # for robustness to noise / small fluctuations 84 | pre_vals = np.empty((0, len(traces)), 'int32') 85 | for idx in iter(pre_idx): 86 | if idx == pre_idx[-1]: 87 | idxs = np.array(range(idx - 3, idx + 1)) 88 | else: 89 | idxs = np.array(range(idx - 2, idx + 3)) 90 | 91 | if np.min(idx) < 0: 92 | idx = idx[idx >= 0] 93 | 94 | median_vals = np.median(traces[:, idxs], axis=1) 95 | pre_vals = np.vstack((pre_vals, median_vals)) 96 | 97 | post_vals = np.empty((0, len(traces)), 'int32') 98 | for idx in iter(post_idx): 99 | if idx == post_idx[0]: 100 | idxs = np.array(range(idx, idx + 4)) 101 | else: 102 | idxs = np.array(range(idx - 2, idx + 3)) 103 | 104 | if np.max(idx) >= len(traces[0]): 105 | idx = idx[idx < len(traces[0])] 106 | 107 | median_vals = np.median(traces[:, idxs], axis=1) 108 | post_vals = np.vstack((post_vals, median_vals)) 109 | 110 | interp_traces = np.vstack((pre_vals, post_vals)).T 111 | 112 | if self._mode == 'cubic' and len(all_idx) >= 5: 113 | # Enough fit points present on either side to do cubic spline fit: 114 | interp_function = interp1d(all_idx, interp_traces, self._mode, 115 | bounds_error=False, 116 | fill_value='extrapolate') 117 | traces[:, gap_idx] = interp_function(gap_idx) 118 | elif self._mode == 'linear' and len(all_idx) >= 2: 119 | # Enough fit points present for a linear fit 120 | interp_function = interp1d(all_idx, interp_traces, self._mode, bounds_error=False, 121 | fill_value='extrapolate') 122 | traces[:, gap_idx] = interp_function(gap_idx) 123 | elif len(pre_idx) > len(post_idx): 124 | # not enough fit points, fill with nearest neighbour on side with the most data points 125 | traces[:, gap_idx] = np.repeat(traces[:, pre_idx[-1]] * np.ones((1, 1)), len(gap_idx), 0).T 126 | elif 
len(post_idx) > len(pre_idx): 127 | # not enough fit points, fill with nearest neighbour on side with the most data points 128 | traces[:, gap_idx] = np.repeat(traces[:, post_idx[0]] * np.ones((1, 1)), len(gap_idx), 0).T 129 | elif len(all_idx) > 0: 130 | # not enough fit points, both sides tied for most data points, fill with last pre value 131 | traces[:, gap_idx] = np.repeat(traces[:, pre_idx[-1]] * np.ones((1, 1)), len(gap_idx), 0).T 132 | else: 133 | # No data to interpolate from on either side of gap; 134 | # Fill with zeros 135 | traces[:, gap_idx] = 0 136 | 137 | return traces 138 | 139 | 140 | def remove_artifacts(recording, triggers, ms_before=0.5, ms_after=3, mode='zeros', fit_sample_spacing=1.): 141 | ''' 142 | Removes stimulation artifacts from recording extractor traces. By default, 143 | artifact periods are zeroed-out (mode = 'zeros'). This is only recommended 144 | for traces that are centered around zero (e.g. through a prior highpass 145 | filter); if this is not the case, linear and cubic interpolation modes are 146 | also available, controlled by the 'mode' input argument. 147 | 148 | Parameters 149 | ---------- 150 | recording: RecordingExtractor 151 | The recording extractor to remove artifacts from 152 | triggers: list 153 | List of int with the stimulation trigger frames 154 | ms_before: float 155 | Time interval in ms to remove before the trigger events 156 | ms_after: float 157 | Time interval in ms to remove after the trigger events 158 | mode: str 159 | Determines what artifacts are replaced by. Can be one of the following: 160 | 161 | - 'zeros' (default): Artifacts are replaced by zeros. 162 | 163 | - 'linear': Replacement are obtained through Linear interpolation between 164 | the trace before and after the artifact. 165 | If the trace starts or ends with an artifact period, the gap is filled 166 | with the closest available value before or after the artifact. 
167 | 168 | - 'cubic': Cubic spline interpolation between the trace before and after 169 | the artifact, referenced to evenly spaced fit points before and after 170 | the artifact. This is an option thatcan be helpful if there are 171 | significant LFP effects around the time of the artifact, but visual 172 | inspection of fit behaviour with your chosen settings is recommended. 173 | The spacing of fit points is controlled by 'fit_sample_spacing', with 174 | greater spacing between points leading to a fit that is less sensitive 175 | to high frequency fluctuations but at the cost of a less smooth 176 | continuation of the trace. 177 | If the trace starts or ends with an artifact, the gap is filled with 178 | the closest available value before or after the artifact. 179 | fit_sample_spacing: float 180 | Determines the spacing (in ms) of reference points for the cubic spline 181 | fit if mode = 'cubic'. Default = 1ms. Note: The actual fit samples are 182 | the median of the 5 data points around the time of each sample point to 183 | avoid excessive influence from hyper-local fluctuations. 
184 | 185 | 186 | Returns 187 | ------- 188 | removed_recording: RemoveArtifactsRecording 189 | The recording extractor after artifact removal 190 | 191 | ''' 192 | return RemoveArtifactsRecording( 193 | recording=recording, triggers=triggers, ms_before=ms_before, 194 | ms_after=ms_after, mode=mode, fit_sample_spacing=fit_sample_spacing) 195 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/remove_bad_channels.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor, SubRecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | import numpy as np 5 | 6 | 7 | class RemoveBadChannelsRecording(BasePreprocessorRecordingExtractor): 8 | preprocessor_name = 'RemoveBadChannels' 9 | 10 | def __init__(self, recording, bad_channel_ids, bad_threshold, seconds, verbose): 11 | self._bad_channel_ids = bad_channel_ids 12 | self._bad_threshold = bad_threshold 13 | self._seconds = seconds 14 | self.verbose = verbose 15 | self._initialize_subrecording_extractor(recording) 16 | BasePreprocessorRecordingExtractor.__init__(self, self._subrecording) 17 | self._kwargs = {'recording': recording.make_serialized_dict(), 'bad_channel_ids': bad_channel_ids, 18 | 'bad_threshold': bad_threshold, 'seconds': seconds, 'verbose': verbose} 19 | 20 | @check_get_traces_args 21 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 22 | traces = self._subrecording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame, 23 | return_scaled=return_scaled) 24 | return traces 25 | 26 | def _initialize_subrecording_extractor(self, recording): 27 | if isinstance(self._bad_channel_ids, (list, np.ndarray)): 28 | active_channels = [] 29 | for chan in recording.get_channel_ids(): 30 | if chan not in 
self._bad_channel_ids: 31 | active_channels.append(chan) 32 | self._subrecording = SubRecordingExtractor(recording, channel_ids=active_channels) 33 | elif self._bad_channel_ids is None: 34 | start_frame = recording.get_num_frames() // 2 35 | end_frame = int(start_frame + self._seconds * recording.get_sampling_frequency()) 36 | if end_frame > recording.get_num_frames(): 37 | end_frame = recording.get_num_frames() 38 | traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame) 39 | stds = np.std(traces, axis=1) 40 | bad_channel_ids = [ch for ch, std in enumerate(stds) if std > self._bad_threshold * np.median(stds)] 41 | if self.verbose: 42 | print('Automatically removing channels:', bad_channel_ids) 43 | active_channels = [] 44 | for chan in recording.get_channel_ids(): 45 | if chan not in bad_channel_ids: 46 | active_channels.append(chan) 47 | self._subrecording = SubRecordingExtractor(recording, channel_ids=active_channels) 48 | else: 49 | self._subrecording = recording 50 | self.active_channels = self._subrecording.get_channel_ids() 51 | 52 | 53 | def remove_bad_channels(recording, bad_channel_ids=None, bad_threshold=2, seconds=10, verbose=False): 54 | ''' 55 | Remove bad channels from the recording extractor. 56 | 57 | Parameters 58 | ---------- 59 | recording: RecordingExtractor 60 | The recording extractor object 61 | bad_channel_ids: list 62 | List of bad channel ids (int). If None, automatic removal will be done based on standard deviation. 
63 | bad_threshold: float 64 | If automatic is used, the threshold for the standard deviation over which channels are removed 65 | seconds: float 66 | If automatic is used, the number of seconds used to compute standard deviations 67 | verbose: bool 68 | If True, output is verbose 69 | 70 | Returns 71 | ------- 72 | remove_bad_channels_recording: RemoveBadChannelsRecording 73 | The recording extractor without bad channels 74 | 75 | ''' 76 | return RemoveBadChannelsRecording(recording=recording, bad_channel_ids=bad_channel_ids, 77 | bad_threshold=bad_threshold, seconds=seconds, verbose=verbose) 78 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/resample.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | import numpy as np 5 | from warnings import warn 6 | 7 | try: 8 | from scipy import special, signal 9 | 10 | HAVE_RR = True 11 | except ImportError: 12 | HAVE_RR = False 13 | 14 | 15 | class ResampleRecording(BasePreprocessorRecordingExtractor): 16 | preprocessor_name = 'Resample' 17 | installed = HAVE_RR # check at class level if installed or not 18 | installation_mesg = "To use the ResampleRecording, install scipy: \n\n pip install scipy\n\n" # err 19 | 20 | def __init__(self, recording, resample_rate): 21 | assert HAVE_RR, "To use the ResampleRecording, install scipy: \n\n pip install scipy\n\n" 22 | self._resample_rate = resample_rate 23 | BasePreprocessorRecordingExtractor.__init__(self, recording, copy_times=False) 24 | self._dtype = recording.get_dtype() 25 | 26 | if recording._times is not None: 27 | # resample timestamps uniformly 28 | warn("Timestamps will be resampled uniformly. 
Non-uniform timestamps will be lost due to resampling.") 29 | resampled_times = np.linspace(recording._times[0], recording._times[-1], self.get_num_frames()) 30 | self.set_times(resampled_times) 31 | 32 | self._kwargs = {'recording': recording.make_serialized_dict(), 'resample_rate': resample_rate} 33 | 34 | def get_sampling_frequency(self): 35 | return self._resample_rate 36 | 37 | def get_num_frames(self): 38 | return int(self._recording.get_num_frames() / self._recording.get_sampling_frequency() * self._resample_rate) 39 | 40 | # avoid filtering one sample 41 | def get_dtype(self, return_scaled=True): 42 | return self._dtype 43 | 44 | # need to override frame_to_time and time_to_frame because self._recording might not have "times" 45 | def frame_to_time(self, frames): 46 | if self._times is not None: 47 | return np.round(frames / self.get_sampling_frequency(), 6) 48 | else: 49 | return self._recording.time_to_frame(frames) 50 | 51 | def time_to_frame(self, times): 52 | if self._times is not None: 53 | return np.round(times * self.get_sampling_frequency()).astype('int64') 54 | else: 55 | return self._recording.time_to_frame(times) 56 | 57 | 58 | @check_get_traces_args 59 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 60 | start_frame_not_sampled = int(start_frame / self.get_sampling_frequency() * 61 | self._recording.get_sampling_frequency()) 62 | start_frame_sampled = start_frame 63 | end_frame_not_sampled = int(end_frame / self.get_sampling_frequency() * 64 | self._recording.get_sampling_frequency()) 65 | end_frame_sampled = end_frame 66 | traces = self._recording.get_traces(start_frame=start_frame_not_sampled, 67 | end_frame=end_frame_not_sampled, 68 | channel_ids=channel_ids, 69 | return_scaled=return_scaled) 70 | traces_resampled = signal.resample(traces, int(end_frame_sampled - start_frame_sampled), axis=1) 71 | 72 | return traces_resampled.astype(self._dtype) 73 | 74 | 75 | def resample(recording, 
resample_rate): 76 | ''' 77 | Resamples the recording extractor traces. If the resampling rate is multiple of the sampling rate, the faster 78 | scipy decimate function is used. 79 | 80 | Parameters 81 | ---------- 82 | recording: RecordingExtractor 83 | The recording extractor to be resampled 84 | resample_rate: int or float 85 | The resampling frequency 86 | 87 | Returns 88 | ------- 89 | resampled_recording: ResampleRecording 90 | The resample recording extractor 91 | 92 | ''' 93 | return ResampleRecording( 94 | recording=recording, 95 | resample_rate=resample_rate 96 | ) 97 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/transform.py: -------------------------------------------------------------------------------- 1 | from spikeextractors import RecordingExtractor 2 | from spikeextractors.extraction_tools import check_get_traces_args 3 | from .basepreprocessorrecording import BasePreprocessorRecordingExtractor 4 | import numpy as np 5 | 6 | 7 | class TransformRecording(BasePreprocessorRecordingExtractor): 8 | preprocessor_name = 'Transform' 9 | 10 | def __init__(self, recording, scalar=1., offset=0., dtype=None): 11 | if not isinstance(recording, RecordingExtractor): 12 | raise ValueError("'recording' must be a RecordingExtractor") 13 | self._scalar = scalar 14 | self._offset = offset 15 | if dtype is None: 16 | self._dtype = recording.get_dtype() 17 | else: 18 | self._dtype = dtype 19 | BasePreprocessorRecordingExtractor.__init__(self, recording) 20 | self.has_unscaled = False 21 | 22 | self._kwargs = {'recording': recording.make_serialized_dict(), 'scalar': scalar, 'offset': offset, 23 | 'dtype': dtype} 24 | 25 | @check_get_traces_args 26 | def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 27 | assert return_scaled, "'transform' only supports return_scaled=True" 28 | 29 | traces = self._recording.get_traces(channel_ids=channel_ids, 
start_frame=start_frame, end_frame=end_frame, 30 | return_scaled=return_scaled) 31 | if isinstance(self._scalar, (int, float, np.integer)): 32 | traces = traces*self._scalar 33 | else: 34 | if len(self._scalar) == len(channel_ids): 35 | scalar = np.array(self._scalar) 36 | else: 37 | channel_idxs = np.array([self._recording.get_channel_ids().index(ch) for ch in channel_ids]) 38 | scalar = np.array(self._scalar)[channel_idxs] 39 | traces = traces * scalar[:, np.newaxis] 40 | if isinstance(self._offset, (int, float, np.integer)): 41 | traces = traces + self._offset 42 | else: 43 | if len(self._offset) == len(channel_ids): 44 | offset = np.array(self._offset) 45 | else: 46 | channel_idxs = np.array([self._recording.get_channel_ids().index(ch) for ch in channel_ids]) 47 | offset = np.array(self._offset)[channel_idxs] 48 | traces = traces + offset[:, np.newaxis] 49 | return traces.astype(self._dtype) 50 | 51 | 52 | def transform(recording, scalar=1, offset=0): 53 | ''' 54 | Transforms the traces from the given recording extractor with a scalar 55 | and offset. New traces = traces*scalar + offset. 
56 | 57 | Parameters 58 | ---------- 59 | recording: RecordingExtractor 60 | The recording extractor to be transformed 61 | scalar: float or array 62 | Scalar for the traces of the recording extractor or array with scalars for each channel 63 | offset: float or array 64 | Offset for the traces of the recording extractor or array with offsets for each channel 65 | Returns 66 | ------- 67 | transform_traces: TransformTracesRecording 68 | The transformed traces recording extractor object 69 | ''' 70 | return TransformRecording( 71 | recording=recording, scalar=scalar, offset=offset 72 | ) 73 | -------------------------------------------------------------------------------- /spiketoolkit/preprocessing/whiten.py: -------------------------------------------------------------------------------- 1 | from .filterrecording import FilterRecording 2 | import numpy as np 3 | 4 | 5 | class WhitenRecording(FilterRecording): 6 | preprocessor_name = 'Whiten' 7 | 8 | def __init__(self, recording, chunk_size=30000, cache_chunks=False, seed=0): 9 | FilterRecording.__init__(self, recording=recording, chunk_size=chunk_size, cache_chunks=cache_chunks) 10 | self._whitening_matrix = self._compute_whitening_matrix(seed=seed) 11 | self.has_unscaled = False 12 | self._kwargs = {'recording': recording.make_serialized_dict(), 'chunk_size': chunk_size, 13 | 'cache_chunks': cache_chunks, 'seed': seed} 14 | 15 | def _get_random_data_for_whitening(self, num_chunks=50, chunk_size=500, seed=0): 16 | N = self._recording.get_num_frames() 17 | random_ints = np.random.RandomState(seed=seed).randint(0, N - chunk_size, size=num_chunks) 18 | chunk_list = [] 19 | for ff in random_ints: 20 | chunk = self._recording.get_traces(start_frame=ff, 21 | end_frame=ff + chunk_size) 22 | chunk_list.append(chunk) 23 | return np.concatenate(chunk_list, axis=1) 24 | 25 | def _compute_whitening_matrix(self, seed): 26 | data = self._get_random_data_for_whitening(seed=seed) 27 | 28 | # center the data 29 | data = data - 
np.mean(data, axis=1, keepdims=True) 30 | 31 | # Original by Jeremy. ZCA whitening: W = U diag(1/sqrt(S)) Ut, from the SVD of the channel covariance matrix 32 | AAt = data @ np.transpose(data) 33 | AAt = AAt / data.shape[1]  # covariance estimate (data was centered above) 34 | U, S, Ut = np.linalg.svd(AAt, full_matrices=True) 35 | W = (U @ np.diag(1 / np.sqrt(S))) @ Ut  # NOTE(review): a zero singular value (e.g. a flat channel) yields inf here; consider adding a small epsilon 36 | 37 | return W 38 | 39 | def filter_chunk(self, start_frame, end_frame, channel_ids, return_scaled): 40 | assert return_scaled, "'whiten' only supports return_scaled=True" 41 | 42 | chan_idxs = np.array([self.get_channel_ids().index(chan) for chan in channel_ids]) 43 | chunk = self._recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled)  # no channel_ids: W mixes every channel, the requested rows are selected afterwards 44 | chunk = chunk - np.mean(chunk, axis=1, keepdims=True)  # NOTE(review): centers with this chunk's own mean, not the mean of the data used to estimate W 45 | chunk2 = self._whitening_matrix @ chunk 46 | return chunk2[chan_idxs] 47 | 48 | 49 | def whiten(recording, chunk_size=30000, cache_chunks=False, seed=0): 50 | ''' 51 | Whitens the recording extractor traces, i.e. decorrelates the channels with a whitening matrix estimated from randomly chosen snippets of the recording. 52 | 53 | Parameters 54 | ---------- 55 | recording: RecordingExtractor 56 | The recording extractor to be whitened. 57 | chunk_size: int 58 | The chunk size to be used for the filtering. 59 | cache_chunks: bool 60 | If True, filtered traces are computed and cached all at once (default False).
61 | seed: int 62 | Random seed for reproducibility 63 | Returns 64 | ------- 65 | whitened_recording: WhitenRecording 66 | The whitened recording extractor 67 | 68 | ''' 69 | return WhitenRecording( 70 | recording=recording, 71 | chunk_size=chunk_size, 72 | cache_chunks=cache_chunks, 73 | seed=seed 74 | ) 75 | -------------------------------------------------------------------------------- /spiketoolkit/sortingcomponents/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import detect_spikes -------------------------------------------------------------------------------- /spiketoolkit/sortingcomponents/detection.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | import spikeextractors as se 3 | from ..postprocessing.postprocessing_tools import divide_recording_into_time_chunks 4 | import itertools 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | def detect_spikes(recording, channel_ids=None, detect_threshold=5, detect_sign=-1, 10 | n_shifts=2, n_snippets_for_threshold=10, snippet_size_sec=1, 11 | start_frame=None, end_frame=None, n_jobs=1, joblib_backend='loky', 12 | chunk_size=None, chunk_mb=500, verbose=False): 13 | ''' 14 | Detects spikes per channel. Spikes are detected as threshold crossings and the threshold is in terms of the median 15 | average deviation (MAD). The MAD is computed by taking 'n_snippets_for_threshold' snippets of the recordings 16 | of 'snippet_size_sec' seconds uniformly distributed between 'start_frame' and 'end_frame'. 17 | 18 | Parameters 19 | ---------- 20 | recording: RecordingExtractor 21 | The recording extractor object 22 | channel_ids: list or None 23 | List of channels to perform detection. If None all channels are used 24 | detect_threshold: float 25 | Threshold in median absolute deviations (MAD) to detect peaks 26 | n_shifts: int 27 | Number of shifts to find peak. E.g. 
if n_shift is 2, a peak is detected (if detect_sign is 'negative') if 28 | a sample is below the threshold, the two samples before are higher than the sample, and the two samples after 29 | the sample are higher than the sample. 30 | n_snippets_for_threshold: int 31 | Number of snippets to use to compute channel-wise thresholds 32 | snippet_size_sec: float 33 | Length of each snippet in seconds 34 | detect_sign: int 35 | Sign of the detection: -1 (negative), 1 (positive), 0 (both) 36 | start_frame: int 37 | Start frame for detection 38 | end_frame: int 39 | End frame end frame for detection 40 | n_jobs: int 41 | Number of jobs for parallelization. Default is None (no parallelization) 42 | joblib_backend: str 43 | The backend for joblib. Default is 'loky' 44 | chunk_size: int 45 | Size of chunks in number of samples. If None, it is automatically calculated 46 | chunk_mb: int 47 | Size of chunks in Mb (default 500 Mb) 48 | verbose: bool 49 | If True output is verbose 50 | 51 | Returns 52 | ------- 53 | sorting_detected: SortingExtractor 54 | The sorting extractor object with the detected spikes. Unit ids are the same as channel ids and units have the 55 | 'channel' property to specify which channel they correspond to. The sorting extractor also has the `spike_rate` 56 | and `spike_amplitude` properties. 57 | ''' 58 | if start_frame is None: 59 | start_frame = 0 60 | if end_frame is None: 61 | end_frame = recording.get_num_frames() 62 | 63 | if channel_ids is None: 64 | channel_ids = recording.get_channel_ids() 65 | else: 66 | assert np.all([ch in recording.get_channel_ids() for ch in channel_ids]), "Not all 'channel_ids' are in the" \ 67 | "recording." 
68 | if n_jobs is None: 69 | n_jobs = 1 70 | if n_jobs == 0: 71 | n_jobs = 1 72 | 73 | if start_frame != 0 or end_frame != recording.get_num_frames(): 74 | recording_sub = se.SubRecordingExtractor(recording, start_frame=start_frame, end_frame=end_frame) 75 | else: 76 | recording_sub = recording 77 | 78 | num_frames = recording_sub.get_num_frames() 79 | 80 | # set chunk size 81 | if chunk_size is not None: 82 | chunk_size = int(chunk_size) 83 | elif chunk_mb is not None: 84 | n_bytes = np.dtype(recording.get_dtype()).itemsize 85 | max_size = int(chunk_mb * 1e6) # set Mb per chunk 86 | chunk_size = max_size // (recording.get_num_channels() * n_bytes) 87 | 88 | if n_jobs > 1: 89 | chunk_size /= n_jobs 90 | 91 | # chunk_size = num_bytes_per_chunk / num_bytes_per_frame 92 | chunks = divide_recording_into_time_chunks( 93 | num_frames=num_frames, 94 | chunk_size=chunk_size, 95 | padding_size=0 96 | ) 97 | n_chunk = len(chunks) 98 | 99 | if verbose: 100 | print(f"Number of chunks: {len(chunks)} - Number of jobs: {n_jobs}") 101 | 102 | if verbose and n_jobs == 1: 103 | chunk_iter = tqdm(range(n_chunk), ascii=True, desc="Detecting spikes in chunks") 104 | else: 105 | chunk_iter = range(n_chunk) 106 | 107 | if not recording_sub.check_if_dumpable(): 108 | if n_jobs > 1: 109 | n_jobs = 1 110 | print("RecordingExtractor is not dumpable and can't be processed in parallel") 111 | rec_arg = recording_sub 112 | else: 113 | if n_jobs > 1: 114 | rec_arg = recording_sub.dump_to_dict() 115 | else: 116 | rec_arg = recording_sub 117 | 118 | all_channel_times = [[] for ii in range(len(channel_ids))] 119 | all_channel_amps = [[] for ii in range(len(channel_ids))] 120 | 121 | snippet_len = int(snippet_size_sec * recording.get_sampling_frequency()) 122 | reference_frames = np.linspace(snippet_len+1, recording.get_num_frames() - snippet_len, 123 | n_snippets_for_threshold) 124 | snippets = recording.get_snippets(reference_frames=reference_frames, snippet_len=snippet_len) 125 | traces_mad = 
np.concatenate(snippets, 1) 126 | thresholds = detect_threshold * np.median(np.abs(traces_mad) / 0.6745, 1)[:, None] 127 | 128 | if n_jobs > 1: 129 | output = Parallel(n_jobs=n_jobs, backend=joblib_backend)(delayed(_detect_and_align_peaks_chunk) 130 | (ii, rec_arg, chunks, channel_ids, thresholds, 131 | detect_sign, 132 | n_shifts, verbose) 133 | for ii in chunk_iter) 134 | for ii, (times_ii, amps_ii) in enumerate(output): 135 | for i, ch in enumerate(channel_ids): 136 | times = times_ii[i] 137 | amps = amps_ii[i] 138 | all_channel_amps[i].append(amps) 139 | all_channel_times[i].append(times) 140 | else: 141 | for ii in chunk_iter: 142 | times_ii, amps_ii = _detect_and_align_peaks_chunk(ii, rec_arg, chunks, channel_ids, thresholds, 143 | detect_sign, n_shifts, False) 144 | 145 | for i, ch in enumerate(channel_ids): 146 | times = times_ii[i] 147 | amps = amps_ii[i] 148 | all_channel_amps[i].append(amps) 149 | all_channel_times[i].append(times) 150 | 151 | if len(chunks) > 1: 152 | times_list = [] 153 | amp_list = [] 154 | for i_ch in range(len(channel_ids)): 155 | times_concat = np.concatenate([all_channel_times[i_ch][ch] for ch in range(len(chunks))], 156 | axis=0) 157 | times_list.append(times_concat) 158 | amps_concat = np.concatenate([all_channel_amps[i_ch][ch] for ch in range(len(chunks))], 159 | axis=0) 160 | amp_list.append(amps_concat) 161 | else: 162 | times_list = [times[0] for times in all_channel_times] 163 | amp_list = [amps[0] for amps in all_channel_amps] 164 | 165 | labels_list = [[ch] * len(times) for (ch, times) in zip(channel_ids, times_list)] 166 | 167 | # create sorting extractor 168 | sorting = se.NumpySortingExtractor() 169 | labels_flat = np.array(list(itertools.chain(*labels_list))) 170 | times_flat = np.array(list(itertools.chain(*times_list))) 171 | sorting.set_times_labels(times=times_flat, labels=labels_flat) 172 | sorting.set_sampling_frequency(recording.get_sampling_frequency()) 173 | 174 | duration = (end_frame - start_frame) / 
recording.get_sampling_frequency() 175 | 176 | for i_u, u in enumerate(sorting.get_unit_ids()): 177 | sorting.set_unit_property(u, 'channel', u) 178 | amps = amp_list[i_u] 179 | if len(amps) > 0: 180 | sorting.set_unit_property(u, 'spike_amplitude', np.median(amp_list[i_u])) 181 | else: 182 | sorting.set_unit_property(u, 'spike_amplitude', 0) 183 | sorting.set_unit_property(u, 'spike_rate', len(sorting.get_unit_spike_train(u)) / duration) 184 | 185 | return sorting 186 | 187 | 188 | def _detect_and_align_peaks_chunk(ii, rec_arg, chunks, channel_ids, thresholds, detect_sign, n_shifts, 189 | verbose): 190 | chunk = chunks[ii] 191 | 192 | if verbose: 193 | print(f"Chunk {ii + 1}: detecting spikes") 194 | if isinstance(rec_arg, dict): 195 | recording = se.load_extractor_from_dict(rec_arg) 196 | else: 197 | recording = rec_arg 198 | 199 | traces = recording.get_traces(start_frame=chunk['istart'], 200 | end_frame=chunk['iend']) 201 | 202 | if detect_sign == -1: 203 | traces = -traces 204 | elif detect_sign == 0: 205 | traces = np.abs(traces) 206 | 207 | sig_center = traces[:, n_shifts:-n_shifts] 208 | peak_mask = sig_center > thresholds 209 | for i in range(n_shifts): 210 | peak_mask &= sig_center > traces[:, i:i + sig_center.shape[1]] 211 | peak_mask &= sig_center >= traces[:, n_shifts + i + 1:n_shifts + i + 1 + sig_center.shape[1]] 212 | 213 | # find peaks 214 | peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask) 215 | # correct for time shift 216 | peak_sample_ind += n_shifts 217 | 218 | sp_times = [] 219 | sp_amplitudes = [] 220 | 221 | for ch in range(len(channel_ids)): 222 | peak_times = peak_sample_ind[np.where(peak_chan_ind == ch)] 223 | sp_times.append(peak_sample_ind[np.where(peak_chan_ind == ch)] + chunk['istart']) 224 | sp_amplitudes.append(traces[ch, peak_times]) 225 | 226 | return sp_times, sp_amplitudes 227 | -------------------------------------------------------------------------------- /spiketoolkit/tests/.gitignore: 
-------------------------------------------------------------------------------- 1 | phy/ 2 | phy_group/ -------------------------------------------------------------------------------- /spiketoolkit/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeInterface/spiketoolkit/d90954388400b55e53cc9cc79dbb8a63a5e39c42/spiketoolkit/tests/__init__.py -------------------------------------------------------------------------------- /spiketoolkit/tests/test_curation.py: -------------------------------------------------------------------------------- 1 | import spikeextractors as se 2 | import numpy as np 3 | import shutil 4 | from spikeextractors.testing import check_dumping 5 | 6 | from spiketoolkit.curation import ( 7 | threshold_snrs, 8 | threshold_silhouette_scores, 9 | threshold_d_primes, 10 | threshold_firing_rates, 11 | threshold_isi_violations, 12 | threshold_num_spikes, 13 | threshold_presence_ratios, 14 | threshold_l_ratios, 15 | threshold_amplitude_cutoffs, 16 | threshold_isolation_distances, 17 | threshold_noise_overlaps, 18 | threshold_nn_metrics, 19 | threshold_drift_metrics, 20 | get_curation_params 21 | ) 22 | 23 | from spiketoolkit.validation import ( 24 | compute_num_spikes, 25 | compute_firing_rates, 26 | compute_presence_ratios, 27 | compute_isi_violations, 28 | compute_amplitude_cutoffs, 29 | compute_snrs, 30 | compute_drift_metrics, 31 | compute_silhouette_scores, 32 | compute_isolation_distances, 33 | compute_noise_overlaps, 34 | compute_l_ratios, 35 | compute_d_primes, 36 | compute_nn_metrics, 37 | ) 38 | 39 | 40 | def test_thresh_num_spikes(): 41 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 42 | seed=0) 43 | s_threshold = 25 44 | 45 | sort_ns = threshold_num_spikes(sort, s_threshold, 'less') 46 | new_ns = compute_num_spikes(sort_ns, sort.get_sampling_frequency()) 47 | 48 | assert np.all(new_ns 
>= s_threshold) 49 | check_dumping(sort_ns) 50 | shutil.rmtree('test') 51 | 52 | 53 | def test_thresh_isi_violations(): 54 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 55 | seed=0) 56 | s_threshold = 0.01 57 | 58 | sort_isi = threshold_isi_violations(sort, s_threshold, 'greater', rec.get_num_frames()) 59 | new_isi = compute_isi_violations(sort_isi, rec.get_num_frames(), sort.get_sampling_frequency()) 60 | 61 | assert np.all(new_isi <= s_threshold) 62 | check_dumping(sort_isi) 63 | shutil.rmtree('test') 64 | 65 | 66 | def test_thresh_presence_ratios(): 67 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 68 | seed=0) 69 | s_threshold = 0.18 70 | 71 | sort_pr = threshold_presence_ratios(sort, s_threshold, 'less', rec.get_num_frames()) 72 | new_pr = compute_presence_ratios(sort_pr, rec.get_num_frames(), sort.get_sampling_frequency()) 73 | 74 | assert np.all(new_pr >= s_threshold) 75 | check_dumping(sort_pr) 76 | shutil.rmtree('test') 77 | 78 | 79 | def test_thresh_amplitude_cutoffs(): 80 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 81 | seed=0) 82 | 83 | amplitude_cutoff_thresh = 0 84 | 85 | sort_amplitude_cutoff = threshold_amplitude_cutoffs(sort, rec, amplitude_cutoff_thresh, "less", 86 | apply_filter=False, seed=0) 87 | new_amplitude_cutoff = compute_amplitude_cutoffs(sort_amplitude_cutoff, rec, apply_filter=False, seed=0) 88 | 89 | assert np.all(new_amplitude_cutoff >= amplitude_cutoff_thresh) 90 | check_dumping(sort_amplitude_cutoff) 91 | shutil.rmtree('test') 92 | 93 | 94 | def test_thresh_frs(): 95 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 96 | seed=0) 97 | fr_thresh = 2 98 | 99 | sort_fr = threshold_firing_rates(sort, fr_thresh, 'less', rec.get_num_frames()) 100 | new_fr = 
compute_firing_rates(sort_fr, rec.get_num_frames()) 101 | 102 | assert np.all(new_fr >= fr_thresh) 103 | check_dumping(sort_fr) 104 | shutil.rmtree('test') 105 | 106 | 107 | def test_thresh_threshold_drift_metrics(): 108 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 109 | seed=0) 110 | s_threshold = 1 111 | 112 | sort_max = threshold_drift_metrics(sort, rec, s_threshold, 'greater', metric_name="max_drift", 113 | apply_filter=False, seed=0) 114 | sort_cum = threshold_drift_metrics(sort, rec, s_threshold, 'greater', metric_name="cumulative_drift", 115 | apply_filter=False, seed=0) 116 | new_max_drift, _ = compute_drift_metrics(sort_max, rec, apply_filter=False, seed=0) 117 | _, new_cum_drift = compute_drift_metrics(sort_cum, rec, apply_filter=False, seed=0) 118 | 119 | assert np.all(new_max_drift <= s_threshold) 120 | assert np.all(new_cum_drift <= s_threshold) 121 | check_dumping(sort_max) 122 | check_dumping(sort_cum) 123 | shutil.rmtree('test') 124 | 125 | 126 | def test_thresh_snrs(): 127 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 128 | seed=0) 129 | 130 | snr_thresh = 4 131 | 132 | sort_snr = threshold_snrs(sort, rec, snr_thresh, 'less', apply_filter=False, seed=0) 133 | new_snr = compute_snrs(sort_snr, rec, apply_filter=False, seed=0) 134 | 135 | assert np.all(new_snr >= snr_thresh) 136 | check_dumping(sort_snr) 137 | shutil.rmtree('test') 138 | 139 | 140 | def test_thresh_noise_overlaps(): 141 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 142 | seed=0) 143 | 144 | noise_thresh = 0.3 145 | 146 | noise_overlaps = compute_noise_overlaps(sort, rec, apply_filter=False, seed=0) 147 | sort_noise = threshold_noise_overlaps(sort, rec, noise_thresh, 'less', apply_filter=False, seed=0) 148 | 149 | original_ids = sort.get_unit_ids() 150 | new_noise = [] 151 | for 
unit in sort_noise.get_unit_ids(): 152 | new_noise.append(noise_overlaps[original_ids.index(unit)]) 153 | new_noise = np.array(new_noise) 154 | assert np.all(new_noise >= noise_thresh) 155 | check_dumping(sort_noise) 156 | shutil.rmtree('test') 157 | 158 | 159 | # PCA-based 160 | def test_thresh_isolation_distances(): 161 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 162 | seed=0) 163 | s_threshold = 200 164 | 165 | iso = compute_isolation_distances(sort, rec, apply_filter=False, seed=0) 166 | sort_iso = threshold_isolation_distances(sort, rec, s_threshold, 'less', apply_filter=False, seed=0) 167 | 168 | original_ids = sort.get_unit_ids() 169 | new_iso = [] 170 | for unit in sort_iso.get_unit_ids(): 171 | new_iso.append(iso[original_ids.index(unit)]) 172 | new_iso = np.array(new_iso) 173 | assert np.all(new_iso >= s_threshold) 174 | check_dumping(sort_iso) 175 | shutil.rmtree('test') 176 | 177 | 178 | def test_thresh_silhouettes(): 179 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 180 | seed=0) 181 | silhouette_thresh = .5 182 | 183 | silhouette = compute_silhouette_scores(sort, rec, apply_filter=False, seed=0) 184 | sort_silhouette = threshold_silhouette_scores(sort, rec, silhouette_thresh, "less", apply_filter=False, seed=0) 185 | 186 | original_ids = sort.get_unit_ids() 187 | new_silhouette = [] 188 | for unit in sort_silhouette.get_unit_ids(): 189 | new_silhouette.append(silhouette[original_ids.index(unit)]) 190 | new_silhouette = np.array(new_silhouette) 191 | assert np.all(new_silhouette >= silhouette_thresh) 192 | check_dumping(sort_silhouette) 193 | shutil.rmtree('test') 194 | 195 | 196 | def test_thresh_nn_metrics(): 197 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 198 | seed=0) 199 | s_threshold_hit = 0.9 200 | s_threshold_miss = 0.002 201 | 202 | 
nn_hit, nn_miss = compute_nn_metrics(sort, rec, apply_filter=False, seed=0) 203 | sort_hit = threshold_nn_metrics(sort, rec, s_threshold_hit, 'less', metric_name="nn_hit_rate", 204 | apply_filter=False, seed=0) 205 | sort_miss = threshold_nn_metrics(sort, rec, s_threshold_miss, 'greater', metric_name="nn_miss_rate", 206 | apply_filter=False, seed=0) 207 | 208 | original_ids = sort.get_unit_ids() 209 | new_nn_hit = [] 210 | for unit in sort_hit.get_unit_ids(): 211 | new_nn_hit.append(nn_hit[original_ids.index(unit)]) 212 | new_nn_miss = [] 213 | for unit in sort_miss.get_unit_ids(): 214 | new_nn_miss.append(nn_miss[original_ids.index(unit)]) 215 | new_nn_hit = np.array(new_nn_hit) 216 | new_nn_miss = np.array(new_nn_miss) 217 | assert np.all(new_nn_hit >= s_threshold_hit) 218 | assert np.all(new_nn_miss <= s_threshold_miss) 219 | check_dumping(sort_hit) 220 | check_dumping(sort_miss) 221 | shutil.rmtree('test') 222 | 223 | 224 | def test_thresh_d_primes(): 225 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 226 | seed=0) 227 | d_primes_thresh = .5 228 | 229 | d_primes = compute_d_primes(sort, rec, apply_filter=False, seed=0) 230 | sort_d_primes = threshold_d_primes(sort, rec, d_primes_thresh, "less", apply_filter=False, seed=0) 231 | 232 | original_ids = sort.get_unit_ids() 233 | new_d_primes = [] 234 | for unit in sort_d_primes.get_unit_ids(): 235 | new_d_primes.append(d_primes[original_ids.index(unit)]) 236 | new_d_primes = np.array(new_d_primes) 237 | assert np.all(new_d_primes >= d_primes_thresh) 238 | check_dumping(sort_d_primes) 239 | shutil.rmtree('test') 240 | 241 | 242 | def test_thresh_l_ratios(): 243 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, K=10, 244 | seed=0) 245 | l_ratios_thresh = 0 246 | 247 | l_ratios = compute_l_ratios(sort, rec, apply_filter=False, seed=0) 248 | sort_l_ratios = threshold_l_ratios(sort, rec, 
l_ratios_thresh, "less", apply_filter=False, seed=0) 249 | 250 | original_ids = sort.get_unit_ids() 251 | new_l_ratios = [] 252 | for unit in sort_l_ratios.get_unit_ids(): 253 | new_l_ratios.append(l_ratios[original_ids.index(unit)]) 254 | new_l_ratios = np.array(new_l_ratios) 255 | assert np.all(new_l_ratios >= l_ratios_thresh) 256 | check_dumping(sort_l_ratios) 257 | shutil.rmtree('test') 258 | 259 | 260 | def test_curation_params(): 261 | print(get_curation_params()) 262 | 263 | 264 | if __name__ == "__main__": 265 | # test_thresh_num_spikes() 266 | # test_thresh_presence_ratios() 267 | # test_thresh_frs() 268 | # test_thresh_isi_violations() 269 | # 270 | # test_thresh_snrs() 271 | # test_thresh_amplitude_cutoffs() 272 | test_thresh_noise_overlaps() 273 | # test_thresh_silhouettes() 274 | # test_thresh_isolation_distances() 275 | # test_thresh_l_ratios() 276 | # test_thresh_threshold_drift_metrics() 277 | # test_thresh_nn_metrics() 278 | -------------------------------------------------------------------------------- /spiketoolkit/tests/test_curation_extractor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spikeextractors as se 3 | import spiketoolkit as st 4 | from spikeextractors.testing import check_dumping 5 | 6 | 7 | def test_curation_sorting_extractor(): 8 | rec, sort = se.example_datasets.toy_example(dump_folder='test', dumpable=True, duration=10, num_channels=4, 9 | K=3, seed=0) 10 | 11 | # Dummy features for testing merging and splitting of features 12 | sort.set_unit_spike_features(1, 'f_int', range(0 + 1, len(sort.get_unit_spike_train(1)) + 1)) 13 | sort.set_unit_spike_features(2, 'f_int', range(0, len(sort.get_unit_spike_train(2)))) 14 | sort.set_unit_spike_features(2, 'bad_features', np.repeat(1, len(sort.get_unit_spike_train(2)))) 15 | sort.set_unit_spike_features(3, 'f_int', range(0, len(sort.get_unit_spike_train(3)))) 16 | 17 | CSX = 
st.curation.CurationSortingExtractor(parent_sorting=sort) 18 | merged_unit_id = CSX.merge_units(unit_ids=[1, 2]) 19 | assert np.allclose(merged_unit_id, 4) 20 | original_spike_train = np.concatenate((sort.get_unit_spike_train(1), sort.get_unit_spike_train(2))) 21 | indices_sort = np.argsort(original_spike_train) 22 | original_spike_train = original_spike_train[indices_sort] 23 | original_features = np.concatenate( 24 | (sort.get_unit_spike_features(1, 'f_int'), sort.get_unit_spike_features(2, 'f_int'))) 25 | original_features = original_features[indices_sort] 26 | assert np.allclose(CSX.get_unit_spike_train(4), original_spike_train) 27 | assert np.allclose(CSX.get_unit_spike_features(4, 'f_int'), original_features) 28 | assert CSX.get_unit_spike_feature_names(4) == ['f_int'] 29 | assert np.allclose(CSX.get_sampling_frequency(), sort.get_sampling_frequency()) 30 | 31 | unit_ids_split = CSX.split_unit(unit_id=3, indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 32 | assert np.allclose(unit_ids_split[0], 5) 33 | assert np.allclose(unit_ids_split[1], 6) 34 | original_spike_train = sort.get_unit_spike_train(3) 35 | original_features = sort.get_unit_spike_features(3, 'f_int') 36 | split_spike_train_1 = CSX.get_unit_spike_train(5) 37 | split_spike_train_2 = CSX.get_unit_spike_train(6) 38 | split_features_1 = CSX.get_unit_spike_features(5, 'f_int') 39 | split_features_2 = CSX.get_unit_spike_features(6, 'f_int') 40 | assert np.allclose(original_spike_train[:10], split_spike_train_1) 41 | assert np.allclose(original_spike_train[10:], split_spike_train_2) 42 | assert np.allclose(original_features[:10], split_features_1) 43 | assert np.allclose(original_features[10:], split_features_2) 44 | 45 | check_dumping(CSX) 46 | 47 | 48 | if __name__ == '__main__': 49 | test_curation_sorting_extractor() 50 | -------------------------------------------------------------------------------- /spiketoolkit/tests/test_sortingcomponents.py: 
-------------------------------------------------------------------------------- 1 | import spikeextractors as se 2 | import spiketoolkit as st 3 | import numpy as np 4 | import shutil 5 | 6 | 7 | def test_detection(): 8 | folder = 'test' 9 | rec, sort = se.example_datasets.toy_example(num_channels=4, duration=20, seed=0, dumpable=True, dump_folder=folder) 10 | 11 | # negative 12 | sort_d_n = st.sortingcomponents.detect_spikes(rec) 13 | sort_dc_n = st.sortingcomponents.detect_spikes(rec, n_jobs=1, chunk_mb=10) 14 | sort_dp_n = st.sortingcomponents.detect_spikes(rec, n_jobs=2, chunk_mb=10) 15 | 16 | assert 'channel' in sort_d_n.get_shared_unit_property_names() 17 | assert 'channel' in sort_dc_n.get_shared_unit_property_names() 18 | assert 'channel' in sort_dp_n.get_shared_unit_property_names() 19 | assert 'spike_rate' in sort_d_n.get_shared_unit_property_names() 20 | assert 'spike_rate' in sort_dc_n.get_shared_unit_property_names() 21 | assert 'spike_rate' in sort_dp_n.get_shared_unit_property_names() 22 | assert 'spike_amplitude' in sort_d_n.get_shared_unit_property_names() 23 | assert 'spike_amplitude' in sort_dc_n.get_shared_unit_property_names() 24 | assert 'spike_amplitude' in sort_dp_n.get_shared_unit_property_names() 25 | 26 | for u in sort_d_n.get_unit_ids(): 27 | assert np.array_equal(sort_d_n.get_unit_spike_train(u), sort_dp_n.get_unit_spike_train(u)) 28 | assert np.array_equal(sort_d_n.get_unit_spike_train(u), sort_dc_n.get_unit_spike_train(u)) 29 | 30 | # positive 31 | sort_d_p = st.sortingcomponents.detect_spikes(rec, detect_sign=1) 32 | sort_dc_p = st.sortingcomponents.detect_spikes(rec, detect_sign=1, n_jobs=1, chunk_mb=10) 33 | sort_dp_p = st.sortingcomponents.detect_spikes(rec, detect_sign=1, n_jobs=2, chunk_mb=10) 34 | 35 | assert 'channel' in sort_d_p.get_shared_unit_property_names() 36 | assert 'channel' in sort_dc_p.get_shared_unit_property_names() 37 | assert 'channel' in sort_dp_p.get_shared_unit_property_names() 38 | assert 'spike_rate' in 
sort_d_p.get_shared_unit_property_names() 39 | assert 'spike_rate' in sort_dc_p.get_shared_unit_property_names() 40 | assert 'spike_rate' in sort_dp_p.get_shared_unit_property_names() 41 | assert 'spike_amplitude' in sort_d_p.get_shared_unit_property_names() 42 | assert 'spike_amplitude' in sort_dc_p.get_shared_unit_property_names() 43 | assert 'spike_amplitude' in sort_dp_p.get_shared_unit_property_names() 44 | 45 | for u in sort_d_p.get_unit_ids(): 46 | assert np.array_equal(sort_d_p.get_unit_spike_train(u), sort_dp_p.get_unit_spike_train(u)) 47 | assert np.array_equal(sort_d_p.get_unit_spike_train(u), sort_dc_p.get_unit_spike_train(u)) 48 | 49 | # both 50 | sort_d_b = st.sortingcomponents.detect_spikes(rec, detect_sign=0) 51 | sort_dc_b = st.sortingcomponents.detect_spikes(rec, detect_sign=0, n_jobs=2, chunk_mb=10) 52 | sort_dp_b = st.sortingcomponents.detect_spikes(rec, detect_sign=0, n_jobs=2, chunk_mb=10) 53 | 54 | assert 'channel' in sort_d_b.get_shared_unit_property_names() 55 | assert 'channel' in sort_dc_b.get_shared_unit_property_names() 56 | assert 'channel' in sort_dp_b.get_shared_unit_property_names() 57 | assert 'spike_rate' in sort_d_b.get_shared_unit_property_names() 58 | assert 'spike_rate' in sort_dc_b.get_shared_unit_property_names() 59 | assert 'spike_rate' in sort_dp_b.get_shared_unit_property_names() 60 | assert 'spike_amplitude' in sort_d_b.get_shared_unit_property_names() 61 | assert 'spike_amplitude' in sort_dc_b.get_shared_unit_property_names() 62 | assert 'spike_amplitude' in sort_dp_b.get_shared_unit_property_names() 63 | 64 | for u in sort_d_b.get_unit_ids(): 65 | assert np.array_equal(sort_d_b.get_unit_spike_train(u), sort_dp_b.get_unit_spike_train(u)) 66 | assert np.array_equal(sort_d_b.get_unit_spike_train(u), sort_dc_b.get_unit_spike_train(u)) 67 | 68 | shutil.rmtree(folder) 69 | 70 | 71 | if __name__ == '__main__': 72 | test_detection() 73 | -------------------------------------------------------------------------------- 
/spiketoolkit/tests/test_validation.py: -------------------------------------------------------------------------------- 1 | import spikeextractors as se 2 | import numpy as np 3 | from spiketoolkit.validation import compute_isolation_distances, compute_isi_violations, compute_snrs, \ 4 | compute_amplitude_cutoffs, compute_d_primes, compute_drift_metrics, compute_firing_rates, compute_l_ratios, \ 5 | compute_quality_metrics, compute_nn_metrics, compute_num_spikes, compute_presence_ratios, \ 6 | compute_silhouette_scores, compute_noise_overlaps, get_validation_params 7 | 8 | 9 | def test_functions(): 10 | rec, sort = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0) 11 | 12 | firing_rates = compute_firing_rates(sort, rec.get_num_frames(), seed=0) 13 | num_spikes = compute_num_spikes(sort, seed=0) 14 | isi = compute_isi_violations(sort, rec.get_num_frames(), seed=0) 15 | presence = compute_presence_ratios(sort, rec.get_num_frames(), seed=0) 16 | amp_cutoff = compute_amplitude_cutoffs(sort, rec, seed=0) 17 | max_drift, cum_drift = compute_drift_metrics(sort, rec, seed=0, memmap=False) 18 | silh = compute_silhouette_scores(sort, rec, seed=0) 19 | iso = compute_isolation_distances(sort, rec, seed=0) 20 | l_ratio = compute_l_ratios(sort, rec, seed=0) 21 | dprime = compute_d_primes(sort, rec, seed=0) 22 | noise_overlaps = compute_noise_overlaps(sort, rec, seed=0) 23 | nn_hit, nn_miss = compute_nn_metrics(sort, rec, seed=0) 24 | snr = compute_snrs(sort, rec, seed=0) 25 | metrics = compute_quality_metrics(sort, rec, return_dict=True, seed=0) 26 | 27 | assert np.allclose(metrics['firing_rate'], firing_rates) 28 | assert np.allclose(metrics['num_spikes'], num_spikes) 29 | assert np.allclose(metrics['isi_violation'], isi) 30 | assert np.allclose(metrics['amplitude_cutoff'], amp_cutoff) 31 | assert np.allclose(metrics['presence_ratio'], presence) 32 | assert np.allclose(metrics['silhouette_score'], silh) 33 | assert np.allclose(metrics['isolation_distance'], 
iso) 34 | assert np.allclose(metrics['l_ratio'], l_ratio) 35 | assert np.allclose(metrics['d_prime'], dprime) 36 | assert np.allclose(metrics['snr'], snr) 37 | assert np.allclose(metrics['max_drift'], max_drift) 38 | assert np.allclose(metrics['cumulative_drift'], cum_drift) 39 | assert np.allclose(metrics['noise_overlap'], noise_overlaps) 40 | assert np.allclose(metrics['nn_hit_rate'], nn_hit) 41 | assert np.allclose(metrics['nn_miss_rate'], nn_miss) 42 | 43 | 44 | def test_validation_params(): 45 | print(get_validation_params()) 46 | 47 | 48 | if __name__ == '__main__': 49 | test_functions() 50 | -------------------------------------------------------------------------------- /spiketoolkit/tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal as ss 3 | import spikeextractors as se 4 | from pathlib import Path 5 | 6 | def check_signal_power_signal1_below_signal2(signals1, signals2, freq_range, fs): 7 | ''' 8 | Check that spectrum power of signal1 is below the one of signal 2 in the range freq_range 9 | ''' 10 | f1, pow1 = ss.welch(signals1, fs, nfft=1024) 11 | f2, pow2 = ss.welch(signals2, fs, nfft=1024) 12 | 13 | below = True 14 | 15 | for (p1, p2) in zip(pow1, pow2): 16 | 17 | r1_idxs = np.where((f1 > freq_range[0]) & (f1 <= freq_range[1])) 18 | r2_idxs = np.where((f2 > freq_range[0]) & (f2 <= freq_range[1])) 19 | 20 | sump1 = np.sum(p1[r1_idxs]) 21 | sump2 = np.sum(p2[r2_idxs]) 22 | 23 | if sump1 >= sump2: 24 | below = False 25 | break 26 | 27 | return below 28 | 29 | 30 | def create_wf(min_val=-100, max_val=50, n_samples=100): 31 | ''' 32 | Creates stereotyped waveform 33 | ''' 34 | wf = np.zeros(n_samples) 35 | inter = n_samples // 4 36 | wf[:inter] = np.linspace(0, min_val, inter) 37 | wf[inter:3 * inter] = np.linspace(min_val, max_val, 2 * inter) 38 | wf[3 * inter:] = np.linspace(max_val, 0, n_samples - 3 * inter) 39 | 40 | return wf 41 | 42 | 43 | def 
generate_template_with_random_amps(n_ch, wf): 44 | ''' 45 | Creates stereotyped templates from waveform 46 | ''' 47 | amps = [] 48 | i = 1 49 | found = False 50 | while len(amps) < n_ch - 1 and i < 1000: 51 | a = np.random.rand() 52 | i = i + 1 53 | if a < 0.2 or a > 0.5: 54 | continue 55 | if sum(amps) + a < 0.9: 56 | amps.append(a) 57 | if len(amps) == n_ch - 1: 58 | amps.append(1 - sum(amps)) 59 | found = True 60 | template = np.zeros((n_ch, len(wf))) 61 | for i, a in enumerate(amps): 62 | template[i] = a * wf 63 | else: 64 | template = [] 65 | 66 | return template, amps, found 67 | 68 | 69 | def create_signal_with_known_waveforms(n_channels=4, n_waveforms=2, n_wf_samples=100, duration=5, fs=30000): 70 | ''' 71 | Creates stereotyped recording, sorting, with waveforms, templates, and max_chans 72 | ''' 73 | a_min = [-200, -50] 74 | a_max = [10, 50] 75 | wfs = [] 76 | 77 | # gen waveforms 78 | for w in range(n_waveforms): 79 | amp_min = np.random.randint(a_min[0], a_min[1]) 80 | amp_max = np.random.randint(a_max[0], a_max[1]) 81 | 82 | wf = create_wf(amp_min, amp_max, n_wf_samples) 83 | wfs.append(wf) 84 | 85 | # gen templates 86 | templates = [] 87 | max_chans = [] 88 | for wf in wfs: 89 | found = False 90 | while not found: 91 | template, amps, found = generate_template_with_random_amps(n_channels, wf) 92 | templates.append(template) 93 | max_chans.append(np.argmax(amps)) 94 | 95 | templates = np.array(templates) 96 | n_samples = int(fs * duration) 97 | 98 | # gen spiketrains 99 | interval = 10 * n_wf_samples 100 | times = np.arange(interval, duration * fs - interval, interval).astype(int) 101 | labels = np.zeros(len(times)).astype(int) 102 | for i, wf in enumerate(wfs): 103 | labels[i::len(wfs)] = i 104 | 105 | timeseries = np.zeros((n_channels, n_samples)) 106 | waveforms = [] 107 | amplitudes = [] 108 | for i, tem in enumerate(templates): 109 | idxs = np.where(labels == i) 110 | wav = [] 111 | amps = [] 112 | for t in times[idxs]: 113 | rand_val = 
np.random.randn() * 0.01 + 1 114 | timeseries[:, t - n_wf_samples // 2:t + n_wf_samples // 2] = rand_val * tem 115 | wav.append(rand_val * tem) 116 | amps.append(np.min(rand_val * tem)) 117 | wav = np.array(wav) 118 | amps = np.array(amps) 119 | waveforms.append(wav) 120 | amplitudes.append(amps) 121 | 122 | rec = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=fs) 123 | sort = se.NumpySortingExtractor() 124 | sort.set_times_labels(times=times, labels=labels) 125 | sort.set_sampling_frequency(fs) 126 | 127 | return rec, sort, waveforms, templates, max_chans, amplitudes 128 | 129 | 130 | def create_fake_waveforms_with_known_pc(): 131 | # HINT: start from Guassians in PC space and stereotyped waveforms and build dataset. 132 | pass 133 | 134 | 135 | def create_dumpable_extractors_from_existing(folder, RX, SX): 136 | folder = Path(folder) 137 | 138 | if 'location' not in RX.get_shared_channel_property_names(): 139 | RX.set_channel_locations(np.random.randn(RX.get_num_channels(), 2)) 140 | se.MdaRecordingExtractor.write_recording(RX, folder) 141 | RX_mda = se.MdaRecordingExtractor(folder) 142 | se.NpzSortingExtractor.write_sorting(SX, folder / 'sorting.npz') 143 | SX_npz = se.NpzSortingExtractor(folder / 'sorting.npz') 144 | 145 | return RX_mda, SX_npz 146 | -------------------------------------------------------------------------------- /spiketoolkit/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_closest_channels(recording, channel_ids, num_channels=None): 5 | """Get closest channels + distances 6 | 7 | Parameters 8 | ---------- 9 | recording: RecordingExtractor 10 | The recording extractor to be re-referenced 11 | channel_ids: list or int 12 | list of channels id to compute there near neighborhood 13 | num_channels: int, optional 14 | Maximum number of neighborhood channel to return 15 | 16 | Returns 17 | ------- 18 | : array (2d) 19 | closest channel ids in 
ascending order for each channel id given in input 20 | : array (2d) 21 | distance in ascending order for each channel id given in input 22 | """ 23 | closest_channels_id = [] 24 | dist = [] 25 | 26 | if num_channels: 27 | num_channels = min(num_channels + 1, len(recording.get_channel_locations())) 28 | else: 29 | num_channels = len(recording.get_channel_locations()) 30 | 31 | if not isinstance(channel_ids, list): 32 | channel_ids = list(channel_ids) 33 | 34 | for n, id in enumerate(channel_ids): 35 | locs = recording.get_channel_locations() 36 | distances = [np.linalg.norm(l - locs[recording.get_channel_ids().index(id)]) for l in locs] 37 | closest_channels_id.append(np.argsort(distances)[1:num_channels]) 38 | dist.append(np.sort(distances)[1:num_channels]) 39 | 40 | return np.array(closest_channels_id), np.array(dist) 41 | -------------------------------------------------------------------------------- /spiketoolkit/validation/__init__.py: -------------------------------------------------------------------------------- 1 | from .validation_list import * 2 | from .curation_list import ( 3 | curation_full_list, 4 | installed_curation_list, 5 | curation_dict, 6 | ) -------------------------------------------------------------------------------- /spiketoolkit/validation/curation_list.py: -------------------------------------------------------------------------------- 1 | from .quality_metric_classes.metric_data import MetricData 2 | from .quality_metric_classes.num_spikes import NumSpikes 3 | from .quality_metric_classes.amplitude_cutoff import AmplitudeCutoff 4 | from .quality_metric_classes.silhouette_score import SilhouetteScore 5 | from .quality_metric_classes.d_prime import DPrime 6 | from .quality_metric_classes.l_ratio import LRatio 7 | from .quality_metric_classes.isolation_distance import IsolationDistance 8 | from .quality_metric_classes.firing_rate import FiringRate 9 | from .quality_metric_classes.presence_ratio import PresenceRatio 10 | from 
.quality_metric_classes.isi_violation import ISIViolation 11 | from .quality_metric_classes.snr import SNR 12 | from .quality_metric_classes.nearest_neighbor import NearestNeighbor 13 | from .quality_metric_classes.drift_metric import DriftMetric 14 | 15 | curation_full_list = [ 16 | NumSpikes, 17 | FiringRate, 18 | PresenceRatio, 19 | ISIViolation, 20 | SNR, 21 | AmplitudeCutoff, 22 | DriftMetric, 23 | SilhouetteScore, 24 | DPrime, 25 | LRatio, 26 | IsolationDistance, 27 | NearestNeighbor, 28 | ] 29 | 30 | installed_curation_list = [c for c in curation_full_list if c.installed] 31 | curation_dict = {c_class.curator_name: c_class for c_class in curation_full_list} 32 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/amplitude_cutoff.py: -------------------------------------------------------------------------------- 1 | from .quality_metric import QualityMetric 2 | import numpy as np 3 | import spikemetrics.metrics as metrics 4 | from .utils.thresholdcurator import ThresholdCurator 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class AmplitudeCutoff(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | curator_name = "ThresholdAmplitudeCutoffs" 13 | params = OrderedDict([]) 14 | 15 | def __init__(self, metric_data): 16 | QualityMetric.__init__(self, metric_data, metric_name="amplitude_cutoff") 17 | if not metric_data.has_amplitudes(): 18 | raise ValueError("MetricData object must have amplitudes") 19 | 20 | def compute_metric(self, **kwargs): 21 | params_dict = 
update_all_param_dicts_with_kwargs(kwargs) 22 | save_property_or_features = params_dict['save_property_or_features'] 23 | amplitude_cutoffs_all = metrics.calculate_amplitude_cutoff( 24 | self._metric_data._spike_clusters_amps, 25 | self._metric_data._amplitudes, 26 | self._metric_data._total_units, 27 | spike_cluster_subset=self._metric_data._unit_indices, 28 | verbose=self._metric_data.verbose, 29 | ) 30 | amplitude_cutoffs_list = [] 31 | for i in self._metric_data._unit_indices: 32 | amplitude_cutoffs_list.append(amplitude_cutoffs_all[i]) 33 | amplitude_cutoffs = np.asarray(amplitude_cutoffs_list) 34 | if save_property_or_features: 35 | self.save_property_or_features(self._metric_data._sorting, amplitude_cutoffs, self._metric_name) 36 | return amplitude_cutoffs 37 | 38 | def threshold_metric(self, threshold, threshold_sign, **kwargs): 39 | amplitude_cutoffs = self.compute_metric(**kwargs) 40 | threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, 41 | metric=amplitude_cutoffs) 42 | threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign) 43 | return threshold_curator 44 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/d_prime.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spikemetrics.metrics as metrics 3 | from .utils.thresholdcurator import ThresholdCurator 4 | from .quality_metric import QualityMetric 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class DPrime(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict([('num_channels_to_compare', 13), ('max_spikes_per_cluster', 500)]) 13 | curator_name = "ThresholdDPrimes" 14 | 15 | def __init__(self, metric_data): 16 | QualityMetric.__init__(self, 
metric_data, metric_name="d_prime") 17 | 18 | if not metric_data.has_pca_scores(): 19 | raise ValueError("MetricData object must have pca scores") 20 | 21 | def compute_metric(self, num_channels_to_compare, max_spikes_per_cluster, **kwargs): 22 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 23 | seed = params_dict['seed'] 24 | save_property_or_features = params_dict['save_property_or_features'] 25 | d_primes_all = metrics.calculate_pc_metrics( 26 | spike_clusters=self._metric_data._spike_clusters_pca, 27 | total_units=self._metric_data._total_units, 28 | pc_features=self._metric_data._pc_features, 29 | pc_feature_ind=self._metric_data._pc_feature_ind, 30 | num_channels_to_compare=num_channels_to_compare, 31 | max_spikes_for_cluster=max_spikes_per_cluster, 32 | channel_locations=self._metric_data._channel_locations, 33 | spikes_for_nn=None, 34 | n_neighbors=None, 35 | metric_names=["d_prime"], 36 | seed=seed, 37 | spike_cluster_subset=self._metric_data._unit_indices, 38 | verbose=self._metric_data.verbose, 39 | )[2] 40 | d_primes_list = [] 41 | for i in self._metric_data._unit_indices: 42 | d_primes_list.append(d_primes_all[i]) 43 | d_primes = np.asarray(d_primes_list) 44 | if save_property_or_features: 45 | self.save_property_or_features(self._metric_data._sorting, d_primes, self._metric_name) 46 | return d_primes 47 | 48 | def threshold_metric(self, threshold, threshold_sign, num_channels_to_compare, max_spikes_per_cluster, **kwargs): 49 | d_primes = \ 50 | self.compute_metric(num_channels_to_compare, max_spikes_per_cluster, **kwargs) 51 | threshold_curator = ThresholdCurator( 52 | sorting=self._metric_data._sorting, metric=d_primes 53 | ) 54 | threshold_curator.threshold_sorting( 55 | threshold=threshold, threshold_sign=threshold_sign 56 | ) 57 | return threshold_curator 58 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/drift_metric.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spikemetrics.metrics as metrics 3 | from .utils.thresholdcurator import ThresholdCurator 4 | from .quality_metric import QualityMetric 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class DriftMetric(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict([('drift_metrics_interval_s', 51), ('drift_metrics_min_spikes_per_interval', 10)]) 13 | curator_name = "ThresholdDriftMetrics" 14 | 15 | def __init__(self, metric_data): 16 | QualityMetric.__init__(self, metric_data, metric_name="drift_metric") 17 | 18 | if not metric_data.has_pca_scores(): 19 | raise ValueError("MetricData object must have pca scores") 20 | 21 | def compute_metric(self, drift_metrics_interval_s, drift_metrics_min_spikes_per_interval, **kwargs): 22 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 23 | save_property_or_features = params_dict['save_property_or_features'] 24 | max_drifts_all, cumulative_drifts_all = metrics.calculate_drift_metrics( 25 | self._metric_data._spike_times_pca, 26 | self._metric_data._spike_clusters_pca, 27 | self._metric_data._total_units, 28 | self._metric_data._pc_features, 29 | self._metric_data._pc_feature_ind, 30 | drift_metrics_interval_s, 31 | drift_metrics_min_spikes_per_interval, 32 | channel_locations=self._metric_data._recording.get_channel_locations(), 33 | spike_cluster_subset=self._metric_data._unit_indices, 34 | verbose=self._metric_data.verbose, 35 | ) 36 | max_drifts_list = [] 37 | cumulative_drifts_list = [] 38 | for i in self._metric_data._unit_indices: 39 | max_drifts_list.append(max_drifts_all[i]) 40 | cumulative_drifts_list.append(cumulative_drifts_all[i]) 41 | max_drifts = np.asarray(max_drifts_list) 42 | cumulative_drifts = np.asarray(cumulative_drifts_list) 43 | 
if save_property_or_features: 44 | self.save_property_or_features(self._metric_data._sorting, max_drifts, metric_name="max_drift") 45 | self.save_property_or_features(self._metric_data._sorting, cumulative_drifts, metric_name="cumulative_drift") 46 | return [max_drifts, cumulative_drifts] 47 | 48 | def threshold_metric(self, threshold, threshold_sign, metric_name, drift_metrics_interval_s, 49 | drift_metrics_min_spikes_per_interval, **kwargs): 50 | max_drifts, cumulative_drifts = \ 51 | self.compute_metric(drift_metrics_interval_s, drift_metrics_min_spikes_per_interval, 52 | **kwargs) 53 | if metric_name == "max_drift": 54 | metric = max_drifts 55 | elif metric_name == "cumulative_drift": 56 | metric = cumulative_drifts 57 | else: 58 | raise ValueError("Invalid metric named entered") 59 | 60 | threshold_curator = ThresholdCurator( 61 | sorting=self._metric_data._sorting, metric=metric 62 | ) 63 | threshold_curator.threshold_sorting( 64 | threshold=threshold, threshold_sign=threshold_sign 65 | ) 66 | return threshold_curator 67 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/firing_rate.py: -------------------------------------------------------------------------------- 1 | from .quality_metric import QualityMetric 2 | import numpy as np 3 | import spikemetrics.metrics as metrics 4 | from .utils.thresholdcurator import ThresholdCurator 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class FiringRate(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict() 13 | curator_name = "ThresholdFiringRates" 14 | 15 | def __init__( 16 | self, 17 | metric_data, 18 | ): 19 | QualityMetric.__init__(self, metric_data, metric_name="firing_rate") 20 | 21 | def compute_metric(self, **kwargs): 22 | params_dict = 
update_all_param_dicts_with_kwargs(kwargs) 23 | save_property_or_features = params_dict['save_property_or_features'] 24 | firing_rate_all = metrics.calculate_firing_rates( 25 | self._metric_data._spike_times, 26 | self._metric_data._spike_clusters, 27 | self._metric_data._total_units, 28 | duration=self._metric_data._duration_in_frames/self._metric_data._sampling_frequency, 29 | spike_cluster_subset=self._metric_data._unit_indices, 30 | verbose=self._metric_data.verbose, 31 | ) 32 | firing_rate_list = [] 33 | for i in self._metric_data._unit_indices: 34 | firing_rate_list.append(firing_rate_all[i]) 35 | firing_rate = np.asarray(firing_rate_list) 36 | if save_property_or_features: 37 | self.save_property_or_features(self._metric_data._sorting, firing_rate, self._metric_name) 38 | return firing_rate 39 | 40 | def threshold_metric(self, threshold, threshold_sign, **kwargs): 41 | firing_rate = self.compute_metric(**kwargs) 42 | threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=firing_rate) 43 | threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign) 44 | return threshold_curator 45 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/isi_violation.py: -------------------------------------------------------------------------------- 1 | from .quality_metric import QualityMetric 2 | import numpy as np 3 | import spikemetrics.metrics as metrics 4 | from .utils.thresholdcurator import ThresholdCurator 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class ISIViolation(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict([('isi_threshold', 0.0015), ('min_isi', None)]) 13 | curator_name = "ThresholdISIViolations" 14 | 15 | def __init__( 16 | self, 17 | metric_data, 18 | ): 19 
| QualityMetric.__init__(self, metric_data, metric_name="isi_violation") 20 | 21 | def compute_metric(self, isi_threshold, min_isi, **kwargs): 22 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 23 | save_property_or_features = params_dict['save_property_or_features'] 24 | if min_isi is None: 25 | min_isi = 1 / (self._metric_data._sampling_frequency) * 0.5 26 | isi_violation_all = metrics.calculate_isi_violations( 27 | self._metric_data._spike_times, 28 | self._metric_data._spike_clusters, 29 | self._metric_data._total_units, 30 | isi_threshold=isi_threshold, 31 | min_isi=min_isi, 32 | duration=self._metric_data._duration_in_frames/self._metric_data._sampling_frequency, 33 | spike_cluster_subset=self._metric_data._unit_indices, 34 | verbose=self._metric_data.verbose, 35 | ) 36 | isi_violation_list = [] 37 | for i in self._metric_data._unit_indices: 38 | isi_violation_list.append(isi_violation_all[i]) 39 | isi_violations = np.asarray(isi_violation_list) 40 | if save_property_or_features: 41 | self.save_property_or_features(self._metric_data._sorting, isi_violations, self._metric_name) 42 | return isi_violations 43 | 44 | def threshold_metric(self, threshold, threshold_sign, isi_threshold, min_isi, **kwargs): 45 | isi_violations = self.compute_metric(isi_threshold, min_isi, **kwargs) 46 | threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=isi_violations) 47 | threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign) 48 | return threshold_curator 49 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/isolation_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spikemetrics.metrics as metrics 3 | from .utils.thresholdcurator import ThresholdCurator 4 | from .quality_metric import QualityMetric 5 | from collections import OrderedDict 6 | from 
.parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class IsolationDistance(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict([('num_channels_to_compare', 13), ('max_spikes_per_cluster', 500)]) 13 | curator_name = "ThresholdIsolationDistances" 14 | 15 | def __init__(self, metric_data): 16 | QualityMetric.__init__(self, metric_data, metric_name="isolation_distance") 17 | 18 | if not metric_data.has_pca_scores(): 19 | raise ValueError("MetricData object must have pca scores") 20 | 21 | def compute_metric(self, num_channels_to_compare, max_spikes_per_cluster, **kwargs): 22 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 23 | save_property_or_features = params_dict['save_property_or_features'] 24 | seed = params_dict['seed'] 25 | isolation_distances_all = metrics.calculate_pc_metrics( 26 | spike_clusters=self._metric_data._spike_clusters_pca, 27 | total_units=self._metric_data._total_units, 28 | pc_features=self._metric_data._pc_features, 29 | pc_feature_ind=self._metric_data._pc_feature_ind, 30 | num_channels_to_compare=num_channels_to_compare, 31 | max_spikes_for_cluster=max_spikes_per_cluster, 32 | channel_locations=self._metric_data._channel_locations, 33 | spikes_for_nn=None, 34 | n_neighbors=None, 35 | metric_names=["isolation_distance"], 36 | seed=seed, 37 | spike_cluster_subset=self._metric_data._unit_indices, 38 | verbose=self._metric_data.verbose, 39 | )[0] 40 | isolation_distances_list = [] 41 | for i in self._metric_data._unit_indices: 42 | isolation_distances_list.append(isolation_distances_all[i]) 43 | isolation_distances = np.asarray(isolation_distances_list) 44 | if save_property_or_features: 45 | self.save_property_or_features(self._metric_data._sorting, isolation_distances, self._metric_name) 46 | return isolation_distances 47 | 48 | def threshold_metric(self, threshold, threshold_sign, num_channels_to_compare, 
max_spikes_per_cluster, **kwargs): 49 | isolation_distances = \ 50 | self.compute_metric(num_channels_to_compare, max_spikes_per_cluster, **kwargs) 51 | threshold_curator = ThresholdCurator( 52 | sorting=self._metric_data._sorting, metric=isolation_distances 53 | ) 54 | threshold_curator.threshold_sorting( 55 | threshold=threshold, threshold_sign=threshold_sign 56 | ) 57 | return threshold_curator 58 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/l_ratio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spikemetrics.metrics as metrics 3 | from .utils.thresholdcurator import ThresholdCurator 4 | from .quality_metric import QualityMetric 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | 9 | class LRatio(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict([('num_channels_to_compare', 13), ('max_spikes_per_cluster', 500)]) 13 | curator_name = "ThresholdLRatios" 14 | 15 | def __init__(self, metric_data): 16 | QualityMetric.__init__(self, metric_data, metric_name="l_ratio") 17 | 18 | if not metric_data.has_pca_scores(): 19 | raise ValueError("MetricData object must have pca scores") 20 | 21 | def compute_metric(self, num_channels_to_compare, max_spikes_per_cluster, **kwargs): 22 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 23 | save_property_or_features = params_dict['save_property_or_features'] 24 | seed = params_dict['seed'] 25 | l_ratios_all = metrics.calculate_pc_metrics( 26 | spike_clusters=self._metric_data._spike_clusters_pca, 27 | total_units=self._metric_data._total_units, 28 | pc_features=self._metric_data._pc_features, 29 | pc_feature_ind=self._metric_data._pc_feature_ind, 30 | num_channels_to_compare=num_channels_to_compare, 31 | 
max_spikes_for_cluster=max_spikes_per_cluster, 32 | channel_locations=self._metric_data._channel_locations, 33 | spikes_for_nn=None, 34 | n_neighbors=None, 35 | metric_names=["l_ratio"], 36 | seed=seed, 37 | spike_cluster_subset=self._metric_data._unit_indices, 38 | verbose=self._metric_data.verbose, 39 | )[1] 40 | l_ratios_list = [] 41 | for i in self._metric_data._unit_indices: 42 | l_ratios_list.append(l_ratios_all[i]) 43 | l_ratios = np.asarray(l_ratios_list) 44 | if save_property_or_features: 45 | self.save_property_or_features(self._metric_data._sorting, l_ratios, self._metric_name) 46 | return l_ratios 47 | 48 | def threshold_metric(self, threshold, threshold_sign, num_channels_to_compare, max_spikes_per_cluster, **kwargs): 49 | l_ratios = \ 50 | self.compute_metric(num_channels_to_compare, max_spikes_per_cluster, **kwargs) 51 | threshold_curator = ThresholdCurator( 52 | sorting=self._metric_data._sorting, metric=l_ratios 53 | ) 54 | threshold_curator.threshold_sorting( 55 | threshold=threshold, threshold_sign=threshold_sign 56 | ) 57 | return threshold_curator 58 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/metric_data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import OrderedDict, defaultdict 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from copy import deepcopy 7 | from spikeextractors import RecordingExtractor, SortingExtractor 8 | 9 | import spikemetrics.metrics as metrics 10 | from spiketoolkit.preprocessing.bandpass_filter import bandpass_filter 11 | from spikemetrics.utils import Epoch, printProgressBar 12 | 13 | from .utils.validation_tools import get_amplitude_metric_data, get_pca_metric_data, get_spike_times_metrics_data 14 | 15 | 16 | # Baseclass for each quality metric 17 | class MetricData: 18 | def __init__( 19 | self, 20 | sorting, 21 | recording, 22 | duration_in_frames, 
23 | sampling_frequency, 24 | apply_filter, 25 | freq_min, 26 | freq_max, 27 | unit_ids, 28 | verbose, 29 | raise_if_empty=True 30 | ): 31 | """ 32 | Computes and stores inital data along with the unit ids to be used for computing metrics. 33 | 34 | Parameters 35 | ---------- 36 | sorting: SortingExtractor 37 | The sorting extractor to be evaluated. 38 | recording: RecordingExtractor 39 | The recording extractor to be stored. If None, the recording extractor can be added later. 40 | duration_in_frames: int 41 | Length of recording (in frames). If None, will look to recording extractor for num frames. 42 | sampling_frequency: 43 | The sampling frequency of the result. If None, will check to see if sampling frequency is in sorting extractor. 44 | apply_filter: bool 45 | If True, recording is bandpass-filtered. 46 | freq_min: float 47 | High-pass frequency for optional filter (default 300 Hz). 48 | freq_max: float 49 | Low-pass frequency for optional filter (default 6000 Hz). 50 | unit_ids: list 51 | List of unit ids to compute metric for. 
If not specified, all units are used 52 | verbose: bool 53 | If True, progress bar is displayed 54 | raise_if_empty: bool 55 | If True, an Exception is thrown if some spike trains are empty 56 | """ 57 | if sampling_frequency is None and sorting.get_sampling_frequency() is None and recording is None: 58 | raise ValueError("Please pass in a sampling frequency (your SortingExtractor does not have one specified " 59 | "and no RecordingExtractor given).") 60 | elif sampling_frequency is None and sorting.get_sampling_frequency() is not None: 61 | self._sampling_frequency = sorting.get_sampling_frequency() 62 | elif sampling_frequency is None and recording is not None: 63 | self._sampling_frequency = recording.get_sampling_frequency() 64 | else: 65 | self._sampling_frequency = sampling_frequency 66 | 67 | if recording is None: 68 | channel_locations = np.array([0,0]) 69 | else: 70 | channel_locations = recording.get_channel_locations() 71 | 72 | # checks to see if any units have no spikes (will break metric calculation) 73 | if raise_if_empty: 74 | for unit_id in sorting.get_unit_ids(): 75 | if len(sorting.get_unit_spike_train(unit_id)) == 0: 76 | raise ValueError("Spike trains must have none zero length. 
" 77 | "Please remove all zero length spike trains") 78 | 79 | if unit_ids is None: 80 | unit_ids = sorting.get_unit_ids() 81 | else: 82 | unit_ids_keep = [] 83 | for unit in unit_ids: 84 | if unit in sorting.get_unit_ids(): 85 | unit_ids_keep.append(unit) 86 | else: 87 | print(f'Unit {unit} is invalid') 88 | unit_ids = unit_ids_keep 89 | 90 | if len(unit_ids) == 0: 91 | raise ValueError("No units found.") 92 | 93 | spike_times, spike_clusters = get_spike_times_metrics_data(sorting, self._sampling_frequency) 94 | assert isinstance( 95 | sorting, SortingExtractor 96 | ), "'sorting' must be a SortingExtractor object" 97 | self._sorting = sorting 98 | self._unit_ids = unit_ids 99 | self._spike_times = spike_times 100 | self._spike_clusters = spike_clusters 101 | self._total_units = len(sorting.get_unit_ids()) 102 | self._unit_indices = _get_unit_indices(self._sorting, unit_ids) 103 | self._channel_locations = channel_locations 104 | # To compute this data, need to call all metric data 105 | self._amplitudes = None 106 | self._pc_features = None 107 | self._pc_feature_ind = None 108 | self._spike_clusters_pca = None 109 | self._spike_clusters_amps = None 110 | self._spike_times_pca = None 111 | self._spike_times_amps = None 112 | self.verbose = verbose 113 | 114 | if recording is not None: 115 | assert isinstance( 116 | recording, RecordingExtractor 117 | ), "'recording' must be a RecordingExtractor object" 118 | self.set_recording( 119 | recording, 120 | apply_filter=apply_filter, 121 | freq_min=freq_min, 122 | freq_max=freq_max, 123 | ) 124 | self._duration_in_frames = recording.get_num_frames() 125 | else: 126 | self._recording = None 127 | self._duration_in_frames = duration_in_frames 128 | 129 | def set_recording(self, recording, apply_filter, freq_min, freq_max): 130 | """ 131 | Sets given recording extractor 132 | 133 | Parameters 134 | ---------- 135 | recording: RecordingExtractor 136 | The recording extractor to be stored. 
137 | apply_filter: bool 138 | If True, recording is bandpass-filtered. 139 | freq_min: float 140 | High-pass frequency for optional filter (default 300 Hz). 141 | freq_max: float 142 | Low-pass frequency for optional filter (default 6000 Hz). 143 | """ 144 | if apply_filter and not recording.is_filtered: 145 | recording_filter = bandpass_filter( 146 | recording=recording, 147 | freq_min=freq_min, 148 | freq_max=freq_max, 149 | ) 150 | else: 151 | recording_filter = recording 152 | self._recording = recording_filter 153 | 154 | def is_filtered(self): 155 | return self._recording.is_filtered 156 | 157 | def has_recording(self): 158 | return self._recording is not None 159 | 160 | def has_amplitudes(self): 161 | return self._amplitudes is not None 162 | 163 | def has_pca_scores(self): 164 | return self._pc_features is not None 165 | 166 | def compute_amplitudes(self, **kwargs): 167 | """ 168 | Computes and stores amplitudes for the amplitude cutoff metric 169 | 170 | Parameters 171 | ---------- 172 | method: str 173 | If 'absolute' (default), amplitudes are absolute amplitudes in uV are returned. 174 | If 'relative', amplitudes are returned as ratios between waveform amplitudes and template amplitudes. 175 | peak: str 176 | If maximum channel has to be found among negative peaks ('neg'), positive ('pos') or both ('both' - default) 177 | frames_before: int 178 | Frames before peak to compute amplitude 179 | frames_after: int 180 | Frames after peak to compute amplitude 181 | max_spikes_per_unit: int 182 | The maximum number of spikes to use to compute amplitudes. 183 | save_property_or_features: bool 184 | If true, it will save amplitudes in the sorting extractor. 185 | recompute_info: bool 186 | If True, will always re-extract waveforms. 
187 | seed: int 188 | Random seed for reproducibility 189 | """ 190 | spike_times, spike_times_amp, spike_clusters, spike_clusters_amp, amplitudes = get_amplitude_metric_data( 191 | self._recording, self._sorting, **kwargs) 192 | self._amplitudes = amplitudes 193 | self._spike_clusters_amps = spike_clusters 194 | self._spike_times_amps = spike_times_amp 195 | self._spike_clusters_amps = spike_clusters_amp 196 | 197 | def compute_pca_scores(self, **kwargs): 198 | """ 199 | Computes and stores pca for the metrics computation 200 | 201 | Parameters 202 | ---------- 203 | n_comp: int 204 | n_compFeatures in template-gui format 205 | ms_before: float 206 | Time period in ms to cut waveforms before the spike events 207 | ms_after: float 208 | Time period in ms to cut waveforms after the spike events 209 | dtype: dtype 210 | The numpy dtype of the waveforms 211 | max_spikes_per_unit: int 212 | The maximum number of spikes to extract per unit. 213 | recompute_info: bool 214 | If True, will always re-extract waveforms. 215 | max_spikes_for_pca: int 216 | The maximum number of spikes per unit to use to compute PCA. 217 | save_property_or_features: bool 218 | If true, it will save amplitudes in the sorting extractor. 
219 | seed: int 220 | Random seed for reproducibility 221 | """ 222 | 223 | spike_times, spike_times_pca, spike_clusters, spike_clusters_pca, pc_features, \ 224 | pc_feature_ind = get_pca_metric_data(self._recording, self._sorting, **kwargs) 225 | self._pc_features = pc_features 226 | self._spike_clusters_pca = spike_clusters 227 | self._spike_times_pca = spike_times_pca 228 | self._spike_clusters_pca = spike_clusters_pca 229 | self._pc_feature_ind = pc_feature_ind 230 | 231 | def set_amplitudes(self, amplitudes): 232 | self._amplitudes = amplitudes 233 | 234 | def set_pc_features(self, pc_features): 235 | self._pc_features = pc_features 236 | 237 | def set_pc_feature_ind(self, pc_feature_ind): 238 | self._pc_feature_ind = pc_feature_ind 239 | 240 | def get_unit_ids(self): 241 | return self._unit_ids 242 | 243 | def _get_unit_indices(sorting, unit_ids): 244 | unit_indices = [] 245 | sorting_unit_ids = np.asarray(sorting.get_unit_ids()) 246 | for unit_id in unit_ids: 247 | (index,) = np.where(sorting_unit_ids == unit_id) 248 | if len(index) != 0: 249 | unit_indices.append(index[0]) 250 | return unit_indices 251 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/nearest_neighbor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import spikemetrics.metrics as metrics 3 | from .utils.thresholdcurator import ThresholdCurator 4 | from .quality_metric import QualityMetric 5 | from collections import OrderedDict 6 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 7 | 8 | # 9 | class NearestNeighbor(QualityMetric): 10 | installed = True # check at class level if installed or not 11 | installation_mesg = "" # err 12 | params = OrderedDict( 13 | [('num_channels_to_compare', 13), ('max_spikes_per_cluster', 500), ('max_spikes_for_nn', 10000), 14 | ('n_neighbors', 4)]) 15 | curator_name = "ThresholdNearestNeighbors" 
16 | 17 | def __init__(self, metric_data): 18 | QualityMetric.__init__(self, metric_data, metric_name="nearest_neighbor") 19 | 20 | if not metric_data.has_pca_scores(): 21 | raise ValueError("MetricData object must have pca scores") 22 | 23 | def compute_metric(self, num_channels_to_compare, max_spikes_per_cluster, max_spikes_for_nn, 24 | n_neighbors, **kwargs): 25 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 26 | save_property_or_features = params_dict['save_property_or_features'] 27 | seed = params_dict['seed'] 28 | total_spikes = self._metric_data._spike_clusters_pca.shape[0] 29 | spikes_for_nn = np.min([total_spikes, max_spikes_for_nn]) 30 | nn_hit_rates_all, nn_miss_rates_all = metrics.calculate_pc_metrics( 31 | spike_clusters=self._metric_data._spike_clusters_pca, 32 | total_units=self._metric_data._total_units, 33 | pc_features=self._metric_data._pc_features, 34 | pc_feature_ind=self._metric_data._pc_feature_ind, 35 | num_channels_to_compare=num_channels_to_compare, 36 | max_spikes_for_cluster=max_spikes_per_cluster, 37 | spikes_for_nn=spikes_for_nn, 38 | n_neighbors=n_neighbors, 39 | channel_locations=self._metric_data._channel_locations, 40 | metric_names=["nearest_neighbor"], 41 | seed=seed, 42 | spike_cluster_subset=self._metric_data._unit_indices, 43 | verbose=self._metric_data.verbose, 44 | )[3:5] 45 | nn_hit_rates_list = [] 46 | nn_miss_rates_list = [] 47 | for i in self._metric_data._unit_indices: 48 | nn_hit_rates_list.append(nn_hit_rates_all[i]) 49 | nn_miss_rates_list.append(nn_miss_rates_all[i]) 50 | nn_hit_rates = np.asarray(nn_hit_rates_list) 51 | nn_miss_rates = np.asarray(nn_miss_rates_list) 52 | if save_property_or_features: 53 | self.save_property_or_features(self._metric_data._sorting, nn_hit_rates, metric_name="nn_hit_rate") 54 | self.save_property_or_features(self._metric_data._sorting, nn_miss_rates, metric_name="nn_miss_rate") 55 | return [nn_hit_rates_list, nn_miss_rates] 56 | 57 | def threshold_metric(self, threshold, 
threshold_sign, metric_name, num_channels_to_compare, max_spikes_per_cluster, 58 | max_spikes_for_nn, n_neighbors, **kwargs): 59 | nn_hit_rates, nn_miss_rates = \ 60 | self.compute_metric(num_channels_to_compare, max_spikes_per_cluster, max_spikes_for_nn, 61 | n_neighbors, **kwargs) 62 | if metric_name == "nn_hit_rate": 63 | metric = nn_hit_rates 64 | elif metric_name == "nn_miss_rate": 65 | metric = nn_miss_rates 66 | else: 67 | raise ValueError("Invalid metric named entered") 68 | 69 | threshold_curator = ThresholdCurator( 70 | sorting=self._metric_data._sorting, metric=metric 71 | ) 72 | threshold_curator.threshold_sorting( 73 | threshold=threshold, threshold_sign=threshold_sign 74 | ) 75 | return threshold_curator 76 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/noise_overlap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import copy 3 | from .utils.thresholdcurator import ThresholdCurator 4 | from .quality_metric import QualityMetric 5 | import spiketoolkit as st 6 | import spikemetrics.metrics as metrics 7 | from spikemetrics.utils import printProgressBar 8 | from spikemetrics.metrics import find_neighboring_channels 9 | from collections import OrderedDict 10 | from sklearn.neighbors import NearestNeighbors 11 | from .parameter_dictionaries import update_all_param_dicts_with_kwargs 12 | 13 | 14 | class NoiseOverlap(QualityMetric): 15 | installed = True # check at class level if installed or not 16 | installation_mesg = "" # err 17 | params = OrderedDict([('num_channels_to_compare', 13), 18 | ('max_spikes_per_unit_for_noise_overlap', 1000), 19 | ('num_features', 10), 20 | ('num_knn', 6)]) 21 | curator_name = "ThresholdNoiseOverlaps" 22 | 23 | def __init__(self, metric_data): 24 | QualityMetric.__init__(self, metric_data, metric_name="noise_overlap") 25 | 26 | if not metric_data.has_recording(): 27 | 
raise ValueError("MetricData object must have a recording") 28 | 29 | def compute_metric(self, num_channels_to_compare, max_spikes_per_unit_for_noise_overlap, 30 | num_features, num_knn, **kwargs): 31 | 32 | # Make sure max_spikes_per_unit_for_noise_overlap is not None 33 | assert max_spikes_per_unit_for_noise_overlap is not None, "'max_spikes_per_unit_for_noise_overlap' must be an integer." 34 | 35 | # update keyword arg in case it's already specified to something 36 | kwargs['max_spikes_per_unit'] = max_spikes_per_unit_for_noise_overlap 37 | params_dict = update_all_param_dicts_with_kwargs(kwargs) 38 | save_property_or_features = params_dict['save_property_or_features'] 39 | seed = params_dict['seed'] 40 | 41 | # set random seed 42 | if seed is not None: 43 | np.random.seed(seed) 44 | 45 | # first, get waveform snippets of every unit (at most n spikes) 46 | # waveforms = List (units,) of np.array (n_spikes, n_channels, n_timepoints) 47 | waveforms = st.postprocessing.get_unit_waveforms( 48 | self._metric_data._recording, 49 | self._metric_data._sorting, 50 | unit_ids=self._metric_data._unit_ids, 51 | **kwargs) 52 | 53 | n_waveforms_per_unit = np.array([len(wf) for wf in waveforms]) 54 | n_spikes_per_unit = np.array([len(self._metric_data._sorting.get_unit_spike_train(u)) for u in self._metric_data._unit_ids]) 55 | 56 | if np.all(n_waveforms_per_unit < max_spikes_per_unit_for_noise_overlap): 57 | # in this case it means that waveforms have been computed on 58 | # less spikes than max_spikes_per_unit_for_noise_overlap --> recompute 59 | kwargs['recompute_info'] = True 60 | waveforms = st.postprocessing.get_unit_waveforms( 61 | self._metric_data._recording, 62 | self._metric_data._sorting, 63 | unit_ids = self._metric_data._unit_ids, 64 | # max_spikes_per_unit = max_spikes_per_unit_for_noise_overlap, 65 | **kwargs) 66 | elif np.all(n_waveforms_per_unit >= max_spikes_per_unit_for_noise_overlap): 67 | # waveforms computed on more spikes than needed --> sample 68 | for 
i_w, wfs in enumerate(waveforms): 69 | if len(wfs) > max_spikes_per_unit_for_noise_overlap: 70 | selecte_idxs = np.random.permutation(len(wfs))[:max_spikes_per_unit_for_noise_overlap] 71 | waveforms[i_w] = wfs[selecte_idxs] 72 | 73 | # get channel idx and locations 74 | channel_idx = np.arange(self._metric_data._recording.get_num_channels()) 75 | channel_locations = self._metric_data._channel_locations 76 | 77 | if num_channels_to_compare > len(channel_idx): 78 | num_channels_to_compare = len(channel_idx) 79 | 80 | # get noise snippets 81 | min_time = min([self._metric_data._sorting.get_unit_spike_train(unit_id=unit)[0] 82 | for unit in self._metric_data._sorting.get_unit_ids()]) 83 | max_time = max([self._metric_data._sorting.get_unit_spike_train(unit_id=unit)[-1] 84 | for unit in self._metric_data._sorting.get_unit_ids()]) 85 | max_spikes = np.max([len(self._metric_data._sorting.get_unit_spike_train(u)) for u in self._metric_data._unit_ids]) 86 | if max_spikes < max_spikes_per_unit_for_noise_overlap: 87 | max_spikes_per_unit_for_noise_overlap = max_spikes 88 | times_control = np.random.choice(np.arange(min_time, max_time), 89 | size=max_spikes_per_unit_for_noise_overlap, replace=False) 90 | clip_size = waveforms[0].shape[-1] 91 | # np.array, (n_spikes, n_channels, n_timepoints) 92 | clips_control_max = np.stack(self._metric_data._recording.get_snippets(snippet_len=clip_size, 93 | reference_frames=times_control)) 94 | 95 | noise_overlaps = [] 96 | for i_u, unit in enumerate(self._metric_data._unit_ids): 97 | # show progress bar 98 | if self._metric_data.verbose: 99 | printProgressBar(i_u + 1, len(self._metric_data._unit_ids)) 100 | 101 | # get spike and noise snippets 102 | # np.array, (n_spikes, n_channels, n_timepoints) 103 | clips = waveforms[i_u] 104 | clips_control = clips_control_max 105 | 106 | # make noise snippets size equal to number of spikes 107 | if len(clips) < max_spikes_per_unit_for_noise_overlap: 108 | selected_idxs = 
np.random.choice(np.arange(max_spikes_per_unit_for_noise_overlap), 109 | size=len(clips), replace=False) 110 | clips_control = clips_control[selected_idxs] 111 | else: 112 | selected_idxs = np.random.choice(np.arange(len(clips)), 113 | size=max_spikes_per_unit_for_noise_overlap, 114 | replace=False) 115 | clips = clips[selected_idxs] 116 | 117 | num_clips = len(clips) 118 | 119 | # compute weight for correcting noise snippets 120 | template = np.median(clips, axis=0) 121 | chmax, tmax = np.unravel_index(np.argmax(np.abs(template)), template.shape) 122 | max_val = template[chmax, tmax] 123 | weighted_clips_control = np.zeros(clips_control.shape) 124 | weights = np.zeros(num_clips) 125 | for j in range(num_clips): 126 | clip0 = clips_control[j, :, :] 127 | val0 = clip0[chmax, tmax] 128 | weight0 = val0 * max_val 129 | weights[j] = weight0 130 | weighted_clips_control[j, :, :] = clip0 * weight0 131 | 132 | noise_template = np.sum(weighted_clips_control, axis=0) 133 | noise_template = noise_template / np.sum(np.abs(noise_template)) * np.sum(np.abs(template)) 134 | 135 | # subtract it out 136 | for j in range(num_clips): 137 | clips[j, :, :] = _subtract_clip_component(clips[j, :, :], noise_template) 138 | clips_control[j, :, :] = _subtract_clip_component(clips_control[j, :, :], noise_template) 139 | 140 | # use only subsets of channels that are closest to peak channel 141 | channels_to_use = find_neighboring_channels(chmax, channel_idx, 142 | num_channels_to_compare, channel_locations) 143 | channels_to_use = np.sort(channels_to_use) 144 | clips = clips[:,channels_to_use,:] 145 | clips_control = clips_control[:,channels_to_use,:] 146 | 147 | all_clips = np.concatenate([clips, clips_control], axis=0) 148 | num_channels_wfs = all_clips.shape[1] 149 | num_samples_wfs = all_clips.shape[2] 150 | all_features = _compute_pca_features(all_clips.reshape((num_clips * 2, 151 | num_channels_wfs * num_samples_wfs)), num_features) 152 | num_all_clips=len(all_clips) 153 | distances, 
indices = NearestNeighbors(n_neighbors=min(num_knn + 1, num_all_clips - 1), algorithm='auto').fit( 154 | all_features.T).kneighbors() 155 | 156 | group_id = np.zeros((num_clips * 2)) 157 | group_id[0:num_clips] = 1 158 | group_id[num_clips:] = 2 159 | num_match = 0 160 | total = 0 161 | for j in range(num_clips * 2): 162 | for k in range(1, min(num_knn + 1, num_all_clips - 1)): 163 | ind = indices[j][k] 164 | if group_id[j] == group_id[ind]: 165 | num_match = num_match + 1 166 | total = total + 1 167 | pct_match = num_match / total 168 | noise_overlap = 1 - pct_match 169 | noise_overlaps.append(noise_overlap) 170 | noise_overlaps = np.asarray(noise_overlaps) 171 | if save_property_or_features: 172 | self.save_property_or_features(self._metric_data._sorting, noise_overlaps, self._metric_name) 173 | return noise_overlaps 174 | 175 | def threshold_metric(self, threshold, threshold_sign, num_channels_to_compare, 176 | max_spikes_per_unit_for_noise_overlap, 177 | num_features, num_knn, **kwargs): 178 | noise_overlaps = self.compute_metric(num_channels_to_compare, 179 | max_spikes_per_unit_for_noise_overlap, 180 | num_features, num_knn, **kwargs) 181 | threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=noise_overlaps) 182 | threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign) 183 | return threshold_curator 184 | 185 | 186 | def _compute_pca_features(X, num_components): 187 | u, s, vt = np.linalg.svd(X) 188 | return u[:, :num_components].T 189 | 190 | 191 | def _subtract_clip_component(clip1, component): 192 | V1 = clip1.flatten() 193 | V2 = component.flatten() 194 | V1 = V1 - np.mean(V1) 195 | V2 = V2 - np.mean(V2) 196 | V1 = V1 - V2 * np.dot(V1, V2) / np.dot(V2, V2) 197 | return V1.reshape(clip1.shape) 198 | -------------------------------------------------------------------------------- /spiketoolkit/validation/quality_metric_classes/num_spikes.py: 
# --------------------------------------------------------------------------------
from .quality_metric import QualityMetric
import numpy as np
import spikemetrics.metrics as metrics
from .utils.thresholdcurator import ThresholdCurator
from collections import OrderedDict
from .parameter_dictionaries import update_all_param_dicts_with_kwargs


class NumSpikes(QualityMetric):
    '''
    Quality metric reporting the number of spikes of each unit.
    '''
    installed = True  # check at class level if installed or not
    installation_mesg = ""  # err
    params = OrderedDict()
    curator_name = "ThresholdNumSpikes"

    def __init__(self, metric_data):
        QualityMetric.__init__(self, metric_data, metric_name="num_spikes")

    def compute_metric(self, **kwargs):
        '''
        Returns an int array with the spike count of every unit, in ``metric_data._unit_indices`` order.
        The counts are stored as unit properties when 'save_property_or_features' is set.
        '''
        md = self._metric_data
        store_results = update_all_param_dicts_with_kwargs(kwargs)['save_property_or_features']
        counts_for_all_units = metrics.calculate_num_spikes(
            md._spike_times,
            md._spike_clusters,
            md._total_units,
            spike_cluster_subset=md._unit_indices,
            verbose=md.verbose,
        )
        num_spikes = np.asarray([counts_for_all_units[idx] for idx in md._unit_indices]).astype('int')
        if store_results:
            self.save_property_or_features(md._sorting, num_spikes, self._metric_name)
        return num_spikes

    def threshold_metric(self, threshold, threshold_sign, **kwargs):
        '''
        Computes the spike counts and returns a ThresholdCurator thresholded on them.
        '''
        counts = self.compute_metric(**kwargs)
        curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=counts)
        curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign)
        return curator

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/parameter_dictionaries.py:
# --------------------------------------------------------------------------------
from collections import OrderedDict
from spiketoolkit.postprocessing.utils import get_amplitudes_params, get_waveforms_params, \
    get_common_params, get_pca_params
import numpy as np

recording_params_dict = OrderedDict([('apply_filter', True), ('freq_min', 300.0), ('freq_max', 6000.0)])


def get_recording_params():
    '''
    Returns a copy of the default recording params ('apply_filter', 'freq_min', 'freq_max').
    '''
    return recording_params_dict.copy()


def get_validation_params():
    '''
    Returns all available keyword argument params

    Returns
    -------
    all_params: dict
        Dictionary with all available keyword arguments for validation and curation functions.
    '''
    all_params = {}
    all_params.update(get_recording_params())
    all_params.update(get_waveforms_params())
    all_params.update(get_amplitudes_params())
    all_params.update(get_pca_params())
    all_params.update(get_common_params())

    return all_params


def update_all_param_dicts_with_kwargs(kwargs):
    '''
    Returns the default validation params with every key also present in ``kwargs`` overridden.

    Keys of ``kwargs`` that are not validation params are ignored; ``kwargs`` is not modified.
    '''
    all_params = get_validation_params()
    all_params.update({key: value for key, value in kwargs.items() if key in all_params})
    return all_params

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/presence_ratio.py:
# --------------------------------------------------------------------------------
from .quality_metric import QualityMetric
import numpy as np
import spikemetrics.metrics as metrics
from .utils.thresholdcurator import ThresholdCurator
from collections import OrderedDict
from .parameter_dictionaries import update_all_param_dicts_with_kwargs


class PresenceRatio(QualityMetric):
    installed = True  # check at class level if installed or not
    installation_mesg = ""  # err
    params = OrderedDict()
    curator_name = "ThresholdPresenceRatios"

    def __init__(
            self,
            metric_data,
    ):
        QualityMetric.__init__(self, metric_data, metric_name="presence_ratio")

    def compute_metric(self, **kwargs):
        '''
        Returns the presence ratio of every unit, in ``metric_data._unit_indices`` order.

        The recording duration passed to spikemetrics is ``duration_in_frames / sampling_frequency``.
        '''
        params_dict = update_all_param_dicts_with_kwargs(kwargs)
        save_property_or_features = params_dict['save_property_or_features']
        presence_ratios_all = metrics.calculate_presence_ratio(
            self._metric_data._spike_times,
            self._metric_data._spike_clusters,
            self._metric_data._total_units,
            duration=self._metric_data._duration_in_frames / self._metric_data._sampling_frequency,
            spike_cluster_subset=self._metric_data._unit_indices,
            verbose=self._metric_data.verbose,
        )
        presence_ratios = np.asarray([presence_ratios_all[i] for i in self._metric_data._unit_indices])
        if save_property_or_features:
            self.save_property_or_features(self._metric_data._sorting, presence_ratios, self._metric_name)
        return presence_ratios

    def threshold_metric(self, threshold, threshold_sign, **kwargs):
        '''
        Computes the presence ratios and returns a ThresholdCurator thresholded on them.
        '''
        presence_ratios = self.compute_metric(**kwargs)
        threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=presence_ratios)
        threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign)
        return threshold_curator

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/quality_metric.py:
# --------------------------------------------------------------------------------
from abc import ABC, abstractmethod


# Baseclass for each quality metric

class QualityMetric(ABC):
    def __init__(
            self,
            metric_data,
            metric_name
    ):
        '''
        Parameters
        ----------
        metric_data: MetricData
            An object for storing and computing preprocessed data
        metric_name: str
            Name used as the unit property key when results are saved
        '''
        self._metric_data = metric_data
        self._metric_name = metric_name

    # implemented by quality metric subclasses
    @abstractmethod
    def compute_metric(self, **kwargs):
        pass

    @abstractmethod
    def threshold_metric(self, threshold, threshold_sign, **kwargs):
        '''
        Parameters
        ----------
        threshold: int or float
            The threshold for the given metric.
        threshold_sign: str
            If 'less', will threshold any metric less than the given threshold.
            If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
            If 'greater', will threshold any metric greater than the given threshold.
            If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.
        Returns
        -------
        tc: ThresholdCurator
            The thresholded sorting extractor.
        '''
        pass

    def save_property_or_features(self, sorting, metric, metric_name):
        '''
        Stores ``metric[i]`` as property ``metric_name`` of the i-th unit in ``metric_data._unit_ids``.
        '''
        for i_u, u in enumerate(self._metric_data._unit_ids):
            sorting.set_unit_property(u, metric_name, metric[i_u])

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/silhouette_score.py:
# --------------------------------------------------------------------------------
import numpy as np
import spikemetrics.metrics as metrics
from .utils.thresholdcurator import ThresholdCurator
from .quality_metric import QualityMetric
from collections import OrderedDict
from .parameter_dictionaries import update_all_param_dicts_with_kwargs


class SilhouetteScore(QualityMetric):
    installed = True  # check at class level if installed or not
    installation_mesg = ""  # err
    params = OrderedDict([('max_spikes_for_silhouette', 10000)])
    curator_name = "ThresholdSilhouetteScores"

    def __init__(self, metric_data):
        '''
        Parameters
        ----------
        metric_data: MetricData
            Must hold PCA scores, otherwise a ValueError is raised.
        '''
        QualityMetric.__init__(self, metric_data, metric_name="silhouette_score")

        if not metric_data.has_pca_scores():
            raise ValueError("MetricData object must have pca scores")

    def compute_metric(self, max_spikes_for_silhouette, **kwargs):
        '''
        Returns the silhouette score of every unit, in ``metric_data._unit_indices`` order.

        At most ``max_spikes_for_silhouette`` spikes (capped at the total number of PCA spikes) are used.
        '''
        params_dict = update_all_param_dicts_with_kwargs(kwargs)
        save_property_or_features = params_dict['save_property_or_features']
        seed = params_dict['seed']
        total_spikes = self._metric_data._spike_clusters_pca.shape[0]
        spikes_for_silhouette = np.min([total_spikes, max_spikes_for_silhouette])
        silhouette_scores_all = metrics.calculate_silhouette_score(
            self._metric_data._spike_clusters_pca,
            self._metric_data._total_units,
            self._metric_data._pc_features,
            self._metric_data._pc_feature_ind,
            spikes_for_silhouette,
            seed=seed,
            spike_cluster_subset=self._metric_data._unit_indices,
            verbose=self._metric_data.verbose,
        )
        silhouette_scores = np.asarray([silhouette_scores_all[index]
                                        for index in self._metric_data._unit_indices])
        if save_property_or_features:
            self.save_property_or_features(self._metric_data._sorting, silhouette_scores, self._metric_name)
        return silhouette_scores

    def threshold_metric(self, threshold, threshold_sign, max_spikes_for_silhouette, **kwargs):
        '''
        Computes the silhouette scores and returns a ThresholdCurator thresholded on them.
        '''
        silhouette_scores = \
            self.compute_metric(max_spikes_for_silhouette, **kwargs)
        threshold_curator = ThresholdCurator(
            sorting=self._metric_data._sorting, metric=silhouette_scores
        )
        threshold_curator.threshold_sorting(
            threshold=threshold, threshold_sign=threshold_sign
        )
        return threshold_curator

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/snr.py:
# --------------------------------------------------------------------------------
import numpy as np
from .utils.thresholdcurator import ThresholdCurator
from .quality_metric import QualityMetric
import spiketoolkit as st
from spikemetrics.utils import printProgressBar
from collections import OrderedDict
from .parameter_dictionaries import update_all_param_dicts_with_kwargs


class SNR(QualityMetric):
    installed = True  # check at class level if installed or not
    installation_mesg = ""  # err
    params = OrderedDict([('snr_mode', "mad"), ('snr_noise_duration', 10.0), ('max_spikes_per_unit_for_snr', 1000),
                          ('template_mode', "median"), ('max_channel_peak', "both")])
    curator_name = "ThresholdSNRs"

    def __init__(self, metric_data):
        '''
        Parameters
        ----------
        metric_data: MetricData
            Must hold a recording, otherwise a ValueError is raised.
        '''
        QualityMetric.__init__(self, metric_data, metric_name="snr")

        if not metric_data.has_recording():
            raise ValueError("MetricData object must have a recording")

    def compute_metric(self, snr_mode, snr_noise_duration, max_spikes_per_unit_for_snr,
                       template_mode, max_channel_peak, **kwargs):
        '''
        Returns the template SNR of every unit, in ``metric_data._unit_ids`` order.

        Noise levels are estimated per channel on ``snr_noise_duration`` seconds of traces, using
        ``snr_mode`` ('std' or 'mad').
        '''
        params_dict = update_all_param_dicts_with_kwargs(kwargs)
        save_property_or_features = params_dict['save_property_or_features']
        seed = params_dict['seed']
        recording = self._metric_data._recording
        sorting = self._metric_data._sorting
        unit_ids = self._metric_data._unit_ids
        channel_noise_levels = _compute_channel_noise_levels(
            recording=recording,
            mode=snr_mode,
            noise_duration=snr_noise_duration,
            seed=seed,
        )
        templates = st.postprocessing.get_unit_templates(
            recording,
            sorting,
            unit_ids=unit_ids,
            max_spikes_per_unit=max_spikes_per_unit_for_snr,
            mode=template_mode, **kwargs
        )
        max_channels = st.postprocessing.get_unit_max_channels(
            recording,
            sorting,
            unit_ids=unit_ids,
            max_spikes_per_unit=max_spikes_per_unit_for_snr,
            peak=max_channel_peak,
            mode=template_mode, **kwargs
        )
        # channel ids are looked up once instead of once per unit
        channel_ids = recording.get_channel_ids()
        snr_list = []
        for i, _unit_id in enumerate(unit_ids):
            if self._metric_data.verbose:
                printProgressBar(i + 1, len(unit_ids))
            max_channel_idx = channel_ids.index(max_channels[i])
            snr_list.append(_compute_template_SNR(templates[i], channel_noise_levels, max_channel_idx))
        snrs = np.asarray(snr_list)
        if save_property_or_features:
            self.save_property_or_features(sorting, snrs, self._metric_name)
        return snrs

    def threshold_metric(self, threshold, threshold_sign, snr_mode, snr_noise_duration, max_spikes_per_unit_for_snr,
                         template_mode, max_channel_peak, **kwargs):
        '''
        Computes the SNRs and returns a ThresholdCurator thresholded on them.
        '''
        snrs = self.compute_metric(snr_mode, snr_noise_duration, max_spikes_per_unit_for_snr,
                                   template_mode, max_channel_peak, **kwargs)
        threshold_curator = ThresholdCurator(sorting=self._metric_data._sorting, metric=snrs)
        threshold_curator.threshold_sorting(threshold=threshold, threshold_sign=threshold_sign)
        return threshold_curator


def _compute_template_SNR(template, channel_noise_levels, max_channel_idx):
    """
    Computes the SNR of a template.

    The signal is the peak absolute amplitude on the template's own largest channel; the noise is
    ``channel_noise_levels[max_channel_idx]``.

    Parameters
    ----------
    template: np.array
        Template (n_elec, n_timepoints)
    channel_noise_levels: list
        Noise levels for the different channels
    max_channel_idx: int
        Index (in the recording's channel list) of the channel with the largest template

    Returns
    -------
    snr: float
        Signal-to-noise ratio for the template
    """
    max_template_channel = np.unravel_index(np.argmax(np.abs(template)), template.shape)[0]
    snr = np.max(np.abs(template[max_template_channel])) / channel_noise_levels[max_channel_idx]
    return snr


def _compute_channel_noise_levels(recording, mode, noise_duration, seed):
    """
    Computes noise level channel-wise

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object
    mode: str
        'std' or 'mad'
    noise_duration: float
        Number of seconds to compute SNR from (the whole recording is used if it is shorter)
    seed: int or None
        Seed for the random choice of the start frame

    Returns
    -------
    noise_levels: np.array
        Noise levels for each channel

    Raises
    ------
    ValueError
        If mode is neither 'std' nor 'mad'
    """
    n_frames = int(noise_duration * recording.get_sampling_frequency())

    if n_frames >= recording.get_num_frames():
        start_frame = 0
        end_frame = recording.get_num_frames()
    else:
        start_frame = np.random.RandomState(seed=seed).randint(0, recording.get_num_frames() - n_frames)
        end_frame = start_frame + n_frames

    X = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
    if mode == "std":
        noise_levels = np.std(X, 1)
    elif mode == "mad":
        # 0.6745 scales the median absolute deviation to a Gaussian standard deviation
        noise_levels = np.median(np.abs(X - np.median(X, 1, keepdims=True)) / 0.6745, 1)
    else:
        raise ValueError("'mode' can be 'std' or 'mad'")
    return noise_levels

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/utils/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/SpikeInterface/spiketoolkit/d90954388400b55e53cc9cc79dbb8a63a5e39c42/spiketoolkit/validation/quality_metric_classes/utils/__init__.py
# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/utils/thresholdcurator.py:
# --------------------------------------------------------------------------------
import operator

from .curationsortingextractor import CurationSortingExtractor

# Maps each supported ``threshold_sign`` to the comparison that marks a unit for exclusion.
_THRESHOLD_COMPARISONS = {
    'less': operator.lt,
    'less_or_equal': operator.le,
    'greater': operator.gt,
    'greater_or_equal': operator.ge,
}


class ThresholdCurator(CurationSortingExtractor):
    def __init__(self, sorting, metric, threshold=None, threshold_sign=None):
        '''
        Parent class for all threshold-based curators.

        Parameters
        ----------
        sorting: SortingExtractor
            The sorting result to be evaluated.
        metric: np.array
            The metric to be thresholded (one value per unit of ``sorting``, in ``get_unit_ids()`` order).
        threshold: int, float or None
            If given together with threshold_sign, the sorting is thresholded immediately.
        threshold_sign: str or None
            See ``threshold_sorting``.
        '''
        CurationSortingExtractor.__init__(self, parent_sorting=sorting)
        self._metric = metric
        self._threshold = threshold
        self._threshold_sign = threshold_sign
        # bypass dumping mechanism of CurationSortingExtractor
        del self._kwargs
        self._kwargs = {'sorting': sorting.make_serialized_dict(), 'metric': metric,
                        'threshold': threshold, 'threshold_sign': threshold_sign}
        if threshold is not None and threshold_sign is not None:
            self.threshold_sorting(threshold, threshold_sign)

    def threshold_sorting(self, threshold, threshold_sign):
        '''
        Parameters
        ----------
        threshold:
            The threshold for the given metric.
        threshold_sign: str
            If 'less', will threshold any metric less than the given threshold.
            If 'less_or_equal', will threshold any metric less than or equal to the given threshold.
            If 'greater', will threshold any metric greater than the given threshold.
            If 'greater_or_equal', will threshold any metric greater than or equal to the given threshold.

        Raises
        ------
        ValueError
            If threshold_sign is not one of the values above (checked even when there are no units).
        '''
        if threshold_sign not in _THRESHOLD_COMPARISONS:
            raise ValueError('Not a correct threshold sign.')
        is_excluded = _THRESHOLD_COMPARISONS[threshold_sign]
        units_to_be_excluded = [unit_id for i, unit_id in enumerate(self._parent_sorting.get_unit_ids())
                                if is_excluded(self._metric[i], threshold)]
        self.exclude_units(units_to_be_excluded)
        # bypass dumping mechanism of CurationSortingExtractor
        if 'curation_steps' in self._kwargs.keys():
            del self._kwargs['curation_steps']
        self._kwargs['threshold'] = threshold
        self._kwargs['threshold_sign'] = threshold_sign

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/quality_metric_classes/utils/validation_tools.py:
# --------------------------------------------------------------------------------
from spiketoolkit.postprocessing.postprocessing_tools import _get_quality_metric_data, _get_pca_metric_data, \
    _get_spike_times_clusters, _get_amp_metric_data
import spikeextractors as se
import numpy as np


def get_spike_times_metrics_data(sorting, sampling_frequency):
    '''
    Computes and returns the spike times in seconds, along with the cluster_ids needed for quality metrics

    Parameters
    ----------
    sorting: SortingExtractor
        The sorting extractor
    sampling_frequency: float
        The sampling frequency of the recording

    Returns
    -------
    spike_times: numpy.ndarray (num_spikes x 0)
        Spike times in seconds
    spike_clusters: numpy.ndarray (num_spikes x 0)
        Cluster IDs for each spike time
    '''
    if not isinstance(sorting, se.SortingExtractor):
        raise AttributeError("'sorting' must be a SortingExtractor")
    if len(sorting.get_unit_ids()) == 0:
        raise Exception("No units in the sorting result, can't compute any metric information.")

    # spike times.npy and spike clusters.npy
    spike_times, spike_clusters = _get_spike_times_clusters(sorting)

    spike_times = np.squeeze((spike_times / sampling_frequency))
    spike_clusters = np.squeeze(spike_clusters.astype(int))

    return spike_times, spike_clusters


def get_pca_metric_data(recording, sorting, **kwargs):
    '''
    Computes and returns all data needed to compute all the quality metrics from SpikeMetrics

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor
    sorting: SortingExtractor
        The sorting extractor
    n_comp: int
        n_compFeatures in template-gui format
    recompute_info: bool
        If True, will always re-extract waveforms
    save_property_or_features: bool
        If True, save all features and properties in the sorting extractor
    verbose: bool
        If True output is verbose
    **wf_args: Keyword arguments
        Keyword arguments for waveforms. A dictionary with default values can be retrieved with:
        st.postprocessing.get_waveforms_params():
            grouping_property: str
                Property to group channels. E.g. if the recording extractor has the 'group' property and
                'grouping_property' is 'group', then waveforms are computed group-wise.
            ms_before: float
                Time period in ms to cut waveforms before the spike events
            ms_after: float
                Time period in ms to cut waveforms after the spike events
            dtype: dtype
                The numpy dtype of the waveforms
            max_spikes_per_unit: int
                The maximum number of spikes to extract per unit.
            compute_property_from_recording: bool
                If True and 'grouping_property' is given, the property of each unit is assigned as the corresponding
                property of the recording extractor channel on which the average waveform is the largest
            seed: int
                Random seed for extracting random waveforms
            n_jobs: int
                Number of parallel jobs (default 1)
            memmap: bool
                If True, waveforms are saved as memmap object (recommended for long recordings with many channels)
            max_channels_per_waveforms: int or None
                Maximum channels per waveforms to return. If None, all channels are returned

    Returns
    -------
    spike_times: numpy.ndarray (num_spikes x 0)
        Spike times in seconds
    spike_times_pca: numpy.ndarray
        Spike times in seconds of the spikes used for PCA
    spike_clusters: numpy.ndarray (num_spikes x 0)
        Cluster IDs for each spike time
    spike_clusters_pca: numpy.ndarray
        Cluster IDs of the spikes used for PCA
    pc_features: numpy.ndarray (num_spikes x num_pcs x num_channels)
        Pre-computed PCs for blocks of channels around each spike
    pc_feature_ind: numpy.ndarray (num_units x num_channels)
        Channel indices of PCs for each unit
    '''
    if not isinstance(recording, se.RecordingExtractor) or not isinstance(sorting, se.SortingExtractor):
        raise AttributeError("'recording' must be a RecordingExtractor and 'sorting' a SortingExtractor")
    if len(sorting.get_unit_ids()) == 0:
        raise Exception("No units in the sorting result, can't compute any metric information.")

    spike_times, spike_times_pca, spike_clusters, \
        spike_clusters_pca, pc_features, pc_feature_ind = _get_pca_metric_data(recording, sorting, **kwargs)

    return np.squeeze(recording.frame_to_time(spike_times)), np.squeeze(recording.frame_to_time(spike_times_pca)), \
        np.squeeze(spike_clusters), np.squeeze(spike_clusters_pca), pc_features, pc_feature_ind


def get_amplitude_metric_data(recording, sorting, **kwargs):
    '''
    Computes and returns all data needed to compute all the quality metrics from SpikeMetrics

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor
    sorting: SortingExtractor
        The sorting extractor
    **kwargs: Keyword arguments
        Keyword arguments for amplitudes. A dictionary with default values can be retrieved with:
        st.postprocessing.get_amplitudes_params():
            method: str
                If 'absolute' (default), absolute amplitudes in uV are returned.
                If 'relative', amplitudes are returned as ratios between waveform amplitudes and template amplitudes.
            peak: str
                If maximum channel has to be found among negative peaks ('neg'), positive ('pos') or
                both ('both' - default)
            frames_before: int
                Frames before peak to compute amplitude
            frames_after: float
                Frames after peak to compute amplitude
            max_spikes_per_unit: int
                The maximum number of amplitudes to extract for each unit (default is np.inf). If less than np.inf,
                the amplitudes will be returned from a random permutation of the spikes.
            recompute_info: bool
                If True, will always re-extract waveforms
            save_property_or_features: bool
                If True, save all features and properties in the sorting extractor
            seed: int
                Random seed for reproducibility
            memmap: bool
                If True, amplitudes are saved as memmap object (recommended for long recordings with many channels)

    Returns
    -------
    spike_times: numpy.ndarray (num_spikes x 0)
        Spike times in seconds
    spike_times_amp: numpy.ndarray
        Spike times in seconds of the spikes with amplitudes
    spike_clusters: numpy.ndarray (num_spikes x 0)
        Cluster IDs for each spike time
    spike_clusters_amp: numpy.ndarray
        Cluster IDs of the spikes with amplitudes
    amplitudes: numpy.ndarray (num_spikes x 0)
        Amplitude value for each spike time
    '''
    if not isinstance(recording, se.RecordingExtractor) or not isinstance(sorting, se.SortingExtractor):
        raise AttributeError("'recording' must be a RecordingExtractor and 'sorting' a SortingExtractor")
    if len(sorting.get_unit_ids()) == 0:
        raise Exception("No units in the sorting result, can't compute any metric information.")

    spike_times, spike_times_amp, spike_clusters, \
        spike_clusters_amp, amplitudes = _get_amp_metric_data(recording, sorting, **kwargs)

    return np.squeeze(recording.frame_to_time(spike_times)), np.squeeze(recording.frame_to_time(spike_times_amp)), \
        np.squeeze(spike_clusters), np.squeeze(spike_clusters_amp), np.squeeze(amplitudes)

# --------------------------------------------------------------------------------
# /spiketoolkit/validation/validation_list.py:
# --------------------------------------------------------------------------------
from .quality_metrics import (
    compute_num_spikes,
    compute_firing_rates,
    compute_presence_ratios,
    compute_isi_violations,
    compute_amplitude_cutoffs,
    compute_snrs,
    compute_drift_metrics,
    compute_silhouette_scores,
    compute_isolation_distances,
    compute_l_ratios,
    compute_d_primes,
    compute_nn_metrics,
    compute_noise_overlaps,
    compute_quality_metrics,
    get_quality_metrics_list
)

from .quality_metric_classes.utils.validation_tools import get_pca_metric_data, \
    get_amplitude_metric_data, get_spike_times_metrics_data

from .quality_metric_classes.parameter_dictionaries import get_validation_params

# --------------------------------------------------------------------------------
# /spiketoolkit/version.py:
# --------------------------------------------------------------------------------
version = '0.7.7'
# --------------------------------------------------------------------------------