├── src └── braingeneers │ ├── py.typed │ ├── data │ ├── transforms │ │ └── __init__.py │ ├── mxw_h5_plugin │ │ ├── Linux │ │ │ └── libcompression.so │ │ ├── Windows │ │ │ └── compression.dll │ │ ├── Mac_arm64 │ │ │ └── libcompression.dylib │ │ └── Mac_x86_64 │ │ │ └── libcompression.dylib │ ├── datasets_fluidics.py │ ├── __init__.py │ ├── datasets_imaging.py │ └── datasets_neuron.py │ ├── ml │ ├── __init__.py │ └── ephys_dataloader.py │ ├── iot │ ├── __init__.py │ ├── example_device_main.py │ ├── example_device.py │ ├── shadows_dev_playground.py │ ├── authenticate.py │ ├── simple.py │ └── gui.py │ ├── analysis │ ├── __init__.py │ ├── visualize_maxwell.py │ └── analysis.py │ ├── utils │ ├── smart_open_braingeneers │ │ └── __init__.py │ ├── s3wrangler │ │ └── __init__.py │ ├── __init__.py │ ├── configure.py │ ├── numpy_s3_memmap.py │ ├── memoize_s3.py │ └── common_utils.py │ └── __init__.py ├── .gitattributes ├── .git_archival.txt ├── .github ├── dependabot.yml ├── matchers │ └── pylint.json ├── workflows │ ├── ci.yml │ ├── cd.yml │ └── publish.yml └── CONTRIBUTING.md ├── tests ├── test_package.py ├── test_s3wrangler.py ├── test_analysis.py ├── test_smart_open_braingeneers.py ├── test_numpy_s3_memmap.py ├── test_memoize_s3.py ├── test_data │ ├── maxwell-metadata.old.json │ └── maxwell-metadata.expected.json ├── test_common_utils.py └── test_messaging.py ├── docs └── source │ ├── index.md │ └── conf.py ├── .readthedocs.yaml ├── .devcontainer ├── devcontainer.json └── post_create.sh ├── Makefile ├── LICENSE ├── .pre-commit-config.yaml ├── noxfile.py ├── .gitignore └── pyproject.toml /src/braingeneers/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/braingeneers/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | .git_archival.txt export-subst 2 | -------------------------------------------------------------------------------- /src/braingeneers/ml/__init__.py: -------------------------------------------------------------------------------- 1 | import braingeneers 2 | -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: 7bf35e396af5f3bd860a84c5798d208c5ae27e26 2 | node-date: 2025-12-15T17:47:57-08:00 3 | describe-name: 0.3.1-9-g7bf35e3 4 | ref-names: HEAD -> master 5 | -------------------------------------------------------------------------------- /src/braingeneers/data/mxw_h5_plugin/Linux/libcompression.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braingeneers/braingeneerspy/HEAD/src/braingeneers/data/mxw_h5_plugin/Linux/libcompression.so -------------------------------------------------------------------------------- /src/braingeneers/data/mxw_h5_plugin/Windows/compression.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braingeneers/braingeneerspy/HEAD/src/braingeneers/data/mxw_h5_plugin/Windows/compression.dll -------------------------------------------------------------------------------- /src/braingeneers/iot/__init__.py: 
-------------------------------------------------------------------------------- 1 | import braingeneers 2 | from braingeneers.iot.messaging import * 3 | from braingeneers.iot.device import * 4 | from braingeneers.iot.simple import * 5 | -------------------------------------------------------------------------------- /src/braingeneers/data/mxw_h5_plugin/Mac_arm64/libcompression.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braingeneers/braingeneerspy/HEAD/src/braingeneers/data/mxw_h5_plugin/Mac_arm64/libcompression.dylib -------------------------------------------------------------------------------- /src/braingeneers/data/mxw_h5_plugin/Mac_x86_64/libcompression.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braingeneers/braingeneerspy/HEAD/src/braingeneers/data/mxw_h5_plugin/Mac_x86_64/libcompression.dylib -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | -------------------------------------------------------------------------------- /tests/test_package.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import braingeneers as m 4 | 5 | 6 | def test_version(): 7 | assert m.__version__ 8 | 9 | 10 | if __name__ == "__main__": 11 | unittest.main() 12 | -------------------------------------------------------------------------------- /src/braingeneers/data/datasets_fluidics.py: -------------------------------------------------------------------------------- 1 | def get_fluidics_data(uuid): 2 | full_path = '/public/groups/braingeneers/Fluidics/derived/' + uuid + '/parameters.txt' 3 | with open(full_path, 'r') as file: 4 | data = file.read().replace('\n', '') 5 | return data 6 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | # braingeneers 2 | 3 | ```{toctree} 4 | :maxdepth: 2 5 | :hidden: 6 | 7 | ``` 8 | 9 | ```{include} ../README.md 10 | :start-after: 11 | ``` 12 | 13 | ## Indices and tables 14 | 15 | - {ref}`genindex` 16 | - {ref}`modindex` 17 | - {ref}`search` 18 | -------------------------------------------------------------------------------- /tests/test_s3wrangler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from braingeneers.utils import s3wrangler 4 | 5 | 6 | class S3WranglerUnitTest(unittest.TestCase): 7 | def test_online_s3wrangler(self): 8 | dir_list = s3wrangler.list_directories("s3://braingeneers/") 9 | self.assertTrue("s3://braingeneers/ephys/" in dir_list) 10 | 11 | 12 | if __name__ == "__main__": 13 | unittest.main() 14 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | sphinx: 11 | configuration: docs/source/conf.py 12 
| 13 | python: 14 | install: 15 | - method: pip 16 | path: . 17 | extra_requirements: 18 | - dev 19 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | { 3 | "image":"ghcr.io/braingeneers/research:latest", 4 | 5 | "customizations": { 6 | "vscode": { 7 | "extensions": [ 8 | "ms-toolsai.jupyter", 9 | "ms-python.python", 10 | "ms-vsliveshare.vsliveshare" 11 | ] 12 | } 13 | }, 14 | "postCreateCommand": "sh .devcontainer/post_create.sh" 15 | } 16 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | # Test locally and then from PRP 3 | BRAINGENEERS_ARCHIVE_PATH="./tests" python3 -B -m pytest -s 4 | python3 -B -m pytest -s 5 | 6 | sync: 7 | # Sync test files up to PRP to test local path vs. URL datasets 8 | aws --profile prp --endpoint https://s3.nautilus.optiputer.net \ 9 | s3 sync --delete \ 10 | ./tests/derived/test-datasets/ \ 11 | s3://braingeneers/archive/derived/test-datasets/ \ 12 | --acl public-read 13 | -------------------------------------------------------------------------------- /.devcontainer/post_create.sh: -------------------------------------------------------------------------------- 1 | # For writing commands that will be executed after the container is created 2 | 3 | # Uninstalls the braingeneerspy package (pre-installed in the research Docker image) from the environment 4 | python3 -m pip uninstall braingeneerspy 5 | 6 | # Installs a Python package located in the current directory in editable mode and includes all optional extras specified in the [all] section of braingeneers. 7 | python3 -m pip install -e ".[all]" 8 | -------------------------------------------------------------------------------- /src/braingeneers/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from spikedata import (DCCResult, SpikeData, best_effort_sample, 2 | burst_detection, cumulative_moving_average, 3 | fano_factors, pearson, population_firing_rate, 4 | randomize_raster, spike_time_tiling) 5 | 6 | from braingeneers.analysis.analysis import (NeuronAttributes, filter, 7 | load_spike_data, read_phy_files) 8 | -------------------------------------------------------------------------------- /src/braingeneers/utils/smart_open_braingeneers/__init__.py: -------------------------------------------------------------------------------- 1 | import braingeneers 2 | import smart_open 3 | 4 | 5 | braingeneers.set_default_endpoint() 6 | 7 | 8 | # noinspection PyProtectedMember 9 | def open(*args, **kwargs): 10 | """ 11 | Simple hand off to smart_open, this explicit handoff is required because open can reference a 12 | new function if configure.set_default_endpoint is called in the future. 13 | """ 14 | return braingeneers.utils.configure._open(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /src/braingeneers/utils/s3wrangler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extends the awswrangler.s3 package for Braingeneers/PRP access. 
3 | See documentation: https://aws-data-wrangler.readthedocs.io/en/2.4.0-docs/api.html#amazon-s3 4 | 5 | Usage examples: 6 | import braingeneers.utils.s3wrangler as wr 7 | uuids = wr.list_directories('s3://braingeneers/ephys/') 8 | print(uuids) 9 | """ 10 | import awswrangler 11 | from awswrangler import config 12 | from awswrangler.s3 import * 13 | import braingeneers 14 | import botocore 15 | 16 | awswrangler.config.botocore_config = botocore.config.Config( 17 | retries={"max_attempts": 10}, 18 | connect_timeout=20, 19 | max_pool_connections=50) 20 | 21 | braingeneers.set_default_endpoint() 22 | -------------------------------------------------------------------------------- /.github/matchers/pylint.json: -------------------------------------------------------------------------------- 1 | { 2 | "problemMatcher": [ 3 | { 4 | "severity": "warning", 5 | "pattern": [ 6 | { 7 | "regexp": "^([^:]+):(\\d+):(\\d+): ([A-DF-Z]\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", 8 | "file": 1, 9 | "line": 2, 10 | "column": 3, 11 | "code": 4, 12 | "message": 5 13 | } 14 | ], 15 | "owner": "pylint-warning" 16 | }, 17 | { 18 | "severity": "error", 19 | "pattern": [ 20 | { 21 | "regexp": "^([^:]+):(\\d+):(\\d+): (E\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", 22 | "file": 1, 23 | "line": 2, 24 | "column": 3, 25 | "code": 4, 26 | "message": 5 27 | } 28 | ], 29 | "owner": "pylint-error" 30 | } 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /src/braingeneers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | # Deprecated import braingeneers.utils.messaging is allowed for backwards compatibility. 5 | # This code should be removed in the future. This was added 27apr2022 by David Parks. 
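# Implementation note: __getattr__ below is the module-level attribute hook defined by PEP 562 (Python 3.7+).
# It is called only when normal attribute lookup on this module fails, so the deprecated names
# braingeneers.utils.messaging and braingeneers.utils.NumpyS3Memmap still resolve for older code,
# emit a DeprecationWarning, and import their replacements lazily on first access.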
6 | def __getattr__(name): 7 | if name == 'messaging': 8 | warnings.warn( 9 | message='braingeneers.utils.messaging has been deprecated, please import braingeneers.iot.messaging.', 10 | category=DeprecationWarning, 11 | ) 12 | from braingeneers.iot import messaging 13 | return messaging 14 | 15 | elif name == 'NumpyS3Memmap': 16 | warnings.warn( 17 | message='braingeneers.utils.NumpyS3Memmap has been deprecated, ' 18 | 'please import braingeneers.utils.numpy_s3_memmap.NumpyS3Memmap.', 19 | category=DeprecationWarning, 20 | ) 21 | from braingeneers.utils.numpy_s3_memmap import NumpyS3Memmap 22 | return NumpyS3Memmap 23 | 24 | else: 25 | raise AttributeError(name) 26 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib.metadata 4 | 5 | project = "braingeneers" 6 | copyright = "2023, Braingeneers" 7 | author = "Braingeneers" 8 | version = release = importlib.metadata.version("braingeneers") 9 | 10 | extensions = [ 11 | "myst_parser", 12 | "sphinx.ext.autodoc", 13 | "sphinx.ext.intersphinx", 14 | "sphinx.ext.mathjax", 15 | "sphinx.ext.napoleon", 16 | "sphinx_autodoc_typehints", 17 | "sphinx_copybutton", 18 | ] 19 | 20 | source_suffix = [".rst", ".md"] 21 | exclude_patterns = [ 22 | "_build", 23 | "**.ipynb_checkpoints", 24 | "Thumbs.db", 25 | ".DS_Store", 26 | ".env", 27 | ".venv", 28 | ] 29 | 30 | html_theme = "furo" 31 | 32 | myst_enable_extensions = [ 33 | "colon_fence", 34 | ] 35 | 36 | intersphinx_mapping = { 37 | "python": ("https://docs.python.org/3", None), 38 | } 39 | 40 | nitpick_ignore = [ 41 | ("py:class", "_io.StringIO"), 42 | ("py:class", "_io.BytesIO"), 43 | ] 44 | 45 | always_document_param_types = True 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2023, Braingeneers 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
9 | -------------------------------------------------------------------------------- /tests/test_analysis.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | import braingeneers.analysis as ba 6 | from braingeneers import skip_unittest_if_offline 7 | 8 | 9 | class TestSpikeDataLoaders(unittest.TestCase): 10 | def assertAll(self, bools, msg=None): 11 | "Assert that two arrays are equal elementwise." 12 | self.assertTrue(np.all(bools), msg=msg) 13 | 14 | @skip_unittest_if_offline 15 | def testSpikeAttributes(self): 16 | uuid = "2023-04-17-e-causal_v1" 17 | sd = ba.load_spike_data(uuid) 18 | self.assertTrue(isinstance(sd, ba.SpikeData)) 19 | r = sd.raster(1) 20 | rr = sd.randomized(1).raster(1) 21 | self.assertAll(r.sum(1) == rr.sum(1)) 22 | self.assertAll(r.sum(0) == rr.sum(0)) 23 | 24 | @skip_unittest_if_offline 25 | def testSpikeAttributesDiffSorter(self): 26 | uuid = "2023-04-17-e-causal_v1" 27 | exp = "data_phy.zip" 28 | sorter = "kilosort3" 29 | sd = ba.load_spike_data(uuid, exp, sorter=sorter) 30 | self.assertTrue(isinstance(sd, ba.SpikeData)) 31 | -------------------------------------------------------------------------------- /src/braingeneers/iot/example_device_main.py: -------------------------------------------------------------------------------- 1 | from example_device import ExampleDevice 2 | import argparse 3 | 4 | if __name__ == "__main__": 5 | 6 | parser = argparse.ArgumentParser(description="Command line tool for the ExampleDevice utility") 7 | # Adding arguments with default values and making them optional 8 | parser.add_argument('--device_name', type=str, required=False, default="spam", help='Name of device (default: spam)') 9 | parser.add_argument('--eggs', type=int, required=False, default=0, help='Starting quantity of eggs (default: 0)') 10 | parser.add_argument('--ham', type=int, required=False, default=0, help='Starting quantity of ham (default: 0)') 11 | parser.add_argument('--spam', type=int, required=False, default=1, help='Starting quantity of spam (default: 1)') 12 | 13 | args = parser.parse_args() 14 | 15 | # Create a device object 16 | device = ExampleDevice(device_name=args.device_name, eggs=args.eggs, ham=args.ham, spam=args.spam) 17 | 18 | # Start the device activities, running in a loop 19 | # Control + C should gracefully stop execution 20 | device.start_mqtt() 21 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - master 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | env: 14 | FORCE_COLOR: 3 15 | 16 | jobs: 17 | checks: 18 | name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} 19 | runs-on: ${{ matrix.runs-on }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | python-version: ["3.11"] 24 | runs-on: [ubuntu-latest] 25 | 26 | steps: 27 | - name: Get Credentials 28 | env: 29 | AWS_CREDENTIALS: | 30 | ${{ secrets.AWS_CREDENTIALS }} 31 | shell: bash 32 | run: | 33 | mkdir ~/.aws 34 | echo "$AWS_CREDENTIALS" > ~/.aws/credentials 35 | wc ~/.aws/credentials 36 | 37 | - uses: actions/checkout@v6 38 | with: 39 | fetch-depth: 0 40 | 41 | - uses: actions/setup-python@v6 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | allow-prereleases: true 45 | 46 | 
- name: Install Package 47 | run: python -m pip install .[dev] 48 | 49 | - name: Run Tests 50 | run: python -m pytest -ra --cov --durations=5 51 | 52 | all-pass: 53 | name: All CI checks passed 54 | runs-on: ubuntu-latest 55 | needs: checks 56 | steps: 57 | - name: Print Confirmation 58 | run: echo "All matrix tests passed!" 59 | -------------------------------------------------------------------------------- /src/braingeneers/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import platform 4 | import warnings 5 | 6 | 7 | # At import time add the Maxwell custom H5 compression plugin environment variable if it's not already set. 8 | # This is necessary to enable all V2 Maxwell H5 data files to be readable. 9 | # We have included the compiled plugin binaries in the braingeneerspy source under braingeneerspy/data/mxw_h5_plugin/*. 10 | if 'HDF5_PLUGIN_PATH' not in os.environ: 11 | system = platform.system() 12 | machine = platform.machine() 13 | 14 | plugin_arch_dir = \ 15 | 'Linux' if system == 'Linux' else \ 16 | 'Windows' if system == 'Windows' else \ 17 | 'Mac_arm64' if system == 'Darwin' and machine == 'arm64' else \ 18 | 'Mac_x86_64' if system == 'Darwin' and machine == 'x86_64' else \ 19 | None 20 | 21 | if plugin_arch_dir is None: 22 | warnings.warn(f'System [{system}] and machine [{machine}] architecture is not supported ' 23 | f'by the Maxwell HDF5 compression plugin. The Maxwell data reader will not ' 24 | f'work for V2 HDF5 files on this system.') 25 | else: 26 | os.environ['HDF5_PLUGIN_PATH'] = os.path.join( 27 | pathlib.Path(__file__).parent.resolve(), # path to this __init__.py file 28 | 'mxw_h5_plugin', # subdirectory where Maxwell plugins are stored (for all architectures) 29 | plugin_arch_dir # architecture-specific subdirectory where the system-specific plugin is stored 30 | ) 31 | -------------------------------------------------------------------------------- /tests/test_smart_open_braingeneers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tempfile 3 | import unittest 4 | 5 | import braingeneers 6 | import braingeneers.utils.smart_open_braingeneers as smart_open 7 | 8 | 9 | class SmartOpenTestCase(unittest.TestCase): 10 | test_bucket = "braingeneersdev" 11 | test_file = "test_file.txt" 12 | 13 | def test_online_smart_open_read(self): 14 | """Tests that a simple file open and read operation succeeds""" 15 | braingeneers.set_default_endpoint() # sets the default PRP endpoint 16 | s3_url = f"s3://{self.test_bucket}/{self.test_file}" 17 | with smart_open.open(s3_url, "r") as f: 18 | txt = f.read() 19 | 20 | self.assertEqual(txt, "Don't panic\n") 21 | 22 | @unittest.skipIf(sys.platform.startswith("win"), "TODO: Test is broken on Windows.") 23 | def test_local_path_endpoint(self): 24 | with tempfile.TemporaryDirectory(prefix="smart_open_unittest_") as tmp_dirname: 25 | with tempfile.NamedTemporaryFile( 26 | dir=tmp_dirname, prefix="temp_unittest" 27 | ) as tmp_file: 28 | tmp_file_name = tmp_file.name 29 | tmp_file.write(b"unittest") 30 | tmp_file.flush() 31 | 32 | braingeneers.set_default_endpoint(f"{tmp_dirname}/") 33 | with smart_open.open(tmp_file_name, mode="rb") as tmp_file_smart_open: 34 | self.assertEqual(tmp_file_smart_open.read(), b"unittest") 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /tests/test_numpy_s3_memmap.py:
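A usage note for the `HDF5_PLUGIN_PATH` setup in `src/braingeneers/data/__init__.py` above: importing `braingeneers.data` before the HDF5 library loads its filter plugins should make compressed Maxwell V2 recordings readable by ordinary HDF5 tools in the same process. A minimal sketch, assuming a local Maxwell V2 file (the filename is a placeholder and the key listing is for illustration only):

```python
import braingeneers.data  # noqa: F401  side effect: sets HDF5_PLUGIN_PATH if it is not already set

import h5py  # imported after braingeneers.data so the Maxwell compression filter can be located

# "recording.raw.h5" stands in for a local copy of a Maxwell V2 recording.
with h5py.File("recording.raw.h5", "r") as f:
    print(list(f.keys()))  # browsing/reading works once the compression plugin loads
```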
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from braingeneers.utils.configure import skip_unittest_if_offline 6 | from braingeneers.utils.numpy_s3_memmap import NumpyS3Memmap 7 | 8 | 9 | class TestNumpyS3Memmap(unittest.TestCase): 10 | @skip_unittest_if_offline 11 | def test_numpy32memmap_online(self): 12 | """Note: this is an online test requiring access to the PRP/S3 braingeneersdev bucket.""" 13 | x = NumpyS3Memmap("s3://braingeneersdev/dfparks/test/test.npy") 14 | 15 | # Online test data at s3://braingeneersdev/dfparks/test/test.npy 16 | # array([[1., 2., 3.], 17 | # [4., 5., 6.]], dtype=float32) 18 | 19 | e = np.arange(1, 7, dtype=np.float32).reshape(2, 3) 20 | 21 | self.assertTrue(np.all(x[0] == e[0])) 22 | self.assertTrue(np.all(x[:, 0:2] == e[:, 0:2])) 23 | self.assertTrue(np.all(x[:, [0, 1]] == e[:, [0, 1]])) 24 | 25 | @skip_unittest_if_offline 26 | def test_online_in_the_wild_file(self): 27 | """ 28 | This test assumes online access. 29 | Specifically this test case found a bug in numpy arrays for fortran order. 30 | """ 31 | x = NumpyS3Memmap( 32 | "s3://braingeneersdev/ephys/2020-07-06-e-MGK-76-2614-Drug/numpy/" 33 | "well_A1_chan_group_idx_1_time_000.npy" 34 | ) 35 | self.assertEqual(x.shape, (3750000, 4)) 36 | 37 | all_data = x[:] 38 | self.assertEqual(all_data.shape, (3750000, 4)) 39 | 40 | 41 | if __name__ == "__main__": 42 | unittest.main() 43 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | 5 | repos: 6 | - repo: https://github.com/psf/black 7 | rev: "23.7.0" 8 | hooks: 9 | - id: black-jupyter 10 | 11 | - repo: https://github.com/asottile/blacken-docs 12 | rev: "1.15.0" 13 | hooks: 14 | - id: blacken-docs 15 | additional_dependencies: [black==23.7.0] 16 | 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: "v4.4.0" 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-case-conflict 22 | - id: check-merge-conflict 23 | - id: check-symlinks 24 | - id: check-yaml 25 | - id: debug-statements 26 | - id: end-of-file-fixer 27 | - id: mixed-line-ending 28 | - id: name-tests-test 29 | args: ["--pytest-test-first"] 30 | - id: requirements-txt-fixer 31 | - id: trailing-whitespace 32 | 33 | - repo: https://github.com/pre-commit/pygrep-hooks 34 | rev: "v1.10.0" 35 | hooks: 36 | - id: rst-backticks 37 | - id: rst-directive-colons 38 | - id: rst-inline-touching-normal 39 | 40 | - repo: https://github.com/pre-commit/mirrors-prettier 41 | rev: "v3.0.0" 42 | hooks: 43 | - id: prettier 44 | types_or: [yaml, markdown, html, css, scss, javascript, json] 45 | args: [--prose-wrap=always] 46 | 47 | - repo: https://github.com/astral-sh/ruff-pre-commit 48 | rev: "v0.0.277" 49 | hooks: 50 | - id: ruff 51 | args: ["--fix", "--show-fixes"] 52 | 53 | - repo: https://github.com/pre-commit/mirrors-mypy 54 | rev: "v1.4.1" 55 | hooks: 56 | - id: mypy 57 | files: src|tests 58 | args: [] 59 | additional_dependencies: 60 | - pytest 61 | 62 | - repo: https://github.com/codespell-project/codespell 63 | rev: "v2.2.5" 64 | hooks: 65 | - id: codespell 66 | 67 | - repo: https://github.com/shellcheck-py/shellcheck-py 68 | rev: "v0.9.0.5" 69 | hooks: 70 | - id: shellcheck 71 | 72 | - repo: local 73 | hooks: 74 | - id: disallow-caps 75 | 
name: Disallow improper capitalization 76 | language: pygrep 77 | entry: PyBind|Numpy|Cmake|CCache|Github|PyTest 78 | exclude: .pre-commit-config.yaml 79 | -------------------------------------------------------------------------------- /src/braingeneers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023 Braingeneers. All rights reserved. 3 | 4 | braingeneers: braingeneerspy 5 | """ 6 | 7 | from __future__ import annotations 8 | import warnings 9 | 10 | from importlib.metadata import version as _pkg_version, PackageNotFoundError 11 | 12 | from . import utils 13 | from .utils.configure import ( 14 | set_default_endpoint, 15 | get_default_endpoint, 16 | skip_unittest_if_offline, 17 | ) 18 | 19 | try: # preferred: read version from installed package metadata 20 | VERSION = _pkg_version("braingeneers") 21 | except PackageNotFoundError: # e.g. running from a source tree without an installed dist 22 | VERSION = "0.0.0.0000" 23 | 24 | __version__ = VERSION 25 | 26 | __all__ = ("set_default_endpoint", "get_default_endpoint", "skip_unittest_if_offline", "utils") 27 | 28 | # Deprecated imports are allowed for backwards compatibility. 29 | # This code should be removed in the future. This was added 27apr2022 by David Parks. 30 | def __getattr__(name): 31 | if name == 'neuron': 32 | warnings.warn( 33 | message='braingeneers.neuron has been deprecated, please import braingeneers.analysis.neuron.', 34 | category=DeprecationWarning, 35 | ) 36 | from braingeneers.analysis import neuron 37 | return neuron 38 | 39 | if name == 'datasets_electrophysiology': 40 | warnings.warn( 41 | message='braingeneers.datasets_electrophysiology has been deprecated, ' 42 | 'please import braingeneers.data.datasets_electrophysiology.', 43 | category=DeprecationWarning, 44 | ) 45 | from braingeneers.data import datasets_electrophysiology 46 | return datasets_electrophysiology 47 | 48 | if name == 'datasets_fluidics': 49 | warnings.warn( 50 | message='braingeneers.datasets_fluidics has been deprecated, ' 51 | 'please import braingeneers.data.datasets_fluidics.', 52 | category=DeprecationWarning, 53 | ) 54 | from braingeneers.data import datasets_fluidics 55 | return datasets_fluidics 56 | 57 | if name == 'datasets_imaging': 58 | warnings.warn( 59 | message='braingeneers.datasets_imaging has been deprecated, ' 60 | 'please import braingeneers.data.datasets_imaging.', 61 | category=DeprecationWarning, 62 | ) 63 | from braingeneers.data import datasets_imaging 64 | return datasets_imaging 65 | 66 | if name == 'datasets_neuron': 67 | warnings.warn( 68 | message='braingeneers.datasets_neuron has been deprecated, ' 69 | 'please import braingeneers.data.datasets_neuron.', 70 | category=DeprecationWarning, 71 | ) 72 | from braingeneers.data import datasets_neuron 73 | return datasets_neuron 74 | 75 | else: 76 | raise AttributeError(name) 77 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed 2 | description of best practices for developing scientific packages. 3 | 4 | [spc-dev-intro]: https://scientific-python-cookie.readthedocs.io/guide/intro 5 | 6 | # Quick development 7 | 8 | The fastest way to start with development is to use nox. If you don't have nox, 9 | you can use `pipx run nox` to run it without installing, or `pipx install nox`. 
10 | If you don't have pipx (pip for applications), then you can install it with 11 | `pip install pipx` (the only case where installing an application with regular 12 | pip is reasonable). If you use macOS, then pipx and nox are both in brew; use 13 | `brew install pipx nox`. 14 | 15 | To use, run `nox`. This will lint and test using every installed version of 16 | Python on your system, skipping ones that are not installed. You can also run 17 | specific jobs: 18 | 19 | ```console 20 | $ nox -s lint # Lint only 21 | $ nox -s tests # Python tests 22 | $ nox -s docs -- serve # Build and serve the docs 23 | $ nox -s build # Make an SDist and wheel 24 | ``` 25 | 26 | Nox handles everything for you, including setting up a temporary virtual 27 | environment for each run. 28 | 29 | # Setting up a development environment manually 30 | 31 | You can set up a development environment by running: 32 | 33 | ```bash 34 | python3 -m venv .venv 35 | source ./.venv/bin/activate 36 | pip install -v -e .[dev] 37 | ``` 38 | 39 | If you have the 40 | [Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you 41 | can instead do: 42 | 43 | ```bash 44 | py -m venv .venv 45 | py -m pip install -v -e .[dev] 46 | ``` 47 | 48 | # Post setup 49 | 50 | You should prepare pre-commit, which will help you by checking that commits pass 51 | required checks: 52 | 53 | ```bash 54 | pip install pre-commit # or brew install pre-commit on macOS 55 | pre-commit install # Will install a pre-commit hook into the git repo 56 | ``` 57 | 58 | You can also run `pre-commit run` (changes only) or 59 | `pre-commit run --all-files` to check even without installing the hook. 60 | 61 | # Testing 62 | 63 | Use pytest to run the unit tests: 64 | 65 | ```bash 66 | pytest 67 | ``` 68 | 69 | # Coverage 70 | 71 | Use pytest-cov to generate coverage reports: 72 | 73 | ```bash 74 | pytest --cov=braingeneers 75 | ``` 76 | 77 | # Building docs 78 | 79 | You can build the docs using: 80 | 81 | ```bash 82 | nox -s docs 83 | ``` 84 | 85 | You can see a preview with: 86 | 87 | ```bash 88 | nox -s docs -- serve 89 | ``` 90 | 91 | # Pre-commit 92 | 93 | This project uses pre-commit for all style checking. While you can run it with 94 | nox, this is such an important tool that it deserves to be installed on its own. 95 | Install pre-commit and run: 96 | 97 | ```bash 98 | pre-commit run -a 99 | ``` 100 | 101 | to check all files.
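Several test modules in this repository guard S3-dependent cases with the `skip_unittest_if_offline` decorator from `braingeneers.utils.configure`, which skips a test when the `ONLINE_TESTS` environment variable is set to `false`. A minimal sketch of writing such a test (the test body is a placeholder):

```python
import unittest

from braingeneers.utils.configure import skip_unittest_if_offline


class MyOnlineTests(unittest.TestCase):
    @skip_unittest_if_offline
    def test_reads_from_s3(self):
        # A real test would touch the PRP/S3 endpoint here; this assertion is a stand-in.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```

Setting `ONLINE_TESTS=false` before invoking pytest exercises the offline path.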
102 | -------------------------------------------------------------------------------- /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - master 8 | - main 9 | - development 10 | release: 11 | types: 12 | - published 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.ref }} 16 | cancel-in-progress: true 17 | 18 | env: 19 | FORCE_COLOR: 3 20 | 21 | jobs: 22 | dist: 23 | name: Distribution build 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/checkout@v6 28 | with: 29 | fetch-depth: 0 30 | 31 | - name: Build sdist and wheel 32 | run: pipx run build 33 | 34 | - uses: actions/upload-artifact@v6 35 | with: 36 | path: dist 37 | 38 | - name: Check products 39 | run: pipx run twine check dist/* 40 | 41 | test-built-dist: 42 | needs: [dist] 43 | name: Test built distribution 44 | runs-on: ubuntu-latest 45 | permissions: 46 | id-token: write 47 | steps: 48 | - uses: actions/setup-python@v6 49 | name: Install Python 50 | with: 51 | python-version: '3.10' 52 | - uses: actions/download-artifact@v7 53 | with: 54 | name: artifact 55 | path: dist 56 | - name: List contents of built dist 57 | run: | 58 | ls -ltrh 59 | ls -ltrh dist 60 | - name: Publish to Test PyPI 61 | uses: pypa/gh-action-pypi-publish@v1.13.0 62 | with: 63 | repository-url: https://test.pypi.org/legacy/ 64 | verbose: true 65 | skip-existing: true 66 | - name: Check pypi packages 67 | run: | 68 | sleep 3 69 | python -m pip install --upgrade pip 70 | 71 | echo "=== Testing wheel file ===" 72 | # Install wheel to get dependencies and check import 73 | python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade --pre braingeneers 74 | python -c "import braingeneers; print(braingeneers.__version__)" 75 | echo "=== Done testing wheel file ===" 76 | 77 | echo "=== Testing source tar file ===" 78 | # Install tar gz and check import 79 | python -m pip uninstall --yes braingeneers 80 | python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade --pre --no-binary=:all: braingeneers 81 | python -c "import braingeneers; print(braingeneers.__version__)" 82 | echo "=== Done testing source tar file ===" 83 | 84 | 85 | publish: 86 | needs: [dist, test-built-dist] 87 | name: Publish to PyPI 88 | environment: pypi 89 | permissions: 90 | id-token: write 91 | runs-on: ubuntu-latest 92 | if: github.event_name == 'release' && github.event.action == 'published' 93 | 94 | steps: 95 | - uses: actions/download-artifact@v7 96 | with: 97 | name: artifact 98 | path: dist 99 | 100 | - uses: pypa/gh-action-pypi-publish@v1.13.0 101 | if: startsWith(github.ref, 'refs/tags') 102 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import shutil 5 | from pathlib import Path 6 | 7 | import nox 8 | 9 | DIR = Path(__file__).parent.resolve() 10 | 11 | nox.options.sessions = ["lint", "pylint", "tests"] 12 | 13 | 14 | @nox.session 15 | def lint(session: nox.Session) -> None: 16 | """ 17 | Run the linter. 18 | """ 19 | session.install("pre-commit") 20 | session.run("pre-commit", "run", "--all-files", *session.posargs) 21 | 22 | 23 | @nox.session 24 | def pylint(session: nox.Session) -> None: 25 | """ 26 | Run PyLint. 
27 | """ 28 | # This needs to be installed into the package environment, and is slower 29 | # than a pre-commit check 30 | session.install(".", "pylint") 31 | session.run("pylint", "src", *session.posargs) 32 | 33 | 34 | @nox.session 35 | def tests(session: nox.Session) -> None: 36 | """ 37 | Run the unit and regular tests. Use --cov to activate coverage. 38 | """ 39 | session.install(".[test]") 40 | session.run("pytest", *session.posargs) 41 | 42 | 43 | @nox.session 44 | def docs(session: nox.Session) -> None: 45 | """ 46 | Build the docs. Pass "--serve" to serve. 47 | """ 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--serve", action="store_true", help="Serve after building") 51 | parser.add_argument( 52 | "-b", dest="builder", default="html", help="Build target (default: html)" 53 | ) 54 | args, posargs = parser.parse_known_args(session.posargs) 55 | 56 | if args.builder != "html" and args.serve: 57 | session.error("Must not specify non-HTML builder with --serve") 58 | 59 | session.install(".[docs]") 60 | session.chdir("docs") 61 | 62 | if args.builder == "linkcheck": 63 | session.run( 64 | "sphinx-build", "-b", "linkcheck", ".", "_build/linkcheck", *posargs 65 | ) 66 | return 67 | 68 | session.run( 69 | "sphinx-build", 70 | "-n", # nitpicky mode 71 | "-T", # full tracebacks 72 | "-W", # Warnings as errors 73 | "--keep-going", # See all errors 74 | "-b", 75 | args.builder, 76 | ".", 77 | f"_build/{args.builder}", 78 | *posargs, 79 | ) 80 | 81 | if args.serve: 82 | session.log("Launching docs at http://localhost:8000/ - use Ctrl-C to quit") 83 | session.run("python", "-m", "http.server", "8000", "-d", "_build/html") 84 | 85 | 86 | @nox.session 87 | def build_api_docs(session: nox.Session) -> None: 88 | """ 89 | Build (regenerate) API docs. 90 | """ 91 | 92 | session.install("sphinx") 93 | session.chdir("docs") 94 | session.run( 95 | "sphinx-apidoc", 96 | "-o", 97 | "api/", 98 | "--module-first", 99 | "--no-toc", 100 | "--force", 101 | "../src/braingeneers", 102 | ) 103 | 104 | 105 | @nox.session 106 | def build(session: nox.Session) -> None: 107 | """ 108 | Build an SDist and wheel. 
109 | """ 110 | 111 | build_p = DIR.joinpath("build") 112 | if build_p.exists(): 113 | shutil.rmtree(build_p) 114 | 115 | session.install("build") 116 | session.run("python", "-m", "build") 117 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | tags: 7 | - '*.*.*' # Github tag format for braingeneerpy repo formatted as in this example, "0.3.1" 8 | release: 9 | types: [published] 10 | pull_request: 11 | branches: [master] 12 | 13 | jobs: 14 | publish: 15 | runs-on: ubuntu-latest 16 | 17 | # Extra safety: only run when base is master 18 | if: > 19 | (github.event_name == 'push' && github.ref == 'refs/heads/master') || 20 | (github.event_name == 'pull_request' && github.event.pull_request.base.ref == 'master') 21 | 22 | steps: 23 | - name: Checkout 24 | uses: actions/checkout@v6 25 | with: 26 | fetch-depth: 0 # IMPORTANT: we need full history for tags + commit count 27 | 28 | - name: Set up Python 29 | uses: actions/setup-python@v6 30 | with: 31 | python-version: "3.11" 32 | 33 | - name: Install build tools 34 | run: | 35 | python -m pip install --upgrade pip 36 | python -m pip install build 37 | 38 | - name: Get latest tag (X.Y.Z) 39 | id: tag 40 | run: | 41 | set -e 42 | # Use latest tag matching X.Y.Z; adjust pattern if you use "vX.Y.Z" tags instead 43 | if git describe --tags --match '*.*.*' >/dev/null 2>&1; then 44 | TAG=$(git describe --tags --match '*.*.*' --abbrev=0) 45 | else 46 | # Fallback if no tags yet 47 | TAG="0.0.0" 48 | fi 49 | echo "tag=$TAG" >> "$GITHUB_OUTPUT" 50 | 51 | - name: Compute version from tag and commit count 52 | id: version 53 | run: | 54 | set -e 55 | TAG="${{ steps.tag.outputs.tag }}" 56 | 57 | if [ "$TAG" = "0.0.0" ]; then 58 | COUNT=$(git rev-list --count HEAD) 59 | else 60 | COUNT=$(git rev-list --count "$TAG"..HEAD) 61 | fi 62 | 63 | VERSION="${TAG}.${COUNT}" 64 | 65 | echo "Using version: $VERSION" 66 | echo "version=$VERSION" >> "$GITHUB_OUTPUT" 67 | echo "VERSION=$VERSION" >> "$GITHUB_ENV" 68 | 69 | - name: Patch version into pyproject.toml 70 | run: | 71 | python - << 'PY' 72 | from pathlib import Path 73 | import os, re 74 | 75 | version = os.environ["VERSION"] 76 | path = Path("pyproject.toml") 77 | text = path.read_text(encoding="utf-8") 78 | 79 | # Replace the first 'version = "..."' line in [project] 80 | new_text, n = re.subn( 81 | r'(?m)^version\s*=\s*".*"', 82 | f'version = "{version}"', 83 | text, 84 | count=1, 85 | ) 86 | if n == 0: 87 | raise SystemExit("Did not find a version = \"...\" line to replace in pyproject.toml") 88 | 89 | path.write_text(new_text, encoding="utf-8") 90 | PY 91 | 92 | - name: Build package 93 | run: | 94 | python -m build 95 | 96 | - name: Publish to PyPI 97 | if: github.event_name == 'push' 98 | uses: pypa/gh-action-pypi-publish@release/v1 99 | with: 100 | password: ${{ secrets.PYPI_API_TOKEN }} 101 | -------------------------------------------------------------------------------- /tests/test_memoize_s3.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | from unittest import mock 4 | 5 | import pytest 6 | 7 | from braingeneers.utils.configure import skip_unittest_if_offline 8 | from braingeneers.utils.memoize_s3 import memoize 9 | 10 | 11 | # These have to all ignore UserWarnings because joblib generates them whenever 12 | # the 
store backend takes more than a few hundred ms, which S3 often does. 13 | @pytest.mark.filterwarnings("ignore::UserWarning") 14 | class TestMemoizeS3(unittest.TestCase): 15 | @skip_unittest_if_offline 16 | def test(self): 17 | # Run these checks in a context where S3_USER is set. 18 | unique_user = f"unittest-{id(self)}" 19 | with mock.patch.dict("os.environ", {"S3_USER": unique_user}): 20 | # Memoize a function that counts its calls. 21 | @memoize()(ignore=["y"]) 22 | def foo(x, y): 23 | nonlocal side_effect 24 | side_effect += 1 25 | return x 26 | 27 | self.assertEqual( 28 | foo.store_backend.location, 29 | f"s3://braingeneersdev/{unique_user}/cache/joblib", 30 | ) 31 | 32 | # Call it a few times and make sure it only runs once. 33 | foo.clear() 34 | side_effect = 0 35 | for i in range(3): 36 | self.assertEqual(foo("bar", i), "bar") 37 | self.assertEqual(side_effect, 1) 38 | 39 | # Force it to run again and make sure it happens. 40 | foo("baz", 1) 41 | self.assertEqual(side_effect, 2) 42 | 43 | # Clean up by reaching into the cache and clearing the directory 44 | # without recreating the cache. This is important to avoid 45 | # cluttering with fake user directories after tests are done. 46 | foo.store_backend.clear() 47 | 48 | @skip_unittest_if_offline 49 | def test_uri_validation(self): 50 | # Our backend only supports S3 URIs. 51 | with self.assertRaises(ValueError): 52 | 53 | @memoize("this has to start with s3://") 54 | def foo(x): 55 | return x 56 | 57 | @skip_unittest_if_offline 58 | def test_cant_mmap(self): 59 | # We have to fail if memory mapping is requested because it's 60 | # impossible on S3. 61 | with self.assertRaises(ValueError): 62 | 63 | @memoize(mmap_mode=True) 64 | def foo(x): 65 | return x 66 | 67 | @skip_unittest_if_offline 68 | def test_bucket_existence(self): 69 | # Bucket existence should be checked at creation, and the user should get a 70 | # warning that we're falling back to local storage. 71 | with self.assertWarns(UserWarning): 72 | 73 | @memoize("s3://i-sure-hope-this-crazy-bucket-doesnt-exist/") 74 | def foo(x): 75 | return x 76 | 77 | @skip_unittest_if_offline 78 | def test_default_location(self): 79 | # Make sure a default location is correctly set when S3_USER is not. 80 | with mock.patch.dict("os.environ", {"S3_USER": ""}): 81 | 82 | @memoize() 83 | def foo(x): 84 | return x 85 | 86 | self.assertEqual( 87 | foo.store_backend.location, "s3://braingeneersdev/common/cache/joblib" 88 | ) 89 | 90 | @skip_unittest_if_offline 91 | def test_custom_location(self): 92 | # Make sure custom locations get set correctly. 93 | @memoize("s3://braingeneersdev/unittest/cache") 94 | def foo(x): 95 | return x 96 | 97 | self.assertEqual( 98 | foo.store_backend.location, "s3://braingeneersdev/unittest/cache/joblib" 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | unittest.main() 104 | -------------------------------------------------------------------------------- /src/braingeneers/utils/configure.py: -------------------------------------------------------------------------------- 1 | """ Global package functions and helpers for Braingeneers specific configuration and package management. """ 2 | import distutils.util 3 | import functools 4 | import os 5 | 6 | 7 | """ 8 | Preconfigure remote filesystem access 9 | Default S3 endpoint/bucket_name, note this can be overridden with the environment variable ENDPOINT_URL 10 | or by calling set_default_endpoint(...). Either an S3 endpoint or local path can be used. 
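Usage example (the S3 path below is illustrative only):

    import braingeneers
    from braingeneers.utils import smart_open_braingeneers as smart_open

    braingeneers.set_default_endpoint()  # PRP/S3 default; pass a local path to work offline
    with smart_open.open('s3://braingeneers/ephys/example.txt', 'r') as f:
        print(f.read())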
11 | """ 12 | DEFAULT_ENDPOINT = "https://s3.braingeneers.gi.ucsc.edu" # PRP/S3/CEPH default 13 | CURRENT_ENDPOINT = None 14 | _open = None # a reference to the current smart_open function that is configured this should not be used directly 15 | 16 | def get_default_endpoint() -> str: 17 | """ 18 | Returns the current default (S3) endpoint. By default this will point to the standard 19 | S3 location where data files are stored. Use set_default_endpoint(...) to change. 20 | 21 | :return: str: the current endpoint 22 | """ 23 | if CURRENT_ENDPOINT is None: 24 | return DEFAULT_ENDPOINT 25 | return CURRENT_ENDPOINT 26 | 27 | 28 | def set_default_endpoint(endpoint: str = None, verify_ssl_cert: bool = True) -> None: 29 | """ 30 | Sets the default S3 endpoint and (re)configures braingeneers.utils.smart_open and 31 | braingeneers.utils.s3wrangler to utilize the new endpoint. This endpoint may also be set 32 | to the local filesystem with a relative or absolute local path. 33 | 34 | Examples: 35 | PRP/S3/CEPH: "https://s3-west.nrp-nautilus.io" 36 | PRP/S3/SeaweedFS: "https://swfs-s3.nrp-nautilus.io" 37 | Local Filesystem: "/home/user/project/files/" (absolute) or "project/files/" (relative) 38 | 39 | :param endpoint: S3 or local-path endpoint as shown in examples above, if None will look for ENDPOINT 40 | environment variable, then default to DEFAULT_ENDPOINT if not found. 41 | :param verify_ssl_cert: advanced option, should be True (default) unless there's a specific reason to disable 42 | it. An example use case: when using a proxy server this must be disabled. 43 | """ 44 | # lazy loading of imports is necessary so that we don't import these classes with braingeneers root 45 | # these imports can cause SSL warning messages for some users, so it's especially important to avoid 46 | # importing them unless S3 access is needed. 47 | import boto3 48 | import smart_open 49 | import awswrangler 50 | 51 | global _open 52 | endpoint = endpoint if endpoint is not None else os.environ.get('ENDPOINT', DEFAULT_ENDPOINT) 53 | 54 | # smart_open 55 | if endpoint.startswith('http'): 56 | transport_params = { 57 | 'client': boto3.Session().client('s3', endpoint_url=endpoint, verify=verify_ssl_cert), 58 | } 59 | _open = functools.partial(smart_open.open, transport_params=transport_params) 60 | else: 61 | _open = smart_open.open 62 | 63 | # s3wrangler - only update s3wrangler if the endpoint is S3 based, s3wrangler doesn't support local 64 | if endpoint.startswith('http'): 65 | awswrangler.config.s3_endpoint_url = endpoint 66 | 67 | global CURRENT_ENDPOINT 68 | CURRENT_ENDPOINT = endpoint 69 | 70 | 71 | def skip_unittest_if_offline(f): 72 | """ 73 | Decorator for unit tests which check if environment variable ONLINE_TESTS is set to "false". 
74 | 75 | Usage example: 76 | -------------- 77 | import unittest 78 | 79 | class MyUnitTests(unittest.TestCase): 80 | @braingeneers.skip_if_offline() 81 | def test_online_features(self): 82 | self.assertTrue(do_something()) 83 | """ 84 | def wrapper(self, *args, **kwargs): 85 | allow_online_tests = bool(distutils.util.strtobool(os.environ.get('ONLINE_TESTS', 'true'))) 86 | if not allow_online_tests: 87 | self.skipTest() 88 | else: 89 | f(self, *args, **kwargs) 90 | return wrapper 91 | -------------------------------------------------------------------------------- /src/braingeneers/data/datasets_imaging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request, json 3 | from urllib.error import HTTPError 4 | 5 | from skimage import io 6 | 7 | 8 | camera_ids = [11, 12, 13, 14, 15, 16, 21, 22, 23, 24, 25, 26, 31, 32, 33, 34, 35, 36, 41, 42, 43, 44, 45, 46] 9 | 10 | def get_timestamps(uuid): 11 | with urllib.request.urlopen("https://s3.nautilus.optiputer.net/braingeneers/archive/"+uuid+ "/images/manifest.json") as url: 12 | data = json.loads(url.read().decode()) 13 | return data['captures'] 14 | 15 | camera_ids = [11, 12, 13, 14, 15, 16, 21, 22, 23, 24, 25, 26, 31, 32, 33, 34, 35, 36, 41, 42, 43, 44, 45, 46] 16 | 17 | 18 | def import_json(uuid): 19 | with urllib.request.urlopen("https://s3.nautilus.optiputer.net/braingeneers/archive/"+uuid+"/images/manifest.json") as url: 20 | json_file = json.loads(url.read().decode()) 21 | return json_file 22 | 23 | 24 | def save_images(uuid, timestamps = None, cameras=None , focal_lengths=None): 25 | if os.path.isdir(uuid): 26 | #the directory is already made 27 | pass 28 | else: 29 | os.mkdir(uuid) 30 | 31 | 32 | if os.path.isdir(uuid+'/images'): 33 | #the directory is already made 34 | pass 35 | else: 36 | os.mkdir(uuid+'/images') 37 | 38 | 39 | 40 | json_file = import_json(uuid) 41 | 42 | if type(timestamps) == int: 43 | actual_timestamps = [json_file['captures'][timestamps]] 44 | 45 | elif type(timestamps) == list: 46 | actual_timestamps = [json_file['captures'][x] for x in timestamps] 47 | 48 | elif timestamps == None: 49 | actual_timestamps = json_file['captures'] 50 | 51 | ################################################################ 52 | 53 | if type(cameras) == int: 54 | actual_cameras = [cameras] 55 | 56 | elif type(cameras) == list: 57 | actual_cameras = cameras 58 | 59 | elif cameras == None: 60 | actual_cameras = camera_ids 61 | 62 | ################################################################ 63 | actual_focal_lengths=[] 64 | if type(focal_lengths) == int: 65 | actual_focal_lengths = [focal_lengths] 66 | 67 | elif type(focal_lengths) == list: 68 | actual_focal_lengths = focal_lengths 69 | 70 | elif focal_lengths == None: 71 | actual_focal_lengths = list(range(1, json_file['stack_size']+1)) 72 | 73 | 74 | for timestamp in actual_timestamps: 75 | if os.path.isdir(uuid+'/images/'+str(timestamp)): 76 | #the directory is already made 77 | pass 78 | else: 79 | os.mkdir(uuid+'/images/'+str(timestamp)) 80 | 81 | for camera in actual_cameras: 82 | if os.path.isdir(uuid+'/images/'+str(timestamp)+"/cameraC" + str(camera)): 83 | #the directory is already made 84 | pass 85 | else: 86 | os.mkdir(uuid+'/images/'+str(timestamp)+"/cameraC" + str(camera)) 87 | 88 | for focal_length in actual_focal_lengths: 89 | 90 | full_path = "https://s3.nautilus.optiputer.net/braingeneers/archive/"+uuid+'/images/'+str(timestamp)+"/cameraC"\ 91 | + str(camera) +"/"+str(focal_length)+".jpg" 92 | 93 | 
io.imsave(uuid+'/images/'+str(timestamp)+"/cameraC"+ str(camera) +"/"+str(focal_length)+".jpg", io.imread(full_path)) 94 | 95 | print('Downloading image to: '+ str(uuid)+'/images/'+str(timestamp)+"/cameraC"\ 96 | + str(camera) +"/"+str(focal_length)+".jpg") 97 | 98 | try: 99 | io.imshow(io.imread(full_path)) 100 | 101 | except HTTPError as err: 102 | if err.code == 403: 103 | print("URL address is wrong.") 104 | 105 | 106 | return None 107 | 108 | 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | 7 | # Distribution / packaging 8 | .Python 9 | build/ 10 | develop-eggs/ 11 | dist/ 12 | downloads/ 13 | eggs/ 14 | .eggs/ 15 | lib/ 16 | lib64/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | share/python-wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | MANIFEST 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .nox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | *.py,cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | cover/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | db.sqlite3-journal 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | */build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # For a library or package, you might want to ignore these files since the code is 85 | # intended to run in multiple environments; otherwise, check them in: 86 | .python-version 87 | uv.lock 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # poetry 97 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 98 | # This is especially recommended for binary packages to ensure reproducibility, and is more 99 | # commonly ignored for libraries. 100 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 101 | #poetry.lock 102 | 103 | # pdm 104 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 105 | #pdm.lock 106 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 107 | # in version control. 108 | # https://pdm.fming.dev/#use-with-ide 109 | .pdm.toml 110 | 111 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ 153 | 154 | # PyCharm 155 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 156 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 157 | # and can be added to the global gitignore or merged into this file. For a more nuclear 158 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 159 | #.idea/ 160 | 161 | # Version file 162 | **/_version.py 163 | 164 | .ipynb_checkpoints 165 | .coverage 166 | __pycache__ 167 | *.egg-info 168 | *#* 169 | .idea 170 | tmp/ 171 | *.pyc 172 | **/.DS_Store 173 | dist/ 174 | **/.vscode/** 175 | 176 | # don't commit the service_account file 177 | src/braingeneers/iot/service_account/config.json 178 | 179 | # data-lifecycle tmp 180 | data-lifecycle/tmp -------------------------------------------------------------------------------- /src/braingeneers/iot/example_device.py: -------------------------------------------------------------------------------- 1 | from device import Device 2 | 3 | # Or use this instead, if calling this in a repo outside braingeneerspy: 4 | # from braingeneers.iot import Device 5 | 6 | class ExampleDevice(Device): 7 | """ Example Device Class 8 | Demonstrates how to use and inherit the Device class for new application 9 | """ 10 | def __init__(self, device_name, eggs = 0, ham = 0, spam = 1): 11 | """Initialize the ExampleDevice class 12 | Args: 13 | device_name (str): name of the device 14 | ham (str): starting quantity of ham 15 | eggs (int): starting quantity of eggs 16 | spam (int): starting quantity of spam 17 | 18 | :param device_specific_handlers: dictionary that maps the device's command keywords 19 | to a function call that handles an incomming command. 20 | """ 21 | self.eggs = eggs 22 | self.ham = ham 23 | self.spam = spam 24 | 25 | super().__init__(device_name=device_name, device_type = "Other", primed_default=True) 26 | 27 | self.device_specific_handlers.update({ 28 | "ADD": self.handle_add, # new command to add any amount of eggs or ham 29 | "LIST": self.handle_list # new command to list current amount of eggs and ham by message 30 | }) 31 | return 32 | 33 | @property 34 | def device_state(self): 35 | """ 36 | Return the device state as a dictionary. This is used by the parent Device class to update the device shadow. 37 | Child can add any additional device specific state to the dictionary i.e., "EGGS", "HAM", and "SPAM" 38 | """ 39 | return { **super().device_state, 40 | "EGGS": self.eggs, 41 | "HAM": self.ham, 42 | "SPAM": self.spam 43 | } 44 | 45 | def is_primed(self): 46 | """ 47 | Modify this function if your device requires a physical prerequsite. 48 | In Parent initialization, when primed_default=True no physical prerequsite is required. 
49 | 50 | If a physical prerequsite is required, set primed_default=False and modify this function to check for a condition to be met to return True. 51 | 52 | For example, you may wait for a hardware button press confirming that a physical resource is attached (i.e., a consumable, like fresh media) before allowing 53 | the device to be used in an experiment. 54 | 55 | This function should not perform any blocking/loops because it is checked periodically by the parent loop in "IDLE" state! 56 | """ 57 | return self.primed_default 58 | 59 | 60 | def handle_add(self, topic, message): 61 | """ 62 | Function to handle the ADD command. This function is called by the parent Device class when an ADD command is received. 63 | Args: 64 | topic (str): topic of the received message 65 | message (dict): message received by the device 66 | ADD assumes that the message contains the keys "EGGS", "HAM", "SPAM", and adds the values to the device's state. 67 | """ 68 | try: 69 | self.eggs += message["EGGS"] 70 | self.ham += message["HAM"] 71 | self.spam += message["SPAM"] 72 | self.update_state(self.state) # to update eggs and ham change ASAP in shadow 73 | except: 74 | self.mb.publish_message(topic= self.generate_response_topic("ADD", "ERROR"), 75 | message= { "COMMAND": "ADD-ERROR", 76 | "ERROR": f"Missing argument for EGGS, HAM, or SPAM"}) 77 | return 78 | 79 | def handle_list(self, topic, message): 80 | """ 81 | Function to handle the LIST command. This function is called by the parent Device class when an LIST command is received. 82 | Args: 83 | topic (str): topic of the received message 84 | message (dict): message received by the device 85 | LIST responds with a message containing the current values for "EGGS", "HAM", and "SPAM". 86 | """ 87 | self.mb.publish_message(topic=self.generate_response_topic("LIST", "RESPONSE"), 88 | message= { "COMMAND": "LIST-RESPONSE", 89 | "EGGS": self.eggs, 90 | "HAM" : self.ham, 91 | "SPAM" : self.spam }) 92 | return -------------------------------------------------------------------------------- /src/braingeneers/iot/shadows_dev_playground.py: -------------------------------------------------------------------------------- 1 | ''' 2 | How do we want this to flow 3 | 4 | Create interaction thing 5 | - only happens once, when a new device is added to the system 6 | - contains important imaging data 7 | 8 | Create Experiment 9 | - happens when a new experiment is started 10 | - Creates a plate or assigns existing plate to experiment 11 | - assigns current experiment to interaction thing? 
12 | 13 | Create Plate 14 | - Plate spawns wells, wells don't exist outside of plates 15 | - Plates have interaction things 16 | - Imaging metadata should be transferred to the plate from the interaction thing 17 | 18 | Starting an image run 19 | - Image uuid must be passed to plate object 20 | - should not be able to start image run if interaction thing has no plate 21 | 22 | 23 | ''' 24 | 25 | 26 | import shadows as sh 27 | import json 28 | 29 | from credentials import API_KEY 30 | 31 | # endpoint = "http://localhost:1337/api" 32 | # endpoint = "http://braingeneers.gi.ucsc.edu:1337/api" 33 | # ENDPOINT = "http://shadows-db.braingeneers.gi.ucsc.edu/api" 34 | 35 | # token = API_KEY 36 | # Create a shadow object 37 | # instance = sh.DatabaseInteractor(overwrite_endpoint = ENDPOINT, overwrite_api_key=token) 38 | instance = sh.DatabaseInteractor() 39 | # print(json.dumps(instance.get_device_state(13), indent=4)) 40 | # print(instance.list_experiments()) 41 | # print(instance.get_device_state_by_name("Evee")) 42 | # print(instance.list_objects_with_name_and_id("interaction-things", "?filters[type][$eq]=BioPlateScope")) 43 | print(instance.list_objects_with_name_and_id("interaction-things", "?filters[type][$eq]=BioPlateScope")) 44 | 45 | instance.empty_trash() 46 | 47 | thing1 = instance.create_interaction_thing("Other", "delete_test") 48 | # shadow = {"params": { 49 | # "interval": 1, 50 | # "stack_size": 15, 51 | # "stack_offset": 750, 52 | # "step_size": 50, 53 | # "camera_params": "-t 4000 -awb off -awbg 1,1 -o", 54 | # "light_mode": "Above" 55 | # }, 56 | # "connected": 1691797301180, 57 | # "timestamp": "2023-08-11T23:19:56", 58 | # "uuid": "2023-09-09-i-test", 59 | # "time": { 60 | # "elapsed": 97924, 61 | # "seconds": 97.924, 62 | # "minutes": 1.6320666666666668 63 | # }, 64 | # "num_cameras": 9, 65 | # "last-upload": "", 66 | # "experiment-state": "engaged", 67 | # "group-id": "I" 68 | # } 69 | # thing1.add_to_shadow(shadow) 70 | print(json.dumps(thing1.to_json(), indent=4)) 71 | 72 | thing1.move_to_trash() 73 | 74 | print(instance.list_objects_with_name_and_id("interaction-things", "?filters[type][$eq]=Other")) 75 | 76 | print(json.dumps(thing1.to_json(), indent=4)) 77 | 78 | # thing1.recover_from_trash() 79 | 80 | print(instance.list_objects_with_name_and_id("interaction-things", "?filters[type][$eq]=Other")) 81 | 82 | # plate = instance.create_plate("ephys-tester",1,1) 83 | # plate.add_thing(thing1) 84 | # plate.add_entry_to_ephys_params(uuid="2023-04-17-e-causal_v1", channels="362-71-31-338-231-315-309-234", timestamp="20230419_153528", data_length="-1") 85 | # thing2 = instance.create_interaction_thing("BioPlateScope", "Evee") 86 | # experiment1 = instance.create_experiment("Feed-Frequency-06-26-2022","Feed frequency experiment") 87 | # experiment2 = instance.create_experiment("Connectoids-06-26-2022","Feed frequency experiment") 88 | # uuids1 = {"uuids": 89 | # { 90 | # "2022-06-26-i-feed-frequency-4": "G", 91 | # "2022-06-28-i-feed-frequency-5": "G" 92 | # } 93 | # } 94 | # uuids2 = {"uuids": 95 | # { 96 | # "2022-06-28-i-connectoid" : "C", 97 | # "2022-06-28-i-connectoid-2":"C", 98 | # "2022-06-29-i-connectoids":"C", 99 | # "2022-06-29-i-connectoid-2":"C", 100 | # "2022-07-11-i-connectoid-3":"C" 101 | # } 102 | # } 103 | # plate1 = instance.create_plate("Fluidic-24-well-06-26-2022",4,6,uuids1) 104 | # plate2 = instance.create_plate("Connectoid-plate-06-26-2022",4,6,uuids2) 105 | # experiment1.add_plate(plate1) 106 | # experiment2.add_plate(plate2) 107 | # 
thing1.set_current_experiment(experiment1) 108 | # thing2.set_current_experiment(experiment2) 109 | # thing1.set_current_plate(plate1) 110 | # thing2.set_current_plate(plate2) 111 | 112 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | 6 | [project] 7 | version = "0.3.1.0000" # NOTE: This is a placeholder; CI overwrites it based on git tags + commit count. It's set only to the last manual push. 8 | name = "braingeneers" 9 | authors = [ 10 | { name = "UCSC Braingeneers", email = "braingeneers-admins-group@ucsc.edu" }, 11 | ] 12 | maintainers = [ 13 | { name = "David Parks", email = "dfparks@ucsc.edu" }, 14 | { name = "Kate Voitiuk", email = "kvoitiuk@ucsc.edu" }, 15 | ] 16 | description = "Braingeneers Python utilities" 17 | readme = "README.md" 18 | requires-python = ">=3.10" 19 | classifiers = [ 20 | "Development Status :: 3 - Alpha", 21 | "Intended Audience :: Science/Research", 22 | "Intended Audience :: Developers", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3 :: Only", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Topic :: Scientific/Engineering", 31 | "Typing :: Typed", 32 | ] 33 | dependencies = [ 34 | "awswrangler==3.*", 35 | "boto3==1.35.95", # until https://github.com/boto/boto3/issues/4398 is fixed 36 | "braingeneers-smart-open==2023.10.6", 37 | "deprecated", 38 | "h5py", 39 | "matplotlib", 40 | "nptyping", 41 | "numpy", 42 | "paho-mqtt>=2", 43 | "pandas", 44 | "redis", 45 | "requests", 46 | "schedule", 47 | "scipy", 48 | "spikedata", 49 | "tenacity", 50 | "typing_extensions>=4.6; python_version<'3.11'", 51 | "diskcache", 52 | "pytz", 53 | "tzlocal", 54 | ] 55 | 56 | [project.optional-dependencies] 57 | all = [ 58 | "braingeneers[ml]", 59 | "braingeneers[dev]", 60 | ] 61 | ml = [ 62 | "torch", 63 | "scikit-learn", 64 | ] 65 | dev = [ 66 | "pytest >=6", 67 | "pytest-cov >=3", 68 | "sphinx>=4.0", 69 | "myst_parser>=0.13", 70 | "sphinx_book_theme>=0.1.0", 71 | "sphinx_copybutton", 72 | "sphinx_autodoc_typehints", 73 | "furo", 74 | "joblib", 75 | ] 76 | 77 | [dependency-groups] 78 | dev = [ 79 | "braingeneers[all]", 80 | ] 81 | 82 | [project.urls] 83 | Homepage = "https://github.com/braingeneers/braingeneerspy" 84 | "Bug Tracker" = "https://github.com/braingeneers/braingeneerspy/issues" 85 | Discussions = "https://github.com/braingeneers/braingeneerspy/discussions" 86 | Changelog = "https://github.com/braingeneers/braingeneerspy/releases" 87 | 88 | [tool.hatch] 89 | metadata.allow-direct-references = true 90 | envs.default.dependencies = [ 91 | "pytest", 92 | "pytest-cov", 93 | ] 94 | 95 | [tool.hatch.build.targets.wheel] 96 | packages = ["src/braingeneers"] 97 | 98 | [tool.pytest.ini_options] 99 | minversion = "6.0" 100 | addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] 101 | xfail_strict = true 102 | filterwarnings = [ 103 | "error", 104 | "ignore::DeprecationWarning" 105 | ] 106 | log_cli_level = "INFO" 107 | testpaths = [ 108 | "tests", 109 | ] 110 | 111 | 112 | [tool.coverage] 113 | run.source = ["braingeneers"] 114 | port.exclude_lines = [ 115 | "pragma: no cover", 116 | "\\.\\.\\.", 117 | "if 
typing.TYPE_CHECKING:", 118 | ] 119 | 120 | [tool.mypy] 121 | files = ["src", "tests"] 122 | python_version = "3.10" 123 | warn_unused_configs = true 124 | strict = true 125 | show_error_codes = true 126 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] 127 | warn_unreachable = true 128 | disallow_untyped_defs = false 129 | disallow_incomplete_defs = false 130 | 131 | [[tool.mypy.overrides]] 132 | module = "braingeneers.*" 133 | disallow_untyped_defs = true 134 | disallow_incomplete_defs = true 135 | 136 | 137 | [tool.ruff] 138 | select = [ 139 | "E", "F", "W", # flake8 140 | "B", # flake8-bugbear 141 | "I", # isort 142 | "ARG", # flake8-unused-arguments 143 | "C4", # flake8-comprehensions 144 | "EM", # flake8-errmsg 145 | "ICN", # flake8-import-conventions 146 | "ISC", # flake8-implicit-str-concat 147 | "G", # flake8-logging-format 148 | "PGH", # pygrep-hooks 149 | "PIE", # flake8-pie 150 | "PL", # pylint 151 | "PT", # flake8-pytest-style 152 | "PTH", # flake8-use-pathlib 153 | "RET", # flake8-return 154 | "RUF", # Ruff-specific 155 | "SIM", # flake8-simplify 156 | "T20", # flake8-print 157 | "UP", # pyupgrade 158 | "YTT", # flake8-2020 159 | "EXE", # flake8-executable 160 | "NPY", # NumPy specific rules 161 | "PD", # pandas-vet 162 | ] 163 | extend-ignore = [ 164 | "PLR", # Design related pylint codes 165 | "E501", # Line too long 166 | ] 167 | typing-modules = ["braingeneers._compat.typing"] 168 | src = ["src"] 169 | unfixable = [ 170 | "T20", # Removes print statements 171 | "F841", # Removes unused variables 172 | ] 173 | exclude = [] 174 | flake8-unused-arguments.ignore-variadic-names = true 175 | isort.required-imports = ["from __future__ import annotations"] 176 | 177 | [tool.ruff.per-file-ignores] 178 | "tests/**" = ["T20"] 179 | "noxfile.py" = ["T20"] 180 | 181 | 182 | [tool.pylint] 183 | py-version = "3.10" 184 | ignore-paths = [] 185 | reports.output-format = "colorized" 186 | similarities.ignore-imports = "yes" 187 | messages_control.disable = [ 188 | "design", 189 | "fixme", 190 | "line-too-long", 191 | "missing-module-docstring", 192 | "wrong-import-position", 193 | ] 194 | 195 | [tool.uv.sources] 196 | braingeneers = { workspace = true } 197 | -------------------------------------------------------------------------------- /src/braingeneers/iot/authenticate.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import webbrowser 5 | import importlib.resources 6 | import configparser 7 | import datetime 8 | import requests 9 | import argparse 10 | 11 | 12 | def authenticate_and_get_token(): 13 | """ 14 | Directs users to a URL to authenticate and get a JWT token. 15 | Once the token has been obtained manually it will refresh automatically every month. 16 | By default, the token is valid for 4 months from issuance. 17 | Returns token data as a dict containing `access_token` and `expires_at` keys. 18 | """ 19 | PACKAGE_NAME = "braingeneers.iot" 20 | 21 | url = 'https://service-accounts.braingeneers.gi.ucsc.edu/generate_token' 22 | print(f'Please visit the following URL to generate your JWT token: {url}') 23 | webbrowser.open(url) 24 | 25 | token_json = input('Please paste the JSON token issued by the page and press Enter:\n') 26 | try: 27 | token_data = json.loads(token_json) 28 | except json.JSONDecodeError: 29 | raise ValueError('Invalid JSON. 
Please make sure you have copied the token correctly.') 30 | 31 | config_dir = os.path.join(importlib.resources.files(PACKAGE_NAME), 'service_account') 32 | os.makedirs(config_dir, exist_ok=True) 33 | config_file = os.path.join(config_dir, 'config.json') 34 | 35 | with open(config_file, 'w') as f: 36 | json.dump(token_data, f) 37 | 38 | print('Token has been saved successfully.') 39 | return token_data 40 | 41 | 42 | def update_config_file(file_path, section, key, new_value): 43 | with open(file_path, 'r') as file: 44 | lines = file.readlines() 45 | 46 | with open(file_path, 'w') as file: 47 | section_found = False 48 | for line in lines: 49 | if line.strip() == f'[{section}]': 50 | section_found = True 51 | if section_found and line.strip().startswith(key): 52 | line = f'{key} = {new_value}\n' 53 | section_found = False # Reset the flag 54 | file.write(line) 55 | 56 | 57 | def picroscope_authenticate_and_update_token(credentials_file): 58 | """ 59 | Authentication and update service-account token for legacy picroscope environment. This updates the AWS credentials file 60 | with the JWT token and updates it if it has <3 months before expiration. This function can be run as a cron job. 61 | """ 62 | # Check if the JWT token exists and if it exists in the credentials file if it's expired. 63 | # The credentials file section is [strapi] with `api_key` containing the jwt token, and `api_key_expires` containing 64 | # the expiration date in ISO format. 65 | config_file_path = os.path.expanduser(credentials_file) 66 | 67 | config = configparser.ConfigParser() 68 | with open(config_file_path, 'r') as f: 69 | config.read_string(f.read()) 70 | 71 | assert 'strapi' in config, \ 72 | 'Your AWS credentials file is missing a section [strapi], you may have the wrong version of the credentials file.' 
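    # For reference, the [strapi] section this function reads and rewrites is expected to look
    # roughly like the sketch below (values are illustrative only; the expiry value just has to
    # parse with datetime.fromisoformat once the trailing timezone token is stripped further down):
    #
    #   [strapi]
    #   api_key = <JWT access token string>
    #   api_key_expires = 2024-05-01 12:00:00 PDT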
73 | 74 | token_exists = 'api_key' in config['strapi'] 75 | expire_exists = 'api_key_expires' in config['strapi'] 76 | 77 | if expire_exists: 78 | expiration_str = config['strapi']['api_key_expires'] 79 | expiration_str = expiration_str.split(' ')[0] + ' ' + expiration_str.split(' ')[1] # Remove timezone 80 | expiration_date = datetime.datetime.fromisoformat(expiration_str) 81 | days_remaining = (expiration_date - datetime.datetime.now()).days 82 | print('Days remaining for token:', days_remaining) 83 | else: 84 | days_remaining = -1 85 | 86 | # check if api_key_expires exists, if not, it's expired, else check if it has <90 days remaining on it 87 | manual_refresh = not token_exists \ 88 | or not expire_exists \ 89 | or (datetime.datetime.fromisoformat(config['strapi']['api_key_expires']) - datetime.datetime.now()).days < 0 90 | auto_refresh = (token_exists and expire_exists) \ 91 | and (datetime.datetime.fromisoformat(config['strapi']['api_key_expires']) - datetime.datetime.now()).days < 90 92 | 93 | if manual_refresh or auto_refresh: 94 | token_data = authenticate_and_get_token() if manual_refresh else requests.get(url).json() 95 | update_config_file(config_file_path, 'strapi', 'api_key', token_data['access_token']) 96 | update_config_file(config_file_path, 'strapi', 'api_key_expires', token_data['expires_at']) 97 | print(f'JWT token has been updated in {config_file_path}') 98 | else: 99 | print('JWT token is still valid, no action taken.') 100 | 101 | 102 | def parse_args(): 103 | """ 104 | Two commands are available: 105 | 106 | # Authenticate and obtain a JWT service account token for braingeneerspy 107 | python -m braingeneers.iot.authenticate 108 | 109 | # Authenticate and obtain a JWT service account token for picroscope specific environment 110 | python -m braingeneers.iot.authenticate picroscope 111 | """ 112 | parser = argparse.ArgumentParser(description='JWT Service Account Token Management') 113 | parser.add_argument('config', nargs='?', choices=['picroscope'], help='Picroscope specific JWT token configuration.') 114 | parser.add_argument('--credentials', default='~/.aws/credentials', help='Path to the AWS credentials file, only used for picroscope authentication.') 115 | 116 | return parser.parse_args() 117 | 118 | 119 | def main(): 120 | args = parse_args() 121 | 122 | if args.config == 'picroscope': 123 | credentials_file = args.credentials 124 | picroscope_authenticate_and_update_token(credentials_file) 125 | else: 126 | authenticate_and_get_token() 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /src/braingeneers/utils/numpy_s3_memmap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ast 3 | import boto3 4 | import os 5 | import logging 6 | import urllib.parse 7 | import numpy as np 8 | from tenacity import * 9 | import braingeneers 10 | 11 | logging.basicConfig(stream=sys.stdout, level=logging.WARNING) 12 | logger = logging.getLogger(__name__) 13 | _s3client = None 14 | braingeneers.set_default_endpoint() 15 | 16 | 17 | # todo list: 18 | # 1) switch to using smart_open and support local or remote files 19 | 20 | 21 | class NumpyS3Memmap: 22 | """ 23 | Provides array slicing over a numpy file on S3 24 | Uses os.environ['ENDPOINT_URL'] and default boto3 credential lookup 25 | 26 | The class 27 | 28 | Example use: 29 | # Open a remote numpy file and query it's size 30 | > from apps.utils import NumpyS3Memmap 31 | > 
data = NumpyS3Memmap('s3://braingeneersdev/dfparks/test/test.npy') 32 | > data.shape 33 | (2, 3) 34 | 35 | # Read a slice using standard numpy slicing syntax 36 | > data[0, :] 37 | array([1., 2., 3.], dtype=float32) 38 | 39 | # If your slice is not contiguous multiple HTTP requests will be made but you'll be warned 40 | > data[:, [1,2]] 41 | WARNING:apps.utils:2 separate requests to S3 are being performed for this slice. 42 | array([[2., 3.], 43 | [5., 6.]], dtype=float32) 44 | 45 | # Read the full ndarray using [:], this does not flatten the data as it does in normal numpy indexing 46 | > data[:] 47 | array([[1., 2., 3.], 48 | [4., 5., 6.]], dtype=float32) 49 | 50 | Available properties: 51 | bucket S3 bucket name from URL 52 | key S3 key from URL 53 | dtype dtype of numpy file 54 | shape shape of numpy file 55 | fortran_order boolean, whether the array data is Fortran-contiguous or not 56 | """ 57 | 58 | def __init__(self, s3url, warning_s3_calls=1): 59 | """ 60 | :param s3url: An S3 url, example: s3://braingeneersdev/dfparks/test/test.npy 61 | :param warning_s3_calls: max number of calls to S3 before a warning is issued 62 | """ 63 | self.s3client = boto3.client( 64 | 's3', endpoint_url=os.environ.get('ENDPOINT_URL', 'https://s3.nautilus.optiputer.net') 65 | ) 66 | self.warning_s3_calls = warning_s3_calls 67 | 68 | # Parse S3 URL 69 | o = urllib.parse.urlparse(s3url) 70 | self.bucket = o.netloc 71 | self.key = o.path[1:] 72 | 73 | # Read numpy header, get shape, dtype, and order 74 | numpy_header = read_s3_bytes(self.bucket, self.key, 0, 128) # initial guess at size, likely correct. 75 | assert numpy_header[:6] == b'\x93NUMPY', 'File {} not in numpy format.'.format(self.key) 76 | self.header_size = np.frombuffer(numpy_header[8:10], dtype=np.int16)[0] + 10 77 | if self.header_size != 128: # re-read numpy header if we guessed wrong on the size 78 | numpy_header = read_s3_bytes(self.bucket, self.key, 0, self.header_size) 79 | header_dict = ast.literal_eval(numpy_header[10:].decode('utf-8')) # parse the header information 80 | self.dtype = np.dtype(header_dict['descr']) 81 | self.fortran_order = header_dict['fortran_order'] 82 | self.shape = header_dict['shape'] 83 | 84 | def __getitem__(self, item): 85 | # This is a naive indexing approach but it works. 86 | # Better options seem to involve reimplementing all numpy indexing options which is quite a rat hole to go down. 
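        # Worked example of the steps below (C order, shape (2, 3), item = (slice(None), [1, 2]),
        # i.e. the `data[:, [1,2]]` case shown in the class docstring):
        #   dummy flat offsets -> [[0, 1, 2], [3, 4, 5]]
        #   dummy[item]        -> [[1, 2], [4, 5]]; flattened -> [1, 2, 4, 5]
        #   split where consecutive offsets jump by more than 1 -> [1, 2] and [4, 5]
        # Each contiguous run becomes one byte range, so this slice issues two ranged S3 GETs,
        # which is exactly the two-request warning illustrated in the docstring.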
87 | dummy = np.arange(np.prod(self.shape), dtype=np.int64).reshape( 88 | self.shape, order='F' if self.fortran_order else 'C' 89 | ) 90 | ixs_tensor = dummy[item] 91 | shape = ixs_tensor.shape 92 | ixs = ixs_tensor.T.flatten() if self.fortran_order else ixs_tensor.flatten() 93 | splits = np.where(ixs[1:] - ixs[:-1] > 1)[0] + 1 94 | read_sets = np.split(ixs, splits) 95 | 96 | # Compute raw byte offsets, adjusting for data size and header length 97 | read_from_to = [ 98 | ( 99 | r[0] * self.dtype.itemsize + self.header_size, 100 | (r[-1] + 1) * self.dtype.itemsize + self.header_size 101 | ) 102 | for r in read_sets 103 | ] 104 | 105 | # Warn if too many requests are made to S3 106 | if len(read_from_to) > self.warning_s3_calls: 107 | logger.warning('{} separate requests to S3 are being performed for this slice.'.format(len(read_from_to))) 108 | 109 | # Read raw bytes and concatenate 110 | b = b''.join([ 111 | read_s3_bytes(self.bucket, self.key, offset_from, offset_to) 112 | for offset_from, offset_to in read_from_to 113 | ]) 114 | 115 | # Convert bytes to numpy dtype and reshape 116 | arr = np.frombuffer(b, dtype=self.dtype).reshape(shape, order='F' if self.fortran_order else 'C') 117 | 118 | return arr 119 | 120 | 121 | # @retry(wait=wait_exponential(multiplier=1/(2**5), max=30), after=after_log(logger, logging.WARNING)) 122 | def read_s3_bytes(bucket, key, bytes_from=None, bytes_to=None, s3client=None): 123 | """ 124 | Performs S3 request, bytes_to is exclusive, using python standard indexing. 125 | 126 | :param bucket: S3 bucket name, example: 'briangeneersdev' 127 | :param key: S3 key, example 'somepath/somefile.txt' 128 | :param bytes_from: starting read byte, or None for the full file. Must be 0 to read from 0 to part of a file. 129 | :param bytes_to: ending read byte + 1 (python standard indexing), 130 | example: (0, 10) reads the first 10 bytes, (10, 20) would read the next 10 bytes of a file. 131 | :param s3client: If s3client is not passed a single client will be instantiated lazily at the 132 | module level and re-used for all requests. 133 | :return: raw bytes as a byte array (b''). 134 | """ 135 | if s3client is None: 136 | global _s3client 137 | if _s3client is None: 138 | _s3client = boto3.client('s3', endpoint_url=braingeneers.get_default_endpoint()) 139 | s3client = _s3client 140 | 141 | rng = '' if bytes_from is None else 'bytes={}-{}'.format(bytes_from, bytes_to - 1) 142 | s3obj = s3client.get_object(Bucket=bucket, Key=key, Range=rng) 143 | b = s3obj['Body'].read() 144 | return b 145 | -------------------------------------------------------------------------------- /tests/test_data/maxwell-metadata.old.json: -------------------------------------------------------------------------------- 1 | { 2 | "uuid": "2023-08-12-e-GABA_v0", 3 | "timestamp": "", 4 | "maxwell_chip_id": "20325, 20247, 20402", 5 | "notes": { 6 | "purpose": "GABA dose response sweep. Gaba added X:55-X+1:~15 for x=2,4,... 
Started at .5uM, doubled each time (last was 14 not 16)", 7 | "biology": { 8 | "sample_type": "organoid", 9 | "aggregation_date": "23-06-22", 10 | "plating_date": "23-07-17", 11 | "species": "mouse", 12 | "cell_line": "e14", 13 | "genotype": "", 14 | "culture_media": "gfcdm", 15 | "organoid_tracker_sr_no": "", 16 | "organoid modification": "", 17 | "organoid modification date": "" 18 | }, 19 | "stimulation": { 20 | "optogenetic": false, 21 | "electrical": false, 22 | "pharmacological": "GABA" 23 | }, 24 | "automated_feeding": false, 25 | "automated_imaging": false, 26 | "comments": "" 27 | }, 28 | "ephys_experiments": { 29 | "data_GABA_BL_20325": { 30 | "name": "data_GABA_BL_20325", 31 | "hardware": "Maxwell", 32 | "maxwell_chip_number": "N/A", 33 | "maxwell_tracker_sr_no": "xxxxxx", 34 | "channels": [], 35 | "notes": "", 36 | "num_channels": 819, 37 | "num_current_input_channels": 0, 38 | "num_voltage_channels": 32, 39 | "offset": 0, 40 | "sample_rate": 20000, 41 | "voltage_scaling_factor": 1, 42 | "timestamp": "2023-08-12 T15:01:19;", 43 | "units": "\u00b5V", 44 | "version": 20190530, 45 | "blocks": [ 46 | { 47 | "num_frames": 6001200, 48 | "path": "original/data/data_GABA_BL_20325.raw.h5", 49 | "timestamp": "2023-08-12 T15:01:19;" 50 | } 51 | ] 52 | }, 53 | "Trace_20230812_17_00_14_GABA_BL_20247": { 54 | "name": "Trace_20230812_17_00_14_GABA_BL_20247", 55 | "hardware": "Maxwell", 56 | "maxwell_chip_number": "N/A", 57 | "maxwell_tracker_sr_no": "xxxxxx", 58 | "channels": [], 59 | "notes": "", 60 | "num_channels": 944, 61 | "num_current_input_channels": 0, 62 | "num_voltage_channels": 32, 63 | "offset": 0, 64 | "sample_rate": 20000, 65 | "voltage_scaling_factor": 1, 66 | "timestamp": "2023-08-12 T17:00:41;", 67 | "units": "\u00b5V", 68 | "version": 20190530, 69 | "blocks": [ 70 | { 71 | "num_frames": 6242800, 72 | "path": "original/data/Trace_20230812_17_00_14_GABA_BL_20247.raw.h5", 73 | "timestamp": "2023-08-12 T17:00:41;" 74 | } 75 | ] 76 | }, 77 | "Trace_20230812_15_55_43_GABA_BL_20402": { 78 | "name": "Trace_20230812_15_55_43_GABA_BL_20402", 79 | "hardware": "Maxwell", 80 | "maxwell_chip_number": "N/A", 81 | "maxwell_tracker_sr_no": "xxxxxx", 82 | "channels": [], 83 | "notes": "", 84 | "num_channels": 971, 85 | "num_current_input_channels": 0, 86 | "num_voltage_channels": 32, 87 | "offset": 0, 88 | "sample_rate": 20000, 89 | "voltage_scaling_factor": 1, 90 | "timestamp": "2023-08-12 T15:58:02;", 91 | "units": "\u00b5V", 92 | "version": 20190530, 93 | "blocks": [ 94 | { 95 | "num_frames": 6108800, 96 | "path": "original/data/Trace_20230812_15_55_43_GABA_BL_20402.raw.h5", 97 | "timestamp": "2023-08-12 T15:58:02;" 98 | } 99 | ] 100 | }, 101 | "Trace_20230812_16_03_20_GABA_DR2_20402": { 102 | "name": "Trace_20230812_16_03_20_GABA_DR2_20402", 103 | "hardware": "Maxwell", 104 | "maxwell_chip_number": "N/A", 105 | "maxwell_tracker_sr_no": "xxxxxx", 106 | "channels": [], 107 | "notes": "", 108 | "num_channels": 971, 109 | "num_current_input_channels": 0, 110 | "num_voltage_channels": 32, 111 | "offset": 0, 112 | "sample_rate": 20000, 113 | "voltage_scaling_factor": 1, 114 | "timestamp": "2023-08-12 T16:04:39;", 115 | "units": "\u00b5V", 116 | "version": 20190530, 117 | "blocks": [ 118 | { 119 | "num_frames": 16962400, 120 | "path": "original/data/Trace_20230812_16_03_20_GABA_DR2_20402.raw.h5", 121 | "timestamp": "2023-08-12 T16:04:39;" 122 | } 123 | ] 124 | }, 125 | "Trace_20230812_15_18_49_GABA_DR_20325": { 126 | "name": "Trace_20230812_15_18_49_GABA_DR_20325", 127 | "hardware": "Maxwell", 
128 | "maxwell_chip_number": "N/A", 129 | "maxwell_tracker_sr_no": "xxxxxx", 130 | "channels": [], 131 | "notes": "", 132 | "num_channels": 819, 133 | "num_current_input_channels": 0, 134 | "num_voltage_channels": 32, 135 | "offset": 0, 136 | "sample_rate": 20000, 137 | "voltage_scaling_factor": 1, 138 | "timestamp": "2023-08-12 T15:19:25;", 139 | "units": "\u00b5V", 140 | "version": 20190530, 141 | "blocks": [ 142 | { 143 | "num_frames": 16888000, 144 | "path": "original/data/Trace_20230812_15_18_49_GABA_DR_20325.raw.h5", 145 | "timestamp": "2023-08-12 T15:19:25;" 146 | } 147 | ] 148 | }, 149 | "Trace_20230812_17_27_36_GABA_DR_20247": { 150 | "name": "Trace_20230812_17_27_36_GABA_DR_20247", 151 | "hardware": "Maxwell", 152 | "maxwell_chip_number": "N/A", 153 | "maxwell_tracker_sr_no": "xxxxxx", 154 | "channels": [], 155 | "notes": "", 156 | "num_channels": 944, 157 | "num_current_input_channels": 0, 158 | "num_voltage_channels": 32, 159 | "offset": 0, 160 | "sample_rate": 20000, 161 | "voltage_scaling_factor": 1, 162 | "timestamp": "2023-08-12 T17:27:59;", 163 | "units": "\u00b5V", 164 | "version": 20190530, 165 | "blocks": [ 166 | { 167 | "num_frames": 16890400, 168 | "path": "original/data/Trace_20230812_17_27_36_GABA_DR_20247.raw.h5", 169 | "timestamp": "2023-08-12 T17:27:59;" 170 | } 171 | ] 172 | } 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /tests/test_data/maxwell-metadata.expected.json: -------------------------------------------------------------------------------- 1 | { 2 | "uuid": "2023-08-12-e-GABA_v0", 3 | "timestamp": "2023-10-05T18:10:02", 4 | "hardware": "Maxwell", 5 | "maxwell_chip_id": "20325, 20247, 20402", 6 | "notes": { 7 | "purpose": "GABA dose response sweep. Gaba added X:55-X+1:~15 for x=2,4,... 
Started at .5uM, doubled each time (last was 14 not 16)", 8 | "biology": { 9 | "sample_type": "organoid", 10 | "aggregation_date": "23-06-22", 11 | "plating_date": "23-07-17", 12 | "species": "mouse", 13 | "cell_line": "e14", 14 | "genotype": "", 15 | "culture_media": "gfcdm", 16 | "organoid_tracker_sr_no": "", 17 | "organoid modification": "", 18 | "organoid modification date": "" 19 | }, 20 | "stimulation": { 21 | "optogenetic": false, 22 | "electrical": false, 23 | "pharmacological": "GABA" 24 | }, 25 | "automated_feeding": false, 26 | "automated_imaging": false, 27 | "comments": "" 28 | }, 29 | "ephys_experiments": { 30 | "data_GABA_BL_20325": { 31 | "name": "data_GABA_BL_20325", 32 | "maxwell_chip_number": "N/A", 33 | "maxwell_tracker_sr_no": "xxxxxx", 34 | "channels": [], 35 | "notes": "", 36 | "num_channels": 819, 37 | "num_current_input_channels": 0, 38 | "num_voltage_channels": 32, 39 | "offset": 0, 40 | "sample_rate": 20000, 41 | "voltage_scaling_factor": 1, 42 | "timestamp": "2023-08-12 T15:01:19;", 43 | "units": "\u00b5V", 44 | "version": 20190530, 45 | "blocks": [ 46 | { 47 | "num_frames": 6001200, 48 | "path": "shared/data_GABA_BL_20325.nwb", 49 | "timestamp": "2023-08-12 T15:01:19;" 50 | } 51 | ], 52 | "data_format": "NeurodataWithoutBorders" 53 | }, 54 | "Trace_20230812_17_00_14_GABA_BL_20247": { 55 | "name": "Trace_20230812_17_00_14_GABA_BL_20247", 56 | "maxwell_chip_number": "N/A", 57 | "maxwell_tracker_sr_no": "xxxxxx", 58 | "channels": [], 59 | "notes": "", 60 | "num_channels": 944, 61 | "num_current_input_channels": 0, 62 | "num_voltage_channels": 32, 63 | "offset": 0, 64 | "sample_rate": 20000, 65 | "voltage_scaling_factor": 1, 66 | "timestamp": "2023-08-12 T17:00:41;", 67 | "units": "\u00b5V", 68 | "version": 20190530, 69 | "blocks": [ 70 | { 71 | "num_frames": 6242800, 72 | "path": "shared/Trace_20230812_17_00_14_GABA_BL_20247.nwb", 73 | "timestamp": "2023-08-12 T17:00:41;" 74 | } 75 | ], 76 | "data_format": "NeurodataWithoutBorders" 77 | }, 78 | "Trace_20230812_15_55_43_GABA_BL_20402": { 79 | "name": "Trace_20230812_15_55_43_GABA_BL_20402", 80 | "maxwell_chip_number": "N/A", 81 | "maxwell_tracker_sr_no": "xxxxxx", 82 | "channels": [], 83 | "notes": "", 84 | "num_channels": 971, 85 | "num_current_input_channels": 0, 86 | "num_voltage_channels": 32, 87 | "offset": 0, 88 | "sample_rate": 20000, 89 | "voltage_scaling_factor": 1, 90 | "timestamp": "2023-08-12 T15:58:02;", 91 | "units": "\u00b5V", 92 | "version": 20190530, 93 | "blocks": [ 94 | { 95 | "num_frames": 6108800, 96 | "path": "shared/Trace_20230812_15_55_43_GABA_BL_20402.nwb", 97 | "timestamp": "2023-08-12 T15:58:02;" 98 | } 99 | ], 100 | "data_format": "NeurodataWithoutBorders" 101 | }, 102 | "Trace_20230812_16_03_20_GABA_DR2_20402": { 103 | "name": "Trace_20230812_16_03_20_GABA_DR2_20402", 104 | "maxwell_chip_number": "N/A", 105 | "maxwell_tracker_sr_no": "xxxxxx", 106 | "channels": [], 107 | "notes": "", 108 | "num_channels": 971, 109 | "num_current_input_channels": 0, 110 | "num_voltage_channels": 32, 111 | "offset": 0, 112 | "sample_rate": 20000, 113 | "voltage_scaling_factor": 1, 114 | "timestamp": "2023-08-12 T16:04:39;", 115 | "units": "\u00b5V", 116 | "version": 20190530, 117 | "blocks": [ 118 | { 119 | "num_frames": 16962400, 120 | "path": "shared/Trace_20230812_16_03_20_GABA_DR2_20402.nwb", 121 | "timestamp": "2023-08-12 T16:04:39;" 122 | } 123 | ], 124 | "data_format": "NeurodataWithoutBorders" 125 | }, 126 | "Trace_20230812_15_18_49_GABA_DR_20325": { 127 | "name": 
"Trace_20230812_15_18_49_GABA_DR_20325", 128 | "maxwell_chip_number": "N/A", 129 | "maxwell_tracker_sr_no": "xxxxxx", 130 | "channels": [], 131 | "notes": "", 132 | "num_channels": 819, 133 | "num_current_input_channels": 0, 134 | "num_voltage_channels": 32, 135 | "offset": 0, 136 | "sample_rate": 20000, 137 | "voltage_scaling_factor": 1, 138 | "timestamp": "2023-08-12 T15:19:25;", 139 | "units": "\u00b5V", 140 | "version": 20190530, 141 | "blocks": [ 142 | { 143 | "num_frames": 16888000, 144 | "path": "shared/Trace_20230812_15_18_49_GABA_DR_20325.nwb", 145 | "timestamp": "2023-08-12 T15:19:25;" 146 | } 147 | ], 148 | "data_format": "NeurodataWithoutBorders" 149 | }, 150 | "Trace_20230812_17_27_36_GABA_DR_20247": { 151 | "name": "Trace_20230812_17_27_36_GABA_DR_20247", 152 | "maxwell_chip_number": "N/A", 153 | "maxwell_tracker_sr_no": "xxxxxx", 154 | "channels": [], 155 | "notes": "", 156 | "num_channels": 944, 157 | "num_current_input_channels": 0, 158 | "num_voltage_channels": 32, 159 | "offset": 0, 160 | "sample_rate": 20000, 161 | "voltage_scaling_factor": 1, 162 | "timestamp": "2023-08-12 T17:27:59;", 163 | "units": "\u00b5V", 164 | "version": 20190530, 165 | "blocks": [ 166 | { 167 | "num_frames": 16890400, 168 | "path": "shared/Trace_20230812_17_27_36_GABA_DR_20247.nwb", 169 | "timestamp": "2023-08-12 T17:27:59;" 170 | } 171 | ], 172 | "data_format": "NeurodataWithoutBorders" 173 | } 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /src/braingeneers/utils/memoize_s3.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from functools import partial 4 | from warnings import warn 5 | 6 | import awswrangler as wr 7 | import boto3 8 | from botocore.client import ClientError 9 | from botocore.exceptions import NoCredentialsError 10 | from smart_open.s3 import parse_uri 11 | 12 | from .smart_open_braingeneers import open 13 | 14 | try: 15 | from joblib import Memory, register_store_backend 16 | from joblib._store_backends import StoreBackendBase, StoreBackendMixin 17 | except ImportError: 18 | raise ImportError("joblib is required to use memoize_s3") 19 | 20 | 21 | def s3_isdir(path): 22 | """ 23 | S3 doesn't support directories, so to check whether some path "exists", 24 | instead check whether it is a prefix of at least one object. 25 | """ 26 | try: 27 | next(wr.s3.list_objects(glob.escape(path), chunked=True)) 28 | return True 29 | except StopIteration: 30 | return False 31 | 32 | 33 | def normalize_location(location: str): 34 | """ 35 | Normalize a location string to use forward slashes instead of backslashes. This is 36 | necessary on Windows because joblib uses `os.path.join` to construct paths, but S3 37 | always uses forward slashes. 38 | """ 39 | return location.replace("\\", "/") 40 | 41 | 42 | class S3StoreBackend(StoreBackendBase, StoreBackendMixin): 43 | def _open_item(self, f, mode): 44 | return open(normalize_location(f), mode) 45 | 46 | def _item_exists(self, location: str): 47 | location = normalize_location(location) 48 | return wr.s3.does_object_exist(location) or s3_isdir(location) 49 | 50 | def _move_item(self, src_uri, dst_uri): 51 | # awswrangler only includes a fancy move/rename method that actually 52 | # makes it pretty hard to just do a simple move. 
53 | src, dst = [parse_uri(normalize_location(x)) for x in (src_uri, dst_uri)] 54 | self.client.copy_object( 55 | Bucket=dst["bucket_id"], 56 | Key=dst["key_id"], 57 | CopySource=f"{src['bucket_id']}/{src['key_id']}", 58 | ) 59 | self.client.delete_object(Bucket=src["bucket_id"], Key=src["key_id"]) 60 | 61 | def create_location(self, location): 62 | # Actually don't do anything. There are no locations on S3. 63 | pass 64 | 65 | def clear_location(self, location): 66 | location = normalize_location(location) 67 | # This should only ever be used for prefixes contained within a joblib cache 68 | # directory, so make sure that's actually happening before deleting. 69 | if not location.startswith(self.location): 70 | raise ValueError("can only clear locations within the cache directory") 71 | wr.s3.delete_objects(glob.escape(location)) 72 | 73 | def get_items(self): 74 | # This is only ever used to find cache items for deletion, which we can't 75 | # support because we don't have access times for S3 objects. Returning nothing 76 | # here means it will silently have no effect. 77 | return [] 78 | 79 | def configure(self, location, verbose, backend_options={}): 80 | # We don't do any logging yet, but `configure()` must accept this 81 | # argument, so store it for forwards compatibility. 82 | self.verbose = verbose 83 | 84 | # We have to save this on the backend because joblib queries it, but 85 | # default to True instead of joblib's usual False because S3 is not 86 | # local disk and compression can really help. 87 | self.compress = backend_options.get("compress", True) 88 | 89 | # This option is available by default but we can't accept it because 90 | # there's no reasonable way to make joblib use NumpyS3Memmap. 91 | self.mmap_mode = backend_options.get("mmap_mode") 92 | if self.mmap_mode is not None: 93 | raise ValueError("impossible to mmap on S3.") 94 | 95 | # Don't attempt to handle local files, just use the default backend 96 | # for that! 97 | if not location.startswith("s3://"): 98 | raise ValueError("location must be an s3:// URI") 99 | 100 | # We don't have to check that the bucket exists because joblib 101 | # performs a `list_objects()` in it, but note that this doesn't 102 | # actually check whether we can write to it! 103 | self.location = normalize_location(location) 104 | 105 | # We need a boto3 client, so create it using the endpoint which was 106 | # configured in awswrangler by importing smart_open_braingeneers. 107 | self.client = boto3.Session().client( 108 | "s3", endpoint_url=wr.config.s3_endpoint_url 109 | ) 110 | 111 | 112 | def memoize( 113 | location=None, backend="s3", ignore=None, cache_validation_callback=None, **kwargs 114 | ): 115 | """ 116 | Memoize a function to S3 using joblib.Memory. By default, saves to 117 | `s3://braingeneersdev/$S3_USER/cache`, where $S3_USER defaults to "common" if unset. 118 | Alternately, the cache directory can be provided explicitly. 119 | 120 | If the user has no access to the provided S3 bucket, fall back to local storage 121 | (but errors may occur if only partial permissions such as read-only are available). 122 | 123 | Accepts all the same keyword arguments as `joblib.Memory`, including `backend`, 124 | which can be set to "local" to recover default behavior. Also accepts the 125 | keyword arguments of `joblib.Memory.cache()` and passes them on. Usage: 126 | 127 | ``` 128 | from braingeneers.utils.memoize_s3 import memoize 129 | 130 | # Cache to the default location on NRP S3. 
131 | @memoize 132 | def foo(x): 133 | return x 134 | 135 | # Cache to a different NRP S3 location. 136 | @memoize("s3://braingeneers/someplace/else/idk") 137 | def bar(x): 138 | return x 139 | 140 | # Ignore some parameters when deciding which cache entry to check. 141 | @memoize(ignore=["verbose"]) 142 | def plover(x, verbose): 143 | if verbose: ... 144 | return x 145 | ``` 146 | 147 | Another known issue is that size-based cache eviction is NOT supported, 148 | and will also silently fail. This is because there is no easy way to get 149 | access times out of S3, so we can't find LRU entries. 150 | """ 151 | if callable(location): 152 | # This case probably means the @memoize decorator was used without 153 | # arguments, but pass the kwargs on anyway just in case. 154 | return memoize( 155 | backend=backend, 156 | ignore=ignore, 157 | cache_validation_callback=cache_validation_callback, 158 | **kwargs, 159 | )(location) 160 | 161 | if backend == "s3": 162 | # Default to the braingeneersdev bucket, using the user's name. 163 | if location is None: 164 | user = os.environ.get("S3_USER") or "common" 165 | location = f"s3://braingeneersdev/{user}/cache" 166 | 167 | # Get the bucket name from the location URI. 168 | try: 169 | uri = parse_uri(normalize_location(location)) 170 | bucket = uri["bucket_id"] 171 | except AssertionError as e: 172 | raise ValueError(f"Invalid location: {e}") from None 173 | 174 | # Check the user's credentials for that bucket. 175 | try: 176 | boto3.Session().client( 177 | "s3", endpoint_url=wr.config.s3_endpoint_url 178 | ).list_objects_v2(Bucket=bucket, MaxKeys=1) 179 | except (NoCredentialsError, ClientError) as e: 180 | warn(f"Cannot access s3://{bucket} ({e}). Memoizing locally.") 181 | backend = "local" 182 | location = uri["key_id"] 183 | 184 | return partial( 185 | Memory(location, backend=backend, **kwargs).cache, 186 | ignore=ignore, 187 | cache_validation_callback=cache_validation_callback, 188 | ) 189 | 190 | 191 | register_store_backend("s3", S3StoreBackend) 192 | -------------------------------------------------------------------------------- /src/braingeneers/iot/simple.py: -------------------------------------------------------------------------------- 1 | """ A simple protocol for connecting devices to IoT and schedule experiments""" 2 | 3 | # import stuff 4 | from braingeneers.iot import messaging 5 | import uuid 6 | import schedule 7 | import time 8 | import warnings 9 | import builtins 10 | import inspect 11 | import logging 12 | import traceback 13 | from dateutil import tz 14 | 15 | 16 | def start_iot(device_name, device_type, experiment, commands=[]): #, #allowed_commands={"pump":["func_1","func_1"], "image":["func_1","func_1"]}): 17 | """Create a device and have it start listening for commands. 
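    A typical call (device, experiment, and command names here are placeholders only):

        start_iot(device_name="my-device", device_type="Other",
                  experiment="my-experiment", commands=["feed"])

    This registers "my-device" if it does not already exist, then blocks while executing any
    python command string published to devices/+/my-device, provided the string contains one
    of the allowed `commands` (see gui.send for the sending side).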
This is intended for simple use cases""" 18 | 19 | # Run Helper Funcitons 20 | def _updateIot(device_name, mb ): 21 | """Updates the IoT device's publicly visible state according to the local schedule/status""" 22 | # Update schedule 23 | try: # Use try/catch so that error like no internet don't stop experiment 24 | jobs, jobs_str = [],[] 25 | for i,job in enumerate(schedule.get_jobs()): # debugging: print(i) #print(job.__str__()) #print(job.__repr__()) 26 | jobs.append( job.__str__() ) # append python command that wass used to create shceduled job 27 | if job.cancel_after: # if job must be cancelled at a certain time, add cancel time 28 | jobs[i]+= job.cancel_after.astimezone(tz.gettz('US/Pacific')).strftime("UNTIL-%Y-%m-%d %H:%M:%S") 29 | 30 | job_str = f"Job {i}: "+job.__repr__() 31 | if job.last_run: # if job was previously run 32 | last_time = job.last_run.astimezone(tz.gettz('US/Pacific')).strftime("%Y-%m-%d %H:%M:%S") # get time of last scheduled event 33 | job_str = job_str.split("last run: ")[0]+"last run: "+last_time+", next run: " 34 | next_time = job.next_run.astimezone(tz.gettz('US/Pacific')).strftime("%Y-%m-%d %H:%M:%S)") # get time of next scheduled event 35 | jobs_str.append( job_str.split("next run: ")[0]+"next run: "+next_time )# append 36 | 37 | # Update log 38 | with open("iot.log") as file: # Get history of iot commands and scheduel from the file iot.log 39 | lines = file.readlines() 40 | log = [line.rstrip() for line in lines] 41 | if len(log)>35: # If iot log is more than 25 lines, get just the last 25 lines 42 | log = log[-35:] 43 | mb.update_device_state( device_name, {"status":iot_status,"jobs":jobs,"schedule":jobs_str,"history":log} ) 44 | 45 | except: 46 | logger.error("\n"+traceback.format_exc()) # Write error to log 47 | 48 | 49 | # Create function for when IoT command is received 50 | def respondToCommand(topic: str, message: dict): # build function that runs when device receives command 51 | global last_mqtts 52 | try: #DEBUG: print("command Received") #print(message) 53 | if len(commands)>0: # if user inputed list of allowed commands 54 | if not "global iot_status; iot_status=" in message["command"] and not any(x in message["command"] for x in commands ): # check if sent command contains an allowed command, if not throw error 55 | raise Exception(f"User message-- {message['command']}-- doesn't contain required commands -- {commands}") 56 | if message["id"] in last_mqtts : 57 | raise Exception(f"User message-- {message['command']}-- received multiple times and was stopped from running again") 58 | last_mqtts = last_mqtts[1:] + [message["id"]] 59 | logger.debug(f"Run Command: {message['command']}") # log to history that the sent command was run 60 | exec(message["command"]) # run python command that was sent 61 | except Exception as e: 62 | logger.error("\n"+traceback.format_exc()) 63 | _updateIot(device_name, mb ) # update iot state, in case schedule/status changed 64 | 65 | 66 | # Initialize environment 67 | from braingeneers.iot import messaging; import uuid; import schedule; import time; import warnings; import logging; import traceback #requried packages 68 | global iot_status # states if iot is running. 
Other iot functions can change it asynchronously 69 | global last_mqtts 70 | warnings.filterwarnings(action='once') # stops same warning from appearing more than once 71 | mb = messaging.MessageBroker(str(uuid.uuid4)) # spin up iot 72 | last_mqtts = [""]*200 73 | 74 | # Set up Logging 75 | open("iot.log", "w").close() 76 | logging.basicConfig( level=logging.WARNING ) 77 | logger = logging.getLogger('schedule') 78 | logger.setLevel(level=logging.DEBUG) 79 | f_handler = logging.FileHandler('iot.log') 80 | f_handler.setFormatter( logging.Formatter("%(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S %Z") ) # %(levelname)s 81 | logger.addHandler(f_handler) 82 | global print 83 | print = logger.debug 84 | 85 | # Start IoT Device 86 | if device_name not in mb.list_devices_by_type(thing_type_name= device_type): # check if device already exists 87 | mb.create_device( device_name= device_name, device_type= device_type) # if not, create it 88 | else: # otherwise, check device is ok and isn't still running 89 | assert "status" in mb.get_device_state(device_name), f"{device_name} has corrupted data! Talk to data team." 90 | assert mb.get_device_state(device_name)["status"]=="shutdown", f"{device_name} already exists and isn't shutdown. Please shutdown with 'iot.shutdown({device_name})'" 91 | mb.update_device_state( device_name, {"experiment":experiment,"status":"run","jobs":[],"schedule":[],"history":[]} ) # initialize iot state 92 | mb.subscribe_message( f"devices/+/{device_name}", respondToCommand ) # start listening for new commands 93 | #mb.subscribe_message(f"devices/{device_type}/{device_name}",respondToCommand) 94 | 95 | # Perpetually listen for IoT commands 96 | iot_status = "run" # keep python running so that listener can do it's job 97 | while not iot_status=="shutdown": # when it's time to stop, iot makes iot_status='shutdown'{} 98 | if iot_status=="run": # if the device is in run mode, 99 | is_pending = sum([job.should_run for job in schedule.jobs]) # Get the number of pending jobs that should be run now 100 | if is_pending: # if there are any pending jobs to run 101 | schedule.run_pending() # run the pending jobs 102 | _updateIot(device_name, mb ) # schedule info has changed, so update IoT state 103 | time.sleep(.1) # wait a little to save cpu usage 104 | mb.shutdown() # shutdown iot at the end. 105 | 106 | 107 | def ready_iot(): 108 | """Save source code for start_iot function to a place where it can be executed by the user""" 109 | builtins.ready_iot = inspect.getsource(start_iot) 110 | 111 | -------------------------------------------------------------------------------- /src/braingeneers/iot/gui.py: -------------------------------------------------------------------------------- 1 | 2 | # import stuff 3 | import uuid 4 | import time 5 | import warnings 6 | from braingeneers.iot import messaging 7 | from datetime import date, datetime, timedelta 8 | from matplotlib.patches import Patch 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def send( device_name, command ): 13 | """Send a python script as a string which is then implemented by an IoT device. 
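    For example (the device name and scheduled function are hypothetical):

        send("my-device", "schedule.every(30).minutes.do(take_picture)")

    The listening device (see simple.start_iot) exec()s the string, so it must be valid python
    in that device's namespace; `schedule` is already imported there, and if the device was
    started with a non-empty `commands` list the string must contain one of those keywords.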
This is intended for simple use cases""" 14 | warnings.filterwarnings("ignore") 15 | my_id =str(uuid.uuid4() ) 16 | mb = messaging.MessageBroker(my_id) # spin up iot 17 | mb.publish_message( topic=f"devices/dummy/{device_name}", message={"command": command,"id":my_id} ) # send command to listening device 18 | mb.shutdown() # shutdown iot 19 | 20 | def get_schedule( device_name ): 21 | """Get a list of scheduled commands from a device's shadow. This is intended for simple use cases""" 22 | warnings.filterwarnings("ignore") 23 | mb = messaging.MessageBroker(str(uuid.uuid4)) # spin up iot 24 | my_schedule = mb.get_device_state( device_name )["schedule"] # get schedule for device 25 | mb.shutdown() # shutdown iot 26 | return my_schedule # return schedule to user 27 | 28 | 29 | def get_status( device_name ): 30 | """Get a list of scheduled commands from a device's shadow. This is intended for simple use cases""" 31 | warnings.filterwarnings("ignore") 32 | mb = messaging.MessageBroker(str(uuid.uuid4)) # spin up iot 33 | status = mb.get_device_state( device_name )["status"] # get schedule for device 34 | mb.shutdown() # shutdown iot 35 | return status # return schedule to user 36 | 37 | def get_history( device_name ): 38 | """Get a list of scheduled commands from a device's shadow. This is intended for simple use cases""" 39 | warnings.filterwarnings("ignore") 40 | mb = messaging.MessageBroker(str(uuid.uuid4)) # spin up iot 41 | status = mb.get_device_state( device_name )["history"] # get schedule for device 42 | mb.shutdown() # shutdown iot 43 | return status # return schedule to user 44 | 45 | def get_info( device_name ): 46 | """Get public device info from its shadow. This is intended for simple use cases""" 47 | warnings.filterwarnings("ignore") 48 | mb = messaging.MessageBroker(str(uuid.uuid4)) # spin up iot 49 | info = mb.get_device_state( device_name ) # get all info 50 | mb.shutdown() # shutdown iot 51 | return info # return info to user 52 | 53 | def draw_schedule( device_list ): 54 | """Draw a weekly schedule of all events that occure for a chose device or experiment""" 55 | # To Do: figure out how to find all device if given the experiment. 56 | fig, ax = plt.subplots( figsize=(18, 15) ) 57 | plt.title('Weekly Schedule', y=1, fontsize=16) # Give the figure a title 58 | ax.grid(axis='y', linestyle='--', linewidth=0.5) # Add horizonal grid lines to the plot 59 | 60 | DAYS = ['Monday','Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday'] 61 | ax.set_xlim(0.5, len(DAYS) + 0.5) 62 | ax.set_xticks(range(1, len(DAYS) + 1)) 63 | ax.set_xticklabels(DAYS) 64 | ax.set_ylim( 24, 0) 65 | ax.set_yticks(range(0, 24)) 66 | ax.set_yticklabels(["{0}:00".format(h) for h in range(0, 24)]) 67 | 68 | # !!! 
This code should be able to handle experiment types 69 | # Locally: schedule_str= [f"Job {i}: "+x.__repr__() for i,x in enumerate(schedule.get_jobs())] #jobs= [job.__str__() for job in schedule.get_jobs()] 70 | if type(device_list)==type(""): # if single device was passed to the function as a string instead of list, 71 | device_list = [device_list] # turn it into a list 72 | colors= ["cornflowerblue","darkorange","mediumpurple","lightgreen"] 73 | legend_elements = [ Patch(colors[i],colors[i],alpha=0.3) for i in range(len(device_list)) ] 74 | plt.legend(legend_elements, device_list, bbox_to_anchor=(1.1,1), loc="upper right", prop={'size': 12}) # Add legend # Show plot 75 | 76 | for dev_num, device in enumerate(device_list): 77 | info = get_info( device ) 78 | jobs, schedule_str = info["jobs"], info["schedule"] 79 | 80 | for i in range(len(jobs)): # for each job, we get it's next run time, and run interval for IoT state information 81 | next_run= datetime.fromisoformat( schedule_str[i].split("next run: ")[1].split(")")[0] ) 82 | period = timedelta(**{ jobs[i].split("unit=")[1].split(",")[0] : int(jobs[i].split("interval=")[1].split(",")[0]) }) 83 | if "UNTIL-" in jobs[i]: # if there is a specified stop time, use it 84 | stop_time = datetime.fromisoformat( jobs[i].split("UNTIL-")[1] ) 85 | else: # otherwise stop at weekly cycle 86 | stop_time = datetime.now() - timedelta(hours=7) + timedelta(weeks=1) # Consider jobs that occur up to a week from now 87 | job_times = [] # create a list of all event times for a job 88 | while next_run < stop_time: 89 | job_times.append(next_run) 90 | next_run += period 91 | 92 | for event in job_times: 93 | d = event.weekday() + 0.52 # get day of week for event 94 | start = float(event.hour) + float(event.minute) / 60 # get start time of event 95 | end = float(event.hour) + (float(event.minute)+15) / 60 # Ends 15 minutes after start 96 | plt.fill_between([d, d + 0.96], [start, start], [end, end], color=colors[dev_num],alpha=0.3) 97 | plt.text(d + 0.02, start + 0.02, '{0}:{1:0>2}'.format(event.hour, event.minute), va='top', fontsize=8) 98 | plt.text(d + 0.48, start + 0.01, f"Job {i}", va='top', fontsize=8) #ha='center', va='center', fontsize=10) 99 | 100 | # Add a red line for when right now is 101 | now = datetime.now() - timedelta(hours=7) 102 | d = now.weekday() + 0.52 # get day of week for event 103 | start = float(now.hour) + float(now.minute) / 60 # get start time of event 104 | end = float(now.hour) + (float(now.minute)+5) / 60 # Ends 15 minutes after start 105 | plt.fill_between([d, d + 0.96 ], [start, start], [end, end], color='red' ) 106 | plt.show() 107 | 108 | 109 | def shutdown( device_name, hard=False ): 110 | """Stops iot listener on device by changing flag on shadow. This is intended for simple use cases""" 111 | send(device_name, "global iot_status; iot_status='shutdown'") 112 | if hard == True: 113 | mb = messaging.MessageBroker(str(uuid.uuid4)) # spin up iot 114 | mb.update_device_state( device_name, {"status":"shutdown"} ) # change status flag on device state to run 115 | mb.shutdown() 116 | 117 | def pause( device_name ): 118 | """Pauses iot listener on device by changing flag on shadow. This is intended for simple use cases""" 119 | send(device_name, "global iot_status; iot_status='pause'") 120 | 121 | def run( device_name ): 122 | """Resumes running of iot listener on device by changing flag on shadow. 
This is intended for simple use cases""" 123 | send(device_name, "global iot_status; iot_status='run'") 124 | -------------------------------------------------------------------------------- /tests/test_common_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import tempfile 4 | import unittest 5 | from unittest.mock import MagicMock, patch 6 | 7 | import braingeneers.utils.smart_open_braingeneers as smart_open 8 | from braingeneers.utils import common_utils 9 | from braingeneers.utils.common_utils import checkout, map2 10 | 11 | 12 | class TestFileListFunction(unittest.TestCase): 13 | @patch( 14 | "braingeneers.utils.common_utils._lazy_init_s3_client" 15 | ) # Updated to common_utils 16 | def test_s3_files_exist(self, mock_s3_client): 17 | # Mock S3 client response 18 | mock_response = { 19 | "Contents": [ 20 | {"Key": "file1.txt", "LastModified": "2023-01-01", "Size": 123}, 21 | {"Key": "file2.txt", "LastModified": "2023-01-02", "Size": 456}, 22 | ] 23 | } 24 | mock_s3_client.return_value.list_objects.return_value = mock_response 25 | 26 | result = common_utils.file_list("s3://test-bucket/") # Updated to common_utils 27 | expected = [("file2.txt", "2023-01-02", 456), ("file1.txt", "2023-01-01", 123)] 28 | self.assertEqual(result, expected) 29 | 30 | @patch( 31 | "braingeneers.utils.common_utils._lazy_init_s3_client" 32 | ) # Updated to common_utils 33 | def test_s3_no_files(self, mock_s3_client): 34 | # Mock S3 client response for no files 35 | mock_s3_client.return_value.list_objects.return_value = {} 36 | result = common_utils.file_list("s3://test-bucket/") # Updated to common_utils 37 | self.assertEqual(result, []) 38 | 39 | def test_local_files_exist(self): 40 | with tempfile.TemporaryDirectory() as temp_dir: 41 | for f in ["tempfile1.txt", "tempfile2.txt"]: 42 | with open(os.path.join(temp_dir, f), "w") as w: 43 | w.write("nothing") 44 | 45 | result = common_utils.file_list(temp_dir) # Updated to common_utils 46 | # The result should contain two files with their details 47 | self.assertEqual(len(result), 2) 48 | 49 | def test_local_no_files(self): 50 | with tempfile.TemporaryDirectory() as temp_dir: 51 | result = common_utils.file_list(temp_dir) # Updated to common_utils 52 | self.assertEqual(result, []) 53 | 54 | 55 | class TestCheckout(unittest.TestCase): 56 | def setUp(self): 57 | # Setup mock for smart_open and MessageBroker 58 | self.message_broker_patch = patch("braingeneers.iot.messaging.MessageBroker") 59 | 60 | # Start the patches 61 | self.mock_message_broker = self.message_broker_patch.start() 62 | 63 | # Mock the message broker's get_lock and delete_lock methods 64 | self.mock_message_broker.return_value.get_lock.return_value = MagicMock() 65 | self.mock_message_broker.return_value.delete_lock = MagicMock() 66 | 67 | self.mock_file = MagicMock(spec=io.StringIO) 68 | self.mock_file.read.return_value = ( 69 | "Test data" # Ensure this is correctly setting the return value for read 70 | ) 71 | self.mock_file.__enter__.return_value = self.mock_file 72 | self.mock_file.__exit__.return_value = None 73 | self.smart_open_mock = MagicMock(spec=smart_open) 74 | self.smart_open_mock.open.return_value = self.mock_file 75 | 76 | common_utils.smart_open = self.smart_open_mock 77 | 78 | def tearDown(self): 79 | # Stop all patches 80 | self.message_broker_patch.stop() 81 | 82 | def test_checkout_context_manager_read(self): 83 | # Test the reading functionality 84 | with checkout("s3://test-bucket/test-file.txt", 
isbinary=False) as locked_obj: 85 | data = locked_obj.get_value() 86 | self.assertEqual(data, "Test data") 87 | 88 | def test_checkout_context_manager_write_text(self): 89 | # Test the writing functionality for text mode 90 | test_data = "New test data" 91 | self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test 92 | with checkout("s3://test-bucket/test-file.txt", isbinary=False) as locked_obj: 93 | locked_obj.checkin(test_data) 94 | self.mock_file.write.assert_called_once_with(test_data) 95 | 96 | def test_checkout_context_manager_write_binary(self): 97 | # Test the writing functionality for binary mode 98 | test_data = b"New binary data" 99 | self.mock_file.write.reset_mock() # Reset mock to ensure clean state for the test 100 | with checkout("s3://test-bucket/test-file.bin", isbinary=True) as locked_obj: 101 | locked_obj.checkin(test_data) 102 | self.mock_file.write.assert_called_once_with(test_data) 103 | 104 | 105 | class TestMap2Function(unittest.TestCase): 106 | def test_with_pass_through_kwargs_handling(self): 107 | """Test map2 with a function accepting dynamic kwargs, specifically to check the handling of 'experiment_name' 108 | passed through **kwargs, using the original signature for f_with_kwargs.""" 109 | 110 | def f_with_kwargs(cache_path: str, max_size_gb: int = 10, **kwargs): 111 | # Simulate loading data where 'experiment_name' and other parameters are expected to come through **kwargs 112 | self.assertTrue(isinstance(kwargs, dict), "kwargs should be a dict") 113 | self.assertFalse("kwargs" in kwargs) 114 | return "some data" 115 | 116 | experiments = [ 117 | {"experiment": "exp1"}, 118 | {"experiment": "exp2"}, 119 | ] # List of experiment names to be passed as individual kwargs 120 | fixed_values = { 121 | "cache_path": "/tmp/ephys_cache", 122 | "max_size_gb": 50, 123 | "metadata": {"some": "metadata"}, 124 | "channels": ["channel1"], 125 | "length": -1, 126 | } 127 | 128 | # Execute the test under the assumption that map2 is supposed to handle 'experiment_name' in **kwargs correctly 129 | map2( 130 | f_with_kwargs, 131 | kwargs=experiments, 132 | fixed_values=fixed_values, 133 | parallelism=False, 134 | ) 135 | self.assertTrue(True) # If the test reaches this point, it has passed 136 | 137 | def test_with_kwargs_function_parallelism_false(self): 138 | # Define a test function that takes a positional argument and arbitrary kwargs 139 | def test_func(a, **kwargs): 140 | return a + kwargs.get("increment", 0) 141 | 142 | # Define the arguments and kwargs to pass to map2 143 | args = [(1,), (2,), (3,)] # positional arguments 144 | kwargs = [ 145 | {"increment": 10}, 146 | {"increment": 20}, 147 | {"increment": 30}, 148 | ] # kwargs for each call 149 | 150 | # Call map2 with the test function, args, kwargs, and parallelism=False 151 | result = map2(func=test_func, args=args, kwargs=kwargs, parallelism=False) 152 | 153 | # Expected results after applying the function with the given args and kwargs 154 | expected_results = [11, 22, 33] 155 | 156 | # Assert that the actual result matches the expected result 157 | self.assertEqual(result, expected_results) 158 | 159 | def test_with_fixed_values_and_variable_kwargs_parallelism_false(self): 160 | # Define a test function that takes fixed positional argument and arbitrary kwargs 161 | def test_func(a, **kwargs): 162 | return a + kwargs.get("increment", 0) 163 | 164 | # Define the kwargs to pass to map2, each dict represents kwargs for one call 165 | kwargs = [{"increment": 10}, {"increment": 20}, 
{"increment": 30}] 166 | 167 | # Call map2 with the test function, no args, variable kwargs, fixed_values containing 'a', and parallelism=False 168 | result = map2( 169 | func=test_func, 170 | kwargs=kwargs, 171 | fixed_values={"a": 1}, # 'a' is fixed for all calls 172 | parallelism=False, 173 | ) 174 | 175 | # Expected results after applying the function with the fixed 'a' and given kwargs 176 | expected_results = [11, 21, 31] 177 | 178 | # Assert that the actual result matches the expected result 179 | self.assertEqual(result, expected_results) 180 | 181 | def test_with_no_kwargs(self): 182 | # Define a test function that takes a positional argument and no kwargs 183 | def test_func(a): 184 | return a + 1 185 | 186 | # While we're at it, also test the pathway that normalizes the args. 187 | args = range(1, 4) 188 | result = map2( 189 | func=test_func, 190 | args=args, 191 | parallelism=False, 192 | ) 193 | 194 | self.assertEqual(result, [2, 3, 4]) 195 | 196 | 197 | if __name__ == "__main__": 198 | unittest.main() 199 | -------------------------------------------------------------------------------- /tests/test_messaging.py: -------------------------------------------------------------------------------- 1 | """ Unit test for BraingeneersMqttClient, assumes Braingeneers ~/.aws/credentials file exists """ 2 | import queue 3 | import threading 4 | import time 5 | import unittest 6 | import uuid 7 | import warnings 8 | 9 | from unittest.mock import MagicMock 10 | from tenacity import retry, stop_after_attempt 11 | 12 | import braingeneers.iot.messaging as messaging 13 | 14 | 15 | def _make_message_broker(): 16 | try: 17 | return messaging.MessageBroker(f"unittest-{uuid.uuid4()}") 18 | except PermissionError: 19 | raise unittest.SkipTest("Can't test messaging without access token.") from None 20 | 21 | 22 | class TestBraingeneersMessageBroker(unittest.TestCase): 23 | def setUp(self) -> None: 24 | warnings.filterwarnings( 25 | "ignore", category=ResourceWarning, message="unclosed.*" 26 | ) 27 | self.mb = _make_message_broker() 28 | self.mb_test_device = messaging.MessageBroker("unittest") 29 | self.mb.create_device("test", "Other") 30 | 31 | 32 | def tearDown(self) -> None: 33 | self.mb.shutdown() 34 | self.mb_test_device.shutdown() 35 | 36 | def test_publish_message_error(self): 37 | self.mb._mqtt_connection = MagicMock() 38 | 39 | # Mock a failed publish_message attempt 40 | self.mb._mqtt_connection.publish.return_value.rc = 1 41 | 42 | with self.assertRaises(messaging.MQTTError): 43 | self.mb.publish_message("test", "message") 44 | 45 | def test_subscribe_system_messages(self): 46 | q = self.mb.subscribe_message("$SYS/#", callback=None) 47 | self.mb.publish_message("test/unittest", message={"test": "true"}) 48 | 49 | t0 = time.time() 50 | while time.time() - t0 < 5: 51 | topic, message = q.get(timeout=5) 52 | print(f"DEBUG TEST> {topic}") 53 | if topic.startswith("$SYS"): 54 | self.assertTrue(True) 55 | break 56 | 57 | def test_two_message_broker_objects(self): 58 | """Tests that two message broker objects can successfully publish and subscribe messages""" 59 | mb1 = messaging.MessageBroker() 60 | mb2 = messaging.MessageBroker() 61 | q1 = messaging.CallableQueue() 62 | q2 = messaging.CallableQueue() 63 | mb1.subscribe_message("test/unittest1", q1) 64 | mb2.subscribe_message("test/unittest2", q2) 65 | mb1.publish_message("test/unittest1", message={"test": "true"}) 66 | mb2.publish_message("test/unittest2", message={"test": "true"}) 67 | topic, message = q1.get() 68 | self.assertEqual(topic, 
"test/unittest1") 69 | self.assertEqual(message, {"test": "true"}) 70 | topic, message = q2.get() 71 | self.assertEqual(topic, "test/unittest2") 72 | self.assertEqual(message, {"test": "true"}) 73 | mb1.shutdown() 74 | mb2.shutdown() 75 | 76 | def test_publish_subscribe_message(self): 77 | """Uses a custom callback to test publish subscribe messages""" 78 | message_received_barrier = threading.Barrier(2, timeout=30) 79 | 80 | def unittest_subscriber(topic, message): 81 | print(f"DEBUG> {topic}: {message}") 82 | self.assertEqual(topic, "test/unittest") 83 | self.assertEqual(message, {"test": "true"}) 84 | message_received_barrier.wait() # synchronize between threads 85 | 86 | self.mb.subscribe_message("test/unittest", unittest_subscriber) 87 | self.mb.publish_message("test/unittest", message={"test": "true"}) 88 | 89 | message_received_barrier.wait() # will throw BrokenBarrierError if timeout 90 | 91 | def test_publish_subscribe_message_with_confirm_receipt(self): 92 | q = messaging.CallableQueue() 93 | self.mb.subscribe_message("test/unittest", q) 94 | self.mb.publish_message( 95 | "test/unittest", message={"test": "true"}, confirm_receipt=True 96 | ) 97 | topic, message = q.get() 98 | self.assertEqual(topic, "test/unittest") 99 | self.assertEqual(message, {"test": "true"}) 100 | 101 | def test_publish_subscribe_data_stream(self): 102 | """Uses queue method to test publish/subscribe data streams""" 103 | q = messaging.CallableQueue(1) 104 | self.mb.subscribe_data_stream(stream_name="unittest", callback=q) 105 | self.mb.publish_data_stream( 106 | stream_name="unittest", data={b"x": b"42"}, stream_size=1 107 | ) 108 | result_stream_name, result_data = q.get(timeout=15) 109 | self.assertEqual(result_stream_name, "unittest") 110 | self.assertDictEqual(result_data, {b"x": b"42"}) 111 | 112 | def test_publish_subscribe_multiple_data_streams(self): 113 | self.mb.redis_client.delete("unittest1", "unittest2") 114 | q = messaging.CallableQueue() 115 | self.mb.subscribe_data_stream( 116 | stream_name=["unittest1", "unittest2"], callback=q 117 | ) 118 | self.mb.publish_data_stream( 119 | stream_name="unittest1", data={b"x": b"42"}, stream_size=1 120 | ) 121 | self.mb.publish_data_stream( 122 | stream_name="unittest2", data={b"x": b"43"}, stream_size=1 123 | ) 124 | self.mb.publish_data_stream( 125 | stream_name="unittest2", data={b"x": b"44"}, stream_size=1 126 | ) 127 | 128 | result_stream_name, result_data = q.get(timeout=15) 129 | self.assertEqual(result_stream_name, "unittest1") 130 | self.assertDictEqual(result_data, {b"x": b"42"}) 131 | 132 | result_stream_name, result_data = q.get(timeout=15) 133 | self.assertEqual(result_stream_name, "unittest2") 134 | self.assertDictEqual(result_data, {b"x": b"43"}) 135 | 136 | result_stream_name, result_data = q.get(timeout=15) 137 | self.assertEqual(result_stream_name, "unittest2") 138 | self.assertDictEqual(result_data, {b"x": b"44"}) 139 | 140 | @retry(stop=stop_after_attempt(3)) 141 | def test_poll_data_stream(self): 142 | """Uses more advanced poll_data_stream function""" 143 | self.mb.redis_client.delete("unittest") 144 | 145 | self.mb.publish_data_stream( 146 | stream_name="unittest", data={b"x": b"42"}, stream_size=1 147 | ) 148 | self.mb.publish_data_stream( 149 | stream_name="unittest", data={b"x": b"43"}, stream_size=1 150 | ) 151 | self.mb.publish_data_stream( 152 | stream_name="unittest", data={b"x": b"44"}, stream_size=1 153 | ) 154 | 155 | result1 = self.mb.poll_data_streams({"unittest": "-"}, count=1) 156 | 
self.assertEqual(len(result1[0][1]), 1) 157 | self.assertDictEqual(result1[0][1][0][1], {b"x": b"42"}) 158 | 159 | result2 = self.mb.poll_data_streams({"unittest": result1[0][1][0][0]}, count=2) 160 | self.assertEqual(len(result2[0][1]), 2) 161 | self.assertDictEqual(result2[0][1][0][1], {b"x": b"43"}) 162 | self.assertDictEqual(result2[0][1][1][1], {b"x": b"44"}) 163 | 164 | result3 = self.mb.poll_data_streams({"unittest": "-"}) 165 | self.assertEqual(len(result3[0][1]), 3) 166 | self.assertDictEqual(result3[0][1][0][1], {b"x": b"42"}) 167 | self.assertDictEqual(result3[0][1][1][1], {b"x": b"43"}) 168 | self.assertDictEqual(result3[0][1][2][1], {b"x": b"44"}) 169 | 170 | # TypeError: 'NoneType' object is not subscriptable 171 | @unittest.expectedFailure 172 | def test_delete_device_state(self): 173 | self.mb.delete_device_state("test") 174 | self.mb.update_device_state("test", {"x": 42, "y": 24}) 175 | state = self.mb.get_device_state("test") 176 | self.assertTrue("x" in state) 177 | self.assertTrue(state["x"] == 42) 178 | self.assertTrue("y" in state) 179 | self.assertTrue(state["y"] == 24) 180 | self.mb.delete_device_state("test", ["x"]) 181 | state_after_del = self.mb.get_device_state("test") 182 | self.assertTrue("x" not in state_after_del) 183 | self.assertTrue("y" in state) 184 | self.assertTrue(state["y"] == 24) 185 | 186 | # TypeError: 'NoneType' object is not subscriptable 187 | @unittest.expectedFailure 188 | def test_get_update_device_state(self): 189 | self.mb_test_device.delete_device_state("test") 190 | self.mb_test_device.update_device_state("test", {"x": 42}) 191 | state = self.mb_test_device.get_device_state("test") 192 | self.assertTrue("x" in state) 193 | self.assertEqual(state["x"], 42) 194 | self.mb_test_device.delete_device_state("test") 195 | 196 | def test_lock(self): 197 | with self.mb.get_lock("unittest"): 198 | print("lock granted") 199 | 200 | def test_unsubscribe(self): 201 | q = messaging.CallableQueue() 202 | self.mb.subscribe_message("test/unittest", callback=q) 203 | self.mb.unsubscribe_message("test/unittest") 204 | self.mb.publish_message("test/unittest", message={"test": 1}) 205 | with self.assertRaises(queue.Empty): 206 | q.get(timeout=3) 207 | 208 | def test_two_subscribers(self): 209 | q1 = messaging.CallableQueue() 210 | q2 = messaging.CallableQueue() 211 | self.mb.subscribe_message("test/unittest1", callback=q1) 212 | self.mb.subscribe_message("test/unittest2", callback=q2) 213 | self.mb.publish_message("test/unittest1", message={"test": 1}) 214 | self.mb.publish_message("test/unittest2", message={"test": 2}) 215 | topic1, message1 = q1.get(timeout=5) 216 | topic2, message2 = q2.get(timeout=5) 217 | self.assertDictEqual(message1, {"test": 1}) 218 | self.assertDictEqual(message2, {"test": 2}) 219 | 220 | 221 | class TestInterprocessQueue(unittest.TestCase): 222 | def setUp(self) -> None: 223 | self.mb = _make_message_broker() 224 | self.mb.delete_queue("unittest") 225 | 226 | @retry(stop=stop_after_attempt(3)) 227 | def test_get_put_defaults(self): 228 | q = self.mb.get_queue("unittest") 229 | q.put("some-value") 230 | result = q.get("some-value") 231 | self.assertEqual(result, "some-value") 232 | 233 | def test_get_put_nonblocking_without_maxsize(self): 234 | q = self.mb.get_queue("unittest") 235 | q.put("some-value", block=False) 236 | result = q.get(block=False) 237 | self.assertEqual(result, "some-value") 238 | 239 | @retry(stop=stop_after_attempt(3)) 240 | def test_maxsize(self): 241 | q = self.mb.get_queue("unittest", maxsize=1) 242 | 
q.put("some-value") 243 | result = q.get() 244 | self.assertEqual(result, "some-value") 245 | 246 | @retry(stop=stop_after_attempt(3)) 247 | def test_timeout_put(self): 248 | q = self.mb.get_queue("unittest", maxsize=1) 249 | q.put("some-value-1") 250 | with self.assertRaises(queue.Full): 251 | q.put("some-value-2", timeout=0.1) 252 | time.sleep(1) 253 | self.fail( 254 | "Queue failed to throw an expected exception after 0.1s timeout period." 255 | ) 256 | 257 | def test_timeout_get(self): 258 | q = self.mb.get_queue("unittest", maxsize=1) 259 | with self.assertRaises(queue.Empty): 260 | q.get(timeout=0.1) 261 | time.sleep(1) 262 | self.fail( 263 | "Queue failed to throw an expected exception after 0.1s timeout period." 264 | ) 265 | 266 | def test_task_done_join(self): 267 | """Test that task_done and join work as expected.""" 268 | 269 | def f(ql, jl, bl): 270 | t0 = time.time() 271 | ql.join() 272 | jl["join_time"] = time.time() - t0 273 | b.wait() 274 | 275 | b = threading.Barrier(2) 276 | join_time = {"join_time": 0} # a mutable datastructure 277 | 278 | q = self.mb.get_queue("unittest") 279 | q.put("some-value") 280 | threading.Thread(target=f, args=(q, join_time, b)).start() 281 | time.sleep(0.1) 282 | q.get() 283 | q.task_done() 284 | q.join() 285 | b.wait() 286 | 287 | t = join_time["join_time"] 288 | self.assertTrue(t >= 0.1, msg=f"Join time {t} less than expected 0.1 sec.") 289 | 290 | 291 | class TestNamedLock(unittest.TestCase): 292 | def setUp(self) -> None: 293 | self.mb = _make_message_broker() 294 | self.mb.delete_lock("unittest") 295 | 296 | def tearDown(self) -> None: 297 | self.mb.delete_lock("unittest") 298 | 299 | def test_enter_exit(self): 300 | with self.mb.get_lock("unittest"): 301 | self.assertTrue(True) 302 | 303 | def test_acquire_release(self): 304 | lock = self.mb.get_lock("unittest") 305 | lock.acquire() 306 | lock.release() 307 | 308 | 309 | if __name__ == "__main__": 310 | unittest.main() 311 | -------------------------------------------------------------------------------- /src/braingeneers/analysis/visualize_maxwell.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from neuraltoolkit import ntk_filters as ntk 5 | import pdb 6 | import braingeneers 7 | from braingeneers.utils import s3wrangler as wr 8 | from braingeneers.utils import smart_open_braingeneers as smart_open 9 | import numpy as np 10 | from typing import List, Union 11 | from braingeneers.data import datasets_electrophysiology as de 12 | from scipy import signal as ssig 13 | import matplotlib.pyplot as plt 14 | import gc 15 | 16 | 17 | def int_or_str(value): 18 | """ 19 | This function is passed as type to accept the Union of two data types. 20 | :param value: value to consider 21 | :return: either int or string 22 | """ 23 | try: 24 | return int(value) 25 | except ValueError: 26 | return value 27 | 28 | 29 | def parse_args(): 30 | """ 31 | This function parses the arguments passed in via CLI 32 | :return: Dictionary of parsed arguments 33 | """ 34 | parser = argparse.ArgumentParser(description='Convert a single neural data file using a specified filter') 35 | parser.add_argument('--uuid', '-u', type=str, required=True, 36 | help='UUID for desired experiment batch') 37 | 38 | parser.add_argument('--experiment', '-e', type=int_or_str, required=True, 39 | help='Experiment number. Can be passed as index (int) or experiment# (string)' 40 | 'e.g. 
1 or \'experiment2\'') 41 | 42 | parser.add_argument('--outputLocation', '-o', type=str, default='local', choices=['local', 's3'], 43 | help='Where to store the output. Either specify \'local\' or \'s3\', or leave blank' 44 | 'to have it saved locally. ') 45 | 46 | parser.add_argument('--details', '-d', nargs='+', required=True, 47 | help='CSList indicating where and how much data to ' 48 | 'take. Usage: -a offset length channels ' 49 | 'where offset is an int, length is an int, ' 50 | 'and channels is a string of values separated by slashes') 51 | parser.add_argument( 52 | '--apply', action='append', required=True, 53 | help='Filter type + arguments, --apply-filter is specified 1 or more times for each filter. Usage options:\n' 54 | '--apply highpass=750 (highpass filter @ 750 hz)\n' 55 | '--apply lowpass=8 (lowpass filter @ 8 hz)\n' 56 | '--apply bandpass=low,high (bandpass values for between the low and high arguments)' 57 | '--apply downsample=200 (downsample to 200 hz)' 58 | ) 59 | 60 | return vars(parser.parse_args()) 61 | 62 | 63 | def highpass(data: np.ndarray, hz: int, fs: int): 64 | data_highpass = np.vstack([ 65 | ntk.butter_highpass(channel_data, highpass=hz, fs=fs, order=3) 66 | for channel_data in data 67 | ]) 68 | print( 69 | f'Running highpass filter with parameters: hz={hz}, fs={fs}, input shape: {data.shape}, output shape: {data_highpass.shape}\n') 70 | return data_highpass 71 | 72 | 73 | def lowpass(data: np.ndarray, hz: int, fs: int): 74 | data_lowpass = np.vstack([ 75 | ntk.butter_lowpass(channel_data, lowpass=hz, fs=fs, order=3) 76 | for channel_data in data 77 | ]) 78 | print( 79 | f'Running lowpass filter with parameters: hz={hz}, fs={fs}, input shape: {data.shape}, output shape: {data_lowpass.shape}\n') 80 | return data_lowpass 81 | 82 | 83 | def bandpass(data: np.ndarray, hz_high: int, hz_low: int, fs: int): 84 | data_bandpass = np.vstack([ 85 | ntk.butter_bandpass(channel_data, highpass=hz_high, lowpass=hz_low, fs=fs, order=3) 86 | for channel_data in data 87 | ]) 88 | print( 89 | f'Running bandpass filter with parameters: hz_low={hz_low}, hz_high={hz_high} fs={fs}, input shape: {data.shape}, output shape: {data_bandpass.shape}\n') 90 | return data_bandpass 91 | 92 | 93 | def main(uuid: str, experiment: Union[str, int], outputLocation: str, details: List[str], apply: List[str]): 94 | # set local endpoint for faster loading of data 95 | # braingeneers.set_default_endpoint(f'{os.getcwd()}') 96 | spectro = False 97 | fsemg = 1 98 | # TODO: Make more robust - if length is not specified, need to still make it work 99 | # TODO: rip out spectrogram code and make it separate file 100 | # load metadata 101 | offset = int(details[0]) 102 | datalen = int(details[1]) 103 | chans = details[2] 104 | metad = de.load_metadata(uuid) 105 | e = f'experiment{experiment + 1}' if isinstance(experiment, int) else experiment 106 | fs = metad['ephys_experiments'][e]['sample_rate'] 107 | # then, load the data 108 | chans = [int(i) for i in chans.split('-')] 109 | dataset = de.load_data(metad, experiment=experiment, offset=offset, length=datalen, channels=chans) 110 | dataset = np.vstack(dataset) 111 | print(dataset.shape) 112 | # if the data is shorter than 3 min, no point in making spectrogram 113 | if dataset.shape[1] >= 3600000: 114 | spectro = True 115 | # parse out apply list 116 | for item in apply: 117 | filt, arg = item.split('=') 118 | if filt == 'highpass': 119 | filt_dataset = highpass(dataset, int(arg), fs) 120 | filt_dataset = np.vstack(filt_dataset) 121 | elif filt == 
'lowpass': 122 | filt_dataset = lowpass(dataset, int(arg), fs) 123 | filt_dataset = np.vstack(filt_dataset) 124 | elif filt == 'bandpass': 125 | # 7/29/22 switched low and high since the arguments should be different 126 | hi_rate, low_rate = arg.split(',') 127 | filt_dataset = bandpass(dataset, int(hi_rate), int(low_rate), fs) 128 | filt_dataset = np.vstack(filt_dataset) 129 | 130 | # here the data should be ready to viz. 131 | # try putting data thru spectrogram 132 | # fig, axs = plt.subplots(nrows=len(chans), ncols=3, figsize=(16,8)) 133 | datafig = plt.figure(figsize=(16, 2 * len(chans) + 2)) 134 | 135 | # create chan x 1 subfigs 136 | subfigs = datafig.subfigures(nrows=len(chans), ncols=1) 137 | if not spectro: 138 | for index in range(len(chans)): 139 | subfigs[index].suptitle(f'Channel {chans[index]}') 140 | 141 | # create 1x2 subplots per subfig 142 | axs = subfigs[index].subplots(nrows=1, ncols=2, subplot_kw={'anchor': 'SW'}, 143 | gridspec_kw={'wspace': 0.15}) 144 | for ax in axs: 145 | ax.tick_params(bottom=True, labelbottom=True, left=True, labelleft=True, right=False, labelright=False, 146 | top=False, labeltop=False) 147 | 148 | raw_plot = dataset[index] 149 | filt_plot = filt_dataset[index] 150 | realtime = np.arange(np.size(raw_plot)) / fsemg 151 | axs[0].plot(realtime, (raw_plot - np.nanmean(raw_plot)) / np.nanstd(raw_plot)) 152 | axs[0].set_xlim(0, datalen) 153 | # plot filtered data in middle 154 | axs[1].plot(realtime, (filt_plot - np.nanmean(filt_plot)) / np.nanstd(filt_plot)) 155 | axs[1].set_xlim(0, datalen) 156 | 157 | # here, assume that the data is long enough to do spectrogram with 158 | else: 159 | specfig = plt.figure(figsize=(16, 2 * len(chans) + 2)) 160 | spectro_subfigs = specfig.subfigures(nrows=len(chans), ncols=1) 161 | for index in range(len(chans)): 162 | subfigs[index].suptitle(f'Channel {chans[index]}') 163 | spectro_subfigs[index].suptitle(f'Channel {chans[index]}') 164 | # axs now needs 4 columns, but axs2 still is just 1x1 165 | axs = subfigs[index].subplots(nrows=1, ncols=4) 166 | axs2 = spectro_subfigs[index].subplots(nrows=1, ncols=1) 167 | 168 | for ax in axs: 169 | ax.tick_params(bottom=True, labelbottom=True, left=True, labelleft=True, right=False, labelright=False, 170 | top=False, labeltop=False) 171 | axs2[0].tick_params(bottom=True, labelbottom=True, left=True, labelleft=True, right=False, 172 | labelright=False, 173 | top=False, labeltop=False) 174 | 175 | raw_plot_1 = dataset[index][:20000] # 1 second 176 | raw_plot_10 = dataset[index][:200000] # 10 seconds 177 | filt_plot = filt_dataset[index] 178 | filt_plot_1 = filt_plot[:20000] 179 | filt_plot_10 = filt_plot[:200000] 180 | 181 | realtime_1 = np.arange(np.size(raw_plot_1)) / fsemg 182 | realtime_10 = np.arange(np.size(raw_plot_10)) / fsemg 183 | axs[0].plot(realtime_1, (raw_plot_1 - np.nanmean(raw_plot_1)) / np.nanstd(raw_plot_1)) 184 | axs[0].set_xlim(0, 20000) 185 | # plot filtered data in middle 186 | axs[1].plot(realtime_1, (filt_plot_1 - np.nanmean(filt_plot_1)) / np.nanstd(filt_plot_1)) 187 | axs[1].set_xlim(0, 20000) 188 | 189 | axs[2].plot(realtime_10, (raw_plot_10 - np.nanmean(raw_plot_10)) / np.nanstd(raw_plot_10)) 190 | axs[2].set_xlim(0, 200000) 191 | 192 | axs[3].plot(realtime_10, (filt_plot_10 - np.nanmean(filt_plot_10)) / np.nanstd(filt_plot_10)) 193 | axs[3].set_xlim(0, 200000) 194 | # plot 195 | # for spectrogram 196 | freq, times, spec = ssig.spectrogram(filt_plot, fs, window='hamming', nperseg=1000, noverlap=1000 - 1, 197 | mode='psd') 198 | fmax = 64 199 | fmin = 
1 200 | x_mesh, y_mesh = np.meshgrid(times, freq[(freq <= fmax) & (freq >= fmin)]) 201 | axs2[0].pcolormesh(x_mesh, y_mesh, np.log10(spec[(freq <= fmax) & (freq >= fmin)]), cmap='jet', 202 | shading='auto') 203 | 204 | datapoints_filename = f'Raw_and_filtered_data_{uuid}_{experiment}_chan_{chans}.png' 205 | spectro_filename = f'Spectrogram_{uuid}_{experiment}_chan_{chans}.png' 206 | with open(datapoints_filename, 'wb') as dfig: 207 | plt.savefig(dfig, format='png') 208 | with open(spectro_filename, 'wb') as sfig: 209 | plt.savefig(sfig, format='png') 210 | # then, if it's meant to be on s3, awswrangle it up there. 211 | if outputLocation == 's3': 212 | # Check if file exists 213 | try: 214 | with smart_open.open(f's3://braingeneersdev/ephys/{uuid}/derived/{e}_visualization_metadata.json'): 215 | file_exists = True 216 | except OSError: 217 | file_exists = False 218 | # pdb.set_trace() 219 | if not file_exists: 220 | # make metadata 221 | new_meta = { 222 | 'notes': f'Raw data, filtered data, and spectrogram for {uuid} {e} ', 223 | 'hardware': 'Maxwell', 224 | 'channels': chans 225 | # , 226 | # 'drug_condition': 'fill', # TODO: (not immediate) FIND WAY TO PARSE DRUG CONDITION FROM DATA 227 | # 'age_of_culture': 'age', 228 | # 'duration_on_chip': 'len', 229 | # 'cell_type': 'H9', 230 | # 'stimulation':'None', 231 | # 'external_considerations': 'microfluidics' 232 | 233 | } 234 | with smart_open.open(f'{uuid}_viz_metadata.json', 'w') as f: 235 | json.dump(new_meta, f) 236 | # then, make the new bucket and put the metadata in. awswrangler for that. 237 | # pdb.set_trace() 238 | # print(wr.config.s3_endpoint_url) 239 | wr.upload(local_file=f'{os.getcwd()}/{uuid}_viz_metadata.json', 240 | path=f's3://braingeneersdev/ephys/{uuid}/derived/{e}_visualization_metadata.json') 241 | 242 | else: 243 | # if the metadata exists, need to add this metadata onto the existing one 244 | with smart_open.open( 245 | f's3://braingeneersdev/ephys/{uuid}/derived/{e}_visualization_metadata.json') as old_json: 246 | fixed_meta = json.load(old_json) 247 | old_chans = fixed_meta['channels'] 248 | new_chans = sorted(list(set(old_chans + chans))) 249 | fixed_meta['channels'] = new_chans 250 | with smart_open.open(f's3://braingeneersdev/ephys/{uuid}/derived/{e}_visualization_metadata.json', 251 | 'w') as out_metadata: 252 | json.dump(fixed_meta, out_metadata, indent=2) 253 | 254 | wr.upload(local_file=f'{os.getcwd()}/Raw_and_filtered_data_{uuid}_{experiment}_chan_{chans}.png', 255 | path=f's3://braingeneersdev/ephys/{uuid}/derived/Raw_and_filtered_data_{uuid}_{experiment}_chan_{chans}.png') 256 | wr.upload(local_file=f'{os.getcwd()}/Spectrogram_{uuid}_{experiment}_chan_{chans}.png', 257 | path=f's3://braingeneersdev/ephys/{uuid}/derived/Spectrogram_{uuid}_{experiment}_chan_{chans}.png') 258 | 259 | plt.close() 260 | gc.collect() 261 | 262 | 263 | if __name__ == '__main__': 264 | # args = parse_args() 265 | # # pdb.set_trace() 266 | main(**parse_args()) 267 | # main('2022-07-27-e-Chris_BCC_APV', 'experiment3', [6000, 10000, '0,1,2,3'], ['bandpass=500,9000']) 268 | -------------------------------------------------------------------------------- /src/braingeneers/ml/ephys_dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from braingeneers.data import datasets_electrophysiology as de 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | # HERE set numpy seed 8 | # todo: Need to have a good default for experiment_num and 
sample_size so can be used w/ spectrogram easily 9 | 10 | class EphysDataset(Dataset): 11 | 12 | def __init__(self, batch_uuid, experiment_num, sample_size, attr_csv, align='center', bounds='exception', 13 | length=-1, 14 | channels=None): 15 | """ 16 | :param batch_uuid: str 17 | String indicating which batch to take from 18 | :param experiment_num: int 19 | Number of desired experiment 20 | :param sample_size: int 21 | This value should be passed in every time; this determines the size of the samples the dataloader 22 | should be returning. This is NOT the size of the dataset being loaded in by load_data. 23 | :param attr_csv: str of filepath or name of readable file 24 | CSV file specifying offset, length, and channels. Offset MUST be specified. 25 | :param align: str 26 | Indicates how to use idx: indicating center of data datachunk, left of datachunk, or right of datachunk 27 | :param bounds: str 28 | Indicates how to treat cases where there's OutOfBounds issues (e.g. center idx @ 0 would go below 0). options are 29 | 'exception', 'pad', 'flush'. 30 | """ 31 | self.attr_df = pd.read_csv(attr_csv) 32 | # store as self variables to do 33 | self.UUID = batch_uuid 34 | self.exp_num = experiment_num 35 | self.datalen = length 36 | self.channels = channels 37 | self.align = align 38 | self.bounds = bounds 39 | self.sample_size = sample_size 40 | # self.x are all the channels in the dataset, use torch.from_numpy to make a tensor 41 | # since pytorch is channels first, dataset doesn't need to do .transpose to work with from_numpy 42 | # self.x = torch.from_numpy(dataset) 43 | 44 | def __getitem__(self, idx): 45 | # should return a datachunk of frames, not each channel 46 | # empty array is rows = channels, col = sample size 47 | 48 | # indexing row of csv 49 | self.offset = self.attr_df.iloc[idx][0] 50 | data_length = self.attr_df.iloc[idx][1] 51 | if data_length is not None: 52 | self.datalen = data_length 53 | channels = self.attr_df.iloc[idx][2] 54 | # checking if 'all' option was given. should be true if 'all'. 55 | if any(char.isalpha() for char in channels): 56 | self.channels = [i for i in range(0, 1028)] 57 | # otherwise, use number passed in 58 | else: 59 | self.channels = [int(i) for i in channels.split('/')] 60 | 61 | dataset = de.load_data(de.load_metadata(self.UUID), self.exp_num, self.offset, self.datalen, self.channels) 62 | datachunk = np.empty((dataset.shape[0], self.sample_size)) 63 | # if 'center', idx should point to the CENTER of a set of data, and get the frames from halfway in front and behind. 64 | # 4/4/22 replacing idx with offset. idx now points to a certain row in the csv. 65 | # offset now refers to the point in data where we're interested in sampling from. 66 | if self.align == 'center': 67 | left_bound = int(self.offset - (self.sample_size / 2)) 68 | right_bound = int(self.offset + (self.sample_size / 2)) 69 | # checking bounds handling 70 | if self.bounds == 'exception': 71 | # exception means throw an exception regarding usage 72 | if left_bound < 0 or right_bound > dataset.shape[1]: 73 | raise IndexError('Sample size too large in one direction. 
Please use valid index/sample size pair.') 74 | else: 75 | datachunk = dataset[:, left_bound:right_bound] 76 | elif self.bounds == 'pad': 77 | # pad means we fill the empty space with zeros 78 | if left_bound < 0 and right_bound > dataset.shape[1]: 79 | # perform both adjustments 80 | left_pad = 0 - left_bound 81 | left_bound = 0 82 | right_pad = right_bound - dataset.shape[1] 83 | # pad both sides 84 | datachunk[:, :left_pad] = 0 85 | datachunk[:, right_pad:] = 0 86 | datachunk[:, left_pad:right_pad] = dataset[:, left_bound:right_pad - left_pad] 87 | elif left_bound < 0: 88 | # find the dimensions of padding space and reset left_bound, pad with 0s and fill rest of datachunk 89 | left_pad = 0 - left_bound 90 | left_bound = 0 91 | datachunk[:, :left_pad] = 0 92 | datachunk[:, left_pad:] = dataset[:, left_bound:right_bound] 93 | elif right_bound > dataset.shape[1]: 94 | # calc overshoot and reset right_bound, fill with datachunk first and pad with 0s 95 | right_bound = dataset.shape[1] 96 | datalimit = right_bound - left_bound 97 | datachunk[:, :datalimit] = dataset[:, left_bound:right_bound] 98 | datachunk[:, datalimit:] = 0 99 | else: 100 | # if none of these apply, then the bounds are valid and should be used. 101 | datachunk = dataset[:, left_bound:right_bound] 102 | elif self.bounds == 'flush': 103 | # if it's flush, then make adjustment to the nearest viable sample. 104 | # if the sample size cannot be accommodated, shrink the bounds arbitrarily. 105 | if left_bound < 0 and right_bound > dataset.shape[1]: 106 | left_bound = 0 107 | right_bound = dataset.shape[1] 108 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 109 | elif left_bound < 0: 110 | left_bound = 0 111 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 112 | elif right_bound > dataset.shape[1]: 113 | right_bound = dataset.shape[1] 114 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 115 | else: 116 | # if none of these apply, then the bounds are valid and should be used. 117 | datachunk = dataset[:, left_bound:right_bound] 118 | # if 'left', idx should be on left side of data, sampling forward 119 | elif self.align == 'left': 120 | # get bounds 121 | left_bound = self.offset 122 | right_bound = self.offset + self.sample_size 123 | if self.bounds == 'exception': 124 | # exception means throw an exception regarding usage 125 | if left_bound < 0 or right_bound > dataset.shape[1]: 126 | raise IndexError('Sample size too large in one direction. 
Please use valid index/sample size pair.') 127 | else: 128 | datachunk = dataset[:, left_bound:right_bound] 129 | elif self.bounds == 'pad': 130 | # pad means we fill the empty space with zeros 131 | if left_bound < 0 and right_bound > dataset.shape[1]: 132 | # perform both adjustments 133 | left_pad = 0 - left_bound 134 | left_bound = 0 135 | right_pad = right_bound - dataset.shape[1] 136 | # pad both sides 137 | datachunk[:, :left_pad] = 0 138 | datachunk[:, right_pad:] = 0 139 | datachunk[:, left_pad:right_pad] = dataset[:, left_bound:right_pad - left_pad] 140 | elif left_bound < 0: 141 | # find the dimensions of padding space and reset left_bound, pad with 0s and fill rest of datachunk 142 | left_pad = 0 - left_bound 143 | left_bound = 0 144 | datachunk[:, :left_pad] = 0 145 | datachunk[:, left_pad:] = dataset[:, left_bound:right_bound] 146 | elif right_bound > dataset.shape[1]: 147 | # calc overshoot and reset right_bound, fill with datachunk first and pad with 0ss 148 | right_pad = right_bound - dataset.shape[1] 149 | right_bound = dataset.shape[1] 150 | datachunk[:, :right_pad] = dataset[:, left_bound:right_bound] 151 | datachunk[:, right_pad:] = 0 152 | else: 153 | # if none of these apply, then the bounds are valid and should be used. 154 | datachunk = dataset[:, left_bound:right_bound] 155 | elif self.bounds == 'flush': 156 | # if it's flush, then make adjustment to the nearest viable sample. 157 | # if the sample size cannot be accommodated, shrink the bounds arbitrarily. 158 | if left_bound < 0 and right_bound > dataset.shape[1]: 159 | left_bound = 0 160 | right_bound = dataset.shape[1] 161 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 162 | elif left_bound < 0: 163 | left_bound = 0 164 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 165 | elif right_bound > dataset.shape[1]: 166 | right_bound = dataset.shape[1] 167 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 168 | else: 169 | # if none of these apply, then the bounds are valid and should be used. 170 | datachunk = dataset[:, left_bound:right_bound] 171 | #datachunk = dataset[:, idx:idx + self.sample_size] 172 | # if 'right', idx is on the right end of the data, sampling backwards 173 | elif self.align == 'right': 174 | left_bound = self.offset - self.sample_size 175 | right_bound = self.offset 176 | if self.bounds == 'exception': 177 | # exception means throw an exception regarding usage 178 | if left_bound < 0 or right_bound > dataset.shape[1]: 179 | raise IndexError('Sample size too large in one direction. 
Please use valid index/sample size pair.') 180 | else: 181 | datachunk = dataset[:, left_bound:right_bound] 182 | elif self.bounds == 'pad': 183 | # pad means we fill the empty space with zeros 184 | if left_bound < 0 and right_bound > dataset.shape[1]: 185 | # perform both adjustments 186 | left_pad = 0 - left_bound 187 | left_bound = 0 188 | right_pad = right_bound - dataset.shape[1] 189 | # pad both sides 190 | datachunk[:, :left_pad] = 0 191 | datachunk[:, right_pad:] = 0 192 | datachunk[:, left_pad:right_pad] = dataset[:, left_bound:right_pad - left_pad] 193 | elif left_bound < 0: 194 | # find the dimensions of padding space and reset left_bound, pad with 0s and fill rest of datachunk 195 | left_pad = 0 - left_bound 196 | left_bound = 0 197 | datachunk[:, :left_pad] = 0 198 | datachunk[:, left_pad:] = dataset[:, left_bound:right_bound] 199 | elif right_bound > dataset.shape[1]: 200 | # calc overshoot and reset right_bound, fill with datachunk first and pad with 0ss 201 | right_pad = right_bound - dataset.shape[1] 202 | right_bound = dataset.shape[1] 203 | datachunk[:, :right_pad] = dataset[:, left_bound:right_bound] 204 | datachunk[:, right_pad:] = 0 205 | else: 206 | # if none of these apply, then the bounds are valid and should be used. 207 | datachunk = dataset[:, left_bound:right_bound] 208 | elif self.bounds == 'flush': 209 | # if it's flush, then make adjustment to the nearest viable sample. 210 | # if the sample size cannot be accommodated, shrink the bounds arbitrarily. 211 | if left_bound < 0 and right_bound > dataset.shape[1]: 212 | left_bound = 0 213 | right_bound = dataset.shape[1] 214 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 215 | elif left_bound < 0: 216 | left_bound = 0 217 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 218 | elif right_bound > dataset.shape[1]: 219 | right_bound = dataset.shape[1] 220 | datachunk[:, left_bound:right_bound] = dataset[:, left_bound:right_bound] 221 | else: 222 | # if none of these apply, then the bounds are valid and should be used. 223 | datachunk = dataset[:, left_bound:right_bound] 224 | return datachunk 225 | 226 | def __len__(self): 227 | # change to reflect length of the csv, NOT the datapoints 228 | return len(self.attr_df.index) 229 | -------------------------------------------------------------------------------- /src/braingeneers/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | """ Common utility functions """ 2 | import io 3 | import urllib 4 | import boto3 5 | from botocore.exceptions import ClientError 6 | import os 7 | import braingeneers 8 | import braingeneers.utils.smart_open_braingeneers as smart_open 9 | from typing import Callable, Iterable, Union, List, Tuple, Dict, Any 10 | import inspect 11 | import multiprocessing 12 | import posixpath 13 | import pathlib 14 | 15 | _s3_client = None # S3 client for boto3, lazy initialization performed in _lazy_init_s3_client() 16 | _message_broker = None # Lazy initialization of the message broker 17 | _named_locks = {} # Named locks for checkout and checkin 18 | 19 | 20 | def _lazy_init_s3_client(): 21 | """ 22 | This function lazy inits the s3 client, it can be called repeatedly and will work with multiprocessing. 23 | This function is for internal use and is automatically called by functions in this class that use the boto3 client. 
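The returned object is a standard boto3 S3 client created against braingeneers.get_default_endpoint(), so ordinary boto3 calls apply. A minimal sketch (the bucket and key below are hypothetical):

    client = _lazy_init_s3_client()
    client.head_object(Bucket='braingeneersdev', Key='test/data.npy')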
24 | """ 25 | global _s3_client 26 | if _s3_client is None: 27 | _s3_client = boto3.client('s3', endpoint_url=braingeneers.get_default_endpoint()) 28 | return _s3_client 29 | 30 | 31 | def get_basepath() -> str: 32 | """ 33 | Returns a local or S3 path so URLs or local paths can be appropriate prefixed. 34 | """ 35 | if braingeneers.get_default_endpoint().startswith('http'): 36 | return 's3://braingeneers/' 37 | else: 38 | return braingeneers.get_default_endpoint() 39 | 40 | 41 | def path_join(*args) -> str: 42 | """ 43 | Joins the basepath from get_basepath() to a set of paths. Example: 44 | 45 | path_join('2020-01-01-e-test', 'original', experiment_name, data_file) 46 | 47 | That would produce one of the following depending on the user: 48 | /some/local/path/2020-01-01-e-test/original/experiment1/data.bin 49 | s3://braingeneers/ephys/2020-01-01-e-test/original/experiment1/data.bin 50 | """ 51 | return posixpath.join(get_basepath(), *args) 52 | 53 | 54 | def file_exists(filename: str) -> bool: 55 | """ 56 | Simple test for whether a file exists or not, supports local and S3. 57 | This is implemented as a utility function because supporting multiple platforms (s3 and local) is not trivial. 58 | Issue history: 59 | - Using tensorflow for this functionality failed when libcurl rejected an expired cert. 60 | - Using PyFilesystem is a bad choice because it doesn't support streaming 61 | - Using smart_open supports streaming but not basic file operations like size and exists 62 | 63 | :param filename: file path + name, local or S3 64 | :return: boolean exists|not_exists 65 | """ 66 | if filename.startswith('s3://'): 67 | s3_client = _lazy_init_s3_client() 68 | o = urllib.parse.urlparse(filename) 69 | try: 70 | s3_client.head_object(Bucket=o.netloc, Key=o.path[1:]) 71 | exists = True 72 | except ClientError: 73 | exists = False 74 | else: 75 | exists = os.path.isfile(filename) 76 | 77 | return exists 78 | 79 | 80 | def file_size(filename: str) -> int: 81 | """ 82 | Gets file size, supports local and S3 files, same issues as documented in function file_exists 83 | :param filename: file path + name, local or S3 84 | :return: int file size in bytes 85 | """ 86 | if filename.startswith('s3://'): 87 | s3_client = _lazy_init_s3_client() 88 | o = urllib.parse.urlparse(filename) 89 | try: 90 | sz = s3_client.head_object(Bucket=o.netloc, Key=o.path[1:])['ContentLength'] 91 | except ClientError as e: 92 | # noinspection PyProtectedMember 93 | raise Exception(f'S3 ClientError using endpoint {s3_client._endpoint} for file {filename}.') from e 94 | else: 95 | sz = os.path.getsize(filename) 96 | 97 | return sz 98 | 99 | 100 | def file_list(filepath: str) -> List[Tuple[str, str, int]]: 101 | """ 102 | Returns a list of files, last modified time, and size on local or S3 in descending order of last modified time 103 | 104 | :param filepath: Local or S3 file path to list, example: "local/dir/" or "s3://bucket/prefix/" 105 | :return: A list of tuples of [('fileA', 'last_modified_A', size), ('fileB', 'last_modified_B', size), ...] 
106 | """ 107 | files_and_details = [] 108 | 109 | if filepath.startswith('s3://'): 110 | s3_client = _lazy_init_s3_client() 111 | o = urllib.parse.urlparse(filepath) 112 | response = s3_client.list_objects(Bucket=o.netloc, Prefix=o.path[1:]) 113 | 114 | if 'Contents' in response: 115 | files_and_details = [ 116 | (f['Key'].split('/')[-1], str(f['LastModified']), int(f['Size'])) 117 | for f in sorted(response['Contents'], key=lambda x: x['LastModified'], reverse=True) 118 | ] 119 | elif os.path.exists(filepath): 120 | files = sorted(pathlib.Path(filepath).iterdir(), key=os.path.getmtime, reverse=True) 121 | files_and_details = [(f.name, str(f.stat().st_mtime), f.stat().st_size) for f in files] 122 | 123 | return files_and_details 124 | 125 | 126 | # Define the wrapper function as a top-level function 127 | def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple, func_kwargs: Dict[str, Any]) -> Any: 128 | """Internal wrapper function for map2 to handle fixed values and dynamic arguments, including kwargs.""" 129 | # Merge fixed_values with provided arguments, aligning provided args with required_params 130 | call_args = {**fixed_values, **dict(zip(required_params, args))} 131 | return func(**call_args, **func_kwargs) 132 | 133 | 134 | def map2(func: Callable, 135 | args: Iterable[Tuple[Any, ...]] = None, 136 | kwargs: Iterable[Dict[str, Any]] = None, 137 | fixed_values: dict = None, 138 | parallelism: Union[bool, int] = True, 139 | use_multithreading: bool = False) -> List[object]: 140 | """ 141 | A universal multiprocessing version of the map function to simplify parallelizing code. 142 | The function provides a simple method for parallelizing operations while making debugging easy. 143 | 144 | Combines functionality of: map, itertools.starmap, and pool.map 145 | 146 | Eliminates the need to understand functools, multiprocessing.Pool, and 147 | argument unpacking operators which should be unnecessary to accomplish a simple 148 | multi-threaded function mapping operation. 149 | 150 | Usage example: 151 | def f(x, y): 152 | print(x, y) 153 | 154 | common_utils.map2( 155 | func=f, 156 | args=[(1, 'yellow'), (2, 'yarn'), (3, 'yack')], # (x, y) arguments 157 | parallelism=3, # use a 3 process multiprocessing pool 158 | ) 159 | 160 | common_utils.map2( 161 | func=f, 162 | args=[1, 2, 3], # x arguments has multiple values to run 163 | fixed_values=dict(y='yellow'), # y always is 'yellow' 164 | parallelism=False, # Runs without parallelism which makes debugging exceptions easier 165 | ) 166 | 167 | Usage example incorporating kwargs: 168 | def myfunc(a, b, **kwargs): 169 | print(a, b, kwargs.get('c')) 170 | 171 | common_utils.map2( 172 | func=myfunc, 173 | args=[(1, 2), (3, 4)], 174 | kwargs=[{'c': 50}, {'c': 100}], 175 | ) 176 | 177 | :param func: a callable function 178 | :param args: a list of arguments (if only 1 argument is left after fixed_values) or a list of tuples 179 | (if multiple arguments are left after fixed_values) 180 | :param kwargs: an iterable of dictionaries where each dictionary represents the keyword arguments to pass 181 | to the function for each call. This parameter allows passing dynamic keyword arguments to the function. 182 | the length of args and kwargs must be equal. 183 | :param fixed_values: a dictionary with parameters that will stay the same for each call to func 184 | :param parallelism: number of processes to use or boolean, default is # of CPU cores. 
185 | When parallelism==False or 1, this maps to itertools.starmap and does not use multiprocessing. 186 | If parallelism >1 then multiprocessing.pool.starmap will be used with this number of worker processes. 187 | If parallelism == True then multiprocessing.pool will be used with multiprocessing.cpu_count() processes. 188 | :param use_multithreading: advanced option, use the default (False) in most cases. Parallelizes using 189 | threads instead of multiprocessing. Multiprocessing should be used if more than one CPU core is needed 190 | due to the GIL, threads are lighter weight than processes for some non cpu-intensive tasks. 191 | :return: a list of the return values of func 192 | """ 193 | if args is not None and kwargs is not None: 194 | assert len(args) == len(kwargs), \ 195 | f"args and kwargs must have the same length, found lengths: len(args)={len(args)} and len(kwargs)={len(kwargs)}" 196 | assert isinstance(fixed_values, (dict, type(None))) 197 | assert isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer" 198 | parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism 199 | assert isinstance(parallelism, int), "parallelism must be resolved to an integer" 200 | 201 | fixed_values = fixed_values or {} 202 | func_signature = inspect.signature(func) 203 | required_params = [p.name for p in func_signature.parameters.values() if 204 | p.default == inspect.Parameter.empty and p.name not in fixed_values] 205 | 206 | if not args: 207 | args = [()] * len(kwargs or []) 208 | if not kwargs: 209 | kwargs = [{}] * len(args) 210 | if not all(isinstance(a, tuple) for a in args): 211 | args = [(a,) for a in args] 212 | call_parameters = list(zip(args, kwargs)) 213 | 214 | if parallelism == 1: 215 | result_iterator = map(lambda params: _map2_wrapper(fixed_values, required_params, func, params[0], params[1]), 216 | call_parameters) 217 | else: 218 | ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool 219 | with ProcessOrThreadPool(parallelism) as pool: 220 | result_iterator = pool.starmap( 221 | _map2_wrapper, 222 | [(fixed_values, required_params, func, args, kw) for args, kw in call_parameters] 223 | ) 224 | 225 | return list(result_iterator) 226 | 227 | 228 | class checkout: 229 | """ 230 | A context manager for atomically checking out a file from S3 for reading or writing. 
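Internally the context manager holds a MessageBroker named lock keyed by the S3 path, so concurrent checkouts of the same file are serialized. If a process dies while holding the lock, force_release_checkout() (defined later in this module) can clear it; a minimal sketch, with a hypothetical path:

    from braingeneers.utils.common_utils import force_release_checkout

    force_release_checkout('s3://braingeneersdev/test/metadata.json')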
231 | 232 | Example usage: 233 | 234 | # Read-then-update metadata.json (or any text based file on S3) 235 | with checkout('s3://braingeneers/ephys/9999-0-0-e-test/metadata.json', isbinary=False) as locked_obj: 236 | metadata_dict = json.loads(locked_obj.get_value()) 237 | metadata_dict['new_key'] = 'new_value' 238 | metadata_updated_str = json.dumps(metadata_dict, indent=2) 239 | locked_obj.checkin(metadata_updated_str) 240 | 241 | # Read-then-update data.npy (or any binary file on S3) 242 | with checkout('s3://braingeneersdev/test/data.npy', isbinary=True) as locked_obj: 243 | file_obj = locked_obj.get_file() 244 | ndarray = np.load(file_obj) 245 | ndarray[3, 3] = 42 246 | locked_obj.checkin(ndarray.tobytes()) 247 | 248 | # Edit a file in place, note checkin is not needed, the file is updated when the context manager exits 249 | with checkout('s3://braingeneersdev/test/test_file.bin', isbinary=True) as locked_obj: 250 | with zipfile.ZipFile(locked_obj.get_file(), 'a') as z: 251 | z.writestr('new_file.txt', 'new file contents') 252 | 253 | locked_obj functions: 254 | get_value() # returns a string or bytes object (depending on isbinary) 255 | get_file() # returns a file-like object akin to open() 256 | checkin() # updates the file, accepts string, bytes, or file like objects 257 | """ 258 | class LockedObject: 259 | def __init__(self, s3_file_object: io.IOBase, s3_path_str: str, isbinary: bool): 260 | self.s3_path_str = s3_path_str 261 | self.s3_file_object = s3_file_object # underlying file object 262 | self.isbinary = isbinary # binary or text mode 263 | self.modified = False # Track if the file has been modified 264 | 265 | def get_value(self): 266 | # Read file object from outer class s3_file_object 267 | self.s3_file_object.seek(0) 268 | return self.s3_file_object.read() 269 | 270 | def get_file(self): 271 | # Mark file as potentially modified when accessed 272 | self.modified = True 273 | # Return file object from outer class s3_file_object 274 | self.s3_file_object.seek(0) 275 | return self.s3_file_object 276 | 277 | def checkin(self, update_file: Union[str, bytes, io.IOBase]): 278 | # Validate input 279 | if not isinstance(update_file, (str, bytes, io.IOBase)): 280 | raise TypeError('File must be a string, bytes, or file object.') 281 | if isinstance(update_file, str) or isinstance(update_file, io.StringIO): 282 | if self.isbinary: 283 | raise ValueError('Cannot check in a string or text file when checkout is specified for binary mode.') 284 | if isinstance(update_file, bytes) or isinstance(update_file, io.BytesIO): 285 | if not self.isbinary: 286 | raise ValueError('Cannot check in bytes or a binary file when checkout is specified for text mode.') 287 | 288 | if isinstance(update_file, io.IOBase): 289 | update_file.seek(0) 290 | update_str_or_bytes = update_file if not isinstance(update_file, io.IOBase) else update_file.read() 291 | mode = 'w' if not self.isbinary else 'wb' 292 | with smart_open.open(self.s3_path_str, mode=mode) as f: 293 | f.write(update_str_or_bytes) 294 | 295 | def __init__(self, s3_path_str: str, isbinary: bool = False): 296 | # TODO: avoid circular import 297 | from braingeneers.iot.messaging import MessageBroker 298 | 299 | self.s3_path_str = s3_path_str 300 | self.isbinary = isbinary 301 | self.mb = MessageBroker() 302 | self.named_lock = None # message broker lock 303 | self.locked_obj = None # user facing locked object 304 | 305 | def __enter__(self): 306 | lock_str = f'common-utils-checkout-{self.s3_path_str}' 307 | named_lock = 
self.mb.get_lock(lock_str) 308 | named_lock.acquire() 309 | self.named_lock = named_lock 310 | f = smart_open.open(self.s3_path_str, 'rb' if self.isbinary else 'r') 311 | self.locked_obj = checkout.LockedObject(f, self.s3_path_str, self.isbinary) 312 | return self.locked_obj 313 | 314 | def __exit__(self, exc_type, exc_val, exc_tb): 315 | if self.locked_obj.modified: 316 | # If the file was modified, automatically check in the changes 317 | self.locked_obj.checkin(self.locked_obj.get_file()) 318 | self.named_lock.release() 319 | 320 | 321 | def force_release_checkout(s3_file: str): 322 | """ 323 | Force release the lock on a file that was checked out with checkout. 324 | """ 325 | # TODO: avoid circular import 326 | from braingeneers.iot.messaging import MessageBroker 327 | 328 | global _message_broker 329 | if _message_broker is None: 330 | _message_broker = MessageBroker() 331 | 332 | _message_broker.delete_lock(f'common-utils-checkout-{s3_file}') 333 | 334 | 335 | def pretty_print(data, n=10, indent=0): 336 | """ 337 | Custom pretty print function that uniformly truncates any collection (list or dictionary) 338 | longer than `n` values, showing the first `n` values and a summary of omitted items. 339 | Ensures mapping sections and similar are displayed compactly. 340 | 341 | Example usage (to display metadata.json): 342 | 343 | from braingeneers.utils.common_utils import pretty_print 344 | from braingeneers.data import datasets_electrophysiology as de 345 | 346 | metadata = de.load_metadata('2023-04-17-e-connectoid16235_CCH') 347 | pretty_print(metadata) 348 | 349 | Parameters: 350 | - data: The data to pretty print, either a list or a dictionary. 351 | - n: Maximum number of elements or items to display before truncation. 352 | - indent: Don't use this. Current indentation level for formatting, used during recursion. 353 | """ 354 | indent_space = ' ' * indent 355 | if isinstance(data, dict): 356 | keys = list(data.keys()) 357 | if len(keys) > n: 358 | truncated_keys = keys[:n] 359 | omitted_keys = len(keys) - n 360 | else: 361 | truncated_keys = keys 362 | omitted_keys = None 363 | 364 | print('{') 365 | for key in truncated_keys: 366 | value = data[key] 367 | print(f"{indent_space} '{key}': ", end='') 368 | if isinstance(value, dict): 369 | pretty_print(value, n, indent + 4) 370 | print() 371 | elif isinstance(value, list) and all(isinstance(x, (list, tuple)) and len(x) == 4 for x in value): 372 | # Compact display for lists of tuples/lists of length 4. 373 | print('[', end='') 374 | if len(value) > n: 375 | for item in value[:n]: 376 | print(f"{item}, ", end='') 377 | print(f"... (+{len(value) - n} more items)", end='') 378 | else: 379 | print(', '.join(map(str, value)), end='') 380 | print('],') 381 | else: 382 | print(f"{value},") 383 | if omitted_keys: 384 | print(f"{indent_space} ... (+{omitted_keys} more items)") 385 | print(f"{indent_space}}}", end='') 386 | elif isinstance(data, list): 387 | print('[') 388 | for item in data[:n]: 389 | pretty_print(item, n, indent + 4) 390 | print(',') 391 | if len(data) > n: 392 | print(f"{indent_space} ... 
(+{len(data) - n} more items)") 393 | print(f"{indent_space}]", end='') 394 | -------------------------------------------------------------------------------- /src/braingeneers/analysis/analysis.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import io 3 | import posixpath 4 | import zipfile 5 | from dataclasses import dataclass 6 | from logging import getLogger 7 | from typing import List, Tuple 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from deprecated import deprecated 12 | from scipy import signal 13 | from spikedata import SpikeData 14 | 15 | import braingeneers.utils.smart_open_braingeneers as smart_open 16 | from braingeneers.utils import s3wrangler 17 | from braingeneers.utils.common_utils import get_basepath 18 | 19 | __all__ = [ 20 | "read_phy_files", 21 | "filter", 22 | "NeuronAttributes", 23 | "load_spike_data", 24 | ] 25 | 26 | logger = getLogger("braingeneers.analysis") 27 | 28 | 29 | @dataclass 30 | class NeuronAttributes: 31 | cluster_id: int 32 | channel: np.ndarray 33 | position: Tuple[float, float] 34 | amplitudes: List[float] 35 | template: np.ndarray 36 | templates: np.ndarray 37 | label: str 38 | 39 | # These lists are the same length and correspond to each other 40 | neighbor_channels: np.ndarray 41 | neighbor_positions: List[Tuple[float, float]] 42 | neighbor_templates: List[np.ndarray] 43 | 44 | def __init__(self, *args, **kwargs): 45 | self.cluster_id = kwargs.pop("cluster_id") 46 | self.channel = kwargs.pop("channel") 47 | self.position = kwargs.pop("position") 48 | self.amplitudes = kwargs.pop("amplitudes") 49 | self.template = kwargs.pop("template") 50 | self.templates = kwargs.pop("templates") 51 | self.label = kwargs.pop("label") 52 | self.neighbor_channels = kwargs.pop("neighbor_channels") 53 | self.neighbor_positions = kwargs.pop("neighbor_positions") 54 | self.neighbor_templates = kwargs.pop("neighbor_templates") 55 | for key, value in kwargs.items(): 56 | setattr(self, key, value) 57 | 58 | def add_attribute(self, key, value): 59 | setattr(self, key, value) 60 | 61 | def list_attributes(self): 62 | return [ 63 | attr 64 | for attr in dir(self) 65 | if not attr.startswith("__") and not callable(getattr(self, attr)) 66 | ] 67 | 68 | 69 | def list_sorted_files(uuid, basepath=None): 70 | """ 71 | Lists files in a directory. 72 | 73 | :param path: the path to the directory. 74 | :param pattern: the pattern to match. 75 | :return: a list of files. 76 | """ 77 | if basepath is None: 78 | basepath = get_basepath() 79 | if "s3://" in basepath: 80 | return s3wrangler.list_objects( 81 | basepath + "ephys/" + uuid + "/derived/kilosort2/" 82 | ) 83 | else: 84 | # return glob.glob(os.path.join(basepath, f'ephys/{uuid}/derived/kilosort2/*')) 85 | return glob.glob(basepath + f"ephys/{uuid}/derived/kilosort2/*") 86 | 87 | 88 | def load_spike_data( 89 | uuid, 90 | experiment=None, 91 | basepath=None, 92 | full_path=None, 93 | fs=20000.0, 94 | groups_to_load=["good", "mua", "", np.nan, "unsorted"], 95 | sorter="kilosort2", 96 | ): 97 | """ 98 | Loads spike data from a dataset. 99 | 100 | :param uuid: the UUID for a specific dataset. 101 | :param experiment: an optional string to specify a particular experiment in the dataset. 102 | :param basepath: an optional string to specify a basepath for the dataset. 103 | :return: SpikeData class with a list of spike time lists and a list of NeuronAttributes. 
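Example usage (illustrative; the UUID is borrowed from another docstring in this package and may not exist, and the .train / .neuron_attributes attribute names are assumed from the SpikeData class):

    from braingeneers.analysis.analysis import load_spike_data

    sd = load_spike_data('2023-04-17-e-connectoid16235_CCH')
    print(len(sd.train), 'units loaded')      # spike-time lists are in milliseconds
    print(sd.neuron_attributes[0].channel)    # best channel for the first unit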
104 | """ 105 | if basepath is None: 106 | basepath = get_basepath() 107 | 108 | if experiment is None: 109 | experiment = "" 110 | prefix = f"ephys/{uuid}/derived/{sorter}/{experiment}" 111 | logger.info("prefix: %s", prefix) 112 | path = posixpath.join(basepath, prefix) 113 | 114 | if full_path is not None: 115 | experiment = full_path.split("/")[-1].split(".")[0] 116 | logger.info("Using full path, experiment: %s", experiment) 117 | path = full_path 118 | else: 119 | if path.startswith("s3://"): 120 | logger.info("Using s3 path for experiment: %s", experiment) 121 | # If path is an s3 path, use wrangler 122 | file_list = s3wrangler.list_objects(path) 123 | 124 | zip_files = [file for file in file_list if file.endswith(".zip")] 125 | 126 | if not zip_files: 127 | raise ValueError("No zip files found in specified location.") 128 | elif len(zip_files) > 1: 129 | logger.warning("Multiple zip files found. Using the first one.") 130 | 131 | path = zip_files[0] 132 | 133 | else: 134 | logger.info("Using local path for experiment: %s", experiment) 135 | # If path is a local path, check locally 136 | file_list = glob.glob(path + "*.zip") 137 | 138 | zip_files = [file for file in file_list if file.endswith(".zip")] 139 | 140 | if not zip_files: 141 | raise ValueError("No zip files found in specified location.") 142 | elif len(zip_files) > 1: 143 | logger.warning("Multiple zip files found. Using the first one.") 144 | 145 | path = zip_files[0] 146 | 147 | with smart_open.open(path, "rb") as f0: 148 | f = io.BytesIO(f0.read()) 149 | logger.debug("Opening zip file...") 150 | with zipfile.ZipFile(f, "r") as f_zip: 151 | assert "params.py" in f_zip.namelist(), "Wrong spike sorting output." 152 | logger.debug("Reading params.py...") 153 | with io.TextIOWrapper(f_zip.open("params.py"), encoding="utf-8") as params: 154 | for line in params: 155 | if "sample_rate" in line: 156 | fs = float(line.split()[-1]) 157 | logger.debug("Reading spike data...") 158 | clusters = np.load(f_zip.open("spike_clusters.npy")).squeeze() 159 | templates_w = np.load(f_zip.open("templates.npy")) 160 | wmi = np.load(f_zip.open("whitening_mat_inv.npy")) 161 | channels = np.load(f_zip.open("channel_map.npy")).squeeze() 162 | spike_templates = np.load(f_zip.open("spike_templates.npy")).squeeze() 163 | spike_times = np.load(f_zip.open("spike_times.npy")).squeeze() / fs * 1e3 164 | positions = np.load(f_zip.open("channel_positions.npy")) 165 | amplitudes = np.load(f_zip.open("amplitudes.npy")).squeeze() 166 | 167 | # Load cluster info from the first detected of several possible filenames. 168 | tsv_names = {"cluster_info.tsv", "cluster_group.tsv", "cluster_KSLabel.tsv"} 169 | for tsv in tsv_names & set(f_zip.namelist()): 170 | cluster_info = pd.read_csv(f_zip.open(tsv), sep="\t") 171 | cluster_id = cluster_info.cluster_id.values 172 | # Sometimes this file has the column "KSLabel" instead of "group". 173 | if "KSLabel" in cluster_info: 174 | cluster_info.rename(columns=dict(KSLabel="group"), inplace=True) 175 | labeled_clusters = cluster_id[cluster_info.group.isin(groups_to_load)] 176 | # Delete labeled clusters that were not assigned to any spike. 177 | labeled_clusters = np.intersect1d(labeled_clusters, clusters) 178 | break 179 | 180 | # If no file is detected, print a warning, but continue with filler labels. 181 | else: 182 | logger.warning( 183 | "No cluster assignment TSV file found. Generating blank labels." 
184 | ) 185 | labeled_clusters = np.unique(clusters) 186 | cluster_info = pd.DataFrame( 187 | { 188 | "cluster_id": labeled_clusters, 189 | "group": [""] * len(labeled_clusters), 190 | } 191 | ) 192 | 193 | assert len(labeled_clusters) > 0, "No clusters found." 194 | logger.debug("Reorganizing data...") 195 | df = pd.DataFrame( 196 | {"clusters": clusters, "spikeTimes": spike_times, "amplitudes": amplitudes} 197 | ) 198 | cluster_agg = df.groupby("clusters").agg( 199 | {"spikeTimes": lambda x: list(x), "amplitudes": lambda x: list(x)} 200 | ) 201 | cluster_agg = cluster_agg[cluster_agg.index.isin(labeled_clusters)] 202 | cls_temp = dict(zip(clusters, spike_templates)) 203 | 204 | logger.debug("Creating neuron attributes...") 205 | neuron_attributes = [] 206 | 207 | # un-whiten the templates before finding the best channel 208 | templates = np.dot(templates_w, wmi) 209 | 210 | for i in range(len(labeled_clusters)): 211 | c = labeled_clusters[i] 212 | temp = templates[cls_temp[c]].T 213 | amp = np.max(temp, axis=1) - np.min(temp, axis=1) 214 | sorted_idx = [ind for _, ind in sorted(zip(amp, np.arange(len(amp))))] 215 | nbgh_chan_idx = sorted_idx[::-1][:12] 216 | nbgh_temps = temp[nbgh_chan_idx] 217 | nbgh_channels = channels[nbgh_chan_idx] 218 | nbgh_postions = [tuple(positions[idx]) for idx in nbgh_chan_idx] 219 | neuron_attributes.append( 220 | NeuronAttributes( 221 | cluster_id=c, 222 | channel=nbgh_channels[0], 223 | position=nbgh_postions[0], 224 | amplitudes=cluster_agg["amplitudes"][c], 225 | template=nbgh_temps[0], 226 | templates=templates[cls_temp[c]].T, 227 | label=cluster_info["group"][cluster_info["cluster_id"] == c].values[0], 228 | neighbor_channels=nbgh_channels, 229 | neighbor_positions=nbgh_postions, 230 | neighbor_templates=nbgh_temps, 231 | ) 232 | ) 233 | 234 | logger.debug("Creating spike data...") 235 | 236 | metadata = {"experiment": experiment} 237 | spike_data = SpikeData( 238 | cluster_agg["spikeTimes"].to_list(), 239 | neuron_attributes=neuron_attributes, 240 | metadata=metadata, 241 | ) 242 | 243 | logger.debug("Done.") 244 | return spike_data 245 | 246 | 247 | @deprecated("Prefer load_spike_data()", version="0.1.13") 248 | def read_phy_files(path: str, fs=20000.0): 249 | """ 250 | :param path: a s3 or local path to a zip of phy files. 251 | :return: SpikeData class with a list of spike time lists and neuron_data. 252 | neuron_data = {0: neuron_dict, 1: config_dict} 253 | neuron_dict = {"new_cluster_id": {"channel": c, "position": (x, y), 254 | "amplitudes": [a0, a1, an], "template": [t0, t1, tn], 255 | "neighbor_channels": [c0, c1, cn], 256 | "neighbor_positions": [(x0, y0), (x1, y1), (xn,yn)], 257 | "neighbor_templates": [[t00, t01, t0n], [tn0, tn1, tnn]}} 258 | config_dict = {chn: pos} 259 | """ 260 | assert path[-3:] == "zip", "Only zip files supported!" 261 | import braingeneers.utils.smart_open_braingeneers as smart_open 262 | 263 | with smart_open.open(path, "rb") as f0: 264 | f = io.BytesIO(f0.read()) 265 | 266 | with zipfile.ZipFile(f, "r") as f_zip: 267 | assert "params.py" in f_zip.namelist(), "Wrong spike sorting output." 
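            # The zip is expected to be a phy/Kilosort export holding, besides params.py,
            # the arrays read below: spike_clusters.npy, templates.npy, whitening_mat_inv.npy,
            # channel_map.npy, spike_templates.npy, spike_times.npy, channel_positions.npy,
            # amplitudes.npy, and optionally a cluster_*.tsv of curation labels.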
268 |             with io.TextIOWrapper(f_zip.open("params.py"), encoding="utf-8") as params:
269 |                 for line in params:
270 |                     if "sample_rate" in line:
271 |                         fs = float(line.split()[-1])
272 |             clusters = np.load(f_zip.open("spike_clusters.npy")).squeeze()
273 |             templates = np.load(
274 |                 f_zip.open("templates.npy")
275 |             )  # (cluster_id, samples, channel_id)
276 |             channels = np.load(f_zip.open("channel_map.npy")).squeeze()
277 |             templates_w = np.load(f_zip.open("templates.npy"))
278 |             wmi = np.load(f_zip.open("whitening_mat_inv.npy"))
279 |             spike_templates = np.load(f_zip.open("spike_templates.npy")).squeeze()
280 |             spike_times = (
281 |                 np.load(f_zip.open("spike_times.npy")).squeeze() / fs * 1e3
282 |             )  # in ms
283 |             positions = np.load(f_zip.open("channel_positions.npy"))
284 |             amplitudes = np.load(f_zip.open("amplitudes.npy")).squeeze()
285 | 
286 |             if "cluster_KSLabel.tsv" in f_zip.namelist():
287 |                 cluster_info = pd.read_csv(f_zip.open("cluster_KSLabel.tsv"), sep="\t").rename(columns={"KSLabel": "group"})  # this file labels the column "KSLabel"
288 |                 cluster_id = np.array(cluster_info["cluster_id"])
289 |                 labeled_clusters = cluster_id[
290 |                     cluster_info["group"].isin(["good", "mua", "", "unsorted"])  # keep curated, non-noise groups
291 |                 ]
292 | 
293 |             elif "cluster_info.tsv" in f_zip.namelist():
294 |                 cluster_info = pd.read_csv(f_zip.open("cluster_info.tsv"), sep="\t")
295 |                 cluster_id = np.array(cluster_info["cluster_id"])
296 |                 # select clusters using curation label, remove units labeled as "noise"
297 |                 # find the best channel by amplitude
298 |                 labeled_clusters = cluster_id[cluster_info["group"] != "noise"]
299 |             else:
300 |                 labeled_clusters = np.unique(clusters)
301 | 
302 |     df = pd.DataFrame(
303 |         {"clusters": clusters, "spikeTimes": spike_times, "amplitudes": amplitudes}
304 |     )
305 |     cluster_agg = df.groupby("clusters").agg(
306 |         {"spikeTimes": lambda x: list(x), "amplitudes": lambda x: list(x)}
307 |     )
308 |     cluster_agg = cluster_agg[cluster_agg.index.isin(labeled_clusters)]
309 | 
310 |     cls_temp = dict(zip(clusters, spike_templates))
311 |     neuron_dict = dict.fromkeys(np.arange(len(labeled_clusters)), None)
312 | 
313 |     # un-whiten the templates before finding the best channel
314 |     templates = np.dot(templates_w, wmi)
315 | 
316 |     neuron_attributes = []
317 |     for i in range(len(labeled_clusters)):
318 |         c = labeled_clusters[i]
319 |         temp = templates[cls_temp[c]]
320 |         amp = np.max(temp, axis=0) - np.min(temp, axis=0)
321 |         sorted_idx = [ind for _, ind in sorted(zip(amp, np.arange(len(amp))))]
322 |         nbgh_chan_idx = sorted_idx[::-1][:12]
323 |         nbgh_temps = temp.transpose()[nbgh_chan_idx]
324 |         best_chan_temp = nbgh_temps[0]
325 |         nbgh_channels = channels[nbgh_chan_idx]
326 |         nbgh_postions = [tuple(positions[idx]) for idx in nbgh_chan_idx]
327 |         best_channel = nbgh_channels[0]
328 |         best_position = nbgh_postions[0]
329 |         # neighbor_templates = dict(zip(nbgh_postions, nbgh_temps))
330 |         cls_amp = cluster_agg["amplitudes"][c]
331 |         neuron_dict[i] = {
332 |             "cluster_id": c,
333 |             "channel": best_channel,
334 |             "position": best_position,
335 |             "amplitudes": cls_amp,
336 |             "template": best_chan_temp,
337 |             "neighbor_channels": nbgh_channels,
338 |             "neighbor_positions": nbgh_postions,
339 |             "neighbor_templates": nbgh_temps,
340 |         }
341 |         neuron_attributes.append(
342 |             NeuronAttributes(
343 |                 cluster_id=c,
344 |                 channel=best_channel,
345 |                 position=best_position,
346 |                 amplitudes=cluster_agg["amplitudes"][c],
347 |                 template=best_chan_temp,
348 |                 templates=templates[cls_temp[c]].T,
349 |                 label=cluster_info["group"][cluster_info["cluster_id"] == c].values[0],
350 |                 neighbor_channels=channels[nbgh_chan_idx],
351 |                 neighbor_positions=[tuple(positions[idx]) for idx in nbgh_chan_idx],
352 |                 neighbor_templates=[templates[cls_temp[c]].T[n] for n in nbgh_chan_idx],
353 |             )
354 |         )
355 | 
356 |     config_dict = dict(zip(channels, positions))
357 |     neuron_data = {0: neuron_dict}
358 |     metadata = {0: config_dict}
359 |     spikedata = SpikeData(
360 |         list(cluster_agg["spikeTimes"]),
361 |         neuron_data=neuron_data,
362 |         metadata=metadata,
363 |         neuron_attributes=neuron_attributes,
364 |     )
365 |     return spikedata
366 | 
367 | 
368 | @deprecated("Prefer analysis.butter_filter()", version="0.1.14")
369 | def filter(
370 |     raw_data,
371 |     fs_Hz=20000,
372 |     filter_order=3,
373 |     filter_lo_Hz=300,
374 |     filter_hi_Hz=6000,
375 |     time_step_size_s=10,
376 |     channel_step_size=100,
377 |     verbose=0,
378 |     zi=None,
379 |     return_zi=False,
380 | ):
381 |     """
382 |     Filter the raw data using a bandpass filter.
383 | 
384 |     :param raw_data: [channels, time] array of raw ephys data
385 |     :param fs_Hz: sampling frequency of raw data in Hz
386 |     :param filter_order: order of the filter
387 |     :param filter_lo_Hz: low frequency cutoff in Hz
388 |     :param filter_hi_Hz: high frequency cutoff in Hz
389 |     :param time_step_size_s: size of chunks to filter in seconds
390 |     :param channel_step_size: number of channels to filter at once
391 |     :param verbose: verbosity level
392 |     :param zi: initial conditions for the filter
393 |     :param return_zi: whether to return the final filter conditions
394 | 
395 |     :return: filtered data
396 |     """
397 | 
398 |     time_step_size = int(time_step_size_s * fs_Hz)
399 |     data = np.zeros_like(raw_data)
400 | 
401 |     # Get filter params
402 |     b, a = signal.butter(
403 |         fs=fs_Hz, btype="bandpass", N=filter_order, Wn=[filter_lo_Hz, filter_hi_Hz]
404 |     )
405 | 
406 |     if zi is None:
407 |         # Filter initial state
408 |         zi = signal.lfilter_zi(b, a)
409 |         zi = np.vstack(
410 |             [zi * np.mean(raw_data[ch, :5]) for ch in range(raw_data.shape[0])]
411 |         )
412 | 
413 |     # Step through the data in chunks and filter it
414 |     for ch_start in range(0, raw_data.shape[0], channel_step_size):
415 |         ch_end = min(ch_start + channel_step_size, raw_data.shape[0])
416 | 
417 |         logger.debug(f"Filtering channels {ch_start} to {ch_end}")
418 | 
419 |         for t_start in range(0, raw_data.shape[1], time_step_size):
420 |             t_end = min(t_start + time_step_size, raw_data.shape[1])
421 | 
422 |             (
423 |                 data[ch_start:ch_end, t_start:t_end],
424 |                 zi[ch_start:ch_end, :],
425 |             ) = signal.lfilter(
426 |                 b,
427 |                 a,
428 |                 raw_data[ch_start:ch_end, t_start:t_end],
429 |                 axis=1,
430 |                 zi=zi[ch_start:ch_end, :],
431 |             )
432 | 
433 |     return data if not return_zi else (data, zi)
434 | 
--------------------------------------------------------------------------------
/src/braingeneers/data/datasets_neuron.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | import re
5 | import subprocess
6 | import sys
7 | import warnings
8 | from pathlib import Path
9 | 
10 | import ipywidgets as widgets
11 | import musclebeachtools as mbt
12 | import numpy as np
13 | from braingeneers.utils import s3wrangler as wr
14 | from braingeneers.utils.numpy_s3_memmap import NumpyS3Memmap
15 | from IPython.display import display, clear_output
16 | 
17 | 
18 | warnings.filterwarnings("ignore")
19 | logging.disable(sys.maxsize)
20 | 
21 | 
22 | 
23 | s3 = 'aws --endpoint https://s3.nautilus.optiputer.net s3'
24 | res_path = 'mgk_results/'
25 | s3_path = 'braingeneers/ephys/'
26 | 
27 | 
28 | class NeuralAid:
29 | 
30 | 
31 |     def 
__init__(self,data_path='data/'): 32 | self.data_path = data_path 33 | self.exp = None 34 | self.select_exp = None 35 | self.ratings_dict = get_ratings_dict() 36 | 37 | return 38 | 39 | 40 | def set_data_path(self,new_path): 41 | 42 | self.data_path = str(Path(new_path).resolve()) + '/' 43 | print("Data path changed to:",self.data_path) 44 | return 45 | 46 | 47 | def gen_exp_dropdown(self,exps = None,disp=True): 48 | ''' 49 | Generates dropdown list for experiment selection 50 | 51 | Parameters: 52 | ----------- 53 | exps: list 54 | list of experiments returned fro get_exp_list 55 | if None, generates list 56 | disp: bool 57 | if True, displays dropdown immediately 58 | ''' 59 | if exps == None: 60 | exps = get_exp_list() 61 | 62 | self.select_exp = widgets.Dropdown( options=exps, description='Experiment:') 63 | 64 | if disp: 65 | display(self.select_exp) 66 | return 67 | 68 | 69 | def set_exp(self,exp=None): 70 | ''' 71 | Sets exp 72 | If experiment is none, try to pull from dropdown menu 73 | ''' 74 | if exp == None: 75 | self.exp = self.select_exp.value 76 | else: 77 | if self.select_exp != None: 78 | self.select_exp.value = exp 79 | self.exp = exp 80 | 81 | return 82 | 83 | 84 | def load_experiment_b(self,b): 85 | ''' 86 | Button interfacing, set and load experiment 87 | 88 | ''' 89 | self.set_exp() 90 | print('Loading exp') 91 | 92 | self.load_experiment() 93 | 94 | return 95 | 96 | def choose_experiment(self): 97 | self.gen_exp_dropdown() 98 | self.gen_load_experiment_b() 99 | return 100 | 101 | def choose_well(self): 102 | self.gen_well_dropdown() 103 | self.gen_load_well_b() 104 | 105 | 106 | def gen_load_experiment_b(self): 107 | ''' 108 | Generates button for loading experiment 109 | ''' 110 | self.load_exp_btn = widgets.Button(description="Load") 111 | self.load_exp_btn.on_click(self.load_experiment_b) 112 | 113 | display(self.load_exp_btn) 114 | return 115 | 116 | def set_well_dict(self): 117 | ''' 118 | Load well dict from s3 experiment, used for pathing of exp and well to data location 119 | 120 | 121 | Parameters: 122 | ----------- 123 | self.exp: str 124 | uuid of experiment directory followed by '/' 125 | self.data_path: str 126 | location to path where local data is stored. 127 | 128 | Sets: 129 | -------- 130 | self.well_dict: dict of dicts 131 | A dict which indexes based on experiments. The values are also dicts, which index based on well, with the 132 | value being the path of the experiment-well''' 133 | 134 | #Should error out here? 
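        # Illustrative shape of the mapping built below (well and path names are examples only):
        #   self.well_dict == {'A1': {0: '<data_path><exp>root/data/<exp>results/A1chgroup0', ...}, ...}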
135 | if self.exp == None: 136 | self.set_exp() 137 | 138 | # Sort 139 | loc = f'{self.data_path}{self.exp}root/data/{self.exp}results/*' 140 | chs = sorted(glob.glob(loc)) 141 | 142 | # Remove rasters 143 | chs = [ch for ch in chs if 'rasters' not in ch] 144 | self.well_dict = {} 145 | seen_ch = [] 146 | 147 | for ch in chs: 148 | 149 | well_grp = ch.split('/')[-1] 150 | well, grp = well_grp.split('chgroup') 151 | 152 | # Maps group to full name 153 | if well not in seen_ch: 154 | seen_ch.append(well) 155 | self.well_dict[well] = {} 156 | 157 | self.well_dict[well][int(grp)] = ch 158 | 159 | return 160 | 161 | 162 | def get_ratings_dict(self): 163 | objs = wr.list_objects('s3://braingeneers/ephys/*/dataset/*.npy') 164 | 165 | ratings_dict = {} 166 | 167 | for o in objs: 168 | #This is dirty 169 | key = o.split('/')[4] + '/' 170 | ratings_dict[key] = o 171 | 172 | self.ratings_dict = ratings_dict 173 | 174 | return ratings_dict 175 | 176 | 177 | 178 | def load_experiment(self,exp=None): 179 | '''Load experiment from s3 to local, return well dict. 180 | 181 | Parameters: 182 | ----------- 183 | self.exp: str 184 | name of experiment followed by '/' (ex. test1/) 185 | self.data_path: str 186 | path where data will be downloaded 187 | 188 | Sets: 189 | -------- 190 | self.well_dict: dict 191 | A dict which indexes based on experiments. The values are also dicts, 192 | which index based on well, with the value being the 193 | path of the experiment-well''' 194 | if exp != None: 195 | self.set_exp(exp) 196 | 197 | #Check if experiment already has been downloaded 198 | if os.path.isdir(self.data_path + self.exp): 199 | self.set_well_dict() 200 | 201 | #TODO: Check if this is going to cause an issue 202 | #!{s3} cp s3://{s3_path + exp + 'dataset/'} {data_path}{exp} 203 | print('Selected! Experiment already exists locally:',self.exp) 204 | return 205 | 206 | print('Loading experiment: ' + self.exp[:-1]) 207 | 208 | # Download files from exp 209 | cmd = f'{s3} cp s3://{s3_path + self.exp + res_path} {self.data_path + self.exp}. --recursive --exclude="*" --include="*.zip"' 210 | process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE) 211 | output, error = process.communicate() 212 | # !{s3} cp s3://{s3_path + exp + res_path} data/{exp}. --recursive --exclude="*" --include="*.zip" 213 | 214 | #Unzip those files 215 | cmd = f"unzip -qn {self.data_path + self.exp}*.zip -d {self.data_path + self.exp[:-1]}" 216 | process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE,stderr=subprocess.PIPE) 217 | output, error = process.communicate() 218 | print(error) 219 | # !unzip -qn 'data/{exp}*.zip' -d data/{exp[:-1]} 220 | 221 | cmd = f"rm {self.data_path + self.exp}*.zip" 222 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,stderr=subprocess.PIPE,shell=True) 223 | output, error = process.communicate() 224 | # !rm data/{exp}*.zip 225 | print(error) 226 | 227 | 228 | self.set_well_dict() 229 | 230 | print('Finished loading:',exp) 231 | return 232 | 233 | def load_well(self,well=None): 234 | '''Loads corresponding wells neurons and ratings 235 | 236 | Arguments: 237 | well -- location of well (ex. 'A1') 238 | 239 | self.exp -- name of experiment followed by '/' (ex. 
test1/)
240 | 
241 |         Global vars:
242 |         data_path -- path where data will be downloaded
243 | 
244 |         '''
245 |         neurons = None
246 | 
247 |         #Sort by actual number
248 |         well_data = {k: v for k,v in sorted(self.well_dict[well].items(),key=lambda x: x[0])}
249 | 
250 |         #Load and append each group to the data list, accumulating potential spikes
251 |         for group in well_data.values():
252 |             nf = glob.glob(group + '/spikeintf/outputs/neurons*')
253 |             n_temp = np.load(nf[0],allow_pickle=True)
254 |             n_prb = open(glob.glob(group+'/spikeintf/inputs/*probefile.prb')[0])
255 |             mbt.load_spike_amplitudes(n_temp, group+'/spikeintf/outputs/amplitudes0.npy')
256 | 
257 |             lines = n_prb.readlines()
258 |             real_chans = []
259 |             s = lines[5]
260 |             n = s.split()
261 | 
262 |             for chan in range(1,len(n)):
263 |                 result = re.search('c_(.*)\'', n[chan])
264 |                 real_chans.append(int(result.group(1)))
265 | 
266 |             for i in range(len(n_temp)):
267 |                 chan = n_temp[i].peak_channel
268 |                 n_temp[i].peak_channel = real_chans[chan]
269 | 
270 |             if type(neurons) != np.ndarray:
271 |                 neurons = n_temp
272 |             else:
273 |                 neurons = np.append(neurons,n_temp)
274 | 
275 |             n_prb.close()
276 | 
277 |         print('Well {} selected, {} potential neurons!'.format(well,len(neurons)),end='')
278 | 
279 |         #Load ratings and assign
280 |         ratings_path = f'{self.data_path}{self.exp}dataset/'
281 |         ratings_path = glob.glob(ratings_path + well + '*.npy')
282 | 
283 |         if self.exp in self.ratings_dict:
284 |             ratings = load_ratings(self.ratings_dict[self.exp][well])
285 |             # ratings = np.load(ratings_path[0])
286 |             print('{} ratings loaded.'.format(len(ratings)))
287 | 
288 |         else:
289 |             ratings = np.zeros(len(neurons))
290 |             print('Ratings NOT loaded.')
291 | 
292 |         self.neurons = neurons
293 |         self.ratings = ratings
294 |         return (neurons,ratings)
295 | 
296 | 
297 | 
298 |     def gen_well_dropdown(self):
299 |         '''
300 |         Generates dropdown list for well selection
301 | 
302 |         Uses:
303 |         -----------
304 |         self.well_dict: dict
305 |             mapping built by set_well_dict(); its keys populate the dropdown options
306 | 
307 |         Sets:
308 |             self.select_well, which is displayed immediately
309 |         '''
310 |         self.select_well = widgets.Dropdown( options=self.well_dict.keys(), description='Well:')
311 | 
312 | 
313 |         display(self.select_well)
314 |         return
315 | 
316 |     def gen_load_well_b(self):
317 |         '''
318 |         Generates load well button
319 |         '''
320 |         self.load_well_btn = widgets.Button(description="Load Well")
321 |         self.load_well_btn.on_click(self.load_well_b)
322 | 
323 |         display(self.load_well_btn)
324 |         return
325 | 
326 | 
327 |     def load_well_b(self,b):
328 |         '''Loads all channel groups from the specified well selected in the drop down menu'''
329 | 
330 |         #Load from dropdowns
331 |         self.set_well()
332 | 
333 |         self.load_well(self.well)
334 | 
335 |         return
336 | 
337 | 
338 |     def set_well(self,well=None):
339 |         '''
340 |         Sets well
341 |         If well is none, try to pull from dropdown menu
342 |         '''
343 |         if well == None:
344 |             self.well = self.select_well.value
345 |         else:
346 |             if self.select_well != None:
347 |                 self.select_well.value = well
348 |             self.well = well
349 | 
350 |         return
351 | 
352 | 
353 |     def rate_neuron_b(self,b):
354 |         '''
355 |         Rates current neuron, goes to next neuron/finishes
356 |         '''
357 | 
358 |         #Record the rating for the currently shown neuron; "Keep Current" leaves it unchanged
359 |         if b.description.isdigit():
360 |             rate = int(b.description)
361 |             self.ratings[self.ind_neurons]=rate
362 | 
363 | 
364 |         #Show neuron i
365 |         # clear_output()
366 |         clear_output(wait=True)
367 |         self.ind_neurons = self.ind_neurons + 1
368 | 
369 |         #Finish if no more neurons
370 | if self.ind_neurons >= len(self.ratings): 371 | print('Finished!') 372 | return 373 | 374 | 375 | 376 | self.neurons[self.ind_neurons].qual_prob = [0,0,0,0] 377 | self.neurons[self.ind_neurons].checkqual(binsz=60) 378 | 379 | print('Neuron:{}/{}'.format(self.ind_neurons+1,self.n_neurons)) 380 | print('Current Rating:',self.ratings[self.ind_neurons]) 381 | display(self.btn_1,self.btn_2,self.btn_3,self.btn_4,self.btn_5) 382 | 383 | return 384 | 385 | def gen_rate_neuron(self): 386 | ''' 387 | Iterates through list with buttons, rating neurons 388 | ''' 389 | 390 | 391 | self.btn_1 = widgets.Button(description="1") 392 | self.btn_2 = widgets.Button(description="2") 393 | self.btn_3 = widgets.Button(description="3") 394 | self.btn_4 = widgets.Button(description="4") 395 | self.btn_5 = widgets.Button(description="Keep Current") 396 | 397 | self.btn_1.on_click(self.rate_neuron_b) 398 | self.btn_2.on_click(self.rate_neuron_b) 399 | self.btn_3.on_click(self.rate_neuron_b) 400 | self.btn_4.on_click(self.rate_neuron_b) 401 | self.btn_5.on_click(self.rate_neuron_b) 402 | 403 | self.n_neurons = len(self.neurons) 404 | self.ind_neurons = 0 405 | print('Starting rating for ratings/num_neurons:{}/{}'.format(len(self.ratings),self.n_neurons)) 406 | 407 | self.neurons[self.ind_neurons].qual_prob = [0,0,0,0] 408 | self.neurons[self.ind_neurons].checkqual(binsz=10) 409 | print('Neuron:{}/{}'.format(self.ind_neurons+1,self.n_neurons)) 410 | print('Current Rating:',self.ratings[self.ind_neurons]) 411 | display(self.btn_1,self.btn_2,self.btn_3,self.btn_4, self.btn_5) 412 | 413 | 414 | def save_ratings(self,s3_upload=True): 415 | '''Saves the well ratings as a csv from the text input field prepended by the well. 416 | Uses selected experiment to save the well back into its corresponding dataset folder 417 | UPLOADS to s3 also''' 418 | 419 | save_dir = f'{self.data_path}{self.exp}dataset/' 420 | 421 | os.makedirs(save_dir, exist_ok=True) 422 | save_path = save_dir + self.well + '_' + self.name_field.value 423 | 424 | if not (os.path.isfile(save_path + '.npy')): 425 | np.save(save_path,np.array(self.ratings)) 426 | print('Saved successfully in:',save_path) 427 | else: 428 | print("File already exists") 429 | 430 | # if b.description == "Save & Upload": 431 | if s3_upload: 432 | s3_save_path = self.exp + "dataset/" + self.well + '_' + self.name_field.value + '.npy' 433 | 434 | cmd = f"{s3} cp {save_path + '.npy'} s3://{s3_path + s3_save_path} " 435 | process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE) 436 | output, error = process.communicate() 437 | print('Uploaded Successfully') 438 | # !{s3} cp {save_path + '.npy'} s3://{s3_path + s3_save_path} 439 | 440 | def save_ratings_b(self,b): 441 | 442 | if b.description == "Save & Upload": 443 | self.save_ratings() 444 | else: 445 | self.save_ratings(s3_upload=False) 446 | 447 | 448 | def gen_save_b(self): 449 | ''' 450 | Generate save buttons and widgets for s3 451 | ''' 452 | 453 | self.save_well_btn = widgets.Button(description="Save") 454 | self.save_well_s3_btn = widgets.Button(description="Save & Upload") 455 | self.name_field = widgets.Text( 456 | value='', 457 | placeholder='title here', 458 | description='Filename prepended by well (ex: "A1_YourTextHere")') 459 | 460 | self.save_well_btn.on_click(self.save_ratings_b) 461 | self.save_well_s3_btn.on_click(self.save_ratings_b) 462 | display(self.name_field,self.save_well_btn,self.save_well_s3_btn) 463 | return 464 | 465 | def get_exp_list(): 466 | ''' 467 | Get experiment list from s3 location. 
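    Returns a list of experiment prefixes found under s3://braingeneers/ephys/, each
    keeping its trailing '/', e.g. ['2020-01-01-e-example/', ...] (placeholder UUIDs).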
468 | 469 | ''' 470 | cmd = f'{s3} ls s3://{s3_path}' 471 | process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE) 472 | output, error = process.communicate() 473 | output = str(output).split('\\n') 474 | exps = [t.split('PRE ')[1] for t in output if len(t.split('PRE ')) >1] 475 | return exps 476 | 477 | def get_ratings_list(): 478 | ''' 479 | Get list of uuids which have datasets 480 | ''' 481 | objs = wr.list_objects('s3://braingeneers/ephys/*/dataset/*.npy') 482 | 483 | datasets = [] 484 | for o in objs: 485 | #This is dirty 486 | datasets.append(o.split('/')[4] + '/') 487 | 488 | return datasets 489 | 490 | 491 | def get_ratings_dict(): 492 | ''' 493 | Returns dict of dict mapping UUID -> Well -> s3 path 494 | 495 | Usage of dict looks like d[uuid]['A1'] 496 | 497 | ''' 498 | objs = wr.list_objects('s3://braingeneers/ephys/*/dataset/*.npy') 499 | 500 | ratings_dict = {} 501 | 502 | for o in objs: 503 | #This is dirty 504 | key = o.split('/')[4] + '/' 505 | well = o.split('/')[6][:2] 506 | if ratings_dict.get(key) == None: 507 | ratings_dict[key] = {well:o} 508 | else: 509 | ratings_dict[key][well] = o 510 | 511 | return ratings_dict 512 | 513 | 514 | def load_ratings(fname): 515 | ''' 516 | Loads array of ratings from s3 filename 517 | 518 | Use get_ratings_dict to generate an easy method to index through the s3 filenames 519 | 520 | ''' 521 | return NumpyS3Memmap(fname)[:] 522 | 523 | 524 | 525 | 526 | def load_all_rated(): 527 | ''' 528 | Loads all neurons(outputted from the sorter) that have been rated 529 | ''' 530 | na = NeuralAid() 531 | #Local storage 532 | na.set_data_path('./data/') 533 | 534 | ratings_dict = get_ratings_dict() 535 | 536 | neurons = [] 537 | ratings = [] 538 | for exp in ratings_dict.keys(): 539 | for well in ratings_dict[exp].keys(): 540 | 541 | na.load_experiment(exp) 542 | n_temp,r_temp = na.load_well(well) 543 | assert len(n_temp) == len(r_temp), "Number of neurons in the data and number of ratings must be the same, failure in {}".format(exp) 544 | neurons = np.append(neurons,n_temp) 545 | ratings = np.append(ratings,r_temp) 546 | 547 | print('Loaded {} neurons and ratings'.format(len(neurons))) 548 | return (neurons,ratings) 549 | --------------------------------------------------------------------------------
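A minimal sketch of how the NeuralAid helpers above are meant to be driven from a Jupyter notebook, using only methods defined in datasets_neuron.py; the data path is a placeholder and the dropdowns/buttons assume an interactive ipywidgets frontend:

    # Illustrative notebook workflow; each call renders widgets and waits for clicks.
    from braingeneers.data.datasets_neuron import NeuralAid

    na = NeuralAid(data_path='data/')
    na.choose_experiment()   # experiment dropdown + "Load" button -> set_exp() / load_experiment()
    na.choose_well()         # well dropdown + "Load Well" button -> set_well() / load_well()
    na.gen_rate_neuron()     # steps through neurons with checkqual() and the 1-4 / "Keep Current" buttons
    na.gen_save_b()          # filename field + "Save" / "Save & Upload" buttons -> save_ratings()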