├── .gitignore ├── .travis.yml ├── LICENSE.md ├── MANIFEST.in ├── Makefile ├── README.md ├── environment.yml ├── phylib ├── __init__.py ├── conftest.py ├── electrode │ ├── __init__.py │ ├── mea.py │ ├── probes │ │ └── 1x32_buzsaki.prb │ └── tests │ │ ├── __init__.py │ │ └── test_mea.py ├── io │ ├── __init__.py │ ├── alf.py │ ├── array.py │ ├── datasets.py │ ├── merge.py │ ├── mock.py │ ├── model.py │ ├── tests │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_alf.py │ │ ├── test_array.py │ │ ├── test_datasets.py │ │ ├── test_merge.py │ │ ├── test_mock.py │ │ ├── test_model.py │ │ └── test_traces.py │ └── traces.py ├── stats │ ├── __init__.py │ ├── ccg.py │ ├── clusters.py │ └── tests │ │ ├── __init__.py │ │ ├── test_ccg.py │ │ └── test_clusters.py └── utils │ ├── __init__.py │ ├── _misc.py │ ├── _types.py │ ├── event.py │ ├── geometry.py │ ├── testing.py │ └── tests │ ├── __init__.py │ ├── test_event.py │ ├── test_geometry.py │ ├── test_misc.py │ ├── test_testing.py │ └── test_types.py ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | tmpdir 2 | .cache 3 | .*fuse* 4 | *.orig 5 | *.log* 6 | .eggs 7 | .profile 8 | .idea 9 | __pycache__ 10 | .pytest_cache 11 | _old 12 | *.py[cod] 13 | *~ 14 | *# 15 | .#* 16 | *! 17 | .coverage* 18 | *credentials 19 | tags 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Packages 25 | *.egg 26 | *.egg-info 27 | build 28 | dist 29 | eggs 30 | parts 31 | bin 32 | var 33 | sdist 34 | develop-eggs 35 | lib 36 | lib64 37 | 38 | # Sphinx 39 | docs/_build 40 | 41 | .vscode 42 | *.code-workspace 43 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | dist: xenial 3 | sudo: false 4 | python: 5 | - "3.7" 6 | before_install: 7 | - pip install codecov 8 | install: 9 | - wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 10 | - bash miniconda.sh -b -p $HOME/miniconda 11 | - export PATH="$HOME/miniconda/bin:$PATH" 12 | - hash -r 13 | - conda config --set always_yes yes --set changeps1 no 14 | - conda update -q conda 15 | - conda info -a 16 | # Create the environment. 17 | - conda install pyyaml 18 | - conda env create python=3.7 19 | - source activate phylib 20 | - conda install pyqt 21 | # Dev requirements 22 | - pip install -r requirements.txt 23 | - pip install -r requirements-dev.txt 24 | - pip install -e . 25 | script: 26 | - make test 27 | after_success: 28 | - codecov 29 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, Cortex-lab 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
9 | 10 | * Neither the name of phylib nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | 4 | recursive-include tests * 5 | recursive-include phylib/electrode/probes *.prb 6 | recursive-exclude * __pycache__ 7 | recursive-exclude * *.py[co] 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | clean-build: 2 | rm -fr build/ 3 | rm -fr dist/ 4 | rm -fr *.egg-info 5 | 6 | clean-pyc: 7 | find . -name '*.pyc' -exec rm -f {} + 8 | find . -name '*.pyo' -exec rm -f {} + 9 | find . -name '*~' -exec rm -f {} + 10 | find . -name '__pycache__' -exec rm -fr {} + 11 | 12 | clean: clean-build clean-pyc 13 | 14 | lint: 15 | flake8 phylib 16 | 17 | test: lint 18 | py.test --cov-report term-missing --cov=phylib phylib 19 | 20 | coverage: 21 | coverage --html 22 | 23 | apidoc: 24 | python tools/api.py 25 | 26 | build: 27 | python setup.py sdist --formats=zip 28 | 29 | upload: 30 | python setup.py sdist --formats=zip upload 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # phylib 2 | [![Build Status](https://img.shields.io/travis/cortex-lab/phylib.svg)](https://travis-ci.org/cortex-lab/phylib) 3 | [![codecov.io](https://img.shields.io/codecov/c/github/cortex-lab/phylib.svg)](http://codecov.io/github/cortex-lab/phylib?branch=master) 4 | 5 | Electrophysiological data analysis library used by [phy](https://github.com/kwikteam/phy/), a spike sorting visualization software, and [ibllib](https://github.com/int-brain-lab/ibllib/). 
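A minimal usage sketch: loading one of the probe layouts bundled with the package (see `phylib/electrode/mea.py` for details):

```python
from phylib.electrode.mea import list_probes, load_probe

print(list_probes())        # names of the built-in probe files
probe = load_probe('1x32_buzsaki')
print(probe.n_channels)     # 32
print(probe.positions[:4])  # 2D coordinates of the first few channels
```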
6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: phylib 2 | dependencies: 3 | - python=3.7 4 | - pip 5 | - numpy 6 | - scipy 7 | - dask 8 | - pip: 9 | - requests 10 | - tqdm 11 | - joblib 12 | -------------------------------------------------------------------------------- /phylib/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # flake8: noqa 3 | 4 | """Utilities for large-scale ephys data analysis.""" 5 | 6 | 7 | #------------------------------------------------------------------------------ 8 | # Imports 9 | #------------------------------------------------------------------------------ 10 | 11 | import atexit 12 | import logging 13 | import os.path as op 14 | 15 | from .utils._misc import _git_version 16 | from .utils.event import connect, unconnect, emit 17 | 18 | 19 | #------------------------------------------------------------------------------ 20 | # Global variables and functions 21 | #------------------------------------------------------------------------------ 22 | 23 | __author__ = 'Cyrille Rossant' 24 | __email__ = 'cyrille.rossant at gmail.com' 25 | __version__ = '2.6.0' 26 | __version_git__ = __version__ + _git_version() 27 | 28 | 29 | # Set a null handler on the root logger 30 | logger = logging.getLogger('phylib') 31 | logger.setLevel(logging.DEBUG) 32 | logger.addHandler(logging.NullHandler()) 33 | 34 | 35 | _logger_fmt = '%(asctime)s.%(msecs)03d [%(levelname)s] %(caller)s %(message)s' 36 | _logger_date_fmt = '%H:%M:%S' 37 | 38 | 39 | class _Formatter(logging.Formatter): 40 | color_codes = {'L': '94', 'D': '90', 'I': '0', 'W': '33', 'E': '31'} 41 | 42 | def format(self, record): 43 | # Only keep the first character in the level name. 44 | record.levelname = record.levelname[0] 45 | filename = op.splitext(op.basename(record.pathname))[0] 46 | record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(20) 47 | message = super(_Formatter, self).format(record) 48 | color_code = self.color_codes.get(record.levelname, '90') 49 | message = '\33[%sm%s\33[0m' % (color_code, message) 50 | return message 51 | 52 | 53 | def add_default_handler(level='INFO', logger=logger): 54 | handler = logging.StreamHandler() 55 | handler.setLevel(level) 56 | 57 | formatter = _Formatter(fmt=_logger_fmt, datefmt=_logger_date_fmt) 58 | handler.setFormatter(formatter) 59 | 60 | logger.addHandler(handler) 61 | 62 | 63 | def _add_log_file(filename): # pragma: no cover 64 | """Create a log file with DEBUG level.""" 65 | handler = logging.FileHandler(str(filename)) 66 | handler.setLevel(logging.DEBUG) 67 | formatter = _Formatter(fmt=_logger_fmt, datefmt=_logger_date_fmt) 68 | handler.setFormatter(formatter) 69 | logging.getLogger('phy').addHandler(handler) 70 | 71 | 72 | @atexit.register 73 | def on_exit(): # pragma: no cover 74 | # Close the logging handlers. 
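    # (Registered with atexit, so any attached handlers are flushed and detached when the interpreter shuts down.)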
75 | for handler in logger.handlers: 76 | handler.close() 77 | logger.removeHandler(handler) 78 | 79 | 80 | def test(): # pragma: no cover 81 | """Run the full testing suite of phylib.""" 82 | import pytest 83 | pytest.main() 84 | -------------------------------------------------------------------------------- /phylib/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """py.test utilities.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import logging 10 | import os 11 | from pathlib import Path 12 | import tempfile 13 | import shutil 14 | import warnings 15 | 16 | import numpy as np 17 | from pytest import fixture 18 | 19 | from phylib import add_default_handler 20 | 21 | 22 | #------------------------------------------------------------------------------ 23 | # Common fixtures 24 | #------------------------------------------------------------------------------ 25 | 26 | logger = logging.getLogger('phylib') 27 | logger.setLevel(10) 28 | add_default_handler(5, logger=logger) 29 | 30 | # Fix the random seed in the tests. 31 | np.random.seed(2015) 32 | 33 | warnings.filterwarnings("ignore", message="numpy.dtype size changed") 34 | warnings.filterwarnings("ignore", message="numpy.ufunc size changed") 35 | 36 | 37 | @fixture 38 | def tempdir(): 39 | curdir = os.getcwd() 40 | tempdir = tempfile.mkdtemp() 41 | os.chdir(tempdir) 42 | yield Path(tempdir) 43 | os.chdir(curdir) 44 | shutil.rmtree(tempdir) 45 | -------------------------------------------------------------------------------- /phylib/electrode/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # flake8: noqa 3 | 4 | """Electrodes.""" 5 | 6 | from .mea import MEA, load_probe 7 | -------------------------------------------------------------------------------- /phylib/electrode/mea.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Multi-electrode arrays.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import itertools 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | 15 | from phylib.utils._types import _as_array 16 | from phylib.utils._misc import read_python 17 | 18 | 19 | #------------------------------------------------------------------------------ 20 | # PRB file utilities 21 | #------------------------------------------------------------------------------ 22 | 23 | def _edges_to_adjacency_list(edges): 24 | """Convert a list of edges into an adjacency list.""" 25 | adj = {} 26 | for i, j in edges: 27 | if i in adj: # pragma: no cover 28 | ni = adj[i] 29 | else: 30 | ni = adj[i] = set() 31 | if j in adj: 32 | nj = adj[j] 33 | else: 34 | nj = adj[j] = set() 35 | ni.add(j) 36 | nj.add(i) 37 | return adj 38 | 39 | 40 | def _adjacency_subset(adjacency, subset): 41 | return {c: [v for v in vals if v in subset] 42 | for (c, vals) in adjacency.items() if c in subset} 43 | 44 | 45 | def _remap_adjacency(adjacency, mapping): 46 | remapped = {} 47 | for key, vals in adjacency.items(): 48 | remapped[mapping[key]] = [mapping[i] for i in vals] 49 | return remapped 50 | 51 | 52 | def _probe_positions(probe, 
group): 53 | """Return the positions of a probe channel group.""" 54 | positions = probe['channel_groups'][group]['geometry'] 55 | channels = _probe_channels(probe, group) 56 | return np.array([positions[channel] for channel in channels]) 57 | 58 | 59 | def _probe_channels(probe, group): 60 | """Return the list of channels in a channel group. 61 | 62 | The order is kept. 63 | 64 | """ 65 | return probe['channel_groups'][group]['channels'] 66 | 67 | 68 | def _probe_adjacency_list(probe): 69 | """Return an adjacency list of a whole probe.""" 70 | cgs = probe['channel_groups'].values() 71 | graphs = [cg['graph'] for cg in cgs] 72 | edges = list(itertools.chain(*graphs)) 73 | adjacency_list = _edges_to_adjacency_list(edges) 74 | return adjacency_list 75 | 76 | 77 | def _channels_per_group(probe): 78 | groups = probe['channel_groups'].keys() 79 | return {group: probe['channel_groups'][group]['channels'] 80 | for group in groups} 81 | 82 | 83 | def load_probe(name_or_path): 84 | """Load one of the built-in probes.""" 85 | path = Path(name_or_path) 86 | # The argument can be either a path to a PRB file or the name of a built-in probe.. 87 | if not path.exists(): 88 | path = Path(__file__).parent / ('probes/%s.prb' % name_or_path) 89 | if not path.exists(): 90 | raise IOError("The probe `{}` cannot be found.".format(name_or_path)) 91 | return MEA(probe=read_python(path)) 92 | 93 | 94 | def list_probes(): 95 | """Return the list of built-in probes.""" 96 | path = Path(__file__).parent / 'probes' 97 | return [fn.stem for fn in path.glob('*.prb')] 98 | 99 | 100 | #------------------------------------------------------------------------------ 101 | # MEA class 102 | #------------------------------------------------------------------------------ 103 | 104 | class MEA(object): 105 | """A Multi-Electrode Array. 106 | 107 | There are two modes: 108 | 109 | * No probe specified: one single channel group, positions and adjacency 110 | list specified directly. 111 | * Probe specified: one can change the current channel_group. 112 | 113 | """ 114 | 115 | def __init__(self, 116 | channels=None, 117 | positions=None, 118 | adjacency=None, 119 | probe=None, 120 | ): 121 | self._probe = probe 122 | self._channels = channels 123 | self._check_positions(positions) 124 | self._positions = positions 125 | # This is a mapping {channel: list of neighbors}. 126 | if adjacency is None and probe is not None: 127 | adjacency = _probe_adjacency_list(probe) 128 | self.channels_per_group = _channels_per_group(probe) 129 | self._adjacency = adjacency 130 | if probe: 131 | # Select the first channel group. 
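            # (Groups are sorted so the lowest-numbered group becomes the default; use change_channel_group() to switch later.)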
132 | cg = sorted(self._probe['channel_groups'].keys())[0] 133 | self.change_channel_group(cg) 134 | 135 | def _check_positions(self, positions): 136 | if positions is None: 137 | return 138 | positions = _as_array(positions) 139 | if positions.shape[0] != self.n_channels: 140 | raise ValueError("'positions' " 141 | "(shape {0:s})".format(str(positions.shape)) + 142 | " and 'n_channels' " 143 | "({0:d})".format(self.n_channels) + 144 | " do not match.") 145 | 146 | @property 147 | def positions(self): 148 | """Channel positions in the current channel group.""" 149 | return self._positions 150 | 151 | @property 152 | def channels(self): 153 | """Channel ids in the current channel group.""" 154 | return self._channels 155 | 156 | @property 157 | def n_channels(self): 158 | """Number of channels in the current channel group.""" 159 | return len(self._channels) if self._channels is not None else 0 160 | 161 | @property 162 | def adjacency(self): 163 | """Adjacency graph in the current channel group.""" 164 | return self._adjacency 165 | 166 | def change_channel_group(self, group): 167 | """Change the current channel group.""" 168 | assert self._probe is not None 169 | self._channels = _probe_channels(self._probe, group) 170 | self._positions = _probe_positions(self._probe, group) 171 | -------------------------------------------------------------------------------- /phylib/electrode/probes/1x32_buzsaki.prb: -------------------------------------------------------------------------------- 1 | channel_groups = { 2 | # Shank index. 3 | 0: 4 | { 5 | # List of channels to keep for spike detection. 6 | 'channels': list(range(32)), 7 | 8 | # Adjacency graph. Dead channels will be automatically discarded 9 | # by considering the corresponding subgraph. 10 | 'graph': [ 11 | (0, 1), (0, 2), 12 | (1, 2), (1, 3), 13 | (2, 3), (2, 4), 14 | (3, 4), (3, 5), 15 | (4, 5), (4, 6), 16 | (5, 6), (5, 7), 17 | (6, 7), (6, 8), 18 | (7, 8), (7, 9), 19 | (8, 9), (8, 10), 20 | (9, 10), (9, 11), 21 | (10, 11), (10, 12), 22 | (11, 12), (11, 13), 23 | (12, 13), (12, 14), 24 | (13, 14), (13, 15), 25 | (14, 15), (14, 16), 26 | (15, 16), (15, 17), 27 | (16, 17), (16, 18), 28 | (17, 18), (17, 19), 29 | (18, 19), (18, 20), 30 | (19, 20), (19, 21), 31 | (20, 21), (20, 22), 32 | (21, 22), (21, 23), 33 | (22, 23), (22, 24), 34 | (23, 24), (23, 25), 35 | (24, 25), (24, 26), 36 | (25, 26), (25, 27), 37 | (26, 27), (26, 28), 38 | (27, 28), (27, 29), 39 | (28, 29), (28, 30), 40 | (29, 30), (29, 31), 41 | (30, 31), 42 | ], 43 | 44 | # 2D positions of the channels, only for visualization purposes 45 | # in KlustaViewa. The unit doesn't matter. 
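        # Each entry maps a channel id to an (x, y) pair; the alternating x values below give the staggered shank layout.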
46 | 'geometry': { 47 | 31: (0, 0), 48 | 30: (5, 10), 49 | 29: (-6, 20), 50 | 28: (7, 30), 51 | 27: (-8, 40), 52 | 26: (9, 50), 53 | 25: (-10, 60), 54 | 24: (11, 70), 55 | 23: (-12, 80), 56 | 22: (13, 90), 57 | 21: (-14, 100), 58 | 20: (15, 110), 59 | 19: (-16, 120), 60 | 18: (17, 130), 61 | 17: (-18, 140), 62 | 16: (19, 150), 63 | 15: (-20, 160), 64 | 14: (21, 170), 65 | 13: (-22, 180), 66 | 12: (23, 190), 67 | 11: (-24, 200), 68 | 10: (25, 210), 69 | 9: (-26, 220), 70 | 8: (27, 230), 71 | 7: (-28, 240), 72 | 6: (29, 250), 73 | 5: (-30, 260), 74 | 4: (31, 270), 75 | 3: (-32, 280), 76 | 2: (33, 290), 77 | 1: (-34, 300), 78 | 0: (35, 310), 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /phylib/electrode/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cortex-lab/phylib/72316a4ccb0abed93464ed4bfcd36a3809d1dbdf/phylib/electrode/tests/__init__.py -------------------------------------------------------------------------------- /phylib/electrode/tests/test_mea.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Test MEA.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | from pathlib import Path 10 | 11 | from pytest import raises 12 | import numpy as np 13 | from numpy.testing import assert_array_equal as ae 14 | 15 | from ..mea import (_probe_channels, _remap_adjacency, _adjacency_subset, 16 | _probe_positions, _probe_adjacency_list, 17 | MEA, load_probe, list_probes 18 | ) 19 | 20 | 21 | #------------------------------------------------------------------------------ 22 | # Tests 23 | #------------------------------------------------------------------------------ 24 | 25 | def test_remap(): 26 | adjacency = {1: [2, 3, 7], 3: [5, 11]} 27 | mapping = {1: 3, 2: 20, 3: 30, 5: 50, 7: 70, 11: 1} 28 | remapped = _remap_adjacency(adjacency, mapping) 29 | assert sorted(remapped.keys()) == [3, 30] 30 | assert remapped[3] == [20, 30, 70] 31 | assert remapped[30] == [50, 1] 32 | 33 | 34 | def test_adjacency_subset(): 35 | adjacency = {1: [2, 3, 7], 3: [5, 11], 5: [1, 2, 11]} 36 | subset = [1, 5, 32] 37 | adjsub = _adjacency_subset(adjacency, subset) 38 | assert sorted(adjsub.keys()) == [1, 5] 39 | assert adjsub[1] == [] 40 | assert adjsub[5] == [1] 41 | 42 | 43 | def test_probe(): 44 | probe = {'channel_groups': { 45 | 0: {'channels': [0, 3, 1], 46 | 'graph': [[0, 3], [1, 0]], 47 | 'geometry': {0: (10, 10), 1: (10, 20), 3: (20, 30)}, 48 | }, 49 | 1: {'channels': [7], 50 | 'graph': [], 51 | }, 52 | }} 53 | adjacency = {0: set([1, 3]), 54 | 1: set([0]), 55 | 3: set([0]), 56 | } 57 | assert _probe_channels(probe, 0) == [0, 3, 1] 58 | ae(_probe_positions(probe, 0), [(10, 10), (20, 30), (10, 20)]) 59 | assert _probe_adjacency_list(probe) == adjacency 60 | 61 | mea = MEA(probe=probe) 62 | 63 | assert mea.adjacency == adjacency 64 | assert mea.channels_per_group == {0: [0, 3, 1], 1: [7]} 65 | assert mea.channels == [0, 3, 1] 66 | assert mea.n_channels == 3 67 | ae(mea.positions, [(10, 10), (20, 30), (10, 20)]) 68 | 69 | 70 | def test_mea(): 71 | 72 | n_channels = 10 73 | channels = np.arange(n_channels) 74 | positions = np.random.randn(n_channels, 2) 75 | 76 | mea = MEA(channels, positions=positions) 77 | ae(mea.positions, positions) 78 | assert 
mea.adjacency is None 79 | 80 | mea = MEA(channels, positions=positions) 81 | assert mea.n_channels == n_channels 82 | 83 | mea = MEA(channels, positions=positions) 84 | assert mea.n_channels == n_channels 85 | 86 | with raises(ValueError): 87 | MEA(channels=np.arange(n_channels + 1), positions=positions) 88 | 89 | with raises(ValueError): 90 | MEA(channels=channels, positions=positions[:-1, :]) 91 | 92 | 93 | def test_library(tempdir): 94 | assert '1x32_buzsaki' in list_probes() 95 | 96 | probe = load_probe('1x32_buzsaki') 97 | assert probe 98 | assert probe.channels == list(range(32)) 99 | 100 | path = Path(tempdir) / 'test.prb' 101 | with raises(IOError): 102 | load_probe(path) 103 | 104 | with open(path, 'w') as f: 105 | f.write('') 106 | with raises(KeyError): 107 | load_probe(path) 108 | -------------------------------------------------------------------------------- /phylib/io/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # flake8: noqa 3 | 4 | """Input/output.""" 5 | 6 | from .array import SpikeSelector 7 | from .traces import ( 8 | get_ephys_reader, get_spike_waveforms, NpyWriter, 9 | extract_waveforms, iter_waveforms, export_waveforms) 10 | -------------------------------------------------------------------------------- /phylib/io/alf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ALF dataset generation.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import logging 11 | from pathlib import Path 12 | import shutil 13 | import ast 14 | import uuid 15 | 16 | from tqdm import tqdm 17 | import numpy as np 18 | 19 | from phylib.utils._misc import _read_tsv_simple, ensure_dir_exists 20 | from phylib.io.array import _spikes_per_cluster, _unique 21 | from phylib.io.model import load_model 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | #------------------------------------------------------------------------------ 27 | # File utils 28 | #------------------------------------------------------------------------------ 29 | 30 | NSAMPLE_WAVEFORMS = 500 # number of waveforrms sampled out of the raw data 31 | 32 | _FILE_RENAMES = [ # file_in, file_out, squeeze (bool to squeeze vector from matlab in npy) 33 | ('params.py', 'params.py', None), 34 | ('cluster_KSLabel.tsv', 'cluster_KSLabel.tsv', None), 35 | ('spike_clusters.npy', 'spikes.clusters.npy', True), 36 | ('spike_templates.npy', 'spikes.templates.npy', True), 37 | ('channel_positions.npy', 'channels.localCoordinates.npy', False), 38 | ('channel_probe.npy', 'channels.probes.npy', True), 39 | ('channel_labels.npy', 'channels.labels.npy', True), 40 | ('cluster_probes.npy', 'clusters.probes.npy', True), 41 | ('cluster_shanks.npy', 'clusters.shanks.npy', True), 42 | ('whitening_mat.npy', '_kilosort_whitening.matrix.npy', False), 43 | ('_phy_spikes_subset.channels.npy', '_phy_spikes_subset.channels.npy', False), 44 | ('_phy_spikes_subset.spikes.npy', '_phy_spikes_subset.spikes.npy', False), 45 | ('_phy_spikes_subset.waveforms.npy', '_phy_spikes_subset.waveforms.npy', False), 46 | ('drift_depths.um.npy', 'drift_depths.um.npy', False), 47 | ('drift.times.npy', 'drift.times.npy', False), 48 | ('drift.um.npy', 'drift.um.npy', False), 49 | # ('cluster_group.tsv', 'ks2/clusters.phyAnnotation.tsv', False), # todo check indexing, 
add2QC 50 | ] 51 | 52 | FILE_DELETES = [ 53 | 'temp_wh.dat', # potentially large file that will clog the servers 54 | ] 55 | 56 | 57 | def _read_npy_header(filename): 58 | d = {} 59 | with open(filename, 'rb') as fid: 60 | d['magic_string'] = fid.read(6) 61 | d['version'] = fid.read(2) 62 | d['len'] = int.from_bytes(fid.read(2), byteorder='little') 63 | d = {**d, **ast.literal_eval(fid.read(d['len']).decode())} 64 | return d 65 | 66 | 67 | def _create_if_possible(path, new_path, force=False): 68 | """Prepare the copy/move/symlink of a file, by making sure the source exists 69 | while the destination does not.""" 70 | if not Path(path).exists(): # pragma: no cover 71 | logger.warning("Path %s does not exist, skipping.", path) 72 | return False 73 | if Path(new_path).exists() and not force: # pragma: no cover 74 | logger.warning("Path %s already exists, skipping.", new_path) 75 | return False 76 | ensure_dir_exists(new_path.parent) 77 | return True 78 | 79 | 80 | def _copy_if_possible(path, new_path, force=False): 81 | if not _create_if_possible(path, new_path, force=force): 82 | return False 83 | logger.debug("Copying %s to %s.", path, new_path) 84 | shutil.copy(path, new_path) 85 | return True 86 | 87 | 88 | def _load(path): 89 | path = str(path) 90 | if path.endswith('.npy'): 91 | return np.load(path) 92 | elif path.endswith(('.csv', '.tsv')): 93 | return _read_tsv_simple(path)[1] # the function returns a tuple (field, data) 94 | elif path.endswith('.bin'): 95 | # TODO: configurable dtype 96 | return np.fromfile(path, np.int16) 97 | 98 | 99 | #------------------------------------------------------------------------------ 100 | # Ephys ALF creator 101 | #------------------------------------------------------------------------------ 102 | 103 | class EphysAlfCreator(object): 104 | """Class for converting a dataset in KS/phy format into ALF.""" 105 | 106 | def __init__(self, model): 107 | self.model = model 108 | self.dir_path = Path(model.dir_path) 109 | self.spc = _spikes_per_cluster(model.spike_clusters) 110 | self.cluster_ids = _unique(self.model.spike_clusters) 111 | 112 | def convert(self, out_path, force=False, label='', ampfactor=1): 113 | """Convert from KS/phy format to ALF.""" 114 | logger.info("Converting dataset to ALF.") 115 | self.out_path = Path(out_path) 116 | self.label = label 117 | self.ampfactor = ampfactor 118 | if self.out_path.resolve() == self.dir_path.resolve(): 119 | raise IOError("The source and target directories cannot be the same.") 120 | if not self.out_path.exists(): 121 | self.out_path.mkdir() 122 | 123 | with tqdm(desc="Converting to ALF", total=135) as bar: 124 | bar.update(10) 125 | self.make_cluster_objects() 126 | bar.update(10) 127 | self.make_channel_objects() 128 | bar.update(5) 129 | self.make_template_and_spikes_objects() 130 | bar.update(30) 131 | self.model.save_spikes_subset_waveforms( 132 | NSAMPLE_WAVEFORMS, sample2unit=self.ampfactor) 133 | bar.update(50) 134 | self.make_depths() 135 | bar.update(20) 136 | self.rm_files() 137 | bar.update(10) 138 | self.copy_files(force=force) 139 | bar.update(10) 140 | self.rename_with_label() 141 | bar.update(10) 142 | self.compress_spikes_dtypes() 143 | 144 | # Return the TemplateModel of the converted ALF dataset if the params.py file exists. 
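        # (Reloading the output directory lets callers work with the converted ALF dataset straight away.)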
145 | params_path = self.out_path / 'params.py' 146 | if params_path.exists(): 147 | return load_model(params_path) 148 | 149 | def copy_files(self, force=False): 150 | for fn0, fn1, squeeze in _FILE_RENAMES: 151 | f0 = self.dir_path / fn0 152 | f1 = self.out_path / fn1 153 | _copy_if_possible(f0, f1, force=force) 154 | if f0.exists() and squeeze and f0.suffix == '.npy': 155 | h = _read_npy_header(f0) 156 | # ks2 outputs vectors as multidimensional arrays. If there is no distinction 157 | # for Matlab, there is one in Numpy 158 | if len(h['shape']) == 2 and h['shape'][-1] == 1: # pragma: no cover 159 | d = np.load(f0) 160 | np.save(f1, d.squeeze()) 161 | continue 162 | 163 | def rm_files(self): 164 | for fn0 in FILE_DELETES: 165 | fn = self.dir_path.joinpath(fn0) 166 | if fn.exists(): # pragma: no cover 167 | fn.unlink() 168 | 169 | # File creation 170 | # ------------------------------------------------------------------------- 171 | 172 | def _save_npy(self, filename, arr, dtype=None): 173 | """Save an array into a .npy file.""" 174 | dtype = arr.dtype if dtype is None else dtype 175 | np.save(self.out_path / filename, arr.astype(dtype)) 176 | 177 | def make_cluster_objects(self): 178 | """Create clusters.channels, clusters.waveformsDuration and clusters.amps""" 179 | peak_channel_path = self.dir_path / 'clusters.channels.npy' 180 | if not peak_channel_path.exists(): 181 | # self._save_npy(peak_channel_path.name, self.model.templates_channels) 182 | self._save_npy(peak_channel_path.name, self.model.clusters_channels) 183 | 184 | waveform_duration_path = self.dir_path / 'clusters.peakToTrough.npy' 185 | if not waveform_duration_path.exists(): 186 | # self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations) 187 | waveform_duration = self.model.clusters_waveforms_durations 188 | waveform_duration[self.model.nan_idx] = np.nan 189 | self._save_npy(waveform_duration_path.name, waveform_duration) 190 | 191 | # group by average over cluster number 192 | # camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan 193 | camps = np.zeros(self.model.clusters_channels.shape[0], ) * np.nan 194 | camps[self.cluster_ids] = self.model.clusters_amplitudes 195 | amps_path = self.dir_path / 'clusters.amps.npy' 196 | self._save_npy(amps_path.name, camps * self.ampfactor) 197 | 198 | # clusters uuids 199 | uuid_list = ['uuids'] 200 | uuid_list.extend([str(uuid.uuid4()) for _ in range(camps.size)]) 201 | with open(self.out_path / 'clusters.uuids.csv', 'w+') as fid: 202 | fid.write('\n'.join(uuid_list)) 203 | 204 | def make_channel_objects(self): 205 | """If there is no rawInd file, create it""" 206 | rawInd_path = self.dir_path / 'channels.rawInd.npy' 207 | rawInd = np.zeros_like(self.model.channel_probes).astype(int) 208 | channel_offset = 0 209 | for probe in np.unique(self.model.channel_probes): 210 | ind = self.model.channel_probes == probe 211 | rawInd[ind] = self.model.channel_mapping[ind] - channel_offset 212 | channel_offset += np.max(self.model.channel_mapping[ind]) 213 | self._save_npy(rawInd_path.name, rawInd) 214 | 215 | def make_depths(self): 216 | """Make spikes.depths.npy, clusters.depths.npy.""" 217 | channel_positions = self.model.channel_positions 218 | assert channel_positions.ndim == 2 219 | 220 | spike_clusters = self.model.spike_clusters 221 | assert spike_clusters.ndim == 1 222 | 223 | cluster_channels = np.load(self.out_path / 'clusters.channels.npy') 224 | assert cluster_channels.ndim == 1 225 | n_clusters = cluster_channels.shape[0] 226 | 227 | 
clusters_depths = channel_positions[cluster_channels, 1] 228 | clusters_depths[self.model.nan_idx] = np.nan 229 | assert clusters_depths.shape == (n_clusters,) 230 | 231 | if self.model.sparse_features is None: 232 | spikes_depths = clusters_depths[spike_clusters] 233 | else: 234 | spikes_depths = self.model.get_depths() 235 | self._save_npy('spikes.depths.npy', spikes_depths, np.float32) 236 | self._save_npy('clusters.depths.npy', clusters_depths) 237 | 238 | def make_template_and_spikes_objects(self): 239 | """Creates the template waveforms sparse object 240 | Without manual curation, it also corresponds to clusters waveforms objects. 241 | """ 242 | # "We cannot just rename/copy spike_times.npy because it is in unit of samples, 243 | # and not seconds 244 | self._save_npy('spikes.times.npy', self.model.spike_times) 245 | self._save_npy('spikes.samples.npy', self.model.spike_samples) 246 | spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor, 247 | use='templates') 248 | self._save_npy('spikes.amps.npy', spike_amps, np.float32) 249 | self._save_npy('templates.amps.npy', template_amps) 250 | 251 | if self.model.sparse_templates.cols: 252 | raise NotImplementedError("Sparse template export to ALF not implemented yet") 253 | else: 254 | n_templates, n_wavsamps, nchall = templates_v.shape 255 | # for some datasets, 32 may be too much 256 | ncw = min(self.model.n_closest_channels, nchall) 257 | assert n_templates == self.model.n_templates 258 | templates = np.zeros((n_templates, n_wavsamps, ncw), dtype=np.float32) 259 | templates_inds = np.zeros((n_templates, ncw), dtype=np.int32) 260 | # for each template, find the nearest channels to keep (one the same probe...) 261 | for t in np.arange(n_templates): 262 | current_probe = self.model.channel_probes[self.model.templates_channels[t]] 263 | channel_distance = np.sum(np.abs( 264 | self.model.channel_positions - 265 | self.model.channel_positions[self.model.templates_channels[t]]), axis=1) 266 | channel_distance[self.model.channel_probes != current_probe] += np.inf 267 | templates_inds[t, :] = np.argsort(channel_distance)[:ncw] 268 | templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]] 269 | np.save(self.out_path.joinpath('templates.waveforms'), templates) 270 | np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds) 271 | 272 | _, clusters_v, cluster_amps = self.model.get_amplitudes_true(self.ampfactor, 273 | use='clusters') 274 | n_clusters, n_wavsamps, nchall = clusters_v.shape 275 | # for some datasets, 32 may be too much 276 | ncw = min(self.model.n_closest_channels, nchall) 277 | assert n_clusters == self.model.n_clusters 278 | templates = np.zeros((n_clusters, n_wavsamps, ncw), dtype=np.float32) 279 | templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32) 280 | # for each template, find the nearest channels to keep (one the same probe...) 281 | for t in np.arange(n_clusters): 282 | channels = self.model.clusters_channels 283 | 284 | current_probe = self.model.channel_probes[channels[t]] 285 | channel_distance = np.sum(np.abs( 286 | self.model.channel_positions - 287 | self.model.channel_positions[channels[t]]), axis=1) 288 | channel_distance[self.model.channel_probes != current_probe] += np.inf 289 | templates_inds[t, :] = np.argsort(channel_distance)[:ncw] 290 | templates[t, ...] 
= clusters_v[t, :][:, templates_inds[t, :]] 291 | np.save(self.out_path.joinpath('clusters.waveforms'), templates) 292 | np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds) 293 | np.save(self.out_path.joinpath('clusters.amps'), cluster_amps) 294 | 295 | def rename_with_label(self): 296 | """add the label as an ALF part name before the extension if any label provided""" 297 | if not self.label: 298 | return 299 | glob_patterns = ['channels.*', 'clusters.*', 'spikes.*', 'templates.*'] 300 | for pattern in glob_patterns: 301 | for f in self.out_path.glob(pattern): 302 | f.rename(f.with_suffix(f'.{self.label}{f.suffix}')) 303 | 304 | def compress_spikes_dtypes(self): 305 | """Convert clusters and templates to int16.""" 306 | for attribute in ['templates', 'clusters']: 307 | fn = next(self.out_path.glob(f'spikes.{attribute}.*npy')) 308 | np.save(fn, np.load(fn).astype(np.uint16)) 309 | -------------------------------------------------------------------------------- /phylib/io/array.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Utility functions for NumPy arrays.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import logging 10 | from math import floor, ceil 11 | from operator import itemgetter 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | 16 | from phylib.utils import _as_scalar, _as_scalars 17 | from phylib.utils._types import _as_array 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | #------------------------------------------------------------------------------ 23 | # Utility functions 24 | #------------------------------------------------------------------------------ 25 | 26 | def _clip(x, a, b): 27 | return max(a, min(b, x)) 28 | 29 | 30 | def _range_from_slice(myslice, start=None, stop=None, step=None, length=None): 31 | """Convert a slice to an array of integers.""" 32 | assert isinstance(myslice, slice) 33 | # Find 'step'. 34 | step = myslice.step if myslice.step is not None else step 35 | if step is None: 36 | step = 1 37 | # Find 'start'. 38 | start = myslice.start if myslice.start is not None else start 39 | if start is None: 40 | start = 0 41 | # Find 'stop' as a function of length if 'stop' is unspecified. 42 | stop = myslice.stop if myslice.stop is not None else stop 43 | if length is not None: 44 | stop_inferred = floor(start + step * length) 45 | if stop is not None and stop < stop_inferred: 46 | raise ValueError("'stop' ({stop}) and ".format(stop=stop) + 47 | "'length' ({length}) ".format(length=length) + 48 | "are not compatible.") 49 | stop = stop_inferred 50 | if stop is None and length is None: 51 | raise ValueError("'stop' and 'length' cannot be both unspecified.") 52 | myrange = np.arange(start, stop, step) 53 | # Check the length if it was specified. 54 | if length is not None: 55 | assert len(myrange) == length 56 | return myrange 57 | 58 | 59 | def _unique(x): 60 | """Faster version of np.unique(). 61 | 62 | This version is restricted to 1D arrays of non-negative integers. 63 | 64 | It is only faster if len(x) >> len(unique(x)). 65 | 66 | """ 67 | if x is None or len(x) == 0: 68 | return np.array([], dtype=np.int64) 69 | # WARNING: only keep positive values. 70 | # cluster=-1 means "unclustered". 
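    # np.bincount allocates one slot per integer value, so nonzero() below returns the sorted unique ids.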
71 | x = _as_array(x) 72 | x = x[x >= 0] 73 | bc = np.bincount(x) 74 | return np.nonzero(bc)[0] 75 | 76 | 77 | def _normalize(arr, keep_ratio=False): 78 | """Normalize an array into [0, 1].""" 79 | (x_min, y_min), (x_max, y_max) = arr.min(axis=0), arr.max(axis=0) 80 | 81 | if keep_ratio: 82 | a = 1. / max(x_max - x_min, y_max - y_min) 83 | ax = ay = a 84 | bx = .5 - .5 * a * (x_max + x_min) 85 | by = .5 - .5 * a * (y_max + y_min) 86 | else: 87 | ax = 1. / (x_max - x_min) 88 | ay = 1. / (y_max - y_min) 89 | bx = -x_min / (x_max - x_min) 90 | by = -y_min / (y_max - y_min) 91 | 92 | arr_n = arr.copy() 93 | arr_n[:, 0] *= ax 94 | arr_n[:, 0] += bx 95 | arr_n[:, 1] *= ay 96 | arr_n[:, 1] += by 97 | 98 | return arr_n 99 | 100 | 101 | def _index_of(arr, lookup): 102 | """Replace scalars in an array by their indices in a lookup table. 103 | 104 | Implicitely assume that: 105 | 106 | * All elements of arr and lookup are non-negative integers. 107 | * All elements or arr belong to lookup. 108 | 109 | This is not checked for performance reasons. 110 | 111 | """ 112 | # Equivalent of np.digitize(arr, lookup) - 1, but much faster. 113 | # TODO: assertions to disable in production for performance reasons. 114 | # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large 115 | # values 116 | lookup = np.asarray(lookup, dtype=np.int32) 117 | m = (lookup.max() if len(lookup) else 0) + 1 118 | tmp = np.zeros(m + 1, dtype=int) 119 | # Ensure that -1 values are kept. 120 | tmp[-1] = -1 121 | if len(lookup): 122 | tmp[lookup] = np.arange(len(lookup)) 123 | return tmp[arr] 124 | 125 | 126 | def _pad(arr, n, dir='right'): 127 | """Pad an array with zeros along the first axis. 128 | 129 | Parameters 130 | ---------- 131 | 132 | n : int 133 | Size of the returned array in the first axis. 134 | dir : str 135 | Direction of the padding. Must be one 'left' or 'right'. 136 | 137 | """ 138 | assert dir in ('left', 'right') 139 | if n < 0: 140 | raise ValueError("'n' must be positive: {0}.".format(n)) 141 | elif n == 0: 142 | return np.zeros((0,) + arr.shape[1:], dtype=arr.dtype) 143 | n_arr = arr.shape[0] 144 | shape = (n,) + arr.shape[1:] 145 | if n_arr == n: 146 | assert arr.shape == shape 147 | return arr 148 | elif n_arr < n: 149 | out = np.zeros(shape, dtype=arr.dtype) 150 | if dir == 'left': 151 | out[-n_arr:, ...] = arr 152 | elif dir == 'right': 153 | out[:n_arr, ...] = arr 154 | assert out.shape == shape 155 | return out 156 | else: 157 | if dir == 'left': 158 | out = arr[-n:, ...] 159 | elif dir == 'right': 160 | out = arr[:n, ...] 161 | assert out.shape == shape 162 | return out 163 | 164 | 165 | def _get_padded(data, start, end): 166 | """Return `data[start:end]` filling in with zeros outside array bounds 167 | 168 | Assumes that either `start<0` or `end>len(data)` but not both. 169 | 170 | """ 171 | if start < 0 and end > data.shape[0]: 172 | raise RuntimeError() 173 | if start < 0: 174 | start_zeros = np.zeros((-start, data.shape[1]), 175 | dtype=data.dtype) 176 | return np.vstack((start_zeros, data[:end])) 177 | elif end > data.shape[0]: 178 | end_zeros = np.zeros((end - data.shape[0], data.shape[1]), 179 | dtype=data.dtype) 180 | return np.vstack((data[start:], end_zeros)) 181 | else: 182 | return data[start:end] 183 | 184 | 185 | def _get_data_lim(arr, n_spikes=None): 186 | n = arr.shape[0] 187 | k = max(1, n // n_spikes) if n_spikes else 1 188 | arr = np.abs(arr[::k]) 189 | n = arr.shape[0] 190 | arr = arr.reshape((n, -1)) 191 | return arr.max() or 1. 
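# Illustrative behaviour of the helpers above (not part of the library API), assuming
# 1D arrays of non-negative integers as their docstrings require:
#
#   >>> _unique(np.array([3, 3, -1, 5]))
#   array([3, 5])
#   >>> _index_of(np.array([3, 5, 3]), lookup=np.array([3, 5]))
#   array([0, 1, 0])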
192 | 193 | 194 | def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): 195 | """Return a list of pairs `(cluster, similarity)` sorted by decreasing 196 | similarity to a given cluster.""" 197 | l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) 198 | for candidate in _as_scalars(cluster_ids)] 199 | l = sorted(l, key=itemgetter(1), reverse=True) 200 | max_n = None or len(l) 201 | return l[:max_n] 202 | 203 | 204 | def _flatten(l): 205 | return [item for sublist in l for item in sublist] 206 | 207 | 208 | # ----------------------------------------------------------------------------- 209 | # I/O functions 210 | # ----------------------------------------------------------------------------- 211 | 212 | def read_array(path, mmap_mode=None): 213 | """Read a .npy array.""" 214 | path = Path(path) 215 | file_ext = path.suffix 216 | if file_ext == '.npy': 217 | return np.load(str(path), mmap_mode=mmap_mode) 218 | raise NotImplementedError("The file extension `{}` is not currently supported." % file_ext) 219 | 220 | 221 | def write_array(path, arr): 222 | """Write an array to a .npy file.""" 223 | path = Path(path) 224 | file_ext = path.suffix 225 | if file_ext == '.npy': 226 | return np.save(str(path), arr) 227 | raise NotImplementedError("The file extension `{}` is not currently supported." % file_ext) 228 | 229 | 230 | # ----------------------------------------------------------------------------- 231 | # Chunking functions 232 | # ----------------------------------------------------------------------------- 233 | 234 | def _excerpt_step(n_samples, n_excerpts=None, excerpt_size=None): 235 | """Compute the step of an excerpt set as a function of the number 236 | of excerpts or their sizes.""" 237 | assert n_excerpts >= 2 238 | step = max((n_samples - excerpt_size) // (n_excerpts - 1), 239 | excerpt_size) 240 | return step 241 | 242 | 243 | def chunk_bounds(n_samples, chunk_size, overlap=0): 244 | """Return chunk bounds. 245 | 246 | Chunks have the form: 247 | 248 | [ overlap/2 | chunk_size-overlap | overlap/2 ] 249 | s_start keep_start keep_end s_end 250 | 251 | Except for the first and last chunks which do not have a left/right 252 | overlap. 253 | 254 | This generator yields (s_start, s_end, keep_start, keep_end). 
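
    For instance, `chunk_bounds(5, 3, overlap=0)` yields two chunks,
    `(0, 3, 0, 3)` and `(3, 5, 3, 5)`.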
255 | 256 | """ 257 | s_start = 0 258 | s_end = chunk_size 259 | keep_start = s_start 260 | keep_end = s_end - overlap // 2 261 | yield s_start, s_end, keep_start, keep_end 262 | 263 | while s_end - overlap + chunk_size < n_samples: 264 | s_start = s_end - overlap 265 | s_end = s_start + chunk_size 266 | keep_start = keep_end 267 | keep_end = s_end - overlap // 2 268 | if s_start < s_end: 269 | yield s_start, s_end, keep_start, keep_end 270 | 271 | s_start = s_end - overlap 272 | s_end = n_samples 273 | keep_start = keep_end 274 | keep_end = s_end 275 | if s_start < s_end: 276 | yield s_start, s_end, keep_start, keep_end 277 | 278 | 279 | def excerpts(n_samples, n_excerpts=None, excerpt_size=None): 280 | """Yield (start, end) where start is included and end is excluded.""" 281 | assert n_excerpts >= 2 282 | step = _excerpt_step(n_samples, n_excerpts=n_excerpts, excerpt_size=excerpt_size) 283 | for i in range(n_excerpts): 284 | start = i * step 285 | if start >= n_samples: 286 | break 287 | end = min(start + excerpt_size, n_samples) 288 | yield start, end 289 | 290 | 291 | def data_chunk(data, chunk, with_overlap=False): 292 | """Get a data chunk.""" 293 | assert isinstance(chunk, tuple) 294 | if len(chunk) == 2: 295 | i, j = chunk 296 | elif len(chunk) == 4: 297 | if with_overlap: 298 | i, j = chunk[:2] 299 | else: 300 | i, j = chunk[2:] 301 | else: 302 | raise ValueError("'chunk' should have 2 or 4 elements, not {0:d}".format(len(chunk))) 303 | return data[i:j, ...] 304 | 305 | 306 | def get_excerpts(data, n_excerpts=None, excerpt_size=None): 307 | """Return excerpts of a data array.""" 308 | assert n_excerpts is not None 309 | assert excerpt_size is not None 310 | if len(data) < n_excerpts * excerpt_size: 311 | return data 312 | elif n_excerpts == 0: 313 | return data[:0] 314 | elif n_excerpts == 1: 315 | return data[:excerpt_size] 316 | out = np.concatenate([ 317 | data_chunk(data, chunk) 318 | for chunk in excerpts(len(data), n_excerpts=n_excerpts, excerpt_size=excerpt_size)]) 319 | assert len(out) <= n_excerpts * excerpt_size 320 | return out 321 | 322 | 323 | # ----------------------------------------------------------------------------- 324 | # Spike clusters utility functions 325 | # ----------------------------------------------------------------------------- 326 | 327 | def _spikes_in_clusters(spike_clusters, clusters): 328 | """Return the ids of all spikes belonging to the specified clusters.""" 329 | if len(spike_clusters) == 0 or len(clusters) == 0: 330 | return np.array([], dtype=int) 331 | return np.nonzero(np.isin(spike_clusters, clusters))[0] 332 | 333 | 334 | def _spikes_per_cluster(spike_clusters, spike_ids=None): 335 | """Return a dictionary {cluster: list_of_spikes}.""" 336 | if spike_clusters is None or not len(spike_clusters): 337 | return {} 338 | if spike_ids is None: 339 | spike_ids = np.arange(len(spike_clusters)).astype(np.int64) 340 | # NOTE: this sort method is stable, so spike ids are increasing 341 | # among any cluster. Therefore we don't have to sort again down here, 342 | # when creating the spikes_in_clusters dictionary. 343 | rel_spikes = np.argsort(spike_clusters, kind='mergesort') 344 | abs_spikes = spike_ids[rel_spikes] 345 | spike_clusters = spike_clusters[rel_spikes] 346 | 347 | diff = np.empty_like(spike_clusters) 348 | diff[0] = 1 349 | diff[1:] = np.diff(spike_clusters) 350 | 351 | idx = np.nonzero(diff > 0)[0] 352 | clusters = spike_clusters[idx] 353 | 354 | # NOTE: we don't have to sort abs_spikes[...] 
here because the argsort 355 | # using 'mergesort' above is stable. 356 | spikes_in_clusters = { 357 | clusters[i]: abs_spikes[idx[i]:idx[i + 1]] for i in range(len(clusters) - 1)} 358 | spikes_in_clusters[clusters[-1]] = abs_spikes[idx[-1]:] 359 | 360 | return spikes_in_clusters 361 | 362 | 363 | def _flatten_per_cluster(per_cluster): 364 | """Convert a dictionary {cluster: spikes} to a spikes array.""" 365 | return np.unique(np.concatenate(list(per_cluster.values()))).astype(np.int64) 366 | 367 | 368 | def grouped_mean(arr, spike_clusters): 369 | """Compute the mean of a spike-dependent quantity for every cluster. 370 | 371 | The two arguments should be 1D array with `n_spikes` elements. 372 | 373 | The output is a 1D array with `n_clusters` elements. The clusters are 374 | sorted in increasing order. 375 | 376 | """ 377 | arr = np.asarray(arr) 378 | spike_clusters = np.asarray(spike_clusters) 379 | assert arr.shape[0] == len(spike_clusters) 380 | cluster_ids = _unique(spike_clusters) 381 | spike_clusters_rel = _index_of(spike_clusters, cluster_ids) 382 | spike_counts = np.bincount(spike_clusters_rel) 383 | assert len(spike_counts) == len(cluster_ids) 384 | t = np.zeros((len(cluster_ids),) + arr.shape[1:]) 385 | # Compute the sum with possible repetitions. 386 | np.add.at(t, spike_clusters_rel, arr) 387 | return t / spike_counts.reshape((-1,) + (1,) * (arr.ndim - 1)) 388 | 389 | 390 | # ----------------------------------------------------------------------------- 391 | # Spike selection 392 | # ----------------------------------------------------------------------------- 393 | 394 | def _times_in_chunks(times, chunks_kept): 395 | """Return the indices of the times that belong to a list of kept chunks.""" 396 | ind = np.searchsorted(chunks_kept, times, side='right') 397 | return ind % 2 == 1 398 | 399 | 400 | class SpikeSelector(object): 401 | """Select a given number of spikes per cluster among a subset of the chunks.""" 402 | def __init__( 403 | self, get_spikes_per_cluster=None, spike_times=None, 404 | chunk_bounds=None, n_chunks_kept=None): 405 | self.get_spikes_per_cluster = get_spikes_per_cluster 406 | self.spike_times = spike_times 407 | self.chunks_kept = [] 408 | n_chunks = len(chunk_bounds) - 1 409 | 410 | for i in range(0, n_chunks, max(1, int(ceil(n_chunks / n_chunks_kept)))): 411 | self.chunks_kept.extend(chunk_bounds[i:i + 2]) 412 | self.chunks_kept = np.array(self.chunks_kept) 413 | 414 | def __call__(self, n_spk_clu, cluster_ids, subset_chunks=False, subset_spikes=None): 415 | """Select about n_spk_clu random spikes from each of the requested clusters, only 416 | in the kept chunks.""" 417 | if not len(cluster_ids): 418 | return np.array([], dtype=np.int64) 419 | # Start with all spikes from each cluster. 420 | selection = {} 421 | for cluster in cluster_ids: 422 | # Get all spikes from that cluster. 423 | spike_ids = self.get_spikes_per_cluster(cluster) 424 | # Get the spike times. 425 | t = self.spike_times[spike_ids] 426 | # Keep the spikes belonging to the chunks. 427 | if subset_chunks: 428 | spike_ids = spike_ids[_times_in_chunks(t, self.chunks_kept)] 429 | # Keep spikes from a given subset. 430 | if subset_spikes is not None: 431 | spike_ids = np.intersect1d(spike_ids, subset_spikes) 432 | # Make a subselection if needed. 433 | if n_spk_clu is not None and n_spk_clu > 0 and len(spike_ids) > n_spk_clu: 434 | spike_ids = np.random.choice(spike_ids, n_spk_clu, replace=False) 435 | selection[cluster] = spike_ids 436 | # Return the concatenation of all spikes. 
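        # (_flatten_per_cluster sorts and deduplicates, so the result is one increasing array of spike ids.)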
437 | return _flatten_per_cluster(selection) 438 | -------------------------------------------------------------------------------- /phylib/io/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Utility functions for test datasets.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import hashlib 10 | import logging 11 | from pathlib import Path 12 | 13 | from phylib.utils._misc import ensure_dir_exists, phy_config_dir 14 | from phylib.utils.event import ProgressReporter 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | #------------------------------------------------------------------------------ 21 | # Utility functions 22 | #------------------------------------------------------------------------------ 23 | 24 | def _remote_file_size(path): 25 | import requests 26 | try: # pragma: no cover 27 | response = requests.head(path) 28 | return int(response.headers.get('content-length', 0)) 29 | except Exception: 30 | # Unable to get the file size: no progress report. 31 | pass 32 | return 0 33 | 34 | 35 | def _save_stream(r, path): 36 | size = _remote_file_size(r.url) 37 | pr = ProgressReporter() 38 | pr.value_max = size or 1 39 | pr.set_progress_message('Downloading `' + str(path) + '`: {progress:.1f}%.') 40 | pr.set_complete_message('Download complete.') 41 | downloaded = 0 42 | with open(path, 'wb') as f: 43 | for i, chunk in enumerate(r.iter_content(chunk_size=1024)): 44 | if chunk: 45 | f.write(chunk) 46 | f.flush() 47 | downloaded += len(chunk) 48 | if i % 100 == 0: 49 | pr.value = downloaded 50 | pr.set_complete() 51 | 52 | 53 | def _download(url, stream=None): 54 | from requests import get 55 | r = get(url, stream=stream) 56 | if r.status_code != 200: # pragma: no cover 57 | logger.debug("Error while downloading %s.", url) 58 | r.raise_for_status() 59 | return r 60 | 61 | 62 | def download_text_file(url): 63 | """Download a text file.""" 64 | return _download(url).text 65 | 66 | 67 | def _md5(path, blocksize=2 ** 20): 68 | """Compute the checksum of a file.""" 69 | m = hashlib.md5() 70 | with open(path, 'rb') as f: 71 | while True: 72 | buf = f.read(blocksize) 73 | if not buf: 74 | break 75 | m.update(buf) 76 | return m.hexdigest() 77 | 78 | 79 | def _check_md5(path, checksum): 80 | return (_md5(path) == checksum) if checksum else None 81 | 82 | 83 | def _check_md5_of_url(output_path, url): 84 | try: 85 | checksum = download_text_file(url + '.md5').split(' ')[0] 86 | except Exception: 87 | checksum = None 88 | finally: 89 | if checksum: 90 | return _check_md5(output_path, checksum) 91 | 92 | 93 | def download_file(url, output_path): 94 | """Download a binary file from an URL. 95 | 96 | The checksum will be downloaded from `URL + .md5`. If this download 97 | succeeds, the file's MD5 will be compared to the expected checksum. 98 | 99 | Parameters 100 | ---------- 101 | 102 | url : str 103 | The file's URL. 104 | output_path : str 105 | The path where the file is to be saved. 
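
    A typical call (URL and destination shown purely as an illustration):

        download_file('https://example.com/data.npy', 'data.npy')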
106 | 107 | """ 108 | path = Path(output_path) 109 | if path.exists(): 110 | checked = _check_md5_of_url(output_path, url) 111 | if checked is False: 112 | logger.debug( 113 | "The file `%s` already exists but is invalid: redownloading.", output_path) 114 | elif checked is True: 115 | logger.debug("The file `%s` already exists: skipping.", output_path) 116 | return output_path 117 | r = _download(url, stream=True) 118 | _save_stream(r, output_path) 119 | if _check_md5_of_url(output_path, url) is False: 120 | logger.debug("The checksum doesn't match: retrying the download.") 121 | r = _download(url, stream=True) 122 | _save_stream(r, output_path) 123 | if _check_md5_of_url(output_path, url) is False: 124 | raise RuntimeError("The checksum of the downloaded file " 125 | "doesn't match the provided checksum.") 126 | return 127 | 128 | 129 | _BASE_URL = 'https://raw.githubusercontent.com/kwikteam/phy-data/master/' 130 | 131 | 132 | def download_test_file(name, config_dir=None, force=False): 133 | """Download a test file.""" 134 | config_dir = Path(config_dir or phy_config_dir()) 135 | path = config_dir / 'test_data' / name 136 | # Ensure the directory exists. 137 | ensure_dir_exists(path.parent) 138 | if not force and path.exists(): 139 | return path 140 | url = _BASE_URL + name 141 | download_file(url, output_path=path) 142 | return path 143 | -------------------------------------------------------------------------------- /phylib/io/merge.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Probe merging.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import logging 11 | from pathlib import Path 12 | 13 | from tqdm import tqdm 14 | import numpy as np 15 | from scipy.linalg import block_diag 16 | 17 | from phylib.utils._misc import ( 18 | _read_tsv_simple, _write_tsv_simple, write_tsv, read_python, write_python) 19 | from phylib.io.model import load_model 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | #------------------------------------------------------------------------------ 25 | # Merge utils 26 | #------------------------------------------------------------------------------ 27 | 28 | def _concat(arrs, axis=0, dtype=None): 29 | dtype = dtype or arrs[0].dtype 30 | return np.concatenate(arrs).astype(dtype) 31 | 32 | 33 | def _load_multiple_spike_times(*spike_times_l): 34 | """Load multiple spike_times arrays and merge them into a single one.""" 35 | # We concatenate all spike times arrays. 36 | spike_times_concat = _concat(spike_times_l) 37 | # We sort by increasing time. 38 | spike_order = np.argsort(spike_times_concat, kind='stable') 39 | spike_times_ordered = spike_times_concat[spike_order] 40 | assert np.all(np.diff(spike_times_ordered) >= 0) 41 | # We return the ordered spike times, and the reordering array. 42 | return spike_times_ordered, spike_order 43 | 44 | 45 | def _load_multiple_spike_arrays(*spike_array_l, spike_order=None): 46 | """Load multiple spike-dependent arrays and concatenate them along the first dimension. 47 | Keep the spike time ordering. 
48 | """ 49 | assert spike_order is not None 50 | spike_array_concat = _concat(spike_array_l, axis=0) 51 | assert spike_array_concat.shape[0] == spike_order.shape[0] 52 | return spike_array_concat[spike_order] 53 | 54 | 55 | def _load_multiple_files(fn, subdirs): 56 | """Load the same filename in the different subdirectories.""" 57 | # Warning: squeeze may fail in degenerate cases. 58 | return [np.load(str(subdir / fn)).squeeze() for subdir in subdirs] 59 | 60 | 61 | #------------------------------------------------------------------------------ 62 | # Main Merger class 63 | #------------------------------------------------------------------------------ 64 | 65 | class Merger(object): 66 | """Merge spike-sorted data from different probes and output datasets to disk. 67 | 68 | Constructor 69 | ----------- 70 | 71 | subdirs : list 72 | List of paths to the probe directories. 73 | out_dir : str or Path 74 | Output directory for the merged spike-sorted data. 75 | probe_info : list 76 | For each probe, a dictionary with the following fields: 77 | * `label`: a string with the probe label (defaults to the probe folder name) 78 | * any other field, such as `model`, `serial_number`, etc. 79 | 80 | All fields will be saved in `probes.description.tsv`. 81 | 82 | """ 83 | def __init__(self, subdirs, out_dir, probe_info=None): 84 | assert subdirs 85 | self.subdirs = [Path(subdir) for subdir in subdirs] 86 | 87 | self.out_dir = Path(out_dir) 88 | self.out_dir.mkdir(parents=True, exist_ok=True) 89 | 90 | # Default probe info if not provided: the label is the probe folder name. 91 | self.probe_info = probe_info or [{'label': subdir.parts[-1]} for subdir in self.subdirs] 92 | assert len(self.probe_info) == len(self.subdirs) 93 | 94 | def _save(self, name, arr): 95 | """Save a npy array in the output directory.""" 96 | logger.debug("Saving %s %s %s.", name, arr.dtype, arr.shape) 97 | np.save(self.out_dir / name, arr) 98 | 99 | def write_params(self): 100 | """Write a params.py for the merged dataset.""" 101 | params_l = [read_python(subdir / 'params.py') for subdir in self.subdirs] 102 | n_channels_dat = sum(params['n_channels_dat'] for params in params_l) 103 | 104 | params_merged = params_l[0] 105 | params_merged['dat_path'] = [] 106 | params_merged['n_channels_dat'] = n_channels_dat 107 | 108 | write_python(self.out_dir / 'params.py', params_merged) 109 | 110 | def write_probe_desc(self): 111 | """Write the probe description in a TSV file.""" 112 | write_tsv(self.out_dir / 'probes.description.tsv', self.probe_info, first_field='label') 113 | 114 | def write_spike_times(self): 115 | """Write the merged spike times, and register self.spike_order with the reordering 116 | of the spikes.""" 117 | spike_times_l = _load_multiple_files('spike_times.npy', self.subdirs) 118 | spike_times, self.spike_order = _load_multiple_spike_times(*spike_times_l) 119 | self._save('spike_times.npy', spike_times) 120 | 121 | def write_spike_data(self): 122 | """Write spike-dependent data.""" 123 | spike_data = [ 124 | 'amplitudes.npy', 125 | 'spike_templates.npy', 126 | # 'pc_features.npy', 127 | # 'template_features.npy', 128 | ] 129 | for fn in spike_data: 130 | arrays = _load_multiple_files(fn, self.subdirs) 131 | concat = _load_multiple_spike_arrays(*arrays, spike_order=self.spike_order) 132 | self._save(fn, concat) 133 | 134 | def write_spike_clusters(self): 135 | """Write the merged spike clusters, and register self.cluster_offsets. 136 | Write the merged spike templates, and register self.template_offsets. 
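The offsets registered by this method are what keep cluster and template ids from colliding across probes: each probe's ids are shifted past the previous probe's `max + 1`. A small illustration with made-up ids:

```python
import numpy as np

sc_probe0 = np.array([0, 1, 1, 3])   # probe 0 uses ids 0..3
sc_probe1 = np.array([0, 0, 2])      # probe 1 uses ids 0..2, which would clash unshifted

offset, merged = 0, []
for sc in (sc_probe0, sc_probe1):
    merged.append(sc + offset)       # renumbered ids for this probe
    offset += sc.max() + 1           # the next probe starts after this one's range
print(np.concatenate(merged))        # [0 1 1 3 4 4 6]
```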
137 | """ 138 | spike_clusters_l = _load_multiple_files('spike_clusters.npy', self.subdirs) 139 | spike_templates_l = _load_multiple_files('spike_templates.npy', self.subdirs) 140 | self.cluster_offsets = [] 141 | self.template_offsets = [] 142 | cluster_probes_l = [] 143 | coffset = 0 144 | toffset = 0 145 | for i, (subdir, sc, st) in enumerate( 146 | zip(self.subdirs, spike_clusters_l, spike_templates_l)): 147 | n_clu = np.max(sc) + 1 148 | n_tmp = np.max(st) + 1 149 | sc += coffset 150 | st += toffset 151 | self.cluster_offsets.append(coffset) 152 | self.template_offsets.append(toffset) 153 | cluster_probes_l.append(i * np.ones(n_clu, dtype=np.int32)) 154 | coffset += n_clu 155 | toffset += n_tmp 156 | spike_clusters = _load_multiple_spike_arrays( 157 | *spike_clusters_l, spike_order=self.spike_order) 158 | spike_templates = _load_multiple_spike_arrays( 159 | *spike_templates_l, spike_order=self.spike_order) 160 | cluster_probes = _concat(cluster_probes_l) 161 | assert np.max(spike_clusters) + 1 == cluster_probes.size 162 | self._save('spike_clusters.npy', spike_clusters) 163 | self._save('spike_templates.npy', spike_templates) 164 | self._save('cluster_probes.npy', cluster_probes) 165 | 166 | def write_cluster_data(self): 167 | """We load all cluster metadata from TSV files, renumber the clusters, 168 | merge the dictionaries, and save in a new merged TSV file. """ 169 | 170 | cluster_data = [ 171 | 'cluster_Amplitude.tsv', 172 | 'cluster_ContamPct.tsv', 173 | 'cluster_KSLabel.tsv' 174 | ] 175 | 176 | for fn in cluster_data: 177 | metadata = {} 178 | for subdir, offset in zip(self.subdirs, self.cluster_offsets): 179 | try: 180 | field_name, metadata_loc = _read_tsv_simple(subdir / fn) 181 | except ValueError: 182 | # Skipping non-existing file. 183 | continue 184 | for k, v in metadata_loc.items(): 185 | metadata[k + offset] = v 186 | if metadata: 187 | _write_tsv_simple(self.out_dir / fn, field_name, metadata) 188 | 189 | def write_channel_data(self): 190 | """Write channel-dependent data, and register self.channel_offsets.""" 191 | self.channel_offsets = [] 192 | channel_probes = [] 193 | channel_maps_l = _load_multiple_files('channel_map.npy', self.subdirs) 194 | # TODO if needed: channel_shanks.npy 195 | offset = 0 196 | for ind, array in enumerate(channel_maps_l): 197 | array += offset 198 | self.channel_offsets.append(offset) 199 | offset = array.max() 200 | channel_probes.append(array * 0 + ind) 201 | channel_maps = _concat(channel_maps_l, axis=0) 202 | channel_probes = _concat(channel_probes, axis=0) 203 | self._save('channel_map.npy', channel_maps) 204 | self._save('channel_probe.npy', channel_probes) 205 | 206 | def write_channel_positions(self): 207 | """Write the channel positions.""" 208 | channel_positions_l = _load_multiple_files('channel_positions.npy', self.subdirs) 209 | x_offset = 0. 210 | for array in channel_positions_l: 211 | array[:, 0] += x_offset 212 | x_offset = 2. * array[:, 0].max() - array[:, 0].min() 213 | channel_positions = _concat(channel_positions_l, axis=0) 214 | self._save('channel_positions.npy', channel_positions) 215 | 216 | def write_templates(self): 217 | """Write the templates (only dense format for now).""" 218 | # TODO: write the templates array in sparse format. 219 | 220 | path = self.out_dir / 'templates.npy' 221 | 222 | templates_l = _load_multiple_files('templates.npy', self.subdirs) 223 | 224 | # Determine the templates array shape. 
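The raw write a few lines below streams template blocks straight into the file created by `np.save`, relying on the NPY version-1 layout: 6 magic bytes, 2 version bytes, then a 2-byte little-endian header length at offset 8, after which the data starts. A small sketch of that layout with a throwaway array:

```python
import numpy as np

np.save('toy.npy', np.zeros((2, 3), dtype=np.float32))
with open('toy.npy', 'rb') as fid:
    assert fid.read(6) == b'\x93NUMPY'                  # magic string
    version = fid.read(2)                               # (major, minor), here 1.0
    header_len = int.from_bytes(fid.read(2), 'little')  # length of the header dict
    fid.seek(header_len, 1)                             # skip the header: raw data starts here
    data = np.frombuffer(fid.read(), dtype=np.float32).reshape(2, 3)
assert data.shape == (2, 3)
```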
225 | n_templates = sum(tmp.shape[0] for tmp in templates_l) 226 | n_samples = templates_l[0].shape[1] # assuming all have the same number of samples 227 | assert np.all(np.array([templates_i.shape[1] for templates_i in templates_l]) == n_samples) 228 | 229 | n_channels = sum(tmp.shape[2] for tmp in templates_l) 230 | shape = (n_templates, n_samples, n_channels) 231 | 232 | np.save(path, np.empty(shape, dtype=templates_l[0].dtype)) 233 | offset = 0 234 | with open(path, 'r+b') as fid: 235 | fid.seek(8) 236 | offset = int.from_bytes(fid.read(2), byteorder='little') 237 | fid.seek(offset, 1) 238 | for i in range(len(self.subdirs)): 239 | j0 = templates_l[i - 1].shape[2] if i > 0 else 0 240 | j1 = j0 + templates_l[i].shape[2] 241 | for it in np.arange(templates_l[i].shape[0]): 242 | one_template = np.zeros((n_samples, n_channels), dtype=templates_l[0].dtype) 243 | one_template[:, j0:j1] = templates_l[i][it, :] 244 | fid.write(one_template.tobytes()) 245 | 246 | def write_template_data(self): 247 | template_data = [ 248 | # 'templates_ind.npy', # HACK: do not copy this array (which is trivial with 0 1 2 3.. 249 | # on each row), 250 | # the templates.npy file is really dense in KS2 and should stay this way 251 | 'pc_feature_ind.npy', 252 | 'template_feature_ind.npy', 253 | ] 254 | 255 | for fn in template_data: 256 | arrays = _load_multiple_files(fn, self.subdirs) 257 | # For ind arrays, we need to take into account the channel offset. 258 | for array, offset in zip(arrays, self.channel_offsets): 259 | array += offset 260 | concat = _concat(arrays, axis=0).astype(np.uint32) 261 | self._save(fn, concat) 262 | 263 | def write_misc(self): 264 | """Write misc merged data. 265 | 266 | Similar templates: we make a block diagonal matrix from the n_templates * n_templates 267 | matrices, assuming no similarity between templates from different probes. 268 | 269 | Whitening matrix: same thing, except that the matrices are n_channels * n_channels. 
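`scipy.linalg.block_diag` does exactly what this docstring describes: the off-diagonal blocks, i.e. cross-probe similarities, are filled with zeros. For instance:

```python
import numpy as np
from scipy.linalg import block_diag

sim_a = np.array([[1., .2],
                  [.2, 1.]])   # probe 0: 2 templates
sim_b = np.array([[1.]])       # probe 1: 1 template

print(block_diag(sim_a, sim_b))
# [[1.  0.2 0. ]
#  [0.2 1.  0. ]
#  [0.  0.  1. ]]
```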
270 | 271 | """ 272 | diag_data = [ 273 | 'similar_templates.npy', 274 | 'whitening_mat.npy', 275 | 'whitening_mat_inv.npy', 276 | ] 277 | for fn in diag_data: 278 | try: 279 | concat = block_diag(*_load_multiple_files(fn, self.subdirs)) 280 | except FileNotFoundError: 281 | logger.debug("File %s not found, skipping.", fn) 282 | continue 283 | self._save(fn, concat) 284 | 285 | def merge(self): 286 | """Merge the probes data and return a TemplateModel instance of the merged data.""" 287 | 288 | with tqdm(desc="Merging", total=100) as bar: 289 | self.write_params() 290 | self.write_probe_desc() 291 | bar.update(10) 292 | self.write_spike_times() 293 | bar.update(10) 294 | self.write_spike_data() 295 | bar.update(10) 296 | self.write_spike_clusters() 297 | bar.update(10) 298 | self.write_cluster_data() 299 | bar.update(10) 300 | self.write_channel_data() 301 | bar.update(10) 302 | self.write_channel_positions() 303 | bar.update(10) 304 | self.write_templates() 305 | bar.update(10) 306 | self.write_template_data() 307 | bar.update(10) 308 | self.write_misc() 309 | bar.update(10) 310 | 311 | return load_model(self.out_dir / 'params.py') 312 | -------------------------------------------------------------------------------- /phylib/io/mock.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Mock datasets.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | import numpy.random as nr 11 | 12 | 13 | #------------------------------------------------------------------------------ 14 | # Artificial data 15 | #------------------------------------------------------------------------------ 16 | 17 | def artificial_waveforms(n_spikes=None, n_samples=None, n_channels=None): 18 | return .25 * nr.normal(size=(n_spikes, n_samples, n_channels)) 19 | 20 | 21 | def artificial_features(*args): 22 | return .25 * nr.normal(size=args) 23 | 24 | 25 | def artificial_masks(n_spikes=None, n_channels=None): 26 | masks = nr.uniform(size=(n_spikes, n_channels)) 27 | masks[masks < .25] = 0 28 | return masks 29 | 30 | 31 | def artificial_traces(n_samples, n_channels): 32 | return .25 * nr.normal(size=(n_samples, n_channels)) 33 | 34 | 35 | def artificial_spike_clusters(n_spikes, n_clusters, low=0): 36 | return nr.randint(size=n_spikes, low=low, high=max(1, n_clusters)) 37 | 38 | 39 | def artificial_spike_samples(n_spikes, max_isi=50): 40 | return np.cumsum(nr.randint(low=0, high=max_isi, size=n_spikes)) 41 | 42 | 43 | def artificial_correlograms(n_clusters, n_samples): 44 | return nr.uniform(size=(n_clusters, n_clusters, n_samples)) 45 | -------------------------------------------------------------------------------- /phylib/io/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cortex-lab/phylib/72316a4ccb0abed93464ed4bfcd36a3809d1dbdf/phylib/io/tests/__init__.py -------------------------------------------------------------------------------- /phylib/io/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Test fixtures.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import 
logging 10 | import shutil 11 | 12 | import numpy as np 13 | from pytest import fixture 14 | 15 | from phylib.utils._misc import write_text, write_tsv 16 | from ..model import load_model, write_array 17 | from phylib.io.datasets import download_test_file 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | #------------------------------------------------------------------------------ 23 | # Fixtures 24 | #------------------------------------------------------------------------------ 25 | 26 | _FILES = [ 27 | 'template/params.py', 28 | 'template/sim_binary.dat', 29 | 'template/spike_times.npy', 30 | 'template/spike_templates.npy', 31 | 'template/spike_clusters.npy', 32 | 'template/amplitudes.npy', 33 | 34 | 'template/cluster_group.tsv', 35 | 36 | 'template/channel_map.npy', 37 | 'template/channel_positions.npy', 38 | 'template/channel_shanks.npy', 39 | 40 | 'template/similar_templates.npy', 41 | 'template/whitening_mat.npy', 42 | 43 | 'template/templates.npy', 44 | 'template/template_ind.npy', 45 | 46 | 'template/pc_features.npy', 47 | 'template/pc_feature_ind.npy', 48 | 'template/pc_feature_spike_ids.npy', 49 | 50 | 'template/template_features.npy', 51 | 'template/template_feature_ind.npy', 52 | 'template/template_feature_spike_ids.npy', 53 | ] 54 | 55 | 56 | def _remove(path): 57 | if path.exists(): 58 | path.unlink() 59 | logger.debug("Removed %s.", path) 60 | 61 | 62 | def _make_dataset(tempdir, param='dense', has_spike_attributes=True): 63 | np.random.seed(0) 64 | 65 | # Download the dataset. 66 | paths = list(map(download_test_file, _FILES)) 67 | # Copy the dataset to a temporary directory. 68 | for path in paths: 69 | to_path = tempdir / path.name 70 | # Skip sparse arrays if is_sparse is False. 71 | if param == 'sparse' and ('_ind.' in str(to_path) or 'spike_ids.' in str(to_path)): 72 | continue 73 | logger.debug("Copying file to %s.", to_path) 74 | shutil.copy(path, to_path) 75 | 76 | # Some changes to files if 'misc' fixture parameter. 77 | if param == 'misc': 78 | # Remove spike_clusters and recreate it from spike_templates. 79 | _remove(tempdir / 'spike_clusters.npy') 80 | # Replace spike_times.npy, in samples, by spikes.times.npy, in seconds. 81 | if (tempdir / 'spike_times.npy').exists(): 82 | st = np.load(tempdir / 'spike_times.npy').squeeze() 83 | st_r = st + np.random.randint(low=-20000, high=+20000, size=st.size) 84 | assert st_r.shape == st.shape 85 | # Reordered spikes. 86 | np.save(tempdir / 'spike_times_reordered.npy', st_r) 87 | np.save(tempdir / 'spikes.times.npy', st / 25000.) # sample rate 88 | _remove(tempdir / 'spike_times.npy') 89 | # Buggy TSV file should not cause a crash. 90 | write_text(tempdir / 'error.tsv', '') 91 | # Remove some non-necessary files. 
92 | _remove(tempdir / 'template_features.npy') 93 | _remove(tempdir / 'pc_features.npy') 94 | _remove(tempdir / 'channel_probes.npy') 95 | _remove(tempdir / 'channel_shanks.npy') 96 | _remove(tempdir / 'amplitudes.npy') 97 | _remove(tempdir / 'whitening_mat.npy') 98 | _remove(tempdir / 'whitening_mat_inv.npy') 99 | _remove(tempdir / 'sim_binary.dat') 100 | 101 | if param == 'merged': 102 | # remove this file to make templates dense 103 | _remove(tempdir / 'template_ind.npy') 104 | clus = np.load(tempdir / 'spike_clusters.npy') 105 | max_clus = np.max(clus) 106 | # merge cluster 0 and 1 107 | clus[np.bitwise_or(clus == 0, clus == 1)] = max_clus + 1 108 | # split cluster 9 into two clusters 109 | idx = np.where(clus == 9)[0] 110 | clus[idx[0:3]] = max_clus + 2 111 | clus[idx[3:]] = max_clus + 3 112 | np.save(tempdir / 'spike_clusters.npy', clus) 113 | 114 | # Spike attributes. 115 | if has_spike_attributes: 116 | write_array(tempdir / 'spike_fail.npy', np.full(10, np.nan)) # wrong number of spikes 117 | write_array(tempdir / 'spike_works.npy', np.random.rand(314)) 118 | write_array(tempdir / 'spike_randn.npy', np.random.randn(314, 2)) 119 | 120 | # TSV file with cluster data. 121 | write_tsv( 122 | tempdir / 'cluster_Amplitude.tsv', [{'cluster_id': 1, 'Amplitude': 123.4}], 123 | first_field='cluster_id') 124 | 125 | write_tsv( 126 | tempdir / 'cluster_metrics.tsv', [ 127 | {'cluster_id': 2, 'met1': 123.4, 'met2': 'hello world 1'}, 128 | {'cluster_id': 3, 'met1': 5.678}, 129 | {'cluster_id': 5, 'met2': 'hello world 2'}, 130 | ]) 131 | 132 | template_path = tempdir / paths[0].name 133 | return template_path 134 | 135 | 136 | @fixture(scope='function', params=('dense', 'sparse', 'misc', 'merged')) 137 | def template_path_full(tempdir, request): 138 | return _make_dataset(tempdir, request.param) 139 | 140 | 141 | @fixture(scope='function') 142 | def template_path(tempdir, request): 143 | return _make_dataset(tempdir, param='dense', has_spike_attributes=False) 144 | 145 | 146 | @fixture 147 | def template_model_full(template_path_full): 148 | model = load_model(template_path_full) 149 | yield model 150 | model.close() 151 | 152 | 153 | @fixture 154 | def template_model(template_path): 155 | model = load_model(template_path) 156 | yield model 157 | model.close() 158 | -------------------------------------------------------------------------------- /phylib/io/tests/test_alf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Test ALF dataset generation.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import os 11 | from pathlib import Path 12 | import shutil 13 | from pytest import fixture, raises 14 | 15 | import numpy as np 16 | import numpy.random as nr 17 | 18 | from phylib.utils._misc import _write_tsv_simple 19 | from ..alf import _FILE_RENAMES, _load, EphysAlfCreator 20 | from ..model import TemplateModel 21 | 22 | 23 | #------------------------------------------------------------------------------ 24 | # Fixture 25 | #------------------------------------------------------------------------------ 26 | 27 | class Dataset(object): 28 | def __init__(self, tempdir): 29 | np.random.seed(42) 30 | self.tmp_dir = tempdir 31 | p = Path(self.tmp_dir) 32 | self.ns = 100 33 | self.nsamp = 25 34 | self.ncmax = 42 35 | self.nc = 10 36 | self.nt = 5 37 | self.ncd = 1000 38 | 
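The synthetic arrays written just below follow a simple recipe: spike times are the cumulative sum of exponential inter-spike intervals (so they never decrease), and the other arrays are random data of the right shape. The spike-time pattern on its own:

```python
import numpy as np
import numpy.random as nr

n_spikes = 100
isi = nr.exponential(size=n_spikes)     # random inter-spike intervals
spike_times = .01 * np.cumsum(isi)      # non-decreasing times, roughly in seconds
assert np.all(np.diff(spike_times) >= 0)
```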
np.save(p / 'spike_times.npy', .01 * np.cumsum(nr.exponential(size=self.ns))) 39 | np.save(p / 'spike_clusters.npy', nr.randint(low=1, high=self.nt, size=self.ns)) 40 | shutil.copy(p / 'spike_clusters.npy', p / 'spike_templates.npy') 41 | np.save(p / 'amplitudes.npy', nr.uniform(low=0.5, high=1.5, size=self.ns)) 42 | np.save(p / 'channel_positions.npy', np.c_[np.arange(self.nc), np.zeros(self.nc)]) 43 | np.save(p / 'templates.npy', np.random.normal(size=(self.nt, 50, self.nc))) 44 | np.save(p / 'similar_templates.npy', np.tile(np.arange(self.nt), (self.nt, 1))) 45 | np.save(p / 'channel_map.npy', np.c_[np.arange(self.nc)]) 46 | np.save(p / 'channel_probe.npy', np.zeros(self.nc)) 47 | np.save(p / 'whitening_mat.npy', np.eye(self.nc, self.nc)) 48 | np.save(p / '_phy_spikes_subset.channels.npy', np.zeros([self.ns, self.ncmax])) 49 | np.save(p / '_phy_spikes_subset.spikes.npy', np.zeros([self.ns])) 50 | np.save(p / '_phy_spikes_subset.waveforms.npy', np.zeros( 51 | [self.ns, self.nsamp, self.ncmax]) 52 | ) 53 | 54 | _write_tsv_simple(p / 'cluster_group.tsv', 'group', {2: 'good', 3: 'mua', 5: 'noise'}) 55 | _write_tsv_simple(p / 'cluster_Amplitude.tsv', field_name='Amplitude', 56 | data={str(n): np.random.rand() * 120 for n in np.arange(self.nt)}) 57 | with open(p / 'probes.description.txt', 'w+') as fid: 58 | fid.writelines(['label\n']) 59 | 60 | # Raw data 61 | self.dat_path = p / 'rawdata.npy' 62 | np.save(self.dat_path, np.random.normal(size=(self.ncd, self.nc))) 63 | 64 | # LFP data. 65 | lfdata = (100 * np.random.normal(size=(1000, self.nc))).astype(np.int16) 66 | with (p / 'mydata.lf.bin').open('wb') as f: 67 | lfdata.tofile(f) 68 | 69 | self.files = os.listdir(self.tmp_dir) 70 | 71 | def _load(self, fn): 72 | p = Path(self.tmp_dir) 73 | return _load(p / fn) 74 | 75 | 76 | @fixture 77 | def dataset(tempdir): 78 | return Dataset(tempdir) 79 | 80 | 81 | def test_ephys_1(dataset): 82 | assert dataset._load('spike_times.npy').shape == (dataset.ns,) 83 | assert dataset._load('spike_clusters.npy').shape == (dataset.ns,) 84 | assert dataset._load('amplitudes.npy').shape == (dataset.ns,) 85 | assert dataset._load('channel_positions.npy').shape == (dataset.nc, 2) 86 | assert dataset._load('templates.npy').shape == (dataset.nt, 50, dataset.nc) 87 | assert dataset._load('channel_map.npy').shape == (dataset.nc, 1) 88 | assert dataset._load('channel_probe.npy').shape == (dataset.nc,) 89 | assert len(dataset._load('cluster_group.tsv')) == 3 90 | assert dataset._load('rawdata.npy').shape == (1000, dataset.nc) 91 | assert dataset._load('mydata.lf.bin').shape == (1000 * dataset.nc,) 92 | assert dataset._load('whitening_mat.npy').shape == (dataset.nc, dataset.nc) 93 | assert dataset._load('_phy_spikes_subset.channels.npy').shape == (dataset.ns, dataset.ncmax) 94 | assert dataset._load('_phy_spikes_subset.spikes.npy').shape == (dataset.ns,) 95 | assert dataset._load('_phy_spikes_subset.waveforms.npy').shape == ( 96 | (dataset.ns, dataset.nsamp, dataset.ncmax) 97 | ) 98 | 99 | 100 | def test_spike_depths(dataset): 101 | path = Path(dataset.tmp_dir) 102 | out_path = path / 'alf' 103 | 104 | mtemp = TemplateModel( 105 | dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) 106 | 107 | # create some sparse PC features 108 | n_subch = int(np.round(mtemp.n_channels / 2) - 1) 109 | pc_features = np.zeros((mtemp.n_spikes, n_subch, 3)) 110 | close_channels = np.meshgrid(np.ones(mtemp.n_templates), np.arange(n_subch))[1] 111 | chind = mtemp.templates_channels 112 | chind = 
mtemp.templates_channels + close_channels * ((chind < 5) * 2 - 1) 113 | pc_features_ind = chind.transpose() 114 | # all PCs max between first and second channel 115 | pc_features[:, 0, 0] = 1 116 | pc_features[:, 1, 0] = 0.5 117 | print(mtemp.templates_channels) 118 | print(mtemp.templates_waveforms_durations) 119 | # add some depth information 120 | mtemp.channel_positions[:, 1] = mtemp.channel_positions[:, 0] + 10 121 | np.save(path / 'pc_features.npy', np.swapaxes(pc_features, 2, 1)) 122 | np.save(path / 'pc_feature_ind.npy', pc_features_ind) 123 | np.save(path / 'channel_positions.npy', mtemp.channel_positions) 124 | 125 | model = TemplateModel( 126 | dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) 127 | 128 | c = EphysAlfCreator(model) 129 | shutil.rmtree(out_path, ignore_errors=True) 130 | c.convert(out_path) 131 | sd = np.load(next(out_path.glob('spikes.depths.npy'))) 132 | sd_ = model.channel_positions[model.templates_channels[model.spike_templates], 1] 133 | assert np.all(np.abs(sd - sd_) <= 0.5) 134 | 135 | 136 | def test_creator(dataset): 137 | _FILE_CREATES = ( 138 | 'spikes.times*.npy', 139 | 'spikes.depths*.npy', 140 | 'spikes.samples*.npy', 141 | 'clusters.uuids*.csv', 142 | 'clusters.amps*.npy', 143 | 'clusters.channels*.npy', 144 | 'clusters.depths*.npy', 145 | 'clusters.peakToTrough*.npy', 146 | 'clusters.waveforms*.npy', 147 | 'clusters.waveformsChannels*.npy', 148 | 'channels.localCoordinates*.npy', 149 | 'channels.rawInd*.npy', 150 | 'templates.waveforms*.npy', 151 | 'templates.waveformsChannels*.npy', 152 | ) 153 | path = Path(dataset.tmp_dir) 154 | out_path = path / 'alf' 155 | 156 | model = TemplateModel( 157 | dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) 158 | 159 | c = EphysAlfCreator(model) 160 | with raises(IOError): 161 | c.convert(dataset.tmp_dir) 162 | 163 | def check_conversion_output(): 164 | # Check all renames. 
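The loops below use a `*` right before the suffix so that an optional label, added when `convert(out_path, label='probe00')` is called later in this test, still matches. The globbing behaviour, shown with hypothetical file names:

```python
import fnmatch

names = ['spikes.times.npy', 'spikes.times.probe00.npy', 'spikes.depths.probe00.npy']
print(fnmatch.filter(names, 'spikes.times*.npy'))
# ['spikes.times.npy', 'spikes.times.probe00.npy']
```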
165 | for old, new, _ in _FILE_RENAMES: 166 | if (path / old).exists(): 167 | pattern = f'{Path(new).stem}*{Path(new).suffix}' 168 | assert next(out_path.glob(pattern)).exists() 169 | 170 | new_files = [] 171 | for new in _FILE_CREATES: 172 | f = next(out_path.glob(new)) 173 | new_files.append(f) 174 | assert f.exists() 175 | 176 | # makes sure the output dimensions match (especially clusters which should be 4) 177 | cl_shape = [] 178 | for f in new_files: 179 | if f.name.startswith('clusters.') and f.name.endswith('.npy'): 180 | cl_shape.append(np.load(f).shape[0]) 181 | elif f.name.startswith('clusters.') and f.name.endswith('.csv'): 182 | with open(f) as fid: 183 | cl_shape.append(len(fid.readlines()) - 1) 184 | sp_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('spikes.')] 185 | ch_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('channels.')] 186 | 187 | assert len(set(cl_shape)) == 1 188 | assert len(set(sp_shape)) == 1 189 | assert len(set(ch_shape)) == 1 190 | 191 | dur = np.load(next(out_path.glob('clusters.peakToTrough*.npy'))) 192 | assert np.all(dur == np.array([-9.5, 3., 13., -4.5, -2.5])) 193 | 194 | def read_after_write(): 195 | model = TemplateModel(dir_path=out_path, dat_path=dataset.dat_path, 196 | sample_rate=2000, n_channels_dat=dataset.nc) 197 | 198 | np.all(model.spike_templates == c.model.spike_templates) 199 | np.all(model.spike_times == c.model.spike_times) 200 | np.all(model.spike_samples == c.model.spike_samples) 201 | 202 | # test a straight export, make sure we can reload the data 203 | shutil.rmtree(out_path, ignore_errors=True) 204 | c.convert(out_path) 205 | check_conversion_output() 206 | read_after_write() 207 | 208 | # test with a label after the attribute name 209 | shutil.rmtree(out_path) 210 | c.convert(out_path, label='probe00') 211 | check_conversion_output() 212 | read_after_write() 213 | 214 | 215 | def test_merger(dataset): 216 | 217 | path = Path(dataset.tmp_dir) 218 | out_path = path / 'alf' 219 | 220 | model = TemplateModel( 221 | dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) 222 | 223 | c = EphysAlfCreator(model) 224 | c.convert(out_path) 225 | 226 | model.close() 227 | 228 | # path.joinpath('_phy_spikes_subset.channels.npy').unlink() 229 | # path.joinpath('_phy_spikes_subset.waveforms.npy').unlink() 230 | # path.joinpath('_phy_spikes_subset.spikes.npy').unlink() 231 | 232 | out_path_merge = path / 'alf_merge' 233 | spike_clusters = dataset._load('spike_clusters.npy') 234 | clu, n_clu = np.unique(spike_clusters, return_counts=True) 235 | 236 | # merge the first two clusters 237 | merge_clu = clu[0:2] 238 | spike_clusters[np.bitwise_or(spike_clusters == clu[0], 239 | spike_clusters == clu[1])] = np.max(clu) + 1 240 | # split the cluster with the most spikes 241 | split_clu = clu[-1] 242 | idx = np.where(spike_clusters == split_clu)[0] 243 | spike_clusters[idx[0:int(n_clu[-1] / 2)]] = np.max(clu) + 2 244 | spike_clusters[idx[int(n_clu[-1] / 2):]] = np.max(clu) + 3 245 | 246 | np.save(path / 'spike_clusters.npy', spike_clusters) 247 | 248 | model = TemplateModel( 249 | dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) 250 | 251 | c = EphysAlfCreator(model) 252 | c.convert(out_path_merge) 253 | 254 | # Test that the split are the same for the expected datasets 255 | clu_old = np.load(next(out_path.glob('clusters.peakToTrough.npy'))) 256 | clu_new = np.load(next(out_path_merge.glob('clusters.peakToTrough.npy'))) 257 | assert 
clu_old[split_clu] == clu_new[np.max(clu) + 2] 258 | assert clu_old[split_clu] == clu_new[np.max(clu) + 3] 259 | 260 | assert np.isnan([clu_new[split_clu]])[0] 261 | assert np.isnan([clu_new[merge_clu[0]]])[0] 262 | assert np.isnan([clu_new[merge_clu[1]]])[0] 263 | 264 | clu_old = np.load(next(out_path.glob('clusters.channels.npy'))) 265 | clu_new = np.load(next(out_path_merge.glob('clusters.channels.npy'))) 266 | assert clu_old[split_clu] == clu_new[np.max(clu) + 2] 267 | assert clu_old[split_clu] == clu_new[np.max(clu) + 3] 268 | assert clu_new[split_clu] == 0 269 | assert clu_new[merge_clu[0]] == 0 270 | assert clu_new[merge_clu[1]] == 0 271 | 272 | clu_old = np.load(next(out_path.glob('clusters.depths.npy'))) 273 | clu_new = np.load(next(out_path_merge.glob('clusters.depths.npy'))) 274 | assert clu_old[split_clu] == clu_new[np.max(clu) + 2] 275 | assert clu_old[split_clu] == clu_new[np.max(clu) + 3] 276 | assert np.isnan([clu_new[split_clu]])[0] 277 | assert np.isnan([clu_new[merge_clu[0]]])[0] 278 | assert np.isnan([clu_new[merge_clu[1]]])[0] 279 | 280 | clu_old = np.load(next(out_path.glob('clusters.waveformsChannels.npy'))) 281 | clu_new = np.load(next(out_path_merge.glob('clusters.waveformsChannels.npy'))) 282 | assert np.array_equal(clu_old[split_clu], clu_new[np.max(clu) + 2]) 283 | assert np.array_equal(clu_old[split_clu], clu_new[np.max(clu) + 3]) 284 | -------------------------------------------------------------------------------- /phylib/io/tests/test_array.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of array utility functions.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | from pytest import raises 13 | 14 | from ..array import ( 15 | _unique, _normalize, _index_of, _spikes_in_clusters, _spikes_per_cluster, 16 | _flatten_per_cluster, get_closest_clusters, _get_data_lim, _flatten, _clip, 17 | chunk_bounds, excerpts, data_chunk, grouped_mean, SpikeSelector, 18 | get_excerpts, _range_from_slice, _pad, _get_padded, 19 | read_array, write_array) 20 | from phylib.utils._types import _as_array 21 | from phylib.utils.testing import _assert_equal as ae 22 | from ..mock import artificial_spike_clusters, artificial_spike_samples 23 | 24 | 25 | #------------------------------------------------------------------------------ 26 | # Test utility functions 27 | #------------------------------------------------------------------------------ 28 | 29 | def test_clip(): 30 | assert _clip(-1, 0, 1) == 0 31 | 32 | 33 | def test_range_from_slice(): 34 | """Test '_range_from_slice'.""" 35 | 36 | class _SliceTest(object): 37 | """Utility class to make it more convenient to test slice objects.""" 38 | def __init__(self, **kwargs): 39 | self._kwargs = kwargs 40 | 41 | def __getitem__(self, item): 42 | if isinstance(item, slice): 43 | return _range_from_slice(item, **self._kwargs) 44 | 45 | with raises(ValueError): 46 | _SliceTest()[:] 47 | with raises(ValueError): 48 | _SliceTest()[1:] 49 | ae(_SliceTest()[:5], [0, 1, 2, 3, 4]) 50 | ae(_SliceTest()[1:5], [1, 2, 3, 4]) 51 | 52 | with raises(ValueError): 53 | _SliceTest()[::2] 54 | with raises(ValueError): 55 | _SliceTest()[1::2] 56 | ae(_SliceTest()[1:5:2], [1, 3]) 57 | 58 | with raises(ValueError): 59 | _SliceTest(start=0)[:] 60 | with raises(ValueError): 61 | 
_SliceTest(start=1)[:] 62 | with raises(ValueError): 63 | _SliceTest(step=2)[:] 64 | 65 | ae(_SliceTest(stop=5)[:], [0, 1, 2, 3, 4]) 66 | ae(_SliceTest(start=1, stop=5)[:], [1, 2, 3, 4]) 67 | ae(_SliceTest(stop=5)[1:], [1, 2, 3, 4]) 68 | ae(_SliceTest(start=1)[:5], [1, 2, 3, 4]) 69 | ae(_SliceTest(start=1, step=2)[:5], [1, 3]) 70 | ae(_SliceTest(start=1)[:5:2], [1, 3]) 71 | 72 | ae(_SliceTest(length=5)[:], [0, 1, 2, 3, 4]) 73 | with raises(ValueError): 74 | _SliceTest(length=5)[:3] 75 | ae(_SliceTest(length=5)[:10], [0, 1, 2, 3, 4]) 76 | ae(_SliceTest(length=5)[:5], [0, 1, 2, 3, 4]) 77 | ae(_SliceTest(start=1, length=5)[:], [1, 2, 3, 4, 5]) 78 | ae(_SliceTest(start=1, length=5)[:6], [1, 2, 3, 4, 5]) 79 | with raises(ValueError): 80 | _SliceTest(start=1, length=5)[:4] 81 | ae(_SliceTest(start=1, step=2, stop=5)[:], [1, 3]) 82 | ae(_SliceTest(start=1, stop=5)[::2], [1, 3]) 83 | ae(_SliceTest(stop=5)[1::2], [1, 3]) 84 | 85 | 86 | def test_pad(): 87 | arr = np.random.rand(10, 3) 88 | 89 | ae(_pad(arr, 0, 'right'), arr[:0, :]) 90 | ae(_pad(arr, 3, 'right'), arr[:3, :]) 91 | ae(_pad(arr, 9), arr[:9, :]) 92 | ae(_pad(arr, 10), arr) 93 | 94 | ae(_pad(arr, 12, 'right')[:10, :], arr) 95 | ae(_pad(arr, 12)[10:, :], np.zeros((2, 3))) 96 | 97 | ae(_pad(arr, 0, 'left'), arr[:0, :]) 98 | ae(_pad(arr, 3, 'left'), arr[7:, :]) 99 | ae(_pad(arr, 9, 'left'), arr[1:, :]) 100 | ae(_pad(arr, 10, 'left'), arr) 101 | 102 | ae(_pad(arr, 12, 'left')[2:, :], arr) 103 | ae(_pad(arr, 12, 'left')[:2, :], np.zeros((2, 3))) 104 | 105 | with raises(ValueError): 106 | _pad(arr, -1) 107 | 108 | 109 | def test_get_padded(): 110 | arr = np.array([1, 2, 3])[:, np.newaxis] 111 | 112 | with raises(RuntimeError): 113 | ae(_get_padded(arr, -2, 5).ravel(), [1, 2, 3, 0, 0]) 114 | ae(_get_padded(arr, 1, 2).ravel(), [2]) 115 | ae(_get_padded(arr, 0, 5).ravel(), [1, 2, 3, 0, 0]) 116 | ae(_get_padded(arr, -2, 3).ravel(), [0, 0, 1, 2, 3]) 117 | 118 | 119 | def test_get_data_lim(): 120 | arr = np.random.rand(10, 5) 121 | assert 0 < _get_data_lim(arr) < 1 122 | assert 0 < _get_data_lim(arr, 2) < 1 123 | 124 | 125 | def test_unique(): 126 | """Test _unique() function""" 127 | _unique([]) 128 | 129 | n_spikes = 300 130 | n_clusters = 3 131 | spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) 132 | ae(_unique(spike_clusters), np.arange(n_clusters)) 133 | 134 | 135 | def test_normalize(): 136 | """Test _normalize() function.""" 137 | 138 | n_channels = 10 139 | positions = 1 + 2 * np.random.randn(n_channels, 2) 140 | 141 | # Keep ration is False. 142 | positions_n = _normalize(positions) 143 | 144 | x_min, y_min = positions_n.min(axis=0) 145 | x_max, y_max = positions_n.max(axis=0) 146 | 147 | np.allclose(x_min, 0.) 148 | np.allclose(x_max, 1.) 149 | np.allclose(y_min, 0.) 150 | np.allclose(y_max, 1.) 151 | 152 | # Keep ratio is True. 153 | positions_n = _normalize(positions, keep_ratio=True) 154 | 155 | x_min, y_min = positions_n.min(axis=0) 156 | x_max, y_max = positions_n.max(axis=0) 157 | 158 | np.allclose(min(x_min, y_min), 0.) 159 | np.allclose(max(x_max, y_max), 1.) 
160 | np.allclose(x_min + x_max, 1) 161 | np.allclose(y_min + y_max, 1) 162 | 163 | 164 | def test_index_of(): 165 | """Test _index_of.""" 166 | arr = [36, 42, 42, 36, 36, 2, 42] 167 | lookup = _unique(arr) 168 | ae(_index_of(arr, lookup), [1, 2, 2, 1, 1, 0, 2]) 169 | 170 | 171 | def test_as_array(): 172 | ae(_as_array(3), [3]) 173 | ae(_as_array([3]), [3]) 174 | ae(_as_array(3.), [3.]) 175 | ae(_as_array([3.]), [3.]) 176 | 177 | with raises(ValueError): 178 | _as_array(map) 179 | 180 | 181 | def test_flatten(): 182 | assert _flatten([[0, 1], [2]]) == [0, 1, 2] 183 | 184 | 185 | def test_get_closest_clusters(): 186 | out = get_closest_clusters(1, [0, 1, 2], lambda c, d: (d - c)) 187 | assert [_ for _, __ in out] == [2, 1, 0] 188 | 189 | 190 | #------------------------------------------------------------------------------ 191 | # Test read/save 192 | #------------------------------------------------------------------------------ 193 | 194 | def test_read_write(tempdir): 195 | arr = np.arange(10).astype(np.float32) 196 | 197 | path = Path(tempdir) / 'test.npy' 198 | 199 | write_array(path, arr) 200 | ae(read_array(path), arr) 201 | ae(read_array(path, mmap_mode='r'), arr) 202 | 203 | 204 | #------------------------------------------------------------------------------ 205 | # Test chunking 206 | #------------------------------------------------------------------------------ 207 | 208 | def test_chunk_bounds(): 209 | chunks = chunk_bounds(200, 100, overlap=20) 210 | 211 | assert next(chunks) == (0, 100, 0, 90) 212 | assert next(chunks) == (80, 180, 90, 170) 213 | assert next(chunks) == (160, 200, 170, 200) 214 | 215 | 216 | def test_chunk(): 217 | data = np.random.randn(200, 4) 218 | chunks = chunk_bounds(data.shape[0], 100, overlap=20) 219 | 220 | with raises(ValueError): 221 | data_chunk(data, (0, 0, 0)) 222 | 223 | assert data_chunk(data, (0, 0)).shape == (0, 4) 224 | 225 | # Chunk 1. 226 | ch = next(chunks) 227 | d = data_chunk(data, ch) 228 | d_o = data_chunk(data, ch, with_overlap=True) 229 | 230 | ae(d_o, data[0:100]) 231 | ae(d, data[0:90]) 232 | 233 | # Chunk 2. 
234 | ch = next(chunks) 235 | d = data_chunk(data, ch) 236 | d_o = data_chunk(data, ch, with_overlap=True) 237 | 238 | ae(d_o, data[80:180]) 239 | ae(d, data[90:170]) 240 | 241 | 242 | def test_excerpts_1(): 243 | bounds = [(start, end) for (start, end) in excerpts(100, 244 | n_excerpts=3, 245 | excerpt_size=10)] 246 | assert bounds == [(0, 10), (45, 55), (90, 100)] 247 | 248 | 249 | def test_excerpts_2(): 250 | bounds = [(start, end) for (start, end) in excerpts(10, 251 | n_excerpts=3, 252 | excerpt_size=10)] 253 | assert bounds == [(0, 10)] 254 | 255 | 256 | def test_get_excerpts(): 257 | data = np.random.rand(100, 2) 258 | subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) 259 | assert subdata.shape == (50, 2) 260 | ae(subdata[:5, :], data[:5, :]) 261 | ae(subdata[-5:, :], data[-10:-5, :]) 262 | 263 | data = np.random.rand(10, 2) 264 | subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) 265 | ae(subdata, data) 266 | 267 | data = np.random.rand(10, 2) 268 | subdata = get_excerpts(data, n_excerpts=1, excerpt_size=10) 269 | ae(subdata, data) 270 | 271 | assert len(get_excerpts(data, n_excerpts=0, excerpt_size=10)) == 0 272 | 273 | 274 | #------------------------------------------------------------------------------ 275 | # Test spike clusters functions 276 | #------------------------------------------------------------------------------ 277 | 278 | def test_spikes_in_clusters(): 279 | """Test _spikes_in_clusters().""" 280 | 281 | n_spikes = 100 282 | n_clusters = 5 283 | spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) 284 | 285 | ae(_spikes_in_clusters(spike_clusters, []), []) 286 | 287 | for i in range(n_clusters): 288 | assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, [i])] == i) 289 | 290 | clusters = [1, 2, 3] 291 | assert np.all(np.isin( 292 | spike_clusters[_spikes_in_clusters(spike_clusters, clusters)], clusters)) 293 | 294 | 295 | def test_spikes_per_cluster(): 296 | """Test _spikes_per_cluster().""" 297 | 298 | n_spikes = 100 299 | n_clusters = 3 300 | spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) 301 | 302 | assert not _spikes_per_cluster([]) 303 | 304 | spikes_per_cluster = _spikes_per_cluster(spike_clusters) 305 | assert list(spikes_per_cluster.keys()) == list(range(n_clusters)) 306 | 307 | for i in range(n_clusters): 308 | ae(spikes_per_cluster[i], np.sort(spikes_per_cluster[i])) 309 | assert np.all(spike_clusters[spikes_per_cluster[i]] == i) 310 | 311 | 312 | def test_flatten_per_cluster(): 313 | spc = {2: [2, 7, 11], 3: [3, 5], 5: []} 314 | arr = _flatten_per_cluster(spc) 315 | ae(arr, [2, 3, 5, 7, 11]) 316 | 317 | 318 | def test_grouped_mean(): 319 | spike_clusters = np.array([2, 3, 2, 2, 5]) 320 | arr = [9, -3, 10, 11, -5] 321 | ae(grouped_mean(arr, spike_clusters), [10, -3, -5]) 322 | 323 | 324 | #------------------------------------------------------------------------------ 325 | # Test spike selection 326 | #------------------------------------------------------------------------------ 327 | 328 | def test_select_spikes_1(): 329 | spike_times = np.array([0., 1., 2., 3.3, 4.4]) 330 | spike_clusters = np.array([1, 2, 1, 2, 4]) 331 | chunk_bounds = [0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6] 332 | n_chunks_kept = 2 333 | cluster_ids = [1, 2, 4] 334 | spikes_ids_kept = [0, 1, 3] 335 | 336 | spc = _spikes_per_cluster(spike_clusters) 337 | ss = SpikeSelector( 338 | get_spikes_per_cluster=lambda cl: spc.get(cl, np.array([], dtype=np.int64)), 339 | spike_times=spike_times, chunk_bounds=chunk_bounds, 
n_chunks_kept=n_chunks_kept) 340 | ae(ss.chunks_kept, [0.0, 1.1, 3.3, 4.4]) 341 | 342 | ae(ss(3, [], subset_chunks=True), []) 343 | ae(ss(3, [0], subset_chunks=True), []) 344 | ae(ss(3, [1], subset_chunks=True), [0]) 345 | 346 | ae(ss(None, cluster_ids, subset_chunks=True), spikes_ids_kept) 347 | ae(ss(0, cluster_ids, subset_chunks=True), spikes_ids_kept) 348 | ae(ss(3, cluster_ids, subset_chunks=True), spikes_ids_kept) 349 | ae(ss(2, cluster_ids, subset_chunks=True), spikes_ids_kept) 350 | assert list(ss(1, cluster_ids, subset_chunks=True)) in [[0, 1], [0, 3]] 351 | 352 | ae(ss(2, cluster_ids, subset_spikes=[0, 1], subset_chunks=True), [0, 1]) 353 | ae(ss(2, cluster_ids, subset_chunks=False), np.arange(5)) 354 | 355 | 356 | def test_select_spikes_2(): 357 | n_spikes = 1000 358 | n_clusters = 10 359 | spike_times = artificial_spike_samples(n_spikes) 360 | spike_times = 10. * spike_times / spike_times.max() 361 | chunk_bounds = np.linspace(0.0, 10.0, 11) 362 | n_chunks_kept = 3 363 | chunks_kept = [0., 1., 4., 5., 8., 9.] 364 | spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) 365 | 366 | spc = _spikes_per_cluster(spike_clusters) 367 | ss = SpikeSelector( 368 | get_spikes_per_cluster=lambda cl: spc.get(cl, np.array([], dtype=np.int64)), 369 | spike_times=spike_times, chunk_bounds=chunk_bounds, n_chunks_kept=n_chunks_kept) 370 | ae(ss.chunks_kept, chunks_kept) 371 | 372 | def _check_chunks(sid): 373 | chunk_ids = np.searchsorted(chunk_bounds, spike_times[sid], 'right') - 1 374 | ae(np.unique(chunk_ids), [0, 4, 8]) 375 | 376 | # Select all spikes belonging to the kept chunks. 377 | sid = ss(n_spikes, np.arange(n_clusters), subset_chunks=True) 378 | _check_chunks(sid) 379 | 380 | # Select 10 spikes from each cluster. 381 | sid = ss(10, np.arange(n_clusters), subset_chunks=True) 382 | assert np.all(np.diff(sid) > 0) 383 | _check_chunks(sid) 384 | ae(np.bincount(spike_clusters[sid]), [10] * 10) 385 | -------------------------------------------------------------------------------- /phylib/io/tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of dataset utility functions.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import logging 10 | from pathlib import Path 11 | from itertools import product 12 | 13 | import numpy as np 14 | from numpy.testing import assert_array_equal as ae 15 | import responses 16 | from pytest import raises, fixture 17 | 18 | from ..datasets import (download_file, 19 | download_test_file, 20 | _check_md5_of_url, 21 | ) 22 | from phylib.utils.testing import captured_logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | #------------------------------------------------------------------------------ 28 | # Fixtures 29 | #------------------------------------------------------------------------------ 30 | 31 | # Test URL and data 32 | _URL = 'http://test/data' 33 | _DATA = np.linspace(0., 1., 100000).astype(np.float32) 34 | _CHECKSUM = '7d257d0ae7e3af8ca3574ccc3a4bf072' 35 | 36 | 37 | def _add_mock_response(url, body, file_type='binary'): 38 | content_type = ('application/octet-stream' 39 | if file_type == 'binary' else 'text/plain') 40 | responses.add(responses.GET, url, 41 | body=body, 42 | status=200, 43 | content_type=content_type, 44 | ) 45 | 46 | 47 | @fixture 48 | def mock_url(): 49 | 
_add_mock_response(_URL, _DATA.tobytes()) 50 | _add_mock_response(_URL + '.md5', _CHECKSUM + ' ' + Path(_URL).name) 51 | yield _URL 52 | responses.reset() 53 | 54 | 55 | @fixture(params=product((True, False), repeat=4)) 56 | def mock_urls(request): 57 | data = _DATA.tobytes() 58 | checksum = _CHECKSUM 59 | url_data = _URL 60 | url_checksum = _URL + '.md5' 61 | 62 | if not request.param[0]: 63 | # Data URL is corrupted. 64 | url_data = url_data[:-1] 65 | if not request.param[1]: 66 | # Data is corrupted. 67 | data = data[:-1] 68 | if not request.param[2]: 69 | # Checksum URL is corrupted. 70 | url_checksum = url_checksum[:-1] 71 | if not request.param[3]: 72 | # Checksum is corrupted. 73 | checksum = checksum[:-1] 74 | 75 | _add_mock_response(url_data, data) 76 | _add_mock_response(url_checksum, checksum) 77 | yield request.param, url_data, url_checksum 78 | responses.reset() 79 | 80 | 81 | def _dl(path): 82 | assert path 83 | download_file(_URL, path) 84 | with open(path, 'rb') as f: 85 | data = f.read() 86 | return data 87 | 88 | 89 | def _check(data): 90 | ae(np.frombuffer(data, np.float32), _DATA) 91 | 92 | 93 | #------------------------------------------------------------------------------ 94 | # Test utility functions 95 | #------------------------------------------------------------------------------ 96 | 97 | @responses.activate 98 | def test_check_md5_of_url(tempdir, mock_url): 99 | output_path = Path(tempdir) / 'data' 100 | download_file(_URL, output_path) 101 | assert _check_md5_of_url(output_path, _URL) 102 | 103 | 104 | #------------------------------------------------------------------------------ 105 | # Test download functions 106 | #------------------------------------------------------------------------------ 107 | 108 | @responses.activate 109 | def test_download_not_found(tempdir): 110 | path = Path(tempdir) / 'test' 111 | with raises(Exception): 112 | download_file(_URL + '_notfound', path) 113 | 114 | 115 | @responses.activate 116 | def test_download_already_exists_invalid(tempdir, mock_url): 117 | with captured_logging() as buf: 118 | path = Path(tempdir) / 'test' 119 | # Create empty file. 120 | open(path, 'a').close() 121 | _check(_dl(path)) 122 | assert 'redownload' in buf.getvalue() 123 | 124 | 125 | @responses.activate 126 | def test_download_already_exists_valid(tempdir, mock_url): 127 | with captured_logging() as buf: 128 | path = Path(tempdir) / 'test' 129 | # Create valid file. 
130 | with open(path, 'ab') as f: 131 | f.write(_DATA.tobytes()) 132 | _check(_dl(path)) 133 | assert 'skip' in buf.getvalue() 134 | 135 | 136 | @responses.activate 137 | def test_download_file(tempdir, mock_urls): 138 | path = Path(tempdir) / 'test' 139 | param, url_data, url_checksum = mock_urls 140 | data_here, data_valid, checksum_here, checksum_valid = param 141 | 142 | assert_succeeds = (data_here and data_valid and 143 | ((checksum_here == checksum_valid) or 144 | (not(checksum_here) and checksum_valid))) 145 | 146 | download_succeeds = (assert_succeeds or (data_here and 147 | (not(data_valid) and not(checksum_here)))) 148 | 149 | if download_succeeds: 150 | data = _dl(path) 151 | else: 152 | with raises(Exception): 153 | data = _dl(path) 154 | 155 | if assert_succeeds: 156 | _check(data) 157 | 158 | 159 | def test_download_test_file(tempdir): 160 | name = 'test/test-4ch-1s.dat' 161 | path = download_test_file(name, config_dir=tempdir) 162 | assert path.exists() 163 | assert path.stat().st_size == 160000 164 | path = download_test_file(name, config_dir=tempdir) 165 | -------------------------------------------------------------------------------- /phylib/io/tests/test_merge.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Test probe merging.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import numpy as np 11 | 12 | from ..merge import Merger 13 | from phylib.io.alf import EphysAlfCreator 14 | from phylib.io.model import load_model 15 | from phylib.io.tests.conftest import _make_dataset 16 | 17 | 18 | #------------------------------------------------------------------------------ 19 | # Merging tests 20 | #------------------------------------------------------------------------------ 21 | 22 | def test_probe_merge_1(tempdir): 23 | out_dir = tempdir / 'merged' 24 | 25 | # Create two identical datasets. 26 | probe_names = ('probe_left', 'probe_right') 27 | for name in probe_names: 28 | (tempdir / name).mkdir(exist_ok=True, parents=True) 29 | _make_dataset(tempdir / name, param='dense', has_spike_attributes=False) 30 | 31 | subdirs = [tempdir / name for name in probe_names] 32 | 33 | # Merge them. 34 | m = Merger(subdirs, out_dir) 35 | single = load_model(tempdir / probe_names[0] / 'params.py') 36 | 37 | # Test the merged dataset. 38 | merged = m.merge() 39 | for name in ('n_spikes', 'n_channels', 'n_templates'): 40 | assert getattr(merged, name) == getattr(single, name) * 2 41 | assert merged.sample_rate == single.sample_rate 42 | 43 | 44 | def test_probe_merge_2(tempdir): 45 | out_dir = tempdir / 'merged' 46 | 47 | # Create two identical datasets. 48 | probe_names = ('probe_left', 'probe_right') 49 | for name in probe_names: 50 | (tempdir / name).mkdir(exist_ok=True, parents=True) 51 | _make_dataset(tempdir / name, param='dense', has_spike_attributes=False) 52 | subdirs = [tempdir / name for name in probe_names] 53 | 54 | # Add small shift in the spike times of the second probe. 
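The shift applied just below is a single sample; at the dataset's 25 kHz sampling rate that is the 4e-5 s offset asserted on the merged spike times further down. The arithmetic, for reference:

```python
sample_rate = 25000.0               # Hz, the test dataset's sampling rate
shift_samples = 1
print(shift_samples / sample_rate)  # 4e-05 s between the two probes' spike times
```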
55 | single = load_model(tempdir / probe_names[0] / 'params.py') 56 | st_path = tempdir / 'probe_right/spike_times.npy' 57 | np.save(st_path, single.spike_samples + 1) 58 | # make amplitudes unique and growing so they can serve as key and sorting indices 59 | single.amplitudes = np.linspace(5, 15, single.n_spikes) 60 | # single.spike_clusters[single.spike_clusters == 0] = 12 61 | for m, subdir in enumerate(subdirs): 62 | np.save(subdir / 'amplitudes.npy', single.amplitudes + 20 * m) 63 | np.save(subdir / 'spike_clusters.npy', single.spike_clusters) 64 | 65 | # Merge them. 66 | m = Merger(subdirs, out_dir) 67 | merged = m.merge() 68 | 69 | # Test the merged dataset. 70 | for name in ('n_spikes', 'n_channels', 'n_templates'): 71 | assert getattr(merged, name) == getattr(single, name) * 2 72 | assert merged.sample_rate == single.sample_rate 73 | 74 | # Check the spikes. 75 | single = load_model(tempdir / probe_names[0] / 'params.py') 76 | 77 | def test_merged_single(merged, merged_original_amps=None): 78 | if merged_original_amps is None: 79 | merged_original_amps = merged.amplitudes 80 | _, im1, i1 = np.intersect1d(merged_original_amps, single.amplitudes, return_indices=True) 81 | _, im2, i2 = np.intersect1d(merged_original_amps, single.amplitudes + 20, 82 | return_indices=True) 83 | # intersection spans the full vector 84 | assert i1.size + i2.size == merged.amplitudes.size 85 | # test spikes 86 | assert np.allclose(merged.spike_times[im1], single.spike_times[i1]) 87 | assert np.allclose(merged.spike_times[im2], single.spike_times[i2] + 4e-5) 88 | # test clusters 89 | assert np.allclose(merged.spike_clusters[im2], single.spike_clusters[i2] + 64) 90 | assert np.allclose(merged.spike_clusters[im1], single.spike_clusters[i1]) 91 | # test templates 92 | assert np.all(merged.spike_templates[im1] - single.spike_templates[i1] == 0) 93 | assert np.all(merged.spike_templates[im2] - single.spike_templates[i2] == 64) 94 | # test probes 95 | assert np.all(merged.channel_probes == np.r_[single.channel_probes, 96 | single.channel_probes + 1]) 97 | assert np.all(merged.templates_channels[merged.templates_probes == 0] < single.n_channels) 98 | assert np.all(merged.templates_channels[merged.templates_probes == 1] >= single.n_channels) 99 | spike_probes = merged.templates_probes[merged.spike_templates] 100 | 101 | assert np.all(merged_original_amps[spike_probes == 0] <= 15) 102 | assert np.all(merged_original_amps[spike_probes == 1] >= 20) 103 | 104 | # np.all(merged.sparse_templates.data[:64, :, 0:32] == single.sparse_templates.data) 105 | 106 | # Convert into ALF and load. 
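`np.intersect1d(..., return_indices=True)` is the workhorse of `test_merged_single` above: because the amplitudes were made unique per probe, it recovers which merged spike came from which original spike. A minimal illustration:

```python
import numpy as np

merged_amps = np.array([5.0, 25.0, 7.5, 27.5])  # toy values; probe 1 amps = probe 0 amps + 20
single_amps = np.array([5.0, 7.5])

common, i_merged, i_single = np.intersect1d(merged_amps, single_amps, return_indices=True)
print(common)    # [5.  7.5]
print(i_merged)  # [0 2] -> positions in the merged arrays
print(i_single)  # [0 1] -> positions in the single-probe arrays
```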
107 | alf = EphysAlfCreator(merged).convert(tempdir / 'alf') 108 | test_merged_single(merged) 109 | test_merged_single(alf, merged_original_amps=merged.amplitudes) 110 | 111 | # specific test channel ids only for ALF merge dataset: the raw indices are still individual 112 | # file indices, the merged channel mapping is in `channels._phy_ids.npy` 113 | chid = np.load(tempdir.joinpath('alf', 'channels.rawInd.npy')) 114 | assert np.all(chid == np.r_[single.channel_mapping, single.channel_mapping]) 115 | 116 | out_files = list(tempdir.joinpath('alf').glob('*.*')) 117 | cl_shape = [np.load(f).shape[0] for f in out_files if f.name.startswith('clusters.') and 118 | f.name.endswith('.npy')] 119 | sp_shape = [np.load(f).shape[0] for f in out_files if f.name.startswith('spikes.')] 120 | ch_shape = [np.load(f).shape[0] for f in out_files if f.name.startswith('channels.')] 121 | assert len(set(cl_shape)) == 1 122 | assert len(set(sp_shape)) == 1 123 | assert len(set(ch_shape)) == 1 124 | -------------------------------------------------------------------------------- /phylib/io/tests/test_mock.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of mock datasets.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | from numpy.testing import assert_array_equal as ae 11 | 12 | from ..mock import (artificial_waveforms, 13 | artificial_traces, 14 | artificial_spike_clusters, 15 | artificial_features, 16 | artificial_masks, 17 | artificial_spike_samples, 18 | artificial_correlograms, 19 | ) 20 | 21 | 22 | #------------------------------------------------------------------------------ 23 | # Tests 24 | #------------------------------------------------------------------------------ 25 | 26 | def _test_artificial(n_spikes=None, n_clusters=None): 27 | n_samples_waveforms = 32 28 | n_samples_traces = 50 29 | n_channels = 35 30 | n_features = n_channels * 2 31 | 32 | # Waveforms. 33 | waveforms = artificial_waveforms(n_spikes=n_spikes, 34 | n_samples=n_samples_waveforms, 35 | n_channels=n_channels) 36 | assert waveforms.shape == (n_spikes, n_samples_waveforms, n_channels) 37 | 38 | # Traces. 39 | traces = artificial_traces(n_samples=n_samples_traces, 40 | n_channels=n_channels) 41 | assert traces.shape == (n_samples_traces, n_channels) 42 | 43 | # Spike clusters. 44 | spike_clusters = artificial_spike_clusters(n_spikes=n_spikes, 45 | n_clusters=n_clusters) 46 | assert spike_clusters.shape == (n_spikes,) 47 | if n_clusters >= 1: 48 | assert spike_clusters.min() in (0, 1) 49 | assert spike_clusters.max() in (n_clusters - 1, n_clusters - 2) 50 | ae(np.unique(spike_clusters), np.arange(n_clusters)) 51 | 52 | # Features. 53 | features = artificial_features(n_spikes, n_features) 54 | assert features.shape == (n_spikes, n_features) 55 | 56 | # Masks. 57 | masks = artificial_masks(n_spikes, n_channels) 58 | assert masks.shape == (n_spikes, n_channels) 59 | 60 | # Spikes. 61 | spikes = artificial_spike_samples(n_spikes) 62 | assert spikes.shape == (n_spikes,) 63 | 64 | # CCG. 
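All of these generators live in `phylib.io.mock` (listed earlier in this dump); outside of the test they can fabricate a tiny dataset in a couple of lines:

```python
from phylib.io.mock import artificial_spike_clusters, artificial_spike_samples

spike_clusters = artificial_spike_clusters(n_spikes=50, n_clusters=5)  # ints in [0, 5)
spike_samples = artificial_spike_samples(50)                           # non-decreasing sample indices
print(spike_clusters.shape, spike_samples.shape)                       # (50,) (50,)
```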
65 | ccg = artificial_correlograms(n_clusters, 10) 66 | assert ccg.shape == (n_clusters, n_clusters, 10) 67 | 68 | 69 | def test_artificial(): 70 | _test_artificial(n_spikes=100, n_clusters=10) 71 | _test_artificial(n_spikes=0, n_clusters=0) 72 | -------------------------------------------------------------------------------- /phylib/io/tests/test_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Testing the Template model.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import logging 10 | 11 | import numpy as np 12 | from numpy.testing import assert_equal as ae 13 | from pytest import raises 14 | 15 | # from phylib.utils import Bunch 16 | from phylib.utils.testing import captured_output 17 | from ..model import from_sparse, load_model 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | #------------------------------------------------------------------------------ 23 | # Tests 24 | #------------------------------------------------------------------------------ 25 | 26 | def test_from_sparse(): 27 | data = np.array([[0, 1, 2], [3, 4, 5]]) 28 | cols = np.array([[20, 23, 21], [21, 19, 22]]) 29 | 30 | def _test(channel_ids, expected): 31 | expected = np.asarray(expected) 32 | dense = from_sparse(data, cols, np.array(channel_ids)) 33 | assert dense.shape == expected.shape 34 | ae(dense, expected) 35 | 36 | _test([0], np.zeros((2, 1))) 37 | _test([19], [[0], [4]]) 38 | _test([20], [[0], [0]]) 39 | _test([21], [[2], [3]]) 40 | 41 | _test([19, 21], [[0, 2], [4, 3]]) 42 | _test([21, 19], [[2, 0], [3, 4]]) 43 | 44 | with raises(NotImplementedError): 45 | _test([19, 19], [[0, 0], [4, 4]]) 46 | 47 | 48 | def test_model_1(template_model_full): 49 | with captured_output() as (stdout, stderr): 50 | template_model_full.describe() 51 | out = stdout.getvalue() 52 | assert 'sim_binary.dat' in out 53 | assert '64' in out 54 | 55 | 56 | def test_model_2(template_model_full): 57 | m = template_model_full 58 | tmp = m.get_template(3) 59 | channel_ids = tmp.channel_ids 60 | spike_ids = m.get_cluster_spikes(3) 61 | 62 | w = m.get_waveforms(spike_ids, channel_ids) 63 | assert w is None or w.shape == (len(spike_ids), tmp.template.shape[0], len(channel_ids)) 64 | 65 | f = m.get_features(spike_ids, channel_ids) 66 | assert f is None or f.shape == (len(spike_ids), len(channel_ids), 3) 67 | 68 | tf = m.get_template_features(spike_ids) 69 | assert tf is None or tf.shape == (len(spike_ids), m.n_templates) 70 | 71 | 72 | def test_model_3(template_model_full): 73 | m = template_model_full 74 | 75 | spike_ids = m.get_template_spikes(3) 76 | n_spikes = len(spike_ids) 77 | 78 | channel_ids = m.get_template_channels(3) 79 | n_channels = len(channel_ids) 80 | 81 | waveforms = m.get_template_spike_waveforms(3) 82 | if waveforms is not None: 83 | assert waveforms.ndim == 3 84 | assert waveforms.shape[0] == n_spikes 85 | assert waveforms.shape[2] == n_channels 86 | 87 | tw = m.get_template_waveforms(3) 88 | assert tw.ndim == 2 89 | assert tw.shape[1] == n_channels 90 | 91 | 92 | def test_model_4(template_model_full): 93 | m = template_model_full 94 | 95 | spike_ids = m.get_cluster_spikes(3) 96 | n_spikes = len(spike_ids) 97 | 98 | channel_ids = m.get_cluster_channels(3) 99 | n_channels = len(channel_ids) 100 | 101 | waveforms = m.get_cluster_spike_waveforms(3) 102 | if waveforms is not None: 103 
| assert waveforms.ndim == 3 104 | assert waveforms.shape[0] == n_spikes 105 | assert waveforms.shape[2] == n_channels 106 | 107 | mean_waveforms = m.get_cluster_mean_waveforms(3) 108 | assert mean_waveforms.mean_waveforms.shape[1] == len(mean_waveforms.channel_ids) 109 | 110 | 111 | def test_model_depth(template_model): 112 | depths = template_model.get_depths() 113 | assert depths.shape == (template_model.n_spikes,) 114 | 115 | 116 | def test_model_merge(template_model_full): 117 | m = template_model_full 118 | 119 | # This is the case where we can do the merging 120 | if not np.all(m.spike_templates == m.spike_clusters) and m.sparse_clusters.cols is None: 121 | assert len(m.merge_map) > 0 122 | assert not np.array_equal(m.sparse_clusters.data, m.sparse_templates.data) 123 | assert m.sparse_clusters.data.shape[0] == m.n_clusters 124 | assert m.sparse_templates.data.shape[0] == m.n_templates 125 | 126 | else: 127 | assert len(m.merge_map) == 0 128 | assert np.array_equal(m.sparse_clusters.data, m.sparse_templates.data) 129 | assert np.array_equal(m.n_templates, m.n_clusters) 130 | 131 | 132 | def test_model_save(template_model_full): 133 | m = template_model_full 134 | m.save_metadata('test', {1: 1}) 135 | m.save_spike_clusters(m.spike_clusters) 136 | 137 | 138 | def test_model_spike_waveforms(template_path_full): 139 | model = load_model(template_path_full) 140 | 141 | if model.traces is not None: 142 | traces = model.traces[:] 143 | assert isinstance(traces, np.ndarray) 144 | 145 | waveforms = {} 146 | for tid in model.template_ids: 147 | spike_ids = model.get_template_spikes(tid) 148 | channel_ids = model.get_template_channels(tid) 149 | waveforms[tid] = model.get_waveforms(spike_ids, channel_ids) 150 | 151 | # Export the waveforms. 152 | model.save_spikes_subset_waveforms(1000, 16) 153 | # Fill spike_waveforms after saving them. 154 | model.spike_waveforms = model._load_spike_waveforms() 155 | 156 | # Check the waveforms loaded from the spike subset waveforms arrays. 157 | nsw = model.n_samples_waveforms // 2 158 | if model.spike_waveforms is None: 159 | return 160 | for tid in model.template_ids: 161 | spike_ids = model.get_template_spikes(tid) 162 | channel_ids = model.get_template_channels(tid) 163 | spike_ids = np.intersect1d(spike_ids, model.spike_waveforms.spike_ids) 164 | w = model.get_waveforms(spike_ids, channel_ids) 165 | 166 | # Check the 2 ways of getting the waveforms. 167 | ae(w, waveforms[tid]) 168 | 169 | if model.traces is not None: 170 | # Check each array with the ground truth, obtained from the raw data. 
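            # Each ground-truth waveform spans 2 * nsw samples centred on the
            # spike: with a hypothetical spike at sample t=100 and nsw=10, the
            # expected chunk is traces[90:110, channel_ids].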
171 | for i, spike in enumerate(spike_ids): 172 | t = int(model.spike_samples[spike]) 173 | wt = traces[t - nsw:t + nsw, channel_ids] 174 | 175 | ae(waveforms[tid][i], wt) 176 | ae(w[i], wt) 177 | 178 | model.close() 179 | 180 | 181 | def test_model_metadata_1(template_model_full): 182 | m = template_model_full 183 | 184 | assert m.metadata.get('group', {}).get(4, None) == 'good' 185 | assert m.metadata.get('unknown', {}).get(4, None) is None 186 | 187 | assert m.metadata.get('quality', {}).get(6, None) is None 188 | m.save_metadata('quality', {6: 3}) 189 | m.metadata = m._load_metadata() 190 | assert m.metadata.get('quality', {}).get(6, None) == 3 191 | 192 | 193 | def test_model_metadata_2(template_model): 194 | m = template_model 195 | 196 | m.save_metadata('quality', {0: None, 1: 1}) 197 | m.metadata = m._load_metadata() 198 | assert m.metadata.get('quality', {}).get(0, None) is None 199 | assert m.metadata.get('quality', {}).get(1, None) == 1 200 | 201 | 202 | def test_model_metadata_3(template_model): 203 | m = template_model 204 | 205 | assert m.metadata.get('met1', {}).get(2, None) == 123.4 206 | assert m.metadata.get('met1', {}).get(3, None) == 5.678 207 | assert m.metadata.get('met1', {}).get(4, None) is None 208 | assert m.metadata.get('met1', {}).get(5, None) is None 209 | 210 | assert m.metadata.get('met2', {}).get(2, None) == 'hello world 1' 211 | assert m.metadata.get('met2', {}).get(3, None) is None 212 | assert m.metadata.get('met2', {}).get(4, None) is None 213 | assert m.metadata.get('met2', {}).get(5, None) == 'hello world 2' 214 | 215 | 216 | def test_model_spike_attributes(template_model_full): 217 | model = template_model_full 218 | assert set(model.spike_attributes.keys()) == set(('randn', 'works')) 219 | assert model.spike_attributes.works.shape == (model.n_spikes,) 220 | assert model.spike_attributes.randn.shape == (model.n_spikes, 2) 221 | -------------------------------------------------------------------------------- /phylib/io/tests/test_traces.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Testing the BaseEphysTraces class.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import logging 10 | 11 | import numpy as np 12 | from numpy.testing import assert_equal as ae 13 | from numpy.testing import assert_allclose as ac 14 | import mtscomp 15 | from pytest import raises, fixture, mark 16 | 17 | from phylib.utils import Bunch 18 | from ..traces import ( 19 | _get_subitems, _get_chunk_bounds, 20 | get_ephys_reader, BaseEphysReader, extract_waveforms, export_waveforms, RandomEphysReader, 21 | get_spike_waveforms) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | #------------------------------------------------------------------------------ 27 | # Test utils 28 | #------------------------------------------------------------------------------ 29 | 30 | def test_get_subitems(): 31 | bounds = [0, 2, 5] 32 | 33 | def _a(x, y): 34 | res = _get_subitems(bounds, x) 35 | res = [(chk, val.tolist() if isinstance(val, np.ndarray) else val) for chk, val in res] 36 | assert res == y 37 | 38 | _a(-1, [(1, 2)]) 39 | _a(0, [(0, 0)]) 40 | _a(2, [(1, 0)]) 41 | _a(4, [(1, 2)]) 42 | with raises(IndexError): 43 | _a(5, []) 44 | 45 | _a(slice(None, None, None), [(0, slice(0, 2, 1)), (1, slice(0, 3, 1))]) 46 | 47 | _a(slice(1, None, 1), [(0, slice(1, 2, 
1)), (1, slice(0, 3, 1))]) 48 | 49 | _a(slice(2, None, 1), [(1, slice(0, 3, 1))]) 50 | _a(slice(3, None, 1), [(1, slice(1, 3, 1))]) 51 | _a(slice(5, None, 1), []) 52 | 53 | _a(slice(0, 4, 1), [(0, slice(0, 2, 1)), (1, slice(0, 2, 1))]) 54 | _a(slice(1, 2, 1), [(0, slice(1, 2, 1))]) 55 | _a(slice(1, -1, 1), [(0, slice(1, 2, 1)), (1, slice(0, 2, 1))]) 56 | _a(slice(-2, -1, 1), [(1, slice(1, 2, 1))]) 57 | 58 | _a([0], [(0, [0])]) 59 | _a([2], [(1, [0])]) 60 | _a([4], [(1, [2])]) 61 | with raises(IndexError): 62 | _a([5], []) 63 | 64 | _a([0, 1], [(0, [0, 1])]) 65 | _a([0, 2], [(0, [0]), (1, [0])]) 66 | _a([0, 3], [(0, [0]), (1, [1])]) 67 | with raises(IndexError): 68 | _a([0, 5], [(0, [0])]) 69 | _a([3, 4], [(1, [1, 2])]) 70 | 71 | _a(([3, 4], None), [(1, [1, 2])]) 72 | 73 | 74 | def test_get_chunk_bounds(): 75 | def _a(x, y, z): 76 | assert _get_chunk_bounds(x, y) == z 77 | 78 | _a([3], 2, [0, 2, 3]) 79 | _a([3], 3, [0, 3]) 80 | _a([3], 4, [0, 3]) 81 | 82 | _a([3, 2], 2, [0, 2, 3, 5]) 83 | _a([3, 2], 3, [0, 3, 5]) 84 | 85 | _a([3, 7, 5], 4, [0, 3, 7, 10, 14, 15]) 86 | _a([3, 7, 6], 4, [0, 3, 7, 10, 14, 16]) 87 | 88 | _a([3, 7, 5], 10, [0, 3, 10, 15]) 89 | 90 | 91 | #------------------------------------------------------------------------------ 92 | # Test ephys reader 93 | #------------------------------------------------------------------------------ 94 | 95 | @fixture 96 | def arr(): 97 | return np.random.randn(2000, 10) 98 | 99 | 100 | @fixture(params=[10000, 1000, 100]) 101 | def sample_rate(request): 102 | return request.param 103 | 104 | 105 | @fixture(params=['numpy', 'npy', 'flat', 'flat_concat', 'mtscomp', 'mtscomp_reader']) 106 | def traces(request, tempdir, arr, sample_rate): 107 | if request.param == 'numpy': 108 | return get_ephys_reader(arr, sample_rate=sample_rate) 109 | 110 | elif request.param == 'npy': 111 | path = tempdir / 'data.npy' 112 | np.save(path, arr) 113 | return get_ephys_reader(path, sample_rate=sample_rate) 114 | 115 | elif request.param == 'flat': 116 | path = tempdir / 'data.bin' 117 | with open(path, 'wb') as f: 118 | arr.tofile(f) 119 | return get_ephys_reader( 120 | path, sample_rate=sample_rate, dtype=arr.dtype, n_channels=arr.shape[1]) 121 | 122 | elif request.param == 'flat_concat': 123 | path0 = tempdir / 'data0.bin' 124 | with open(path0, 'wb') as f: 125 | arr[:arr.shape[0] // 2, :].tofile(f) 126 | path1 = tempdir / 'data1.bin' 127 | with open(path1, 'wb') as f: 128 | arr[arr.shape[0] // 2:, :].tofile(f) 129 | return get_ephys_reader( 130 | [path0, path1], sample_rate=sample_rate, dtype=arr.dtype, n_channels=arr.shape[1]) 131 | 132 | elif request.param in ('mtscomp', 'mtscomp_reader'): 133 | path = tempdir / 'data.bin' 134 | with open(path, 'wb') as f: 135 | arr.tofile(f) 136 | out = tempdir / 'data.cbin' 137 | outmeta = tempdir / 'data.ch' 138 | mtscomp.compress( 139 | path, out, outmeta, sample_rate=sample_rate, 140 | n_channels=arr.shape[1], dtype=arr.dtype, 141 | n_threads=1, check_after_compress=False, quiet=True) 142 | reader = mtscomp.decompress(out, outmeta, check_after_decompress=False, quiet=True) 143 | if request.param == 'mtscomp': 144 | return get_ephys_reader(reader) 145 | else: 146 | return get_ephys_reader(out) 147 | 148 | 149 | def test_ephys_reader_1(tempdir, arr, traces, sample_rate): 150 | assert isinstance(traces, BaseEphysReader) 151 | assert traces.dtype == arr.dtype 152 | assert traces.ndim == 2 153 | assert traces.shape == arr.shape 154 | assert traces.n_samples == arr.shape[0] 155 | assert traces.n_channels == arr.shape[1] 
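    # Sketch of the part bookkeeping (hypothetical sizes): a 2000-sample
    # recording stored as two flat binary files of 1000 samples each should
    # give n_parts == 2 and part_bounds == [0, 1000, 2000].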
156 | assert traces.n_parts in (1, 2) 157 | assert traces.duration == arr.shape[0] / sample_rate 158 | assert len(traces.part_bounds) == traces.n_parts + 1 159 | assert len(traces.chunk_bounds) == traces.n_chunks + 1 160 | 161 | ac(traces[:], arr) 162 | 163 | def _a(f): 164 | ac(f(traces)[:], f(arr)) 165 | 166 | _a(lambda x: x[:, ::-1]) 167 | 168 | _a(lambda x: x + 1) 169 | _a(lambda x: 1 + x) 170 | 171 | _a(lambda x: x - 1) 172 | _a(lambda x: 1 - x) 173 | 174 | _a(lambda x: x * 2) 175 | _a(lambda x: 2 * x) 176 | 177 | _a(lambda x: x ** 2) 178 | _a(lambda x: 2 ** x) 179 | 180 | _a(lambda x: x / 2) 181 | _a(lambda x: 2 / x) 182 | 183 | _a(lambda x: x / 2.) 184 | _a(lambda x: 2. / x) 185 | 186 | _a(lambda x: x // 2) 187 | _a(lambda x: 2 // x) 188 | 189 | _a(lambda x: +x) 190 | _a(lambda x: -x) 191 | 192 | _a(lambda x: -x[:, [1, 3, 5]]) 193 | 194 | _a(lambda x: 1 + x * 2) 195 | _a(lambda x: 1 + (2 * x)) 196 | _a(lambda x: -x * 2) 197 | 198 | _a(lambda x: x[::1]) 199 | _a(lambda x: x[::1, :]) 200 | _a(lambda x: x[::1, 1:5]) 201 | _a(lambda x: x[::1, ::3]) 202 | 203 | 204 | def test_ephys_random(sample_rate): 205 | reader = RandomEphysReader(2000, 10, sample_rate=sample_rate) 206 | assert reader[:10].shape == (10, 10) 207 | assert reader[:].shape == (2000, 10) 208 | assert reader[0].shape == (1, 10) 209 | assert reader[10:20].shape == (10, 10) 210 | assert reader[[1, 3, 5]].shape == (3, 10) 211 | assert reader[[1, 3, 5], :].shape == (3, 10) 212 | assert reader[[1, 3, 5], ::2].shape == (3, 5) 213 | assert reader[[1, 3, 5], [0, 2, 4]].shape == (3, 3) 214 | assert reader[0:-1].shape == (1999, 10) 215 | assert reader[-10:-1].shape == (9, 10) 216 | 217 | 218 | def test_get_spike_waveforms(): 219 | ns, nsw, nc = 8, 5, 3 220 | 221 | w = np.random.rand(ns, nsw, nc) 222 | s = np.arange(1, 1 + 2 * ns, 2) 223 | c = np.tile(np.array([1, 2, 3]), (ns, 1)) 224 | 225 | assert w.shape == (ns, nsw, nc) 226 | assert s.shape == (ns,) 227 | assert c.shape == (ns, nc) 228 | 229 | sw = Bunch(waveforms=w, spike_ids=s, spike_channels=c) 230 | out = get_spike_waveforms([5, 1, 3], [2, 1], spike_waveforms=sw, n_samples_waveforms=nsw) 231 | 232 | expected = w[[2, 0, 1], ...][..., [1, 0]] 233 | ae(out, expected) 234 | 235 | 236 | @mark.parametrize('do_export', [False, True]) 237 | @mark.parametrize('do_cache', [False, True]) 238 | def test_waveform_extractor(tempdir, arr, traces, sample_rate, do_export, do_cache): 239 | data = arr 240 | 241 | nsw = 20 242 | channel_ids = [1, 3, 5] 243 | spike_samples = [5, 25, 100, 1000, 1995] 244 | spike_ids = np.arange(len(spike_samples)) 245 | spike_channels = np.array([channel_ids] * len(spike_samples)) 246 | 247 | # Export waveforms into a npy file. 248 | if do_export: 249 | export_waveforms( 250 | tempdir / 'waveforms.npy', traces, spike_samples, spike_channels, 251 | n_samples_waveforms=nsw, cache=do_cache) 252 | w = np.load(tempdir / 'waveforms.npy') 253 | # Extract waveforms directly. 
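    # In both branches `w` has shape (n_spikes, n_samples_waveforms, n_channels),
    # i.e. (5, 20, 3) here: five spike_samples, nsw=20, three channel_ids.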
254 | else: 255 | w = extract_waveforms(traces, spike_samples, channel_ids, n_samples_waveforms=nsw) 256 | 257 | assert w.dtype == data.dtype == traces.dtype 258 | 259 | spike_waveforms = Bunch( 260 | spike_ids=spike_ids, 261 | spike_channels=spike_channels, 262 | waveforms=w, 263 | ) 264 | 265 | ww = get_spike_waveforms( 266 | spike_ids, channel_ids, spike_waveforms=spike_waveforms, 267 | n_samples_waveforms=nsw) 268 | ae(w, ww) 269 | 270 | assert np.all(w[0, :5, :] == 0) 271 | ac(w[0, 5:, :], data[0:15, [1, 3, 5]]) 272 | 273 | ac(w[1, ...], data[15:35, [1, 3, 5]]) 274 | ac(w[2, ...], data[90:110, [1, 3, 5]]) 275 | ac(w[3, ...], data[990:1010, [1, 3, 5]]) 276 | 277 | assert np.all(w[4, -5:, :] == 0) 278 | ac(w[4, :-5, :], data[-15:, [1, 3, 5]]) 279 | -------------------------------------------------------------------------------- /phylib/stats/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # flake8: noqa 3 | 4 | """Statistics functions.""" 5 | 6 | from .ccg import correlograms, firing_rate 7 | -------------------------------------------------------------------------------- /phylib/stats/ccg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Cross-correlograms.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | 11 | from phylib.utils._types import _as_array 12 | from phylib.io.array import _index_of, _unique 13 | 14 | 15 | #------------------------------------------------------------------------------ 16 | # Cross-correlograms 17 | #------------------------------------------------------------------------------ 18 | 19 | def _increment(arr, indices): 20 | """Increment some indices in a 1D vector of non-negative integers. 21 | Repeated indices are taken into account.""" 22 | arr = _as_array(arr) 23 | indices = _as_array(indices) 24 | bbins = np.bincount(indices) 25 | arr[:len(bbins)] += bbins 26 | return arr 27 | 28 | 29 | def _diff_shifted(arr, steps=1): 30 | arr = _as_array(arr) 31 | return arr[steps:] - arr[:len(arr) - steps] 32 | 33 | 34 | def _create_correlograms_array(n_clusters, winsize_bins): 35 | return np.zeros((n_clusters, n_clusters, winsize_bins // 2 + 1), 36 | dtype=np.int32) 37 | 38 | 39 | def _symmetrize_correlograms(correlograms): 40 | """Return the symmetrized version of the CCG arrays.""" 41 | 42 | n_clusters, _, n_bins = correlograms.shape 43 | assert n_clusters == _ 44 | 45 | # We symmetrize c[i, j, 0]. 46 | # This is necessary because the algorithm in correlograms() 47 | # is sensitive to the order of identical spikes. 48 | correlograms[..., 0] = np.maximum(correlograms[..., 0], 49 | correlograms[..., 0].T) 50 | 51 | sym = correlograms[..., 1:][..., ::-1] 52 | sym = np.transpose(sym, (1, 0, 2)) 53 | 54 | return np.dstack((sym, correlograms)) 55 | 56 | 57 | def firing_rate(spike_clusters, cluster_ids=None, bin_size=None, duration=None): 58 | """Compute the average number of spikes per cluster per bin.""" 59 | 60 | # Take the cluster order into account. 61 | if cluster_ids is None: 62 | cluster_ids = _unique(spike_clusters) 63 | else: 64 | cluster_ids = _as_array(cluster_ids) 65 | 66 | # Like spike_clusters, but with 0..n_clusters-1 indices. 
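    # For instance (hypothetical values), spike_clusters=[3, 5, 3, 7] with
    # cluster_ids=[3, 5, 7] maps to spike_clusters_i=[0, 1, 0, 2].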
67 | spike_clusters_i = _index_of(spike_clusters, cluster_ids) 68 | 69 | assert bin_size > 0 70 | bc = np.bincount(spike_clusters_i) 71 | # Handle the case where the last cluster(s) are empty. 72 | if len(bc) < len(cluster_ids): 73 | n = len(cluster_ids) - len(bc) 74 | bc = np.concatenate((bc, np.zeros(n, dtype=bc.dtype))) 75 | assert bc.shape == (len(cluster_ids),) 76 | return bc * np.c_[bc] * (bin_size / (duration or 1.)) 77 | 78 | 79 | def correlograms( 80 | spike_times, spike_clusters, cluster_ids=None, sample_rate=1., 81 | bin_size=None, window_size=None, symmetrize=True): 82 | """Compute all pairwise cross-correlograms among the clusters appearing 83 | in `spike_clusters`. 84 | 85 | Parameters 86 | ---------- 87 | 88 | spike_times : array-like 89 | Spike times in seconds. 90 | spike_clusters : array-like 91 | Spike-cluster mapping. 92 | cluster_ids : array-like 93 | The list of *all* unique clusters, in any order. That order will be used 94 | in the output array. 95 | bin_size : float 96 | Size of the bin, in seconds. 97 | window_size : float 98 | Size of the window, in seconds. 99 | sample_rate : float 100 | Sampling rate. 101 | symmetrize : boolean (True) 102 | Whether the output matrix should be symmetrized or not. 103 | 104 | Returns 105 | ------- 106 | 107 | correlograms : array 108 | A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise CCGs. 109 | 110 | """ 111 | assert sample_rate > 0. 112 | assert np.all(np.diff(spike_times) >= 0), ("The spike times must be " 113 | "increasing.") 114 | 115 | # Get the spike samples. 116 | spike_times = np.asarray(spike_times, dtype=np.float64) 117 | spike_samples = (spike_times * sample_rate).astype(np.int64) 118 | 119 | spike_clusters = _as_array(spike_clusters) 120 | 121 | assert spike_samples.ndim == 1 122 | assert spike_samples.shape == spike_clusters.shape 123 | 124 | # Find `binsize`. 125 | bin_size = np.clip(bin_size, 1e-5, 1e5) # in seconds 126 | binsize = int(sample_rate * bin_size) # in samples 127 | assert binsize >= 1 128 | 129 | # Find `winsize_bins`. 130 | window_size = np.clip(window_size, 1e-5, 1e5) # in seconds 131 | winsize_bins = 2 * int(.5 * window_size / bin_size) + 1 132 | 133 | assert winsize_bins >= 1 134 | assert winsize_bins % 2 == 1 135 | 136 | # Take the cluster order into account. 137 | if cluster_ids is None: 138 | clusters = _unique(spike_clusters) 139 | else: 140 | clusters = _as_array(cluster_ids) 141 | n_clusters = len(clusters) 142 | 143 | # Like spike_clusters, but with 0..n_clusters-1 indices. 144 | spike_clusters_i = _index_of(spike_clusters, clusters) 145 | 146 | # Shift between the two copies of the spike trains. 147 | shift = 1 148 | 149 | # At a given shift, the mask precises which spikes have matching spikes 150 | # within the correlogram time window. 151 | mask = np.ones_like(spike_samples, dtype=bool) 152 | 153 | correlograms = _create_correlograms_array(n_clusters, winsize_bins) 154 | 155 | # The loop continues as long as there is at least one spike with 156 | # a matching spike. 157 | while mask[:-shift].any(): 158 | # Number of time samples between spike i and spike i+shift. 159 | spike_diff = _diff_shifted(spike_samples, shift) 160 | 161 | # Binarize the delays between spike i and spike i+shift. 162 | spike_diff_b = spike_diff // binsize 163 | 164 | # Spikes with no matching spikes are masked. 165 | mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False 166 | 167 | # Cache the masked spike delays. 
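        # For example (hypothetical numbers): with binsize=20 samples and
        # spike_samples=[100, 130, 500], the shift=1 delays are [30, 370],
        # i.e. bins [1, 18]; bins beyond winsize_bins // 2 were masked out
        # just above and no longer contribute to the counts.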
168 |         m = mask[:-shift].copy()
169 |         d = spike_diff_b[m]
170 | 
171 |         # # Update the masks given the clusters to update.
172 |         # m0 = np.in1d(spike_clusters[:-shift], clusters)
173 |         # m = m & m0
174 |         # d = spike_diff_b[m]
175 | 
176 |         # Find the indices in the raveled correlograms array that need
177 |         # to be incremented, taking into account the spike clusters.
178 |         indices = np.ravel_multi_index(
179 |             (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d), correlograms.shape)
180 | 
181 |         # Increment the matching spikes in the correlograms array.
182 |         _increment(correlograms.ravel(), indices)
183 | 
184 |         shift += 1
185 | 
186 |     if symmetrize:
187 |         return _symmetrize_correlograms(correlograms)
188 |     else:
189 |         return correlograms
190 | 
--------------------------------------------------------------------------------
/phylib/stats/clusters.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Cluster statistics."""
4 | 
5 | #------------------------------------------------------------------------------
6 | # Imports
7 | #------------------------------------------------------------------------------
8 | 
9 | import numpy as np
10 | 
11 | 
12 | #------------------------------------------------------------------------------
13 | # Cluster statistics
14 | #------------------------------------------------------------------------------
15 | 
16 | def mean(x):
17 |     """Return the mean of an array across the first dimension."""
18 |     return x.mean(axis=0)
19 | 
20 | 
21 | def get_unmasked_channels(mean_masks, min_mask=.25):
22 |     """Return the unmasked channels (mean masks above a given threshold)."""
23 |     return np.nonzero(mean_masks > min_mask)[0]
24 | 
25 | 
26 | def get_mean_probe_position(mean_masks, site_positions):
27 |     """Return the mask-weighted mean position of a cluster on the probe."""
28 |     m = max(1, np.sum(mean_masks))
29 |     return np.sum(site_positions * mean_masks[:, np.newaxis], axis=0) / m
30 | 
31 | 
32 | def get_sorted_main_channels(mean_masks, unmasked_channels):
33 |     """Return the unmasked channels, sorted by decreasing mean mask."""
34 |     main_channels = np.argsort(mean_masks)[::-1]
35 |     main_channels = np.array([c for c in main_channels
36 |                               if c in unmasked_channels])
37 |     return main_channels
38 | 
39 | 
40 | #------------------------------------------------------------------------------
41 | # Wizard measures
42 | #------------------------------------------------------------------------------
43 | 
44 | def get_waveform_amplitude(mean_masks, mean_waveforms):
45 |     """Return the peak-to-peak amplitude of the masked mean waveform on each channel."""
46 | 
47 |     assert mean_waveforms.ndim == 2
48 |     n_samples, n_channels = mean_waveforms.shape
49 | 
50 |     assert mean_masks.ndim == 1
51 |     assert mean_masks.shape == (n_channels,)
52 | 
53 |     mean_waveforms = mean_waveforms * mean_masks
54 |     assert mean_waveforms.shape == (n_samples, n_channels)
55 | 
56 |     # Amplitudes.
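    # Peak-to-peak value per channel: a hypothetical channel whose masked mean
    # waveform spans [-40., 25.] gets an amplitude of 65.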
57 |     m, M = mean_waveforms.min(axis=0), mean_waveforms.max(axis=0)
58 |     return M - m
59 | 
60 | 
61 | def get_mean_masked_features_distance(
62 |         mean_features_0, mean_features_1, mean_masks_0, mean_masks_1, n_features_per_channel=None):
63 |     """Compute the distance between the mean masked features."""
64 | 
65 |     assert n_features_per_channel > 0
66 | 
67 |     mu_0 = mean_features_0.ravel()
68 |     mu_1 = mean_features_1.ravel()
69 | 
70 |     omeg_0 = mean_masks_0
71 |     omeg_1 = mean_masks_1
72 | 
73 |     omeg_0 = np.repeat(omeg_0, n_features_per_channel)
74 |     omeg_1 = np.repeat(omeg_1, n_features_per_channel)
75 | 
76 |     d_0 = mu_0 * omeg_0
77 |     d_1 = mu_1 * omeg_1
78 | 
79 |     return np.linalg.norm(d_0 - d_1)
80 | 
--------------------------------------------------------------------------------
/phylib/stats/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cortex-lab/phylib/72316a4ccb0abed93464ed4bfcd36a3809d1dbdf/phylib/stats/tests/__init__.py
--------------------------------------------------------------------------------
/phylib/stats/tests/test_ccg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | """Tests of CCG functions."""
4 | 
5 | #------------------------------------------------------------------------------
6 | # Imports
7 | #------------------------------------------------------------------------------
8 | 
9 | import numpy as np
10 | from numpy.testing import assert_array_equal as ae
11 | 
12 | from ..ccg import (_increment,
13 |                    _diff_shifted,
14 |                    correlograms,
15 |                    firing_rate,
16 |                    )
17 | 
18 | 
19 | #------------------------------------------------------------------------------
20 | # Tests
21 | #------------------------------------------------------------------------------
22 | 
23 | def _random_data(max_cluster):
24 |     sr = 20000
25 |     nspikes = 10000
26 |     spike_samples = np.cumsum(np.random.exponential(scale=.025, size=nspikes))
27 |     spike_samples = (spike_samples * sr).astype(np.uint64)
28 |     spike_clusters = np.random.randint(0, max_cluster, nspikes)
29 |     return spike_samples, spike_clusters
30 | 
31 | 
32 | def _ccg_params():
33 |     return .001, .05
34 | 
35 | 
36 | def test_utils():
37 |     # First, test _increment().
38 | 
39 |     # Original array.
40 |     arr = np.arange(10)
41 |     # Indices of elements to increment.
42 |     indices = [0, 2, 4, 2, 2, 2, 2, 2, 2]
43 | 
44 |     ae(_increment(arr, indices), [1, 1, 9, 3, 5, 5, 6, 7, 8, 9])
45 | 
46 |     # Then, test _diff_shifted().
47 |     # Original array.
48 |     arr = [2, 3, 5, 7, 11, 13, 17]
49 |     # Shifted once.
50 |     ds1 = [1, 2, 2, 4, 2, 4]
51 |     # Shifted twice.
52 |     ds2 = [3, 4, 6, 6, 6]
53 | 
54 |     ae(_diff_shifted(arr, 1), ds1)
55 |     ae(_diff_shifted(arr, 2), ds2)
56 | 
57 | 
58 | def test_firing_rate_0():
59 |     spike_clusters = [0, 1, 0, 1]
60 |     bin_size = 1
61 | 
62 |     fr = firing_rate(spike_clusters, bin_size=bin_size, duration=20)
63 |     ae(fr, .2 * np.ones((2, 2)))
64 | 
65 | 
66 | def test_firing_rate_1():
67 |     spike_clusters = [0, 1, 0, 1]
68 |     bin_size = 1
69 | 
70 |     fr = firing_rate(spike_clusters, cluster_ids=[0, 1, 2], bin_size=bin_size, duration=20)
71 |     print(fr)
72 |     ae(fr[:2, :2], .2 * np.ones((2, 2)))
73 |     assert np.all(fr[2, :] == fr[:, 2])
74 |     assert np.all(fr[2, :] == 0)
75 | 
76 | 
77 | def test_firing_rate_2():
78 |     spike_clusters = np.tile(np.arange(10), 100)
79 |     fr = firing_rate(spike_clusters, cluster_ids=np.arange(10), bin_size=.1, duration=1.)
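    # 1000 spikes, 100 per cluster: every cell is 100 * 100 * (.1 / 1.) = 1000,
    # per the bc * np.c_[bc] * (bin_size / duration) formula in firing_rate().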
80 | ae(fr, np.ones((10, 10)) * 1000) 81 | 82 | 83 | def test_ccg_0(): 84 | spike_samples = [0, 10, 10, 20] 85 | spike_clusters = [0, 1, 0, 1] 86 | bin_size = 1 87 | winsize_bins = 2 * 3 + 1 88 | 89 | c_expected = np.zeros((2, 2, 4)) 90 | 91 | # WARNING: correlograms() is sensitive to the order of identical spike 92 | # times. This needs to be taken into account when post-processing the 93 | # CCGs. 94 | c_expected[1, 0, 0] = 1 95 | c_expected[0, 1, 0] = 0 # This is a peculiarity of the algorithm. 96 | 97 | c = correlograms(spike_samples, spike_clusters, 98 | bin_size=bin_size, window_size=winsize_bins, 99 | cluster_ids=[0, 1], symmetrize=False) 100 | 101 | ae(c, c_expected) 102 | 103 | 104 | def test_ccg_1(): 105 | spike_samples = np.array([2, 3, 10, 12, 20, 24, 30, 40], dtype=np.uint64) 106 | spike_clusters = [0, 1, 0, 0, 2, 1, 0, 2] 107 | bin_size = 1 108 | winsize_bins = 2 * 3 + 1 109 | 110 | c_expected = np.zeros((3, 3, 4)) 111 | c_expected[0, 1, 1] = 1 112 | c_expected[0, 0, 2] = 1 113 | 114 | c = correlograms(spike_samples, spike_clusters, 115 | bin_size=bin_size, window_size=winsize_bins, 116 | symmetrize=False) 117 | 118 | ae(c, c_expected) 119 | 120 | 121 | def test_ccg_2(): 122 | max_cluster = 10 123 | spike_samples, spike_clusters = _random_data(max_cluster) 124 | bin_size, winsize_bins = _ccg_params() 125 | 126 | c = correlograms( 127 | spike_samples, spike_clusters, bin_size=bin_size, window_size=winsize_bins, 128 | sample_rate=20000, symmetrize=False) 129 | 130 | assert c.shape == (max_cluster, max_cluster, 26) 131 | 132 | 133 | def test_ccg_symmetry_time(): 134 | """Reverse time and check that the CCGs are just transposed.""" 135 | 136 | spike_samples, spike_clusters = _random_data(2) 137 | bin_size, winsize_bins = _ccg_params() 138 | 139 | c0 = correlograms(spike_samples, spike_clusters, 140 | bin_size=bin_size, window_size=winsize_bins, 141 | sample_rate=20000, symmetrize=False) 142 | 143 | spike_samples_1 = np.cumsum(np.r_[np.arange(1), 144 | np.diff(spike_samples)[::-1]]) 145 | spike_samples_1 = spike_samples_1.astype(np.uint64) 146 | spike_clusters_1 = spike_clusters[::-1] 147 | c1 = correlograms(spike_samples_1, spike_clusters_1, 148 | bin_size=bin_size, window_size=winsize_bins, 149 | sample_rate=20000, symmetrize=False) 150 | 151 | # The ACGs are identical. 152 | ae(c0[0, 0], c1[0, 0]) 153 | ae(c0[1, 1], c1[1, 1]) 154 | 155 | # The CCGs are just transposed. 156 | ae(c0[0, 1], c1[1, 0]) 157 | ae(c0[1, 0], c1[0, 1]) 158 | 159 | 160 | def test_ccg_symmetry_clusters(): 161 | """Exchange clusters and check that the CCGs are just transposed.""" 162 | 163 | spike_samples, spike_clusters = _random_data(2) 164 | bin_size, winsize_bins = _ccg_params() 165 | 166 | c0 = correlograms(spike_samples, spike_clusters, 167 | bin_size=bin_size, window_size=winsize_bins, 168 | sample_rate=20000, symmetrize=False) 169 | 170 | spike_clusters_1 = 1 - spike_clusters 171 | c1 = correlograms(spike_samples, spike_clusters_1, 172 | bin_size=bin_size, window_size=winsize_bins, 173 | sample_rate=20000, symmetrize=False) 174 | 175 | # The ACGs are identical. 176 | ae(c0[0, 0], c1[1, 1]) 177 | ae(c0[1, 1], c1[0, 0]) 178 | 179 | # The CCGs are just transposed. 
180 | ae(c0[0, 1], c1[1, 0]) 181 | ae(c0[1, 0], c1[0, 1]) 182 | 183 | 184 | def test_symmetrize_correlograms(): 185 | spike_samples, spike_clusters = _random_data(3) 186 | bin_size, winsize_bins = _ccg_params() 187 | 188 | sym = correlograms(spike_samples, spike_clusters, 189 | bin_size=bin_size, window_size=winsize_bins, 190 | sample_rate=20000) 191 | assert sym.shape == (3, 3, 51) 192 | 193 | # The ACG are reversed. 194 | for i in range(3): 195 | ae(sym[i, i, :], sym[i, i, ::-1]) 196 | 197 | ae(sym[0, 1, :], sym[1, 0, ::-1]) 198 | -------------------------------------------------------------------------------- /phylib/stats/tests/test_clusters.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of cluster statistics.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | from numpy.testing import assert_array_equal as ae 11 | from numpy.testing import assert_allclose as ac 12 | from pytest import fixture 13 | 14 | from ..clusters import (mean, 15 | get_unmasked_channels, 16 | get_mean_probe_position, 17 | get_sorted_main_channels, 18 | get_mean_masked_features_distance, 19 | get_waveform_amplitude, 20 | ) 21 | from phylib.utils.geometry import staggered_positions 22 | from phylib.io.mock import artificial_features, artificial_masks, artificial_waveforms 23 | 24 | 25 | #------------------------------------------------------------------------------ 26 | # Fixtures 27 | #------------------------------------------------------------------------------ 28 | 29 | @fixture 30 | def n_channels(): 31 | yield 28 32 | 33 | 34 | @fixture 35 | def n_spikes(): 36 | yield 50 37 | 38 | 39 | @fixture 40 | def n_samples(): 41 | yield 40 42 | 43 | 44 | @fixture 45 | def n_features_per_channel(): 46 | yield 4 47 | 48 | 49 | @fixture 50 | def features(n_spikes, n_channels, n_features_per_channel): 51 | yield artificial_features(n_spikes, n_channels, n_features_per_channel) 52 | 53 | 54 | @fixture 55 | def masks(n_spikes, n_channels): 56 | yield artificial_masks(n_spikes, n_channels) 57 | 58 | 59 | @fixture 60 | def waveforms(n_spikes, n_samples, n_channels): 61 | yield artificial_waveforms(n_spikes, n_samples, n_channels) 62 | 63 | 64 | @fixture 65 | def site_positions(n_channels): 66 | yield staggered_positions(n_channels) 67 | 68 | 69 | #------------------------------------------------------------------------------ 70 | # Tests 71 | #------------------------------------------------------------------------------ 72 | 73 | def test_mean(features, n_channels, n_features_per_channel): 74 | mf = mean(features) 75 | assert mf.shape == (n_channels, n_features_per_channel) 76 | ae(mf, features.mean(axis=0)) 77 | 78 | 79 | def test_unmasked_channels(masks, n_channels): 80 | # Mask many values in the masks array. 81 | threshold = .05 82 | masks[:, 1::2] *= threshold 83 | # Compute the mean masks. 84 | mean_masks = mean(masks) 85 | # Find the unmasked channels. 86 | channels = get_unmasked_channels(mean_masks, threshold) 87 | # These are 0, 2, 4, etc. 
88 | ae(channels, np.arange(0, n_channels, 2)) 89 | 90 | 91 | def test_mean_probe_position(masks, site_positions): 92 | masks[:, ::2] *= .05 93 | mean_masks = mean(masks) 94 | mean_pos = get_mean_probe_position(mean_masks, site_positions) 95 | assert mean_pos.shape == (2,) 96 | assert mean_pos[0] < 0 97 | assert mean_pos[1] > 0 98 | 99 | 100 | def test_sorted_main_channels(masks): 101 | masks *= .05 102 | masks[:, [5, 7]] *= 20 103 | mean_masks = mean(masks) 104 | channels = get_sorted_main_channels(mean_masks, 105 | get_unmasked_channels(mean_masks)) 106 | assert np.all(np.isin(channels, [5, 7])) 107 | 108 | 109 | def test_waveform_amplitude(masks, waveforms): 110 | waveforms *= .1 111 | masks *= .1 112 | 113 | waveforms[:, 10, :] *= 10 114 | masks[:, 10] *= 10 115 | 116 | mean_waveforms = mean(waveforms) 117 | mean_masks = mean(masks) 118 | 119 | amplitude = get_waveform_amplitude(mean_masks, mean_waveforms) 120 | assert np.all(amplitude >= 0) 121 | assert amplitude.shape == (mean_waveforms.shape[1],) 122 | 123 | 124 | def test_mean_masked_features_distance(features, 125 | n_channels, 126 | n_features_per_channel, 127 | ): 128 | 129 | # Shifted feature vectors. 130 | shift = 10. 131 | f0 = mean(features) 132 | f1 = mean(features) + shift 133 | 134 | # Only one channel is unmasked. 135 | m0 = m1 = np.zeros(n_channels) 136 | m0[n_channels // 2] = 1 137 | 138 | # Check the distance. 139 | d_expected = np.sqrt(n_features_per_channel) * shift 140 | d_computed = get_mean_masked_features_distance(f0, f1, m0, m1, 141 | n_features_per_channel) 142 | ac(d_expected, d_computed) 143 | -------------------------------------------------------------------------------- /phylib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # flake8: noqa 3 | 4 | """Utilities.""" 5 | 6 | from ._misc import ( 7 | load_json, save_json, load_pickle, save_pickle, _fullname, read_python, 8 | read_text, write_text, read_tsv, write_tsv) 9 | from ._types import ( 10 | _is_array_like, _as_array, _as_tuple, _as_list, _as_scalar, _as_scalars, 11 | Bunch, _is_list, _bunchify) 12 | from .event import ProgressReporter, emit, connect, unconnect, silent, reset, set_silent 13 | -------------------------------------------------------------------------------- /phylib/utils/_misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Utility functions.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import base64 11 | import csv 12 | from importlib import import_module 13 | import json 14 | import logging 15 | import os 16 | from pathlib import Path 17 | import subprocess 18 | from textwrap import dedent 19 | 20 | import numpy as np 21 | 22 | from ._types import _is_integer 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | #------------------------------------------------------------------------------ 28 | # JSON utility functions 29 | #------------------------------------------------------------------------------ 30 | 31 | def _encode_qbytearray(arr): 32 | """Encode binary arrays with base64.""" 33 | b = arr.toBase64().data() 34 | data_b64 = base64.b64encode(b).decode('utf8') 35 | return data_b64 36 | 37 | 38 | def _decode_qbytearray(data_b64): 39 | """Decode binary arrays with base64.""" 40 | encoded = 
base64.b64decode(data_b64) 41 | try: 42 | from PyQt5.QtCore import QByteArray 43 | out = QByteArray.fromBase64(encoded) 44 | except ImportError: # pragma: no cover 45 | pass 46 | return out 47 | 48 | 49 | class _CustomEncoder(json.JSONEncoder): 50 | """JSON encoder that accepts NumPy arrays.""" 51 | def default(self, obj): 52 | if isinstance(obj, np.ndarray) and obj.ndim == 1 and obj.shape[0] <= 10: 53 | # Serialize small arrays in clear text (lists of numbers). 54 | return obj.tolist() 55 | elif isinstance(obj, np.ndarray): 56 | obj_contiguous = np.ascontiguousarray(obj) 57 | data_b64 = base64.b64encode(obj_contiguous.data).decode('utf8') 58 | return dict(__ndarray__=data_b64, dtype=str(obj.dtype), shape=obj.shape) 59 | elif obj.__class__.__name__ == 'QByteArray': 60 | return {'__qbytearray__': _encode_qbytearray(obj)} 61 | elif isinstance(obj, np.generic): 62 | return obj.item() 63 | return super(_CustomEncoder, self).default(obj) # pragma: no cover 64 | 65 | 66 | def _json_custom_hook(d): 67 | """Serialize NumPy arrays.""" 68 | if isinstance(d, dict) and '__ndarray__' in d: 69 | data = base64.b64decode(d['__ndarray__']) 70 | return np.frombuffer(data, d['dtype']).reshape(d['shape']) 71 | elif isinstance(d, dict) and '__qbytearray__' in d: 72 | return _decode_qbytearray(d['__qbytearray__']) 73 | return d 74 | 75 | 76 | def _intify_keys(d): 77 | """Make sure all integer strings in a dictionary are converted into integers.""" 78 | assert isinstance(d, dict) 79 | out = {} 80 | for k, v in d.items(): 81 | if isinstance(k, str) and k.isdigit(): 82 | k = int(k) 83 | out[k] = v 84 | return out 85 | 86 | 87 | def _stringify_keys(d): 88 | """Make sure all integers in a dictionary are converted into strings.""" 89 | assert isinstance(d, dict) 90 | out = {} 91 | for k, v in d.items(): 92 | if _is_integer(k): 93 | k = str(k) 94 | out[k] = v 95 | return out 96 | 97 | 98 | def _pretty_floats(obj, n=2): 99 | """Display floating point numbers properly.""" 100 | if isinstance(obj, (float, np.float64, np.float32)): 101 | return ('%.' + str(n) + 'f') % obj 102 | elif isinstance(obj, dict): 103 | return dict((k, _pretty_floats(v)) for k, v in obj.items()) 104 | elif isinstance(obj, (list, tuple)): 105 | return list(map(_pretty_floats, obj)) 106 | return obj 107 | 108 | 109 | def load_json(path): 110 | """Load a JSON file.""" 111 | path = Path(path) 112 | if not path.exists(): 113 | raise IOError("The JSON file `{}` doesn't exist.".format(path)) 114 | contents = path.read_text() 115 | if not contents: 116 | return {} 117 | out = json.loads(contents, object_hook=_json_custom_hook) 118 | return _intify_keys(out) 119 | 120 | 121 | def save_json(path, data): 122 | """Save a dictionary to a JSON file. 123 | 124 | Support NumPy arrays and QByteArray objects. NumPy arrays are saved as base64-encoded strings, 125 | except for 1D arrays with less than 10 elements, which are saved as a list for human 126 | readability. 
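
    A minimal usage sketch (hypothetical file name and values):

    ```python
    save_json('params.json', {'n_channels': 64, 'offsets': np.arange(3)})
    load_json('params.json')  # {'n_channels': 64, 'offsets': [0, 1, 2]}
    ```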
127 | 128 | """ 129 | assert isinstance(data, dict) 130 | data = _stringify_keys(data) 131 | path = Path(path) 132 | ensure_dir_exists(path.parent) 133 | with path.open('w') as f: 134 | json.dump(data, f, cls=_CustomEncoder, indent=2, sort_keys=True) 135 | 136 | 137 | #------------------------------------------------------------------------------ 138 | # Other read/write functions 139 | #------------------------------------------------------------------------------ 140 | 141 | def load_pickle(path): 142 | """Load a pickle file using joblib.""" 143 | from joblib import load 144 | return load(path) 145 | 146 | 147 | def save_pickle(path, data): 148 | """Save data to a pickle file using joblib.""" 149 | from joblib import dump 150 | return dump(data, path) 151 | 152 | 153 | def read_python(path): 154 | """Read a Python file. 155 | 156 | Parameters 157 | ---------- 158 | 159 | path : str or Path 160 | 161 | Returns 162 | ------- 163 | 164 | metadata : dict 165 | A dictionary containing all variables defined in the Python file (with `exec()`). 166 | 167 | """ 168 | path = Path(path) 169 | if not path.exists(): # pragma: no cover 170 | raise IOError("Path %s does not exist.", path) 171 | contents = path.read_text() 172 | metadata = {} 173 | exec(contents, {}, metadata) 174 | metadata = {k.lower(): v for (k, v) in metadata.items()} 175 | return metadata 176 | 177 | 178 | def write_python(path, data): 179 | """Write a dictionary in a Python file. 180 | 181 | Parameters 182 | ---------- 183 | 184 | path : str or Path 185 | Path to the Python file to write. 186 | data : dict 187 | A key-value mapping to write as a Python file. 188 | 189 | Returns 190 | ------- 191 | 192 | """ 193 | with open(path, 'w') as f: 194 | for k, v in data.items(): 195 | if isinstance(v, str): 196 | v = '"%s"' % v 197 | f.write('%s = %s\n' % (k, str(v))) 198 | 199 | 200 | def read_text(path): 201 | """Read a text file.""" 202 | path = Path(path) 203 | return path.read_text() 204 | 205 | 206 | def write_text(path, contents): 207 | """Write a text file.""" 208 | contents = dedent(contents) 209 | path = Path(path) 210 | ensure_dir_exists(path.parent) 211 | path.write_text(contents) 212 | 213 | 214 | def _try_make_number(value): 215 | """Convert a string into an int or float if possible, otherwise do nothing.""" 216 | try: 217 | return int(value) 218 | except ValueError: 219 | try: 220 | return float(value) 221 | except ValueError: 222 | return value 223 | raise ValueError() 224 | 225 | 226 | def read_tsv(path): 227 | """Read a CSV/TSV file. 228 | 229 | Returns 230 | ------- 231 | 232 | data : list of dicts 233 | 234 | """ 235 | path = Path(path) 236 | data = [] 237 | if not path.exists(): 238 | logger.debug("%s does not exist, skipping.", path) 239 | return data 240 | # Find whether the delimiter is tab or comma. 241 | with path.open('r') as f: 242 | delimiter = '\t' if '\t' in f.readline() else ',' 243 | with path.open('r') as f: 244 | reader = csv.reader(f, delimiter=delimiter) 245 | # Skip the header. 246 | field_names = list(next(reader)) 247 | for row in reader: 248 | data.append({k: _try_make_number(v) for k, v in zip(field_names, row) if v != ''}) 249 | logger.log(5, "Read %s.", path) 250 | return data 251 | 252 | 253 | def write_tsv(path, data, first_field=None, exclude_fields=(), n_significant_figures=4): 254 | """Write a CSV/TSV file. 255 | 256 | Parameters 257 | ---------- 258 | 259 | data : list of dicts 260 | first_field : str 261 | The name of the field that should come first in the file. 
262 | exclude_fields : list-like 263 | Fields present in the data that should not be saved in the file. 264 | n_significant_figures : int 265 | Number of significant figures used for floating-point numbers in the file. 266 | 267 | """ 268 | path = Path(path) 269 | ensure_dir_exists(path.parent) 270 | delimiter = '\t' if path.suffix == '.tsv' else ',' 271 | with path.open('w', newline='') as f: 272 | if not data: 273 | logger.info("Data was empty when writing %s.", path) 274 | return 275 | # Get the union of all keys from all rows. 276 | fields = set().union(*data) 277 | # Remove ignored fields. 278 | for field in exclude_fields: 279 | if field in fields: 280 | fields.remove(field) 281 | # Make sure the first field is the first one. 282 | if first_field in fields: 283 | fields.remove(first_field) 284 | fields = [first_field] + sorted(fields) 285 | else: 286 | fields = sorted(fields) 287 | writer = csv.writer(f, delimiter=delimiter) 288 | # Write the header. 289 | writer.writerow(fields) 290 | # Write all rows. 291 | writer.writerows( 292 | [[_pretty_floats(row.get(field, None), n_significant_figures) 293 | for field in fields] for row in data]) 294 | logger.debug("Wrote %s.", path) 295 | 296 | 297 | def _read_tsv_simple(path): 298 | """Read a CSV/TSV file with only two columns: cluster_id and . 299 | 300 | Return (field_name, dictionary {cluster_id: value}). 301 | 302 | """ 303 | path = Path(path) 304 | data = {} 305 | if not path.exists(): 306 | logger.debug("%s does not exist, skipping.", path) 307 | return data 308 | # Find whether the delimiter is tab or comma. 309 | with path.open('r') as f: 310 | delimiter = '\t' if '\t' in f.readline() else ',' 311 | with path.open('r') as f: 312 | reader = csv.reader(f, delimiter=delimiter) 313 | # Skip the header. 314 | _, field_name = next(reader) 315 | for row in reader: 316 | cluster_id, value = row 317 | cluster_id = int(cluster_id) 318 | data[cluster_id] = _try_make_number(value) 319 | logger.debug("Read %s.", path) 320 | return field_name, data 321 | 322 | 323 | def _write_tsv_simple(path, field_name, data): 324 | """Write a CSV/TSV file with two columns: cluster_id and . 325 | 326 | data is a dictionary {cluster_id: value}. 327 | 328 | """ 329 | path = Path(path) 330 | ensure_dir_exists(path.parent) 331 | delimiter = '\t' if path.suffix == '.tsv' else ',' 332 | with path.open('w', newline='') as f: 333 | writer = csv.writer(f, delimiter=delimiter) 334 | writer.writerow(['cluster_id', field_name]) 335 | writer.writerows([(cluster_id, data[cluster_id]) for cluster_id in sorted(data)]) 336 | logger.debug("Wrote %s.", path) 337 | 338 | 339 | #------------------------------------------------------------------------------ 340 | # Various Python utility functions 341 | #------------------------------------------------------------------------------ 342 | 343 | def _fullname(o): 344 | """Return the fully-qualified name of a function.""" 345 | return o.__module__ + "." 
+ o.__name__ if o.__module__ else o.__name__ 346 | 347 | 348 | def _load_from_fullname(name): 349 | """Load a Python object from its fully qualified name.""" 350 | if not isinstance(name, str): 351 | return name 352 | parts = name.rsplit('.', 1) 353 | return getattr(import_module(parts[0]), parts[1], parts[1]) 354 | 355 | 356 | def _git_version(): 357 | """Return the git version.""" 358 | curdir = os.getcwd() 359 | os.chdir(str(Path(__file__).parent)) 360 | try: 361 | with open(os.devnull, 'w') as fnull: 362 | version = ('-git-' + subprocess.check_output( 363 | ['git', 'describe', '--abbrev=8', '--dirty', '--always', '--tags'], 364 | stderr=fnull).strip().decode('ascii')) 365 | return version 366 | except (OSError, subprocess.CalledProcessError): # pragma: no cover 367 | return "" 368 | finally: 369 | os.chdir(curdir) 370 | 371 | 372 | def phy_config_dir(): 373 | """Return the absolute path to the phy user directory. By default, `~/.phy/`.""" 374 | return Path.home() / '.phy' 375 | 376 | 377 | def ensure_dir_exists(path): 378 | """Ensure a directory exists, and create it otherwise.""" 379 | path = Path(path) 380 | if path.exists(): 381 | assert path.is_dir() 382 | else: 383 | path.mkdir(exist_ok=True, parents=True) 384 | assert path.exists() and path.is_dir() 385 | -------------------------------------------------------------------------------- /phylib/utils/_types.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Utility functions.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import numpy as np 11 | 12 | 13 | #------------------------------------------------------------------------------ 14 | # Various Python utility functions 15 | #------------------------------------------------------------------------------ 16 | 17 | _ACCEPTED_ARRAY_DTYPES = ( 18 | float, np.float32, np.float64, int, np.int8, np.int16, np.uint8, np.uint16, 19 | np.int32, np.int64, np.uint32, np.uint64, bool) 20 | 21 | 22 | class Bunch(dict): 23 | """A subclass of dictionary with an additional dot syntax.""" 24 | def __init__(self, *args, **kwargs): 25 | super(Bunch, self).__init__(*args, **kwargs) 26 | self.__dict__ = self 27 | 28 | def copy(self): 29 | """Return a new Bunch instance which is a copy of the current Bunch instance.""" 30 | return Bunch(super(Bunch, self).copy()) 31 | 32 | 33 | def _bunchify(b): 34 | """Ensure all dict elements are Bunch.""" 35 | assert isinstance(b, dict) 36 | b = Bunch(b) 37 | for k in b: 38 | if isinstance(b[k], dict): 39 | b[k] = Bunch(b[k]) 40 | return b 41 | 42 | 43 | def _is_list(obj): 44 | """Return whether an object is a list.""" 45 | return isinstance(obj, list) 46 | 47 | 48 | def _as_scalar(obj): 49 | """Return whether an object is a scalar number (integer or floating point number).""" 50 | if isinstance(obj, np.generic): 51 | return obj.item() 52 | assert isinstance(obj, (int, float)) 53 | return obj 54 | 55 | 56 | def _as_scalars(arr): 57 | """Make sure a list only contains scalar numbers.""" 58 | return [_as_scalar(x) for x in arr] 59 | 60 | 61 | def _is_integer(x): 62 | """Return whether an object is an integer.""" 63 | return isinstance(x, (int, np.generic)) 64 | 65 | 66 | def _is_float(x): 67 | """Return whether an object is a floating point number.""" 68 | return isinstance(x, (float, np.float32, np.float64)) 69 | 70 | 71 | def _as_list(obj): 
72 | """Ensure an object is a list.""" 73 | if obj is None: 74 | return None 75 | elif isinstance(obj, str): 76 | return [obj] 77 | elif isinstance(obj, tuple): 78 | return list(obj) 79 | elif not hasattr(obj, '__len__'): 80 | return [obj] 81 | else: 82 | return obj 83 | 84 | 85 | def _is_array_like(arr): 86 | """Return whether an object is an array or a list.""" 87 | return isinstance(arr, (list, np.ndarray)) 88 | 89 | 90 | def _as_array(arr, dtype=None): 91 | """Convert an object to a numerical NumPy array. 92 | 93 | Avoid a copy if possible. 94 | 95 | """ 96 | if arr is None: 97 | return None 98 | if isinstance(arr, np.ndarray) and dtype is None: 99 | return arr 100 | if isinstance(arr, (int, float)): 101 | arr = [arr] 102 | out = np.asarray(arr) 103 | if dtype is not None: 104 | if out.dtype != dtype: 105 | out = out.astype(dtype) 106 | if out.dtype not in _ACCEPTED_ARRAY_DTYPES: 107 | raise ValueError("'arr' seems to have an invalid dtype: " 108 | "{0:s}".format(str(out.dtype))) 109 | return out 110 | 111 | 112 | def _as_tuple(item): 113 | """Ensure an item is a tuple.""" 114 | if item is None: 115 | return None 116 | elif not isinstance(item, tuple): 117 | return (item,) 118 | else: 119 | return item 120 | -------------------------------------------------------------------------------- /phylib/utils/event.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | """Simple event system.""" 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | from contextlib import contextmanager 11 | import logging 12 | import string 13 | import re 14 | from functools import partial 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | #------------------------------------------------------------------------------ 20 | # Event system 21 | #------------------------------------------------------------------------------ 22 | 23 | class EventEmitter(object): 24 | """Singleton class that emits events and accepts registered callbacks. 25 | 26 | Example 27 | ------- 28 | 29 | ```python 30 | class MyClass(EventEmitter): 31 | def f(self): 32 | self.emit('my_event', 1, key=2) 33 | 34 | o = MyClass() 35 | 36 | # The following function will be called when `o.f()` is called. 37 | @o.connect 38 | def on_my_event(arg, key=None): 39 | print(arg, key) 40 | 41 | ``` 42 | 43 | """ 44 | 45 | def __init__(self): 46 | self.reset() 47 | self.is_silent = False 48 | 49 | def set_silent(self, silent): 50 | """Set whether to silence the events.""" 51 | self.is_silent = silent 52 | 53 | def reset(self): 54 | """Remove all registered callbacks.""" 55 | self._callbacks = [] 56 | 57 | def _get_on_name(self, func): 58 | """Return `eventname` when the function name is `on_()`.""" 59 | r = re.match("^on_(.+)$", func.__name__) 60 | if r: 61 | event = r.group(1) 62 | else: 63 | raise ValueError("The function name should be " 64 | "`on_`().") 65 | return event 66 | 67 | @contextmanager 68 | def silent(self): 69 | """Prevent all callbacks to be called if events are raised 70 | in the context manager. 71 | """ 72 | self.is_silent = not(self.is_silent) 73 | yield 74 | self.is_silent = not(self.is_silent) 75 | 76 | def connect(self, func=None, event=None, sender=None, **kwargs): 77 | """Register a callback function to a given event. 
78 | 79 | To register a callback function to the `spam` event, where `obj` is 80 | an instance of a class deriving from `EventEmitter`: 81 | 82 | ```python 83 | @obj.connect(sender=sender) 84 | def on_spam(sender, arg1, arg2): 85 | pass 86 | ``` 87 | 88 | This is called when `obj.emit('spam', sender, arg1, arg2)` is called. 89 | 90 | Several callback functions can be registered for a given event. 91 | 92 | The registration order is conserved and may matter in applications. 93 | 94 | """ 95 | if func is None: 96 | return partial(self.connect, event=event, sender=sender, **kwargs) 97 | 98 | # Get the event name from the function. 99 | if event is None: 100 | event = self._get_on_name(func) 101 | 102 | # We register the callback function. 103 | self._callbacks.append((event, sender, func, kwargs)) 104 | 105 | return func 106 | 107 | def unconnect(self, *items): 108 | """Unconnect specified callback functions or senders.""" 109 | self._callbacks = [ 110 | (event, sender, f, kwargs) 111 | for (event, sender, f, kwargs) in self._callbacks 112 | if f not in items and sender not in items and 113 | getattr(f, '__self__', None) not in items] 114 | 115 | def emit(self, event, sender, *args, **kwargs): 116 | """Call all callback functions registered with an event. 117 | 118 | Any positional and keyword arguments can be passed here, and they will 119 | be forwarded to the callback functions. 120 | 121 | Return the list of callback return results. 122 | 123 | """ 124 | if self.is_silent: 125 | return 126 | sender_name = sender.__class__.__name__ 127 | logger.log( 128 | 5, "Emit %s.%s(%s, %s)", sender_name, event, 129 | ', '.join(map(str, args)), ', '.join('%s=%s' % (k, v) for k, v in kwargs.items())) 130 | # Call the last callback if this is a single event. 131 | single = kwargs.pop('single', None) 132 | res = [] 133 | # Put `last=True` callbacks at the end. 134 | callbacks = [c for c in self._callbacks if not c[-1].get('last', None)] 135 | callbacks += [c for c in self._callbacks if c[-1].get('last', None)] 136 | for e, s, f, k in callbacks: 137 | if e == event and (s is None or s == sender): 138 | f_name = getattr(f, '__qualname__', getattr(f, '__name__', str(f))) 139 | s_name = s.__class__.__name__ 140 | logger.log(5, "Callback %s (%s).", f_name, s_name) 141 | res.append(f(sender, *args, **kwargs)) 142 | if single: 143 | return res[-1] 144 | return res 145 | 146 | 147 | #------------------------------------------------------------------------------ 148 | # Progress reporter 149 | #------------------------------------------------------------------------------ 150 | 151 | class PartialFormatter(string.Formatter): 152 | """Prevent KeyError when a format parameter is absent.""" 153 | def get_field(self, field_name, args, kwargs): 154 | try: 155 | return super(PartialFormatter, self).get_field(field_name, 156 | args, 157 | kwargs) 158 | except (KeyError, AttributeError): 159 | return None, field_name 160 | 161 | def format_field(self, value, spec): 162 | """Format a field.""" 163 | if value is None: 164 | return '?' 165 | try: 166 | return super(PartialFormatter, self).format_field(value, spec) 167 | except ValueError: 168 | return '?' 
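
# A small sketch of how PartialFormatter degrades gracefully (hypothetical
# message and values): fields missing from the keyword arguments render as '?'.
#
#   fmt = PartialFormatter()
#   fmt.format("{value}/{value_max} ({progress:.0f}%)", value=3, progress=30.)
#   # -> '3/? (30%)'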
169 | 170 | 171 | def _default_on_progress(sender, message, value, value_max, end='\r', **kwargs): 172 | if value_max == 0: # pragma: no cover 173 | return 174 | if value <= value_max: 175 | progress = 100 * value / float(value_max) 176 | fmt = PartialFormatter() 177 | kwargs['value'] = value 178 | kwargs['value_max'] = value_max 179 | print(fmt.format(message, progress=progress, **kwargs), end=end) 180 | 181 | 182 | def _default_on_complete(message, end='\n', **kwargs): 183 | # Override the initializing message and clear the terminal 184 | # line. 185 | fmt = PartialFormatter() 186 | print(fmt.format(message + '\033[K', **kwargs), end=end) 187 | 188 | 189 | class ProgressReporter(object): 190 | """A class that reports progress done. 191 | 192 | Example 193 | ------- 194 | 195 | ```python 196 | pr = ProgressReporter() 197 | pr.set_progress_message("Progress: {progress}%...") 198 | pr.set_complete_message("Completed!") 199 | pr.value_max = 10 200 | 201 | for i in range(10): 202 | pr.value += 1 # or pr.increment() 203 | ``` 204 | 205 | You can also add custom keyword arguments in `pr.increment()`: these 206 | will be replaced in the message string. 207 | 208 | Emits 209 | ----- 210 | 211 | * `progress(value, value_max)` 212 | * `complete()` 213 | 214 | """ 215 | def __init__(self): 216 | super(ProgressReporter, self).__init__() 217 | self._value = 0 218 | self._value_max = 0 219 | self._has_completed = False 220 | 221 | def set_progress_message(self, message, line_break=False): 222 | """Set a progress message. 223 | 224 | The string needs to contain `{progress}`. 225 | 226 | """ 227 | 228 | end = '\r' if not line_break else None 229 | 230 | @connect(sender=self) 231 | def on_progress(sender, value, value_max, **kwargs): 232 | kwargs['end'] = None if value == value_max else end 233 | _default_on_progress(sender, message, value, value_max, **kwargs) 234 | 235 | def set_complete_message(self, message): 236 | """Set a complete message.""" 237 | 238 | @connect(sender=self) 239 | def on_complete(sender, **kwargs): 240 | _default_on_complete(message, **kwargs) 241 | 242 | def _set_value(self, value, **kwargs): 243 | if value < self._value_max: 244 | self._has_completed = False 245 | self._value = value 246 | emit('progress', self, self._value, self._value_max, **kwargs) 247 | if not self._has_completed and self._value >= self._value_max: 248 | emit('complete', self, **kwargs) 249 | self._has_completed = True 250 | 251 | def increment(self, **kwargs): 252 | """Equivalent to `self.value += 1`. 253 | 254 | Custom keywoard arguments can also be passed to be processed in the 255 | progress message format string. 
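
        For instance (hypothetical message and keyword), after
        `pr.set_progress_message("Loading {file}: {progress:.0f}%...")`,
        calling `pr.increment(file='data.bin')` substitutes `{file}` in the
        printed progress message.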
256 | 257 | """ 258 | self._set_value(self._value + 1, **kwargs) 259 | 260 | def reset(self, value_max=None): 261 | """Reset the value to 0 and the value max to a given value.""" 262 | self._value = 0 263 | if value_max is not None: 264 | self._value_max = value_max 265 | 266 | @property 267 | def value(self): 268 | """Current value (integer).""" 269 | return self._value 270 | 271 | @value.setter 272 | def value(self, value): 273 | self._set_value(value) 274 | 275 | @property 276 | def value_max(self): 277 | """Maximum value (integer).""" 278 | return self._value_max 279 | 280 | @value_max.setter 281 | def value_max(self, value_max): 282 | if value_max > self._value_max: 283 | self._has_completed = False 284 | self._value_max = value_max 285 | 286 | def is_complete(self): 287 | """Return whether the task has completed.""" 288 | return self._value >= self._value_max 289 | 290 | def set_complete(self, **kwargs): 291 | """Set the task as complete.""" 292 | self._set_value(self.value_max, **kwargs) 293 | 294 | @property 295 | def progress(self): 296 | """Return the current progress as a float value in `[0, 1]`.""" 297 | return self._value / float(self._value_max) 298 | 299 | 300 | #------------------------------------------------------------------------------ 301 | # Global event system 302 | #------------------------------------------------------------------------------ 303 | 304 | _EVENT = EventEmitter() 305 | 306 | emit = _EVENT.emit 307 | connect = _EVENT.connect 308 | unconnect = _EVENT.unconnect 309 | silent = _EVENT.silent 310 | set_silent = _EVENT.set_silent 311 | reset = _EVENT.reset 312 | -------------------------------------------------------------------------------- /phylib/utils/geometry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Plotting utilities.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | from functools import partial 11 | import logging 12 | 13 | import numpy as np 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | #------------------------------------------------------------------------------ 19 | # Common probe layouts 20 | #------------------------------------------------------------------------------ 21 | 22 | def linear_positions(n_channels): 23 | """Linear channel positions along the vertical axis.""" 24 | return np.c_[np.zeros(n_channels), 25 | np.linspace(0., 1., n_channels)] 26 | 27 | 28 | def staggered_positions(n_channels): 29 | """Generate channel positions for a staggered probe.""" 30 | i = np.arange(n_channels - 1) 31 | x, y = (-1) ** i * (5 + i), 10 * (i + 1) 32 | pos = np.flipud(np.r_[np.zeros((1, 2)), np.c_[x, y]]) 33 | return pos 34 | 35 | 36 | #------------------------------------------------------------------------------ 37 | # Box positioning 38 | #------------------------------------------------------------------------------ 39 | 40 | def range_transform(from_bounds, to_bounds, positions, do_offset=True): 41 | """Transform positions from one rectangle to another.""" 42 | from_bounds = np.asarray(from_bounds) 43 | to_bounds = np.asarray(to_bounds) 44 | positions = np.asarray(positions) 45 | 46 | assert from_bounds.ndim == to_bounds.ndim == positions.ndim == 2 47 | 48 | f0 = from_bounds[..., :2] 49 | f1 = from_bounds[..., 2:] 50 | t0 = to_bounds[..., :2] 51 | t1 = to_bounds[..., 2:] 52 | 53 | # Degenerate axes are extended
maximally. 54 | for z0, z1 in ((f0, f1), (t0, t1)): 55 | for i in range(2): 56 | ind = np.abs(z0[:, i] - z1[:, i]) < 1e-8 57 | z0[ind, i] = -1 58 | z1[ind, i] = +1 59 | 60 | d = (f1 - f0) 61 | d[d == 0] = 1 62 | 63 | out = positions.copy() 64 | if do_offset: 65 | out -= f0.astype(out.dtype) 66 | out *= ((t1 - t0) / d).astype(out.dtype) 67 | if do_offset: 68 | out += t0.astype(out.dtype) 69 | return out 70 | 71 | 72 | def _boxes_overlap(x0, y0, x1, y1): 73 | """Return whether a set of boxes, defined by their 2D corners, overlap or not.""" 74 | assert x0.ndim == y0.ndim == x1.ndim == y1.ndim == 2 75 | n = len(x0) 76 | overlap_matrix = ((x0 < x1.T) & (x1 > x0.T) & (y0 < y1.T) & (y1 > y0.T)) 77 | overlap_matrix[np.arange(n), np.arange(n)] = False 78 | return np.any(overlap_matrix.ravel()) 79 | 80 | 81 | def _binary_search(f, xmin, xmax, eps=1e-9): 82 | """Return the largest x such that f(x) is True.""" 83 | middle = (xmax + xmin) / 2. 84 | while xmax - xmin > eps: 85 | assert xmin < xmax 86 | middle = (xmax + xmin) / 2. 87 | if f(xmax): 88 | return xmax 89 | if not f(xmin): 90 | return xmin 91 | if f(middle): 92 | xmin = middle 93 | else: 94 | xmax = middle 95 | return middle 96 | 97 | 98 | def _find_box_size(x, y, ar=.5, margin=0): 99 | """Return the maximum (half) box size such that boxes centered around box positions 100 | do not overlap.""" 101 | if x.ndim == 1: 102 | x = x[:, np.newaxis] 103 | if y.ndim == 1: 104 | y = y[:, np.newaxis] 105 | logger.log(5, "Get box size for %d points.", len(x)) 106 | # Deal with degenerate x case. 107 | xmin, xmax = x.min(), x.max() 108 | if xmin == xmax: 109 | # If all positions are vertical, the width can be maximum. 110 | wmax = 1. 111 | else: 112 | wmax = xmax - xmin 113 | 114 | def f1(w, keep_aspect_ratio=True, h=None): 115 | """Return true if the configuration with the current box size 116 | is non-overlapping.""" 117 | # NOTE: w|h are the *half* width|height. 118 | if keep_aspect_ratio: 119 | h = w * ar # fixed aspect ratio 120 | return not _boxes_overlap(x - w, y - h, x + w, y + h) 121 | 122 | # Find the largest box size leading to non-overlapping boxes. 123 | w = _binary_search(f1, 0, wmax) 124 | w = w * (1 - margin) # margin 125 | # Clip the half-width. 126 | h = w * ar # aspect ratio 127 | 128 | # Extend the boxes horizontally as much as possible. 129 | w = _binary_search(partial(f1, keep_aspect_ratio=False, h=h), w, wmax) 130 | w = w * (1 - margin) # margin 131 | 132 | return w, h 133 | 134 | 135 | def get_non_overlapping_boxes(box_pos): 136 | """Normalize box positions and return a convenient half box size.""" 137 | box_pos = np.asarray(box_pos) 138 | assert box_pos.ndim == 2 139 | assert box_pos.shape[1] == 2 140 | # Renormalize box_pos. 141 | mx, my = box_pos.min(axis=0) 142 | Mx, My = box_pos.max(axis=0) 143 | box_pos = range_transform([[mx, my, Mx, My]], [[-1, -1, +1, +1]], box_pos) 144 | # Compute box size. 145 | x, y = box_pos.T 146 | w, h = _find_box_size(x, y, margin=.1) 147 | # Renormalize again so that the boxes fit inside the view.
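# The half box size (w, h) goes through the same transform with do_offset=False, so only the scaling (not the offset) applies to it.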
148 | mx, my = np.min(box_pos - np.array([[w, h]]), axis=0) 149 | Mx, My = np.max(box_pos + np.array([[w, h]]), axis=0) 150 | b1 = [[mx, my, Mx, My]] 151 | b2 = [[-1, -1, 1, 1]] 152 | box_pos = range_transform(b1, b2, box_pos) 153 | w, h = range_transform(b1, b2, [[w, h]], do_offset=False).ravel() 154 | w *= .95 155 | h *= .9 156 | logger.log(5, "Found box size %s.", (w, h)) 157 | return box_pos, (w, h) 158 | 159 | 160 | def get_closest_box(pos, box_pos, box_size): 161 | """Return the box closest to a given point.""" 162 | # box_size is the half size 163 | # see https://gamedev.stackexchange.com/a/44496 164 | w, h = box_size 165 | x, y = pos 166 | px, py = box_pos.T 167 | dx = np.maximum(np.abs(px - x) - w, 0) 168 | dy = np.maximum(np.abs(py - y) - h, 0) 169 | d = dx * dx + dy * dy 170 | return np.argmin(d) 171 | 172 | 173 | #------------------------------------------------------------------------------ 174 | # Data bounds utilities 175 | #------------------------------------------------------------------------------ 176 | 177 | def _get_data_bounds(data_bounds, pos=None, length=None): 178 | """Prepare data bounds, possibly using min/max of the data.""" 179 | if data_bounds is None or (isinstance(data_bounds, str) and data_bounds == 'auto'): 180 | if pos is not None and len(pos): 181 | m, M = pos.min(axis=0), pos.max(axis=0) 182 | data_bounds = [m[0], m[1], M[0], M[1]] 183 | else: 184 | data_bounds = [-1, -1, 1, 1] 185 | data_bounds = np.atleast_2d(data_bounds) 186 | 187 | ind_x = data_bounds[:, 0] == data_bounds[:, 2] 188 | ind_y = data_bounds[:, 1] == data_bounds[:, 3] 189 | if np.sum(ind_x): 190 | data_bounds[ind_x, 0] -= 1 191 | data_bounds[ind_x, 2] += 1 192 | if np.sum(ind_y): 193 | data_bounds[ind_y, 1] -= 1 194 | data_bounds[ind_y, 3] += 1 195 | 196 | # Extend the data_bounds if needed. 197 | if length is None: 198 | length = pos.shape[0] if pos is not None else 1 199 | if data_bounds.shape[0] == 1: 200 | data_bounds = np.tile(data_bounds, (length, 1)) 201 | 202 | # Check the shape of data_bounds.
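# Each row of data_bounds is (x0, y0, x1, y1), with x0 < x1 and y0 < y1.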
203 | assert data_bounds.shape == (length, 4) 204 | 205 | assert data_bounds.ndim == 2 206 | assert data_bounds.shape[1] == 4 207 | assert np.all(data_bounds[:, 0] < data_bounds[:, 2]) 208 | assert np.all(data_bounds[:, 1] < data_bounds[:, 3]) 209 | 210 | return data_bounds 211 | -------------------------------------------------------------------------------- /phylib/utils/testing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Utility functions used for tests.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | from contextlib import contextmanager 10 | from io import StringIO 11 | import logging 12 | import os 13 | import sys 14 | 15 | from numpy.testing import assert_array_equal as ae 16 | from numpy.testing import assert_allclose as ac 17 | 18 | from ._types import _is_array_like 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | #------------------------------------------------------------------------------ 24 | # Utility functions 25 | #------------------------------------------------------------------------------ 26 | 27 | @contextmanager 28 | def captured_output(): 29 | """Context manager that captures all output to stdout and stderr.""" 30 | new_out, new_err = StringIO(), StringIO() 31 | old_out, old_err = sys.stdout, sys.stderr 32 | try: 33 | sys.stdout, sys.stderr = new_out, new_err 34 | yield sys.stdout, sys.stderr 35 | finally: 36 | sys.stdout, sys.stderr = old_out, old_err 37 | 38 | 39 | @contextmanager 40 | def captured_logging(name=None): 41 | """Context manager that captures all logging.""" 42 | logger = logging.getLogger(name) 43 | handlers = list(logger.handlers) 44 | for handler in logger.handlers: 45 | logger.removeHandler(handler) 46 | buffer = StringIO() 47 | handler = logging.StreamHandler(buffer) 48 | handler.setLevel(logging.DEBUG) 49 | logger.addHandler(handler) 50 | yield buffer 51 | buffer.flush() 52 | logger.removeHandler(handler) 53 | for handler in handlers: 54 | logger.addHandler(handler) 55 | handler.close() 56 | 57 | 58 | def _assert_equal(d_0, d_1): 59 | """Check that two objects are equal.""" 60 | # Compare arrays. 61 | if _is_array_like(d_0): 62 | try: 63 | ae(d_0, d_1) 64 | except AssertionError: 65 | ac(d_0, d_1) 66 | # Compare dicts recursively. 67 | elif isinstance(d_0, dict): 68 | assert set(d_0) == set(d_1) 69 | for k_0 in d_0: 70 | _assert_equal(d_0[k_0], d_1[k_0]) 71 | else: 72 | # General comparison. 
73 | assert d_0 == d_1 74 | 75 | 76 | def _in_travis(): # pragma: no cover 77 | """Return whether we're in travis.""" 78 | return 'TRAVIS' in os.environ 79 | -------------------------------------------------------------------------------- /phylib/utils/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cortex-lab/phylib/72316a4ccb0abed93464ed4bfcd36a3809d1dbdf/phylib/utils/tests/__init__.py -------------------------------------------------------------------------------- /phylib/utils/tests/test_event.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Test event system.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | from pytest import raises 10 | 11 | from ..event import EventEmitter, ProgressReporter, connect 12 | 13 | 14 | #------------------------------------------------------------------------------ 15 | # Test event system 16 | #------------------------------------------------------------------------------ 17 | 18 | def test_event_system(): 19 | ev = EventEmitter() 20 | 21 | _list = [] 22 | 23 | with raises(ValueError): 24 | ev.connect(lambda x: x) 25 | 26 | @ev.connect 27 | def on_my_event(sender, arg, kwarg=None): 28 | _list.append((arg, kwarg)) 29 | 30 | ev.emit('my_event', ev, 'a') 31 | assert _list == [('a', None)] 32 | 33 | ev.emit('my_event', ev, 'b', 'c') 34 | assert _list == [('a', None), ('b', 'c')] 35 | 36 | ev.unconnect(on_my_event) 37 | 38 | ev.emit('my_event', ev, 'b', 'c') 39 | assert _list == [('a', None), ('b', 'c')] 40 | 41 | 42 | def test_event_silent(): 43 | ev = EventEmitter() 44 | 45 | _list = [] 46 | 47 | @ev.connect() 48 | def on_test(sender, x): 49 | _list.append(x) 50 | 51 | ev.emit('test', ev, 1) 52 | assert _list == [1] 53 | 54 | with ev.silent(): 55 | ev.emit('test', ev, 1) 56 | assert _list == [1] 57 | 58 | ev.set_silent(True) 59 | 60 | 61 | def test_event_single(): 62 | ev = EventEmitter() 63 | 64 | l = [] 65 | 66 | @ev.connect(event='test') 67 | def on_test_bou(sender): 68 | l.append(0) 69 | 70 | @ev.connect # noqa 71 | def on_test(sender): 72 | l.append(1) 73 | 74 | ev.emit('test', ev) 75 | assert l == [0, 1] 76 | 77 | ev.emit('test', ev, single=True) 78 | assert l == [0, 1, 0] 79 | 80 | 81 | #------------------------------------------------------------------------------ 82 | # Test progress reporter 83 | #------------------------------------------------------------------------------ 84 | 85 | def test_progress_reporter(): 86 | """Test the progress reporter.""" 87 | pr = ProgressReporter() 88 | 89 | _reported = [] 90 | _completed = [] 91 | 92 | @connect(sender=pr) 93 | def on_progress(sender, value, value_max): 94 | # value is the sum of the values, value_max the sum of the max values 95 | _reported.append((value, value_max)) 96 | 97 | @connect(sender=pr) 98 | def on_complete(sender): 99 | _completed.append(True) 100 | 101 | pr.value_max = 10 102 | pr.value = 0 103 | pr.value = 5 104 | assert pr.value == 5 105 | assert pr.progress == .5 106 | assert not pr.is_complete() 107 | pr.value = 10 108 | assert pr.is_complete() 109 | assert pr.progress == 1. 110 | assert _completed == [True] 111 | 112 | pr.value_max = 11 113 | assert not pr.is_complete() 114 | assert pr.progress < 1. 
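# set_complete() jumps the value to value_max, emitting one more progress and complete event.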
115 | pr.set_complete() 116 | assert pr.is_complete() 117 | assert pr.progress == 1. 118 | 119 | assert _reported == [(0, 10), (5, 10), (10, 10), (11, 11)] 120 | assert _completed == [True, True] 121 | 122 | pr.value = 10 123 | # Only trigger a complete event once. 124 | pr.value = pr.value_max 125 | pr.value = pr.value_max 126 | assert _completed == [True, True, True] 127 | 128 | 129 | def test_progress_message(): 130 | """Test messages with the progress reporter.""" 131 | pr = ProgressReporter() 132 | pr.reset(5) 133 | pr.set_progress_message("The progress is {progress}%. ({hello:d})") 134 | pr.set_complete_message("Finished {hello}.") 135 | 136 | pr.value_max = 10 137 | pr.value = 0 138 | print() 139 | pr.value = 5 140 | print() 141 | pr.increment() 142 | print() 143 | pr.increment(hello='hello') 144 | print() 145 | pr.increment(hello=3) 146 | print() 147 | pr.value = 10 148 | -------------------------------------------------------------------------------- /phylib/utils/tests/test_geometry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Test geometry utilities.""" 4 | 5 | 6 | #------------------------------------------------------------------------------ 7 | # Imports 8 | #------------------------------------------------------------------------------ 9 | 10 | import numpy as np 11 | from numpy.testing import assert_array_equal as ae 12 | from numpy.testing import assert_allclose as ac 13 | import pytest 14 | 15 | from ..geometry import ( 16 | linear_positions, 17 | staggered_positions, 18 | _get_data_bounds, 19 | _boxes_overlap, 20 | _binary_search, 21 | _find_box_size, 22 | get_non_overlapping_boxes, 23 | get_closest_box, 24 | ) 25 | 26 | 27 | #------------------------------------------------------------------------------ 28 | # Test utilities 29 | #------------------------------------------------------------------------------ 30 | 31 | def test_get_data_bounds(): 32 | ae(_get_data_bounds(None), [[-1., -1., 1., 1.]]) 33 | 34 | db0 = np.array([[0, 1, 4, 5], 35 | [0, 1, 4, 5], 36 | [0, 1, 4, 5]]) 37 | arr = np.arange(6).reshape((3, 2)) 38 | assert np.all(_get_data_bounds(None, arr) == [[0, 1, 4, 5]]) 39 | 40 | db = db0.copy() 41 | assert np.all(_get_data_bounds(db, arr) == [[0, 1, 4, 5]]) 42 | 43 | db = db0.copy() 44 | db[2, :] = [1, 1, 1, 1] 45 | assert np.all(_get_data_bounds(db, arr)[:2, :] == [[0, 1, 4, 5]]) 46 | assert np.all(_get_data_bounds(db, arr)[2, :] == [0, 0, 2, 2]) 47 | 48 | db = db0.copy() 49 | db[:2, :] = [1, 1, 1, 1] 50 | assert np.all(_get_data_bounds(db, arr)[:2, :] == [[0, 0, 2, 2]]) 51 | assert np.all(_get_data_bounds(db, arr)[2, :] == [0, 1, 4, 5]) 52 | 53 | 54 | def test_boxes_overlap(): 55 | 56 | def _get_args(boxes): 57 | x0, y0, x1, y1 = np.array(boxes).T 58 | x0 = x0[:, np.newaxis] 59 | x1 = x1[:, np.newaxis] 60 | y0 = y0[:, np.newaxis] 61 | y1 = y1[:, np.newaxis] 62 | return x0, y0, x1, y1 63 | 64 | boxes = [[-1, -1, 0, 0], [0.01, 0.01, 1, 1]] 65 | x0, y0, x1, y1 = _get_args(boxes) 66 | assert not _boxes_overlap(x0, y0, x1, y1) 67 | 68 | boxes = [[-1, -1, 0.1, 0.1], [0, 0, 1, 1]] 69 | x0, y0, x1, y1 = _get_args(boxes) 70 | assert _boxes_overlap(x0, y0, x1, y1) 71 | 72 | x = np.zeros((5, 1)) 73 | x0 = x - .1 74 | x1 = x + .1 75 | y = np.linspace(-1, 1, 5)[:, np.newaxis] 76 | y0 = y - .2 77 | y1 = y + .2 78 | assert not _boxes_overlap(x0, y0, x1, y1) 79 | 80 | 81 | def test_binary_search(): 82 | def f(x): 83 | return x < .4 84 | ac(_binary_search(f, 0, 1), .4) 85 | ac(_binary_search(f, 
0, .3), .3) 86 | ac(_binary_search(f, .5, 1), .5) 87 | 88 | 89 | def test_find_box_size(): 90 | x = np.zeros(5) 91 | y = np.linspace(-1, 1, 5) 92 | w, h = _find_box_size(x, y, margin=0) 93 | ac(w, .5, atol=1e-8) 94 | ac(h, .25, atol=1e-8) 95 | 96 | 97 | @pytest.mark.parametrize('n_channels', [5, 500]) 98 | def test_get_non_overlapping_boxes_1(n_channels): 99 | x = np.zeros(n_channels) 100 | y = np.linspace(-1, 1, n_channels) 101 | box_pos, box_size = get_non_overlapping_boxes(np.c_[x, y]) 102 | ac(box_pos[:, 0], 0, atol=1e-8) 103 | ac(box_pos[:, 1], -box_pos[::-1, 1], atol=1e-8) 104 | 105 | assert box_size[0] >= .8 106 | 107 | s = np.array([box_size]) 108 | box_bounds = np.c_[box_pos - s, box_pos + s] 109 | assert box_bounds.min() >= -1 110 | assert box_bounds.max() <= +1 111 | 112 | 113 | def test_get_non_overlapping_boxes_2(): 114 | pos = staggered_positions(32) 115 | box_pos, box_size = get_non_overlapping_boxes(pos) 116 | assert box_size[0] >= .05 117 | 118 | s = np.array([box_size]) 119 | box_bounds = np.c_[box_pos - s, box_pos + s] 120 | assert box_bounds.min() >= -1 121 | assert box_bounds.max() <= +1 122 | 123 | 124 | def test_get_closest_box(): 125 | n = 10 126 | px = np.zeros(n) 127 | py = np.linspace(-1, 1, n) 128 | box_pos = np.c_[px, py] 129 | w, h = (1, .9 / n) 130 | expected = [] 131 | for x in (0, -1, 1, -2, +2): 132 | for i in range(n): 133 | expected.extend([ 134 | (x, py[i], i), 135 | (x, py[i] - h, i), 136 | (x, py[i] + h, i), 137 | (x, py[i] - 1.25 * h, max(0, min(i - 1, n - 1))), 138 | (x, py[i] + 1.25 * h, max(0, min(i + 1, n - 1))), 139 | ]) 140 | for x, y, i in expected: 141 | assert get_closest_box((x, y), box_pos, (w, h)) == i 142 | 143 | 144 | def test_positions(): 145 | probe = staggered_positions(31) 146 | assert probe.shape == (31, 2) 147 | ae(probe[-1], (0, 0)) 148 | 149 | probe = linear_positions(29) 150 | assert probe.shape == (29, 2) 151 | -------------------------------------------------------------------------------- /phylib/utils/tests/test_misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of misc utility functions.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | from numpy.testing import assert_array_equal as ae 11 | from pytest import raises, mark 12 | 13 | from .._misc import ( 14 | _git_version, load_json, save_json, load_pickle, save_pickle, read_python, write_python, 15 | read_text, write_text, _read_tsv_simple, _write_tsv_simple, read_tsv, write_tsv, 16 | _pretty_floats, _encode_qbytearray, _decode_qbytearray, _fullname, _load_from_fullname) 17 | 18 | 19 | #------------------------------------------------------------------------------ 20 | # Misc tests 21 | #------------------------------------------------------------------------------ 22 | 23 | def test_qbytearray(tempdir): 24 | try: 25 | from PyQt5.QtCore import QByteArray 26 | except ImportError: # pragma: no cover 27 | return 28 | arr = QByteArray() 29 | arr.append('1') 30 | arr.append('2') 31 | arr.append('3') 32 | 33 | encoded = _encode_qbytearray(arr) 34 | assert isinstance(encoded, str) 35 | decoded = _decode_qbytearray(encoded) 36 | assert arr == decoded 37 | 38 | # Test JSON serialization of QByteArray. 
39 | d = {'arr': arr} 40 | path = tempdir / 'test' 41 | save_json(path, d) 42 | d_bis = load_json(path) 43 | assert d == d_bis 44 | 45 | 46 | def test_pretty_float(): 47 | assert _pretty_floats(0.123456) == '0.12' 48 | assert _pretty_floats([0.123456]) == ['0.12'] 49 | assert _pretty_floats({'a': 0.123456}) == {'a': '0.12'} 50 | 51 | 52 | def test_json_simple(tempdir): 53 | d = {'a': 1, 'b': 'bb', 3: '33', 'mock': {'mock': True}} 54 | 55 | path = tempdir / 'test_dir/test' 56 | save_json(path, d) 57 | d_bis = load_json(path) 58 | assert d == d_bis 59 | 60 | path.write_text('') 61 | assert load_json(path) == {} 62 | with raises(IOError): 63 | load_json('%s_bis' % path) 64 | 65 | 66 | @mark.parametrize('kind', ['json', 'pickle']) 67 | def test_json_numpy(tempdir, kind): 68 | arr = np.arange(20).reshape((2, -1)).astype(np.float32) 69 | d = {'a': arr, 'b': arr.ravel()[:10], 'c': arr[0, 0]} 70 | 71 | path = tempdir / 'test' 72 | f = save_json if kind == 'json' else save_pickle 73 | f(path, d) 74 | 75 | f = load_json if kind == 'json' else load_pickle 76 | d_bis = f(path) 77 | arr_bis = d_bis['a'] 78 | 79 | assert arr_bis.dtype == arr.dtype 80 | assert arr_bis.shape == arr.shape 81 | ae(arr_bis, arr) 82 | 83 | ae(d['b'], d_bis['b']) 84 | ae(d['c'], d_bis['c']) 85 | 86 | 87 | def test_read_python(tempdir): 88 | path = tempdir / 'mock.py' 89 | with open(path, 'w') as f: 90 | f.write("""a = {'b': 1}""") 91 | 92 | assert read_python(path) == {'a': {'b': 1}} 93 | 94 | 95 | def test_write_python(tempdir): 96 | data = {'a': 1, 'b': 'hello', 'c': [1, 2, 3]} 97 | path = tempdir / 'mock.py' 98 | 99 | write_python(path, data) 100 | assert read_python(path) == data 101 | 102 | 103 | def test_write_text(tempdir): 104 | for path in (tempdir / 'test_1', 105 | tempdir / 'test_dir/test_2.txt', 106 | ): 107 | write_text(path, 'hello world') 108 | assert read_text(path) == 'hello world' 109 | 110 | 111 | def test_write_tsv_simple(tempdir): 112 | path = tempdir / 'test.tsv' 113 | assert _read_tsv_simple(path) == {} 114 | 115 | # The read/write TSV functions conserve the types: int, float, or strings. 
116 | data = {2: 20, 3: 30.5, 5: 'hello'} 117 | _write_tsv_simple(path, 'myfield', data) 118 | 119 | assert _read_tsv_simple(path) == ('myfield', data) 120 | 121 | 122 | def test_write_tsv(tempdir): 123 | path = tempdir / 'test.tsv' 124 | assert read_tsv(path) == [] 125 | write_tsv(path, []) 126 | 127 | data = [{'a': 1, 'b': 2}, {'a': 10}, {'b': 20, 'c': 30.5}] 128 | 129 | write_tsv(path, data) 130 | assert read_tsv(path) == data 131 | 132 | write_tsv(path, data, first_field='b', exclude_fields=('c', 'd')) 133 | assert read_text(path)[0] == 'b' 134 | del data[2]['c'] 135 | assert read_tsv(path) == data 136 | 137 | 138 | def test_git_version(): 139 | v = _git_version() 140 | assert v 141 | 142 | 143 | def _myfunction(x): 144 | return 145 | 146 | 147 | def test_fullname(): 148 | assert _fullname(_myfunction) == 'phylib.utils.tests.test_misc._myfunction' 149 | 150 | assert _load_from_fullname(_myfunction) == _myfunction 151 | assert _load_from_fullname(_fullname(_myfunction)) == _myfunction 152 | -------------------------------------------------------------------------------- /phylib/utils/tests/test_testing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of testing utility functions.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | from copy import deepcopy 10 | import logging 11 | 12 | import numpy as np 13 | 14 | from ..testing import captured_output, captured_logging, _assert_equal 15 | 16 | logger = logging.getLogger('phylib') 17 | 18 | 19 | #------------------------------------------------------------------------------ 20 | # Tests 21 | #------------------------------------------------------------------------------ 22 | 23 | def test_logging_1(): 24 | print() 25 | logger.setLevel(5) 26 | logger.log(5, "level 5") 27 | logger.log(10, "debug") 28 | logger.log(20, "info") 29 | logger.log(30, "warning") 30 | logger.log(40, "error") 31 | 32 | 33 | def test_captured_output(): 34 | with captured_output() as (out, err): 35 | print('Hello world!') 36 | assert out.getvalue().strip() == 'Hello world!' 37 | 38 | 39 | def test_captured_logging(): 40 | handlers = logger.handlers 41 | with captured_logging() as buf: 42 | logger.debug('Hello world!') 43 | assert 'Hello world!' 
in buf.getvalue() 44 | assert logger.handlers == handlers 45 | 46 | 47 | def test_assert_equal(): 48 | d = {'a': {'b': np.random.rand(5), 3: 'c'}, 'b': 2.} 49 | d_bis = deepcopy(d) 50 | d_bis['a']['b'] = d_bis['a']['b'] + 1e-16 51 | _assert_equal(d, d_bis) 52 | -------------------------------------------------------------------------------- /phylib/utils/tests/test_types.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Tests of misc type utility functions.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import numpy as np 10 | from pytest import raises 11 | 12 | from .._types import (Bunch, _bunchify, _is_integer, _is_list, _is_float, 13 | _as_list, _is_array_like, _as_array, _as_tuple, 14 | _as_scalar, _as_scalars, 15 | ) 16 | 17 | 18 | #------------------------------------------------------------------------------ 19 | # Tests 20 | #------------------------------------------------------------------------------ 21 | 22 | def test_bunch(): 23 | obj = Bunch() 24 | obj['a'] = 1 25 | assert obj.a == 1 26 | obj.b = 2 27 | assert obj['b'] == 2 28 | assert obj.copy() == obj 29 | 30 | 31 | def test_bunchify(): 32 | d = {'a': {'b': 0}} 33 | b = _bunchify(d) 34 | assert isinstance(b, Bunch) 35 | assert isinstance(b['a'], Bunch) 36 | 37 | 38 | def test_number(): 39 | assert not _is_integer(None) 40 | assert not _is_integer(3.) 41 | assert _is_integer(3) 42 | assert _is_integer(np.arange(1)[0]) 43 | 44 | assert not _is_float(None) 45 | assert not _is_float(3) 46 | assert not _is_float(np.array([3])[0]) 47 | assert _is_float(3.) 48 | assert _is_float(np.array([3.])[0]) 49 | 50 | 51 | def test_list(): 52 | assert not _is_list(None) 53 | assert not _is_list(()) 54 | assert _is_list([]) 55 | 56 | assert _as_list(None) is None 57 | assert _as_list(3) == [3] 58 | assert _as_list([3]) == [3] 59 | assert _as_list((3,)) == [3] 60 | assert _as_list('3') == ['3'] 61 | assert np.all(_as_list(np.array([3])) == np.array([3])) 62 | 63 | 64 | def test_as_tuple(): 65 | assert _as_tuple(3) == (3,) 66 | assert _as_tuple((3,)) == (3,) 67 | assert _as_tuple(None) is None 68 | assert _as_tuple((None,)) == (None,) 69 | assert _as_tuple((3, 4)) == (3, 4) 70 | assert _as_tuple([3]) == ([3], ) 71 | assert _as_tuple([3, 4]) == ([3, 4], ) 72 | 73 | 74 | def test_as_scalar(): 75 | assert _as_scalar(1) == 1 76 | assert _as_scalar(np.ones(1)[0]) == 1. 
77 | assert type(_as_scalar(np.ones(1)[0])) == float 78 | 79 | assert _as_scalars(np.arange(3)) == [0, 1, 2] 80 | 81 | 82 | def test_array(): 83 | def _check(arr): 84 | assert isinstance(arr, np.ndarray) 85 | assert np.all(arr == [3]) 86 | 87 | _check(_as_array(3)) 88 | _check(_as_array(3.)) 89 | _check(_as_array([3])) 90 | 91 | _check(_as_array(3, float)) 92 | _check(_as_array(3., float)) 93 | _check(_as_array([3], float)) 94 | _check(_as_array(np.array([3]))) 95 | with raises(ValueError): 96 | _check(_as_array(np.array([3]), dtype=object)) 97 | _check(_as_array(np.array([3]), float)) 98 | 99 | assert _as_array(None) is None 100 | assert not _is_array_like(None) 101 | assert not _is_array_like(3) 102 | assert _is_array_like([3]) 103 | assert _is_array_like(np.array([3])) 104 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | flake8 4 | coverage 5 | coveralls 6 | responses 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | dask 4 | requests 5 | tqdm 6 | toolz 7 | joblib 8 | mtscomp 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [wheel] 2 | universal = 1 3 | 4 | [tool:pytest] 5 | norecursedirs = 6 | filterwarnings = 7 | default 8 | ignore::DeprecationWarning:responses|cookies|socks|matplotlib 9 | ignore:numpy.ufunc 10 | 11 | [flake8] 12 | ignore=E265,E731,E741,W504,W605 13 | max-line-length=99 14 | 15 | [coverage:run] 16 | branch = False 17 | source = phylib 18 | omit = 19 | 20 | [coverage:report] 21 | exclude_lines = 22 | pragma: no cover 23 | raise 24 | except IOError: 25 | pass 26 | return$ 27 | continue$ 28 | omit = 29 | show_missing = True 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # flake8: noqa 3 | 4 | """Installation script.""" 5 | 6 | 7 | #------------------------------------------------------------------------------ 8 | # Imports 9 | #------------------------------------------------------------------------------ 10 | 11 | import os 12 | import os.path as op 13 | from pathlib import Path 14 | import re 15 | 16 | from setuptools import setup 17 | 18 | 19 | #------------------------------------------------------------------------------ 20 | # Setup 21 | #------------------------------------------------------------------------------ 22 | 23 | def _package_tree(pkgroot): 24 | path = op.dirname(__file__) 25 | subdirs = [op.relpath(i[0], path).replace(op.sep, '.') 26 | for i in os.walk(op.join(path, pkgroot)) 27 | if '__init__.py' in i[2]] 28 | return subdirs 29 | 30 | 31 | readme = (Path(__file__).parent / 'README.md').read_text() 32 | 33 | 34 | # Find version number from `__init__.py` without executing it. 
35 | with (Path(__file__).parent / 'phylib/__init__.py').open('r') as f: 36 | version = re.search(r"__version__ = '([^']+)'", f.read()).group(1) 37 | 38 | with open('requirements.txt') as f: 39 | require = [x.strip() for x in f.readlines() if not x.startswith('git+')] 40 | 41 | setup( 42 | name='phylib', 43 | version=version, 44 | license="BSD", 45 | description='Ephys data analysis for thousands of channels', 46 | long_description=readme, 47 | long_description_content_type='text/markdown', 48 | author='Cyrille Rossant', 49 | author_email='cyrille.rossant@gmail.com', 50 | url='https://github.com/cortex-lab/phylib', 51 | packages=_package_tree('phylib'), 52 | package_dir={'phylib': 'phylib'}, 53 | package_data={ 54 | 'phylib': [ 55 | '*.vert', '*.frag', '*.glsl', '*.npy', '*.gz', '*.txt', 56 | '*.html', '*.css', '*.js', '*.prb'], 57 | }, 58 | include_package_data=True, 59 | keywords='phy,data analysis,electrophysiology,neuroscience', 60 | classifiers=[ 61 | 'Development Status :: 4 - Beta', 62 | 'Intended Audience :: Developers', 63 | 'License :: OSI Approved :: BSD License', 64 | 'Natural Language :: English', 65 | "Framework :: IPython", 66 | 'Programming Language :: Python :: 3', 67 | 'Programming Language :: Python :: 3.7', 68 | ], 69 | install_requires=require, 70 | ) 71 | --------------------------------------------------------------------------------