├── .github └── workflows │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── examples ├── baselines │ ├── autolfads │ │ ├── README.md │ │ ├── config │ │ │ ├── area2_bump.yaml │ │ │ ├── area2_bump_20.yaml │ │ │ ├── dmfc_rsg.yaml │ │ │ ├── dmfc_rsg_20.yaml │ │ │ ├── mc_maze.yaml │ │ │ ├── mc_maze_20.yaml │ │ │ ├── mc_maze_large.yaml │ │ │ ├── mc_maze_large_20.yaml │ │ │ ├── mc_maze_medium.yaml │ │ │ ├── mc_maze_medium_20.yaml │ │ │ ├── mc_maze_small.yaml │ │ │ ├── mc_maze_small_20.yaml │ │ │ ├── mc_rtt.yaml │ │ │ └── mc_rtt_20.yaml │ │ ├── lfads_data_prep.py │ │ ├── post_lfads_prep.py │ │ └── run_lfads.py │ ├── gpfa │ │ ├── README.md │ │ ├── gpfa_cv_sweep.py │ │ └── run_gpfa.py │ ├── ndt │ │ └── README.md │ ├── slds │ │ ├── README.md │ │ ├── run_slds.py │ │ └── run_slds_randsearch.py │ └── smoothing │ │ ├── README.md │ │ ├── run_smoothing.py │ │ └── smoothing_cv_sweep.py └── tutorials │ ├── basic_example.ipynb │ ├── gpfa_example.ipynb │ ├── img │ ├── pipeline.png │ └── split.png │ └── slds_example.ipynb ├── nlb_tools ├── __init__.py ├── chop.py ├── evaluation.py ├── make_tensors.py └── nwb_interface.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── test_evaluate.py ├── test_make_tensors.py ├── test_nlb.py └── test_nwb_interface.py /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | pull_request: 8 | branches: [ "main" ] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | os: 18 | - windows-2019 19 | - ubuntu-latest 20 | - macos-latest 21 | python-version: 22 | - "3.7" 23 | - "3.8" 24 | - "3.9" 25 | # - "3.10" 26 | # - "3.11" 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | - name: Set up Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v3 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | python -m pip install flake8 pytest 38 | python -m pip install dandi==0.46.3 nwbinspector==0.4.14 39 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 40 | - name: Lint with flake8 41 | run: | 42 | # stop the build if there are Python syntax errors or undefined names 43 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 44 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 45 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 46 | - name: Test with pytest 47 | run: | 48 | pytest 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # test data 132 | tests/temp_data/ 133 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.2.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 22.8.0 10 | hooks: 11 | - id: black 12 | exclude: ^docs/ 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Neural Latents 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NLB Codepack (nlb_tools) 2 | Python tools for participating in Neural Latents Benchmark '21. 3 | 4 | ## Overview 5 | Neural Latents Benchmark '21 (NLB'21) is a benchmark suite for unsupervised modeling of neural population activity. 6 | The suite includes four datasets spanning a variety of brain areas and experiments. 7 | The primary task in the benchmark is co-smoothing, or inference of firing rates of unseen neurons in the population. 8 | 9 | This repo contains code to facilitate participation in NLB'21: 10 | * `nlb_tools/` has code to load and preprocess our dataset files, format data for modeling, and locally evaluate results 11 | * `examples/tutorials/` contains tutorial notebooks demonstrating basic usage of `nlb_tools` 12 | * `examples/baselines/` holds the code we used to run our baseline methods. 
They may serve as helpful references on more extensive usage of `nlb_tools` 13 | 14 | ## Installation 15 | The package can be installed with the following command: 16 | ``` 17 | pip install nlb-tools 18 | ``` 19 | However, to run the tutorial notebooks locally or make any modifications to the code, you should clone the repo. The package can then be installed with the following commands: 20 | ``` 21 | git clone https://github.com/neurallatents/nlb_tools.git 22 | cd nlb_tools 23 | pip install -e . 24 | ``` 25 | This package requires Python 3.7+ and was developed in Python 3.7, which is the Python version we recommend you use. 26 | 27 | ## Getting started 28 | We recommend reading/running through `examples/tutorials/basic_example.ipynb` to learn how to use `nlb_tools` to load and 29 | format data for our benchmark. You can also find Jupyter notebooks demonstrating running GPFA and SLDS for the benchmark in 30 | `examples/tutorials/`. 31 | 32 | ## Other resources 33 | For more information on the benchmark: 34 | * our [main webpage](https://neurallatents.github.io) contains general information on our benchmark pipeline and introduces the datasets 35 | * our [EvalAI challenge](https://eval.ai/web/challenges/challenge-page/1256/overview) is where submissions are evaluated and displayed on the leaderboard 36 | * our datasets are available on DANDI: [MC_Maze](https://dandiarchive.org/#/dandiset/000128), [MC_RTT](https://dandiarchive.org/#/dandiset/000129), [Area2_Bump](https://dandiarchive.org/#/dandiset/000127), [DMFC_RSG](https://dandiarchive.org/#/dandiset/000130), [MC_Maze_Large](https://dandiarchive.org/#/dandiset/000138), [MC_Maze_Medium](https://dandiarchive.org/#/dandiset/000139), [MC_Maze_Small](https://dandiarchive.org/#/dandiset/000140) 37 | * our [paper](http://arxiv.org/abs/2109.04463) describes our motivations behind this benchmarking effort as well as various technical details and explanations of design choices made in preparing NLB'21 38 | * our [Slack workspace](https://neurallatents.slack.com) lets you interact directly with the developers and other participants. Please email `fpei6 [at] gatech [dot] edu` for an invite link 39 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/README.md: -------------------------------------------------------------------------------- 1 | # AutoLFADS 2 | 3 | Latent Factor Analysis via Dynamical Systems (LFADS) is a deep learning method to infer latent dynamics from single-trial neural spiking data. 4 | AutoLFADS utilizes Population Based Training (PBT) to optimize LFADS hyperparameters efficiently. 5 | You can read more about LFADS in [Pandarinath et al. 2018](https://www.nature.com/articles/s41592-018-0109-9) and AutoLFADS in [Keshtkaran et al. 2021](https://www.biorxiv.org/content/10.1101/2021.01.13.426570v1) 6 | 7 | This directory contains files used to run AutoLFADS for NLB'21: 8 | * `lfads_data_prep.py` saves input data in the expected format for LFADS. 9 | * `run_lfads.py` trains AutoLFADS on the training data and performs inference on test data. 10 | * `post_lfads_prep.py` takes LFADS output and reformats it in the expected submission format for NLB'21. 
11 | * `config/` contains the run config YAML files used to run AutoLFADS 12 | 13 | ## Dependencies 14 | * [nlb_tools](https://github.com/neurallatents/nlb_tools) 15 | * [autolfads-tf2](https://github.com/snel-repo/autolfads-tf2) 16 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/area2_bump.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 65 4 | CS_DIM: 16 5 | SEQ_LEN: 160 6 | FP_LEN: 40 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: area2_bump_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 180 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/area2_bump_20.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 65 4 | CS_DIM: 16 5 | SEQ_LEN: 40 6 | FP_LEN: 10 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: area2_bump_20_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 180 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/dmfc_rsg.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | DATA_DIM: 60 3 | CS_DIM: 15 4 | SEQ_LEN: 340 5 | FP_LEN: 40 6 | IC_ENC_DIM: 100 7 | CI_ENC_DIM: 64 8 | CON_DIM: 64 9 | CO_DIM: 4 10 | IC_DIM: 100 11 | GEN_DIM: 100 12 | FAC_DIM: 40 13 | 
DROPOUT_RATE: 0.05 # OVERWRITTEN 14 | CD_RATE: 0.5 # OVERWRITTEN 15 | CD_PASS_RATE: 0.0 16 | CO_PRIOR_TAU: 10.0 17 | CO_PRIOR_NVAR: 0.1 18 | IC_PRIOR_VAR: 0.1 19 | IC_POST_VAR_MIN: 1.0e-4 20 | TRAIN: 21 | DATA: 22 | DIR: ~/data/lfads_input 23 | PREFIX: dmfc_rsg_train 24 | PATIENCE: 5000 # Early stopping should be off 25 | BATCH_SIZE: 200 # 100 for the 230-sample dataset 26 | MAX_EPOCHS: 10000 27 | MAX_GRAD_NORM: 200.0 28 | LOSS_SCALE: 10000.0 29 | LR: 30 | INIT: 5.0e-3 # OVERWRITTEN 31 | STOP: 1.0e-10 32 | DECAY: 1.0 33 | PATIENCE: 0 34 | ADAM_EPSILON: 1.0e-8 35 | L2: 36 | START_EPOCH: 0 37 | INCREASE_EPOCH: 80 38 | IC_ENC_SCALE: 0.0 # UNUSED 39 | CI_ENC_SCALE: 0.0 # UNUSED 40 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 41 | CON_SCALE: 1.0e-4 # OVERWRITTEN 42 | KL: 43 | START_EPOCH: 0 44 | INCREASE_EPOCH: 80 45 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 46 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | PBT_MODE: True 48 | TUNE_MODE: True 49 | LOG_HPS: True 50 | USE_TB: False 51 | MODEL_DIR: '' # OVERWRITTEN 52 | OVERWRITE: True -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/dmfc_rsg_20.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | DATA_DIM: 60 3 | CS_DIM: 15 4 | SEQ_LEN: 85 5 | FP_LEN: 10 6 | IC_ENC_DIM: 100 7 | CI_ENC_DIM: 64 8 | CON_DIM: 64 9 | CO_DIM: 4 10 | IC_DIM: 100 11 | GEN_DIM: 100 12 | FAC_DIM: 40 13 | DROPOUT_RATE: 0.05 # OVERWRITTEN 14 | CD_RATE: 0.5 # OVERWRITTEN 15 | CD_PASS_RATE: 0.0 16 | CO_PRIOR_TAU: 10.0 17 | CO_PRIOR_NVAR: 0.1 18 | IC_PRIOR_VAR: 0.1 19 | IC_POST_VAR_MIN: 1.0e-4 20 | TRAIN: 21 | DATA: 22 | DIR: ~/data/lfads_input 23 | PREFIX: dmfc_rsg_20_train 24 | PATIENCE: 5000 # Early stopping should be off 25 | BATCH_SIZE: 300 # 100 for the 230-sample dataset 26 | MAX_EPOCHS: 10000 27 | MAX_GRAD_NORM: 200.0 28 | LOSS_SCALE: 10000.0 29 | LR: 30 | INIT: 5.0e-3 # OVERWRITTEN 31 | STOP: 1.0e-10 32 | DECAY: 1.0 33 | PATIENCE: 0 34 | ADAM_EPSILON: 1.0e-8 35 | L2: 36 | START_EPOCH: 0 37 | INCREASE_EPOCH: 80 38 | IC_ENC_SCALE: 0.0 # UNUSED 39 | CI_ENC_SCALE: 0.0 # UNUSED 40 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 41 | CON_SCALE: 1.0e-4 # OVERWRITTEN 42 | KL: 43 | START_EPOCH: 0 44 | INCREASE_EPOCH: 80 45 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 46 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | PBT_MODE: True 48 | TUNE_MODE: True 49 | LOG_HPS: True 50 | USE_TB: False 51 | MODEL_DIR: '' # OVERWRITTEN 52 | OVERWRITE: True -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 182 # 146 4 | CS_DIM: 45 5 | SEQ_LEN: 180 6 | FP_LEN: 40 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 
0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_20.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 182 4 | CS_DIM: 45 5 | SEQ_LEN: 45 6 | FP_LEN: 10 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_20_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_large.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 162 # 180 4 | CS_DIM: 40 5 | SEQ_LEN: 180 6 | FP_LEN: 40 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_large_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_large_20.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 162 # 180 4 | CS_DIM: 40 5 | 
SEQ_LEN: 45 6 | FP_LEN: 10 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_large_20_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_medium.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 152 # 156 4 | CS_DIM: 38 5 | SEQ_LEN: 180 6 | FP_LEN: 40 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_medium_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_medium_20.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 152 # 156 4 | CS_DIM: 38 5 | SEQ_LEN: 45 6 | FP_LEN: 10 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_medium_20_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 
| INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_small.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 142 # 130 4 | CS_DIM: 35 5 | SEQ_LEN: 180 6 | FP_LEN: 40 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_small_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_maze_small_20.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 142 # 130 4 | CS_DIM: 35 5 | SEQ_LEN: 45 6 | FP_LEN: 10 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_maze_small_20_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | 
-------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_rtt.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 130 4 | CS_DIM: 32 5 | SEQ_LEN: 160 6 | FP_LEN: 40 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_rtt_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/config/mc_rtt_20.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | MODEL: 3 | DATA_DIM: 130 4 | CS_DIM: 32 5 | SEQ_LEN: 40 6 | FP_LEN: 10 7 | IC_ENC_DIM: 100 8 | CI_ENC_DIM: 80 9 | CON_DIM: 80 10 | CO_DIM: 4 11 | IC_DIM: 100 12 | GEN_DIM: 100 13 | FAC_DIM: 40 14 | DROPOUT_RATE: 0.05 # OVERWRITTEN 15 | CD_RATE: 0.5 # OVERWRITTEN 16 | CD_PASS_RATE: 0.0 17 | CO_PRIOR_TAU: 10.0 18 | CO_PRIOR_NVAR: 0.1 19 | IC_PRIOR_VAR: 0.1 20 | IC_POST_VAR_MIN: 1.0e-4 21 | TRAIN: 22 | DATA: 23 | DIR: ~/data/lfads_input 24 | PREFIX: mc_rtt_20_train 25 | PATIENCE: 5000 # Early stopping should be off 26 | BATCH_SIZE: 100 # 100 for the 230-sample dataset 27 | MAX_EPOCHS: 10000 28 | MAX_GRAD_NORM: 200.0 29 | LOSS_SCALE: 10000.0 30 | LR: 31 | INIT: 0.004 # OVERWRITTEN 32 | STOP: 1.0e-10 33 | DECAY: 1.0 34 | PATIENCE: 0 35 | ADAM_EPSILON: 1.0e-8 36 | L2: 37 | START_EPOCH: 0 38 | INCREASE_EPOCH: 80 39 | IC_ENC_SCALE: 0.0 # UNUSED 40 | CI_ENC_SCALE: 0.0 # UNUSED 41 | GEN_SCALE: 1.0e-4 # OVERWRITTEN 42 | CON_SCALE: 1.0e-4 # OVERWRITTEN 43 | KL: 44 | START_EPOCH: 0 45 | INCREASE_EPOCH: 80 46 | IC_WEIGHT: 1.0e-4 # OVERWRITTEN 47 | CO_WEIGHT: 1.0e-4 # OVERWRITTEN 48 | PBT_MODE: True 49 | TUNE_MODE: True 50 | LOG_HPS: True 51 | USE_TB: False 52 | MODEL_DIR: '' # OVERWRITTEN 53 | OVERWRITE: True 54 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/lfads_data_prep.py: -------------------------------------------------------------------------------- 1 | # ---- Imports ---- # 2 | from nlb_tools.nwb_interface import NWBDataset 3 | from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors 4 | import numpy as np 5 | import h5py 6 | import sys 7 | 8 | # ---- Run params ---- # 9 | dataset_name = 'mc_rtt' 10 | valid_ratio = 0.2 11 | bin_size_ms = 5 12 | suf = '' if bin_size_ms == 5 else '_20' 13 | 14 | # ---- Data locations ---- # 15 | datapath_dict = { 16 | 'mc_maze': 
'~/data/000128/sub-Jenkins/', 17 | 'mc_rtt': '~/data/000129/sub-Indy/', 18 | 'area2_bump': '~/data/000127/sub-Han/', 19 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 20 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 21 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 22 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 23 | } 24 | prefix_dict = { 25 | 'mc_maze': '*full', 26 | 'mc_maze_large': '*large', 27 | 'mc_maze_medium': '*medium', 28 | 'mc_maze_small': '*small', 29 | } 30 | datapath = datapath_dict[dataset_name] 31 | prefix = prefix_dict.get(dataset_name, '') 32 | savepath_train = f'~/data/lfads_input/{dataset_name}{suf}_train_lfads.h5' 33 | savepath_test = f'~/data/lfads_input/{dataset_name}{suf}_test_lfads.h5' 34 | 35 | # ---- Load data ---- # 36 | dataset = NWBDataset(datapath, prefix) 37 | dataset.resample(bin_size_ms) 38 | 39 | # ---- Extract train data ---- # 40 | data_dict = make_train_input_tensors(dataset, dataset_name, ['train', 'val'], save_file=False, include_forward_pred=True) 41 | 42 | tlen = data_dict['train_spikes_heldin'].shape[1] 43 | num_heldin = data_dict['train_spikes_heldin'].shape[2] 44 | num_heldout = data_dict['train_spikes_heldout'].shape[2] 45 | fp_steps = data_dict['train_spikes_heldin_forward'].shape[1] 46 | spikes = np.hstack([ 47 | np.dstack([data_dict['train_spikes_heldin'], data_dict['train_spikes_heldout']]), 48 | np.dstack([data_dict['train_spikes_heldin_forward'], data_dict['train_spikes_heldout_forward']]), 49 | ]) 50 | 51 | num_trials = len(spikes) 52 | valid_inds = np.arange(0, num_trials, int(1./valid_ratio)) 53 | train_inds = np.delete(np.arange(num_trials), valid_inds) 54 | 55 | with h5py.File(savepath_train, 'w') as h5file: 56 | h5file.create_dataset('train_inds', data=train_inds) 57 | h5file.create_dataset('valid_inds', data=valid_inds) 58 | h5file.create_dataset('train_data', data=spikes[train_inds]) 59 | h5file.create_dataset('valid_data', data=spikes[valid_inds]) 60 | 61 | # ---- Extract test data ---- # 62 | data_dict = make_eval_input_tensors(dataset, dataset_name, 'test', save_file=False, scramble_trials=('maze' not in dataset_name)) 63 | num_trials = len(data_dict['eval_spikes_heldin']) 64 | spikes = np.hstack([ 65 | np.dstack([data_dict['eval_spikes_heldin'], np.full((num_trials, tlen, num_heldout), 0.0)]), 66 | np.full((num_trials, fp_steps, num_heldin + num_heldout), 0.0), 67 | ]) 68 | valid_inds = np.arange(0, num_trials, int(1./valid_ratio)) 69 | train_inds = np.delete(np.arange(num_trials), valid_inds) 70 | 71 | with h5py.File(savepath_test, 'w') as h5file: 72 | h5file.create_dataset('train_inds', data=train_inds) 73 | h5file.create_dataset('valid_inds', data=valid_inds) 74 | h5file.create_dataset('train_data', data=spikes[train_inds]) 75 | h5file.create_dataset('valid_data', data=spikes[valid_inds]) 76 | 77 | # ---- Print summary ---- # 78 | print(f'heldin: {num_heldin}') 79 | print(f'heldout: {num_heldout}') 80 | print(f'tlen: {tlen}') 81 | print(f'fp_steps: {fp_steps}') 82 | -------------------------------------------------------------------------------- /examples/baselines/autolfads/post_lfads_prep.py: -------------------------------------------------------------------------------- 1 | # ---- Imports ---- # 2 | import numpy as np 3 | import pandas as pd 4 | import h5py 5 | import pickle 6 | from lfads_tf2.utils import load_posterior_averages 7 | from nlb_tools.nwb_interface import NWBDataset 8 | from nlb_tools.make_tensors import make_train_input_tensors, save_to_h5 9 | import sys 10 | 11 | # ---- Run params ---- # 12 | 
dataset_name = 'mc_rtt' 13 | bin_size = 5 14 | suf = '' if bin_size == 5 else '_20' 15 | 16 | # ---- Data locations ---- # 17 | datapath_dict = { 18 | 'mc_maze': '~/data/000128/sub-Jenkins/', 19 | 'mc_rtt': '~/data/000129/sub-Indy/', 20 | 'area2_bump': '~/data/000127/sub-Han/', 21 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 22 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 23 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 24 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 25 | } 26 | prefix_dict = { 27 | 'mc_maze': '*full', 28 | 'mc_maze_large': '*large', 29 | 'mc_maze_medium': '*medium', 30 | 'mc_maze_small': '*small', 31 | } 32 | datapath = datapath_dict[dataset_name] 33 | prefix = prefix_dict.get(dataset_name, '') 34 | savepath = f'{dataset_name}{suf}_autolfads_submission.h5' 35 | 36 | # ---- Load LFADS output ---- # 37 | model_dir = f'~/autolfads/runs/{dataset_name}{suf}/best_model/' 38 | train_rates, train_factors, *_ = load_posterior_averages(model_dir, merge_tv=True, ps_filename='posterior_samples.h5') 39 | test_rates, test_factors, *_ = load_posterior_averages(model_dir, merge_tv=True, ps_filename='posterior_samples_test.h5') 40 | 41 | # ---- Load data ---- # 42 | dataset = NWBDataset(datapath, prefix) 43 | dataset.resample(bin_size) 44 | 45 | # ---- Find data shapes ---- # 46 | data_dict = make_train_input_tensors(dataset, dataset_name, 'train', return_dict=True, save_file=False) 47 | train_spikes_heldin = data_dict['train_spikes_heldin'] 48 | train_spikes_heldout = data_dict['train_spikes_heldout'] 49 | num_heldin = train_spikes_heldin.shape[2] 50 | tlen = train_spikes_heldin.shape[1] 51 | 52 | # ---- Split LFADS output ---- # 53 | train_rates_heldin = train_rates[:, :tlen, :num_heldin] 54 | train_rates_heldout = train_rates[:, :tlen, num_heldin:] 55 | train_rates_heldin_forward = train_rates[:, tlen:, :num_heldin] 56 | train_rates_heldout_forward = train_rates[:, tlen:, num_heldin:] 57 | eval_rates_heldin = test_rates[:, :tlen, :num_heldin] 58 | eval_rates_heldout = test_rates[:, :tlen, num_heldin:] 59 | eval_rates_heldin_forward = test_rates[:, tlen:, :num_heldin] 60 | eval_rates_heldout_forward = test_rates[:, tlen:, num_heldin:] 61 | 62 | # ---- Save output ---- # 63 | output_dict = { 64 | dataset_name + suf: { 65 | 'train_rates_heldin': train_rates_heldin, 66 | 'train_rates_heldout': train_rates_heldout, 67 | 'eval_rates_heldin': eval_rates_heldin, 68 | 'eval_rates_heldout': eval_rates_heldout, 69 | 'eval_rates_heldin_forward': eval_rates_heldin_forward, 70 | 'eval_rates_heldout_forward': eval_rates_heldout_forward, 71 | } 72 | } 73 | save_to_h5(output_dict, savepath, overwrite=True) -------------------------------------------------------------------------------- /examples/baselines/autolfads/run_lfads.py: -------------------------------------------------------------------------------- 1 | import ray, yaml, shutil 2 | from ray import tune 3 | from os import path 4 | 5 | from tune_tf2.models import create_trainable_class 6 | from tune_tf2.pbt.hps import HyperParam 7 | from tune_tf2.pbt.schedulers import MultiStrategyPBT 8 | from tune_tf2.pbt.trial_executor import SoftPauseExecutor 9 | from lfads_tf2.utils import flatten 10 | from lfads_tf2.tuples import LoadableData 11 | import h5py 12 | import numpy as np 13 | 14 | import sys 15 | import time 16 | from datetime import datetime 17 | 18 | dataset_name = 'mc_maze' 19 | bin_size_ms = 5 20 | binsuf = "" if bin_size_ms == 5 else "_20" 21 | 22 | # ---------- PBT I/O CONFIGURATION ---------- 23 | # the default configuration file 
for the LFADS model 24 | CFG_PATH = f"./config/{dataset_name}{binsuf}.yaml" 25 | # the directory to save PBT runs (usually '~/ray_results') 26 | PBT_HOME = "~/autolfads/runs/" 27 | # the name of this PBT run (run will be stored at {PBT_HOME}/{PBT_NAME}) 28 | RUN_NAME = f'{dataset_name}' # the name of the PBT run 29 | # the dataset to train the PBT model on 30 | DATA_DIR = '~/data/lfads_input/' 31 | DATA_PREFIX = f'{dataset_name}{binsuf}_train_' 32 | 33 | # ---------- PBT RUN CONFIGURATION ---------- 34 | # whether to use single machine or cluster 35 | SINGLE_MACHINE = True 36 | # the number of workers to use - make sure machine can handle all 37 | NUM_WORKERS = 20 38 | # the resources to allocate per model 39 | RESOURCES_PER_TRIAL = {"cpu": 2, "gpu": 0.5} 40 | # the hyperparameter space to search 41 | HYPERPARAM_SPACE = { 42 | 'TRAIN.LR.INIT': HyperParam(1e-5, 5e-3, explore_wt=0.3, 43 | enforce_limits=True, init=4e-3), 44 | 'MODEL.DROPOUT_RATE': HyperParam(0.0, 0.6, explore_wt=0.3, 45 | enforce_limits=True, sample_fn='uniform'), 46 | 'MODEL.CD_RATE': HyperParam(0.01, 0.7, explore_wt=0.3, 47 | enforce_limits=True, init=0.5, sample_fn='uniform'), 48 | 'TRAIN.L2.GEN_SCALE': HyperParam(1e-4, 1e-0, explore_wt=0.8), 49 | 'TRAIN.L2.CON_SCALE': HyperParam(1e-4, 1e-0, explore_wt=0.8), 50 | 'TRAIN.KL.CO_WEIGHT': HyperParam(1e-6, 1e-4, explore_wt=0.8), 51 | 'TRAIN.KL.IC_WEIGHT': HyperParam(1e-6, 1e-3, explore_wt=0.8), 52 | } 53 | # override if necessary 54 | if dataset_name == 'area2_bump': 55 | HYPERPARAM_SPACE['TRAIN.KL.IC_WEIGHT'] = HyperParam(1e-6, 1e-4, explore_wt=0.8) 56 | elif dataset_name == 'dmfc_rsg': 57 | HYPERPARAM_SPACE['TRAIN.LR.INIT'] = HyperParam(1e-5, 7e-3, explore_wt=0.3, 58 | enforce_limits=True, init=5e-3) 59 | HYPERPARAM_SPACE['MODEL.CD_RATE'] = HyperParam(0.01, 0.99, explore_wt=0.3, 60 | enforce_limits=True, init=0.5, sample_fn='uniform') 61 | HYPERPARAM_SPACE['TRAIN.L2.GEN_SCALE'] = HyperParam(1e-6, 1e-1, explore_wt=0.8) 62 | HYPERPARAM_SPACE['TRAIN.L2.CON_SCALE'] = HyperParam(1e-6, 1e-1, explore_wt=0.8) 63 | HYPERPARAM_SPACE['TRAIN.KL.CO_WEIGHT'] = HyperParam(1e-7, 1e-4, explore_wt=0.8) 64 | HYPERPARAM_SPACE['TRAIN.KL.IC_WEIGHT'] = HyperParam(1e-7, 1e-3, explore_wt=0.8) 65 | PBT_METRIC='smth_val_nll_heldin' 66 | EPOCHS_PER_GENERATION = 25 67 | # --------------------------------------------- 68 | 69 | # setup the data hyperparameters 70 | dataset_info = { 71 | 'TRAIN.DATA.DIR': DATA_DIR, 72 | 'TRAIN.DATA.PREFIX': DATA_PREFIX} 73 | # setup initialization of search hyperparameters 74 | init_space = {name: tune.sample_from(hp.init) 75 | for name, hp in HYPERPARAM_SPACE.items()} 76 | # load the configuration as a dictionary and update for this run 77 | flat_cfg_dict = flatten(yaml.full_load(open(CFG_PATH))) 78 | flat_cfg_dict.update(dataset_info) 79 | flat_cfg_dict.update(init_space) 80 | # Set the number of epochs per generation 81 | tuneLFADS = create_trainable_class(EPOCHS_PER_GENERATION) 82 | # connect to Ray cluster or start on single machine 83 | address = None if SINGLE_MACHINE else 'localhost:10000' 84 | ray.init(address=address) 85 | # create the PBT scheduler 86 | scheduler = MultiStrategyPBT( 87 | HYPERPARAM_SPACE, 88 | metric=PBT_METRIC) 89 | # Create the trial executor 90 | executor = SoftPauseExecutor(reuse_actors=True) 91 | # Create the command-line display table 92 | reporter = tune.CLIReporter(metric_columns=['epoch', PBT_METRIC]) 93 | try: 94 | # run the tune job, excepting errors 95 | tune.run( 96 | tuneLFADS, 97 | name=RUN_NAME, 98 | local_dir=PBT_HOME, 99 | 
config=flat_cfg_dict, 100 | resources_per_trial=RESOURCES_PER_TRIAL, 101 | num_samples=NUM_WORKERS, 102 | sync_to_driver='# {source} {target}', # prevents rsync 103 | scheduler=scheduler, 104 | progress_reporter=reporter, 105 | trial_executor=executor, 106 | verbose=1, 107 | reuse_actors=True, 108 | ) 109 | except tune.error.TuneError: 110 | print("tune error!??!?") 111 | pass 112 | 113 | # load the results dataframe for this run 114 | pbt_dir = path.join(PBT_HOME, RUN_NAME) 115 | df = tune.Analysis(pbt_dir).dataframe() 116 | df = df[df.logdir.apply(lambda path: not 'best_model' in path)] 117 | # find the best model 118 | best_model_logdir = df.loc[df[PBT_METRIC].idxmin()].logdir 119 | best_model_src = path.join(best_model_logdir, 'model_dir') 120 | # copy the best model somewhere easy to find 121 | best_model_dest = path.join(pbt_dir, 'best_model') 122 | shutil.copytree(best_model_src, best_model_dest) 123 | # perform posterior sampling 124 | from lfads_tf2.models import LFADS 125 | model = LFADS(model_dir=best_model_dest) 126 | model.sample_and_average() 127 | 128 | loadpath = f'~/data/lfads_input/{dataset_name}{binsuf}_test_lfads.h5' 129 | h5file = h5py.File(loadpath, 'r') 130 | test_data = LoadableData( 131 | train_data=h5file['train_data'][()].astype(np.float32), 132 | valid_data=h5file['valid_data'][()].astype(np.float32), 133 | train_ext_input=None, 134 | valid_ext_input=None, 135 | train_inds=h5file['train_inds'][()].astype(np.float32), 136 | valid_inds=h5file['valid_inds'][()].astype(np.float32), 137 | ) 138 | h5file.close() 139 | 140 | model.sample_and_average(loadable_data=test_data, ps_filename='posterior_samples_test.h5', merge_tv=True) -------------------------------------------------------------------------------- /examples/baselines/gpfa/README.md: -------------------------------------------------------------------------------- 1 | # GPFA 2 | 3 | Gaussian Process Factor Analysis (GPFA) is a classic method of extracting low-dimensional neural trajectories from spiking activity. 4 | You can read more about it in [Yu et al. 2009](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2712272/). 5 | 6 | This directory contains files used to optimize GPFA for NLB'21: 7 | * `gpfa_cv_sweep.py` runs a 3-fold cross-validated grid search over certain parameter values. 8 | * `run_gpfa.py` runs GPFA and generates a submission for NLB'21. The best parameters found by `gpfa_cv_sweep.py` are stored in `default_dict` in the file. 
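For reference, a minimal sketch of how the best parameters could be read back out of a finished sweep. It assumes the CSV written by `gpfa_cv_sweep.py` (named `<dataset_name>_gpfa_cv_sweep.csv`, e.g. `mc_rtt_gpfa_cv_sweep.csv`) and uses the co-bps column as the selection metric; for the scaling datasets the column carries a prefix such as `[500] co-bps`, so adjust the key accordingly.

```python
# Hypothetical post-sweep step: load the sweep results and pick the
# parameter combination with the highest co-bps.
import pandas as pd

results = pd.read_csv("mc_rtt_gpfa_cv_sweep.csv")  # assumed sweep output file
best = results.loc[results["co-bps"].idxmax()]     # adjust column name if prefixed
print(f"latent_dim={int(best.latent_dim)}, alpha1={best.alpha1}, alpha2={best.alpha2}")
```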
9 | 10 | ## Dependencies 11 | * [nlb_tools](https://github.com/neurallatents/nlb_tools) 12 | * [elephant](https://github.com/NeuralEnsemble/elephant) 13 | * sklearn>=0.23 -------------------------------------------------------------------------------- /examples/baselines/gpfa/gpfa_cv_sweep.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import quantities as pq 4 | import h5py 5 | import neo 6 | import gc 7 | from elephant.gpfa import GPFA 8 | from sklearn.linear_model import PoissonRegressor, Ridge 9 | from itertools import product 10 | 11 | from nlb_tools.nwb_interface import NWBDataset 12 | from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5 13 | from nlb_tools.evaluation import evaluate 14 | 15 | # ---- Default sweep ranges ---- # 16 | latent_dim_dict = { 17 | 'mc_maze': np.linspace(32, 52, 6), 18 | 'mc_maze_large': np.linspace(32, 52, 6), 19 | 'mc_maze_medium': np.linspace(16, 36, 6), 20 | 'mc_maze_small': np.linspace(12, 22, 6), 21 | 'mc_rtt': np.linspace(24, 44, 6), 22 | 'area2_bump': np.linspace(14, 26, 7), 23 | 'dmfc_rsg': np.linspace(20, 40, 6), 24 | } 25 | 26 | # ---- Run Params ---- # 27 | dataset_name = "mc_rtt" # one of {'area2_bump', 'dmfc_rsg', 'mc_maze', 'mc_rtt', 28 | # 'mc_maze_large', 'mc_maze_medium', 'mc_maze_small'} 29 | bin_size_ms = 5 30 | cv_fold = 3 31 | latent_dims = latent_dim_dict[dataset_name] 32 | alpha1s = [0.0, 0.0001, 0.001, 0.01] 33 | alpha2s = [0.0, 0.0001, 0.001, 0.01] 34 | 35 | # ---- Useful variables ---- # 36 | binsuf = '' if bin_size_ms == 5 else f'_{bin_size_ms}' 37 | dskey = f'mc_maze_scaling{binsuf}_split' if 'maze_' in dataset_name else (dataset_name + binsuf + "_split") 38 | pref_dict = {'mc_maze_small': '[100] ', 'mc_maze_medium': '[250] ', 'mc_maze_large': '[500] '} 39 | bpskey = pref_dict.get(dataset_name, '') + 'co-bps' 40 | 41 | # ---- Data locations ----# 42 | datapath_dict = { 43 | 'mc_maze': '~/data/000128/sub-Jenkins/', 44 | 'mc_rtt': '~/data/000129/sub-Indy/', 45 | 'area2_bump': '~/data/000127/sub-Han/', 46 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 47 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 48 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 49 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 50 | } 51 | prefix_dict = { 52 | 'mc_maze': '*full', 53 | 'mc_maze_large': '*large', 54 | 'mc_maze_medium': '*medium', 55 | 'mc_maze_small': '*small', 56 | } 57 | datapath = datapath_dict[dataset_name] 58 | prefix = prefix_dict.get(dataset_name, '') 59 | 60 | # ---- Load data ---- # 61 | dataset = NWBDataset(datapath, prefix, skip_fields=['hand_pos', 'cursor_pos', 'force', 'eye_pos', 'muscle_vel', 'muscle_len', 'joint_vel', 'joint_ang']) 62 | dataset.resample(bin_size_ms) 63 | 64 | # ---- Prepare n folds ---- # 65 | all_mask = np.isin(dataset.trial_info.split, ['train', 'val']) 66 | all_idx = np.arange(all_mask.shape[0])[all_mask] 67 | train_masks = [] 68 | eval_masks = [] 69 | for i in range(cv_fold): 70 | eval_idx = all_idx[i::cv_fold] # take every n samples for each fold 71 | train_idx = all_idx[~np.isin(all_idx, eval_idx)] 72 | train_masks.append(np.isin(np.arange(all_mask.shape[0]), train_idx)) 73 | eval_masks.append(np.isin(np.arange(all_mask.shape[0]), eval_idx)) 74 | 75 | # ---- Conversion helper ---- # 76 | def array_to_spiketrains(array, bin_size): 77 | """Convert B x T x N spiking array to list of list of SpikeTrains""" 78 | stList = [] 79 | for trial in 
range(len(array)): 80 | trialList = [] 81 | for channel in range(array.shape[2]): 82 | times = np.nonzero(array[trial, :, channel])[0] 83 | counts = array[trial, times, channel].astype(int) 84 | times = np.repeat(times, counts) 85 | st = neo.SpikeTrain(times*bin_size*pq.ms, t_stop=array.shape[1]*bin_size*pq.ms) 86 | trialList.append(st) 87 | stList.append(trialList) 88 | return stList 89 | 90 | # ---- Extract data for each fold ---- # 91 | fold_data = [] 92 | for i in range(cv_fold): 93 | train_dict = make_train_input_tensors(dataset, dataset_name, train_masks[i], save_file=False) 94 | eval_dict = make_eval_input_tensors(dataset, dataset_name, eval_masks[i], save_file=False) 95 | 96 | train_spikes_heldin = train_dict['train_spikes_heldin'] 97 | train_spikes_heldout = train_dict['train_spikes_heldout'] 98 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 99 | 100 | train_st_heldin = array_to_spiketrains(train_spikes_heldin, bin_size=bin_size_ms) 101 | eval_st_heldin = array_to_spiketrains(eval_spikes_heldin, bin_size=bin_size_ms) 102 | 103 | target_dict = make_eval_target_tensors(dataset, dataset_name, train_masks[i], eval_masks[i], save_file=False, include_psth=True) 104 | fold_data.append((train_spikes_heldin, train_spikes_heldout, eval_spikes_heldin, train_st_heldin, eval_st_heldin, target_dict)) 105 | del dataset 106 | gc.collect() 107 | 108 | # ---- Define helpers ---- # 109 | flatten2d = lambda x: x.reshape(-1, x.shape[2]) 110 | 111 | def fit_poisson(train_factors_s, eval_factors_s, train_spikes_s, eval_spikes_s=None, alpha=0.0): 112 | """Fit Poisson GLM from factors to spikes and return rate predictions""" 113 | train_in = train_factors_s if eval_spikes_s is None else np.vstack([train_factors_s, eval_factors_s]) 114 | train_out = train_spikes_s if eval_spikes_s is None else np.vstack([train_spikes_s, eval_spikes_s]) 115 | train_pred = [] 116 | eval_pred = [] 117 | for chan in range(train_out.shape[1]): 118 | pr = PoissonRegressor(alpha=alpha, max_iter=500) 119 | pr.fit(train_in, train_out[:, chan]) 120 | while pr.n_iter_ == pr.max_iter and pr.max_iter < 10000: 121 | print(f"didn't converge - retraining {chan} with max_iter={pr.max_iter * 5}") 122 | oldmax = pr.max_iter 123 | del pr 124 | pr = PoissonRegressor(alpha=alpha, max_iter=oldmax * 5) 125 | pr.fit(train_in, train_out[:, chan]) 126 | train_pred.append(pr.predict(train_factors_s)) 127 | eval_pred.append(pr.predict(eval_factors_s)) 128 | train_rates_s = np.vstack(train_pred).T 129 | eval_rates_s = np.vstack(eval_pred).T 130 | return train_rates_s, eval_rates_s 131 | 132 | def fit_rectlin(train_factors_s, eval_factors_s, train_spikes_s, eval_spikes_s=None, alpha=0.0, thresh=1e-10): 133 | """Fit linear regression from factors to spikes, rectify, and return rate predictions""" 134 | train_in = train_factors_s if eval_spikes_s is None else np.vstack([train_factors_s, eval_factors_s]) 135 | train_out = train_spikes_s if eval_spikes_s is None else np.vstack([train_spikes_s, eval_spikes_s]) 136 | ridge = Ridge(alpha=alpha) 137 | ridge.fit(train_in, train_out) 138 | train_rates_s = ridge.predict(train_factors_s) 139 | eval_rates_s = ridge.predict(eval_factors_s) 140 | rect_min = np.min([np.min(train_rates_s[train_rates_s > 0]), np.min(eval_rates_s[eval_rates_s > 0])]) 141 | true_min = np.min([np.min(train_rates_s), np.min(eval_rates_s)]) 142 | train_rates_s[train_rates_s < thresh] = thresh 143 | eval_rates_s[eval_rates_s < thresh] = thresh 144 | return train_rates_s, eval_rates_s 145 | 146 | # ---- Sweep latent dims ---- # 147 | 
results = [] 148 | for latent_dim in latent_dims: 149 | print(f"Evaluating latent_dim={latent_dim}") 150 | fold_gpfa = [] 151 | # ---- n-fold gpfa ---- # 152 | for n, data in enumerate(fold_data): 153 | _, _, _, train_st_heldin, eval_st_heldin, _ = data 154 | gpfa = GPFA(bin_size=(bin_size_ms * pq.ms), x_dim=int(latent_dim)) 155 | train_factors = gpfa.fit_transform(train_st_heldin) 156 | eval_factors = gpfa.transform(eval_st_heldin) 157 | train_factors = np.stack([train_factors[i].T for i in range(len(train_factors))]) 158 | eval_factors = np.stack([eval_factors[i].T for i in range(len(eval_factors))]) 159 | fold_gpfa.append((train_factors, eval_factors)) 160 | 161 | # ---- Sweep alphas ---- # 162 | for alpha1, alpha2 in product(alpha1s, alpha2s): 163 | print(f"Evaluating alpha1={alpha1}, alpha2={alpha2}") 164 | res_list = [] 165 | for n, (data, gpfa_res) in enumerate(zip(fold_data, fold_gpfa)): 166 | train_spikes_heldin, train_spikes_heldout, eval_spikes_heldin, train_st_heldin, eval_st_heldin, target_dict = data 167 | train_factors, eval_factors = gpfa_res 168 | 169 | train_spikes_heldin_s = flatten2d(train_spikes_heldin) 170 | train_spikes_heldout_s = flatten2d(train_spikes_heldout) 171 | eval_spikes_heldin_s = flatten2d(eval_spikes_heldin) 172 | train_factors_s = flatten2d(train_factors) 173 | eval_factors_s = flatten2d(eval_factors) 174 | 175 | train_rates_heldin_s, eval_rates_heldin_s = fit_rectlin(train_factors_s, eval_factors_s, train_spikes_heldin_s, eval_spikes_heldin_s, alpha=alpha1) 176 | train_rates_heldout_s, eval_rates_heldout_s = fit_poisson(train_rates_heldin_s, eval_rates_heldin_s, train_spikes_heldout_s, alpha=alpha2) 177 | 178 | train_rates_heldin = train_rates_heldin_s.reshape(train_spikes_heldin.shape) 179 | train_rates_heldout = train_rates_heldout_s.reshape(train_spikes_heldout.shape) 180 | eval_rates_heldin = eval_rates_heldin_s.reshape(eval_spikes_heldin.shape) 181 | eval_rates_heldout = eval_rates_heldout_s.reshape((eval_spikes_heldin.shape[0], eval_spikes_heldin.shape[1], train_spikes_heldout.shape[2])) 182 | 183 | submission = { 184 | dataset_name + binsuf: { 185 | 'train_rates_heldin': train_rates_heldin, 186 | 'train_rates_heldout': train_rates_heldout, 187 | 'eval_rates_heldin': eval_rates_heldin, 188 | 'eval_rates_heldout': eval_rates_heldout 189 | } 190 | } 191 | 192 | res = evaluate(target_dict, submission)[0][dskey] 193 | res_list.append(res) 194 | print(f" Fold {n}: " + str(res)) 195 | res = pd.DataFrame(res_list).mean().to_dict() 196 | print(" Mean: " + str(res)) 197 | res['latent_dim'] = latent_dim 198 | res['alpha1'] = alpha1 199 | res['alpha2'] = alpha2 200 | results.append(res) 201 | 202 | # ---- Save results ---- # 203 | results = pd.DataFrame(results) 204 | results.to_csv(f'{dataset_name}{binsuf}_gpfa_cv_sweep.csv') 205 | 206 | # ---- Find best parameters ---- # 207 | best_combo = results[bpskey].argmax() 208 | best_latent_dim = results.iloc[best_combo].latent_dim 209 | best_alpha1 = results.iloc[best_combo].alpha1 210 | best_alpha2 = results.iloc[best_combo].alpha2 211 | print(f'Best params: latent_dim={best_latent_dim}, alpha1={best_alpha1}, alpha2={best_alpha2}') 212 | -------------------------------------------------------------------------------- /examples/baselines/gpfa/run_gpfa.py: -------------------------------------------------------------------------------- 1 | # ---- Imports ---- # 2 | import numpy as np 3 | import pandas as pd 4 | import h5py 5 | import neo 6 | import quantities as pq 7 | from elephant.gpfa import GPFA 8 | from 
sklearn.linear_model import LinearRegression, PoissonRegressor, Ridge 9 | 10 | from nlb_tools.nwb_interface import NWBDataset 11 | from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5 12 | from nlb_tools.evaluation import evaluate 13 | 14 | # ---- Default params ---- # 15 | default_dict = { # [latent_dim, alpha1, alpha2] 16 | 'mc_maze': [52, 0.01, 0.0], 17 | 'mc_rtt': [36, 0.0, 0.0], 18 | 'area2_bump': [22, 0.0001, 0.0], 19 | 'dmfc_rsg': [32, 0.0001, 0.0], 20 | 'mc_maze_large': [44, 0.01, 0.0], 21 | 'mc_maze_medium': [28, 0.0, 0.0], 22 | 'mc_maze_small': [18, 0.01, 0.0], 23 | } 24 | 25 | # ---- Run Params ---- # 26 | dataset_name = "area2_bump" # one of {'area2_bump', 'dmfc_rsg', 'mc_maze', 'mc_rtt', 27 | # 'mc_maze_large', 'mc_maze_medium', 'mc_maze_small'} 28 | bin_size_ms = 5 29 | # replace defaults with other values if desired 30 | latent_dim = default_dict[dataset_name][0] 31 | alpha1 = default_dict[dataset_name][1] 32 | alpha2 = default_dict[dataset_name][2] 33 | phase = 'test' # one of {'test', 'val'} 34 | 35 | # ---- Useful variables ---- # 36 | binsuf = '' if bin_size_ms == 5 else f'_{bin_size_ms}' 37 | dskey = f'mc_maze_scaling{binsuf}_split' if 'maze_' in dataset_name else (dataset_name + binsuf + "_split") 38 | pref_dict = {'mc_maze_small': '[100] ', 'mc_maze_medium': '[250] ', 'mc_maze_large': '[500] '} 39 | bpskey = pref_dict.get(dataset_name, '') + 'co-bps' 40 | 41 | # ---- Data locations ---- # 42 | datapath_dict = { 43 | 'mc_maze': '~/data/000128/sub-Jenkins/', 44 | 'mc_rtt': '~/data/000129/sub-Indy/', 45 | 'area2_bump': '~/data/000127/sub-Han/', 46 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 47 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 48 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 49 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 50 | } 51 | prefix_dict = { 52 | 'mc_maze': '*full', 53 | 'mc_maze_large': '*large', 54 | 'mc_maze_medium': '*medium', 55 | 'mc_maze_small': '*small', 56 | } 57 | datapath = datapath_dict[dataset_name] 58 | prefix = prefix_dict.get(dataset_name, '') 59 | savepath = f'{dataset_name}{"" if bin_size_ms == 5 else f"_{bin_size_ms}"}_smoothing_output_{phase}.h5' 60 | 61 | # ---- Load data ---- # 62 | dataset = NWBDataset(datapath, prefix, 63 | skip_fields=['hand_pos', 'cursor_pos', 'eye_pos', 'muscle_vel', 'muscle_len', 'joint_vel', 'joint_ang', 'force']) 64 | dataset.resample(bin_size_ms) 65 | 66 | # ---- Extract data ---- # 67 | if phase == 'val': 68 | train_split = 'train' 69 | eval_split = 'val' 70 | else: 71 | train_split = ['train', 'val'] 72 | eval_split = 'test' 73 | train_dict = make_train_input_tensors(dataset, dataset_name, train_split, save_file=False) 74 | train_spikes_heldin = train_dict['train_spikes_heldin'] 75 | train_spikes_heldout = train_dict['train_spikes_heldout'] 76 | eval_dict = make_eval_input_tensors(dataset, dataset_name, eval_split, save_file=False) 77 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 78 | 79 | # ---- Convert to neo.SpikeTrains ---- # 80 | def array_to_spiketrains(array, bin_size): 81 | """Convert B x T x N spiking array to list of list of SpikeTrains""" 82 | stList = [] 83 | for trial in range(len(array)): 84 | trialList = [] 85 | for channel in range(array.shape[2]): 86 | times = np.nonzero(array[trial, :, channel])[0] 87 | counts = array[trial, times, channel].astype(int) 88 | times = np.repeat(times, counts) 89 | st = neo.SpikeTrain(times*bin_size*pq.ms, t_stop=array.shape[1]*bin_size*pq.ms) 90 | trialList.append(st) 
91 | stList.append(trialList) 92 | return stList 93 | train_st_heldin = array_to_spiketrains(train_spikes_heldin, bin_size_ms) 94 | eval_st_heldin = array_to_spiketrains(eval_spikes_heldin, bin_size_ms) 95 | 96 | # ---- Run GPFA ---- # 97 | gpfa = GPFA(bin_size=(bin_size_ms * pq.ms), x_dim=latent_dim) 98 | train_factors = gpfa.fit_transform(train_st_heldin) 99 | eval_factors = gpfa.transform(eval_st_heldin) 100 | 101 | # ---- Reshape factors ---- # 102 | train_factors_s = np.vstack([train_factors[i].T for i in range(len(train_factors))]) 103 | eval_factors_s = np.vstack([eval_factors[i].T for i in range(len(eval_factors))]) 104 | 105 | # ---- Useful variables ---- # 106 | hi_chan = train_spikes_heldin.shape[2] 107 | ho_chan = train_spikes_heldout.shape[2] 108 | tlength = train_spikes_heldin.shape[1] 109 | num_train = len(train_st_heldin) 110 | num_eval = len(eval_st_heldin) 111 | 112 | # ---- Prepare data for regression ---- # 113 | train_spikes_heldin_s = train_spikes_heldin.reshape(-1, train_spikes_heldin.shape[2]) 114 | train_spikes_heldout_s = train_spikes_heldout.reshape(-1, train_spikes_heldout.shape[2]) 115 | eval_spikes_heldin_s = eval_spikes_heldin.reshape(-1, eval_spikes_heldin.shape[2]) 116 | 117 | # ---- Define helpers ---- # 118 | flatten2d = lambda x: x.reshape(-1, x.shape[2]) 119 | 120 | def fit_poisson(train_factors_s, eval_factors_s, train_spikes_s, eval_spikes_s=None, alpha=0.0): 121 | """Fit Poisson GLM from factors to spikes and return rate predictions""" 122 | train_in = train_factors_s if eval_spikes_s is None else np.vstack([train_factors_s, eval_factors_s]) 123 | train_out = train_spikes_s if eval_spikes_s is None else np.vstack([train_spikes_s, eval_spikes_s]) 124 | train_pred = [] 125 | eval_pred = [] 126 | for chan in range(train_out.shape[1]): 127 | pr = PoissonRegressor(alpha=alpha, max_iter=500) 128 | pr.fit(train_in, train_out[:, chan]) 129 | while pr.n_iter_ == pr.max_iter and pr.max_iter < 10000: 130 | print(f"didn't converge - retraining {chan} with max_iter={pr.max_iter * 5}") 131 | oldmax = pr.max_iter 132 | del pr 133 | pr = PoissonRegressor(alpha=alpha, max_iter=oldmax * 5) 134 | pr.fit(train_in, train_out[:, chan]) 135 | train_pred.append(pr.predict(train_factors_s)) 136 | eval_pred.append(pr.predict(eval_factors_s)) 137 | train_rates_s = np.vstack(train_pred).T 138 | eval_rates_s = np.vstack(eval_pred).T 139 | return train_rates_s, eval_rates_s 140 | 141 | def fit_rectlin(train_factors_s, eval_factors_s, train_spikes_s, eval_spikes_s=None, alpha=0.0, thresh=1e-10): 142 | """Fit linear regression from factors to spikes, rectify, and return rate predictions""" 143 | train_in = train_factors_s if eval_spikes_s is None else np.vstack([train_factors_s, eval_factors_s]) 144 | train_out = train_spikes_s if eval_spikes_s is None else np.vstack([train_spikes_s, eval_spikes_s]) 145 | ridge = Ridge(alpha=alpha) 146 | ridge.fit(train_in, train_out) 147 | train_rates_s = ridge.predict(train_factors_s) 148 | eval_rates_s = ridge.predict(eval_factors_s) 149 | rect_min = np.min([np.min(train_rates_s[train_rates_s > 0]), np.min(eval_rates_s[eval_rates_s > 0])]) 150 | true_min = np.min([np.min(train_rates_s), np.min(eval_rates_s)]) 151 | train_rates_s[train_rates_s < thresh] = thresh 152 | eval_rates_s[eval_rates_s < thresh] = thresh 153 | return train_rates_s, eval_rates_s 154 | 155 | # ---- Rate prediction ---- # 156 | train_rates_heldin_s, eval_rates_heldin_s = fit_rectlin(train_factors_s, eval_factors_s, train_spikes_heldin_s, eval_spikes_heldin_s, alpha=alpha1) 
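# Note: held-in rates come from the rectified ridge regression on the GPFA factors
# above, while held-out rates are produced in a second stage below by fitting a
# Poisson GLM from the held-in rate predictions to the held-out spike counts.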
157 | train_rates_heldout_s, eval_rates_heldout_s = fit_poisson(train_rates_heldin_s, eval_rates_heldin_s, train_spikes_heldout_s, alpha=alpha2) 158 | 159 | train_rates_heldin = train_rates_heldin_s.reshape(num_train, tlength, hi_chan) 160 | train_rates_heldout = train_rates_heldout_s.reshape(num_train, tlength, ho_chan) 161 | eval_rates_heldin = eval_rates_heldin_s.reshape(num_eval, tlength, hi_chan) 162 | eval_rates_heldout = eval_rates_heldout_s.reshape(num_eval, tlength, ho_chan) 163 | 164 | # ---- Save output ---- # 165 | output_dict = { 166 | dataset_name + binsuf: { 167 | 'train_rates_heldin': train_rates_heldin, 168 | 'train_rates_heldout': train_rates_heldout, 169 | 'eval_rates_heldin': eval_rates_heldin, 170 | 'eval_rates_heldout': eval_rates_heldout, 171 | } 172 | } 173 | save_to_h5(output_dict, savepath, overwrite=True) 174 | 175 | if phase == 'val': 176 | target_dict = make_eval_target_tensors(dataset, dataset_name, train_split, eval_split, save_file=False, include_psth=True) 177 | print(evaluate(target_dict, output_dict)) 178 | -------------------------------------------------------------------------------- /examples/baselines/ndt/README.md: -------------------------------------------------------------------------------- 1 | # NDT 2 | 3 | The Neural Data Transformer (NDT) uses an attention mechanism to model neural population activity without recurrence, enabling much faster inference than RNN-based models. 4 | You can read more about it in [Ye et al. 2021](https://www.biorxiv.org/content/10.1101/2021.01.16.426955v2). 5 | 6 | The code for running NDT for NLB'21 can be found in the [neural-data-transformers repo](https://github.com/snel-repo/neural-data-transformers). Config files for each dataset can be found in `configs/` and the random searches can be run with `python ray_random.py -e `, as mentioned in the repo's README. 7 | 8 | ## Dependencies 9 | * [nlb_tools](https://github.com/neurallatents/nlb_tools) 10 | * [neural-data-transformers](https://github.com/snel-repo/neural-data-transformers) -------------------------------------------------------------------------------- /examples/baselines/slds/README.md: -------------------------------------------------------------------------------- 1 | # SLDS 2 | 3 | SLDS is a method that infers latent states that evolve according to one of several distinct linear dynamical systems, switching between them over time, which allows complex non-linear dynamics to be approximated. 4 | You can read more about it in [Linderman et al. 2016](https://arxiv.org/abs/1610.08466). 5 | 6 | This directory contains files used to optimize SLDS for NLB'21 (a minimal `ssm` fitting sketch follows the file list below): 7 | * `run_slds_randsearch.py` runs a random search over certain parameter values using a portion of the training data. 8 | * `run_slds.py` runs SLDS and generates a submission for NLB'21. The parameters in `default_dict` in the file were found by a combination of the random search and manual tuning.
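For orientation, here is a minimal sketch of the core `ssm` calls that `run_slds.py` builds on. The toy data and hyperparameter values are placeholders, and the exact `fit` signature assumes the `ssm` fork listed under Dependencies; see `run_slds.py` for the full NLB pipeline, including evaluation masks and forward prediction.

```python
import numpy as np
from ssm.lds import SLDS

# Toy input: one (time x neurons) integer spike-count array per trial.
datas = [np.random.poisson(0.1, size=(100, 30)).astype(int) for _ in range(5)]

# N neurons, K discrete states, D continuous latent dimensions.
slds = SLDS(30, 4, 15,
            transitions="standard",
            emissions="poisson",
            emission_kwargs=dict(link="softplus"))

# Fit with variational Laplace-EM, as in run_slds.py.
elbos, posterior, *_ = slds.fit(
    datas=datas,
    method="laplace_em",
    variational_posterior="structured_meanfield",
    initialize=True,
    num_init_iters=10, num_iters=10, alpha=0.2,
)
```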
9 | 10 | ## Dependencies 11 | * [nlb_tools](https://github.com/neurallatents/nlb_tools) 12 | * [ssm](https://github.com/felixp8/ssm) -------------------------------------------------------------------------------- /examples/baselines/slds/run_slds.py: -------------------------------------------------------------------------------- 1 | 2 | from ssm.lds import SLDS 3 | import numpy as np 4 | import h5py 5 | from sklearn.linear_model import PoissonRegressor 6 | from datetime import datetime 7 | import gc 8 | 9 | from nlb_tools.nwb_interface import NWBDataset 10 | from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5 11 | from nlb_tools.evaluation import evaluate 12 | 13 | # ---- Default params ---- # 14 | default_dict = { # [states, factors, dynamics_kwargs] 15 | 'mc_maze': [6, 38, {'l2_penalty_A': 30000.0, 'l2_penalty_b': 6.264734351046042e-05}], 16 | 'mc_rtt': [8, 20, {'l2_penalty_A': 5088.303423769022, 'l2_penalty_b': 2.0595034155496943e-07}], 17 | 'area2_bump': [4, 15, {'l2_penalty_A': 10582.770724811768, 'l2_penalty_b': 3.982037833098992e-05}], 18 | 'dmfc_rsg': [10, 30, {'l2_penalty_A': 30000.0, 'l2_penalty_b': 1e-05}], 19 | 'mc_maze_large': [4, 28, {'l2_penalty_A': 5462.032425984561, 'l2_penalty_b': 2.1670446099229413e-05}], 20 | 'mc_maze_medium': [3, 20, {'l2_penalty_A': 2391.229442269956, 'l2_penalty_b': 1.258022745020434e-05}], 21 | 'mc_maze_small': [5, 15, {'l2_penalty_A': 5837.898552701826, 'l2_penalty_b': 1.2150060110686535e-08}], 22 | } 23 | 24 | # ---- Run Params ---- # 25 | dataset_name = "area2_bump" # one of {'area2_bump', 'dmfc_rsg', 'mc_maze', 'mc_rtt', 26 | # 'mc_maze_large', 'mc_maze_medium', 'mc_maze_small'} 27 | bin_size_ms = 5 28 | # replace defaults with other values if desired 29 | # defaults are not optimal for 20 ms resolution 30 | states = default_dict[dataset_name][0] 31 | factors = default_dict[dataset_name][1] 32 | dynamics_kwargs = default_dict[dataset_name][2] 33 | alpha = 0.2 34 | num_iters = 50 35 | phase = 'test' # one of {'test', 'val'} 36 | 37 | # ---- Useful variables ---- # 38 | binsuf = '' if bin_size_ms == 5 else f'_{bin_size_ms}' 39 | dskey = f'mc_maze_scaling{binsuf}_split' if 'maze_' in dataset_name else (dataset_name + binsuf + "_split") 40 | pref_dict = {'mc_maze_small': '[100] ', 'mc_maze_medium': '[250] ', 'mc_maze_large': '[500] '} 41 | bpskey = pref_dict.get(dataset_name, '') + 'co-bps' 42 | 43 | # ---- Data locations ---- # 44 | datapath_dict = { 45 | 'mc_maze': '~/data/000128/sub-Jenkins/', 46 | 'mc_rtt': '~/data/000129/sub-Indy/', 47 | 'area2_bump': '~/data/000127/sub-Han/', 48 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 49 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 50 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 51 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 52 | } 53 | prefix_dict = { 54 | 'mc_maze': '*full', 55 | 'mc_maze_large': '*large', 56 | 'mc_maze_medium': '*medium', 57 | 'mc_maze_small': '*small', 58 | } 59 | datapath = datapath_dict[dataset_name] 60 | prefix = prefix_dict.get(dataset_name, '') 61 | 62 | # ---- Load data ---- # 63 | dataset = NWBDataset(datapath, prefix, 64 | skip_fields=['hand_pos', 'cursor_pos', 'eye_pos', 'muscle_vel', 'muscle_len', 'joint_vel', 'joint_ang', 'force']) 65 | dataset.resample(bin_size_ms) 66 | 67 | # ---- Extract data ---- # 68 | if phase == 'val': 69 | train_split = 'train' 70 | eval_split = 'val' 71 | else: 72 | train_split = ['train', 'val'] 73 | eval_split = 'test' 74 | train_dict = 
make_train_input_tensors(dataset, dataset_name, train_split, save_file=False, include_forward_pred=True) 75 | train_spikes_heldin = train_dict['train_spikes_heldin'] 76 | train_spikes_heldout = train_dict['train_spikes_heldout'] 77 | eval_dict = make_eval_input_tensors(dataset, dataset_name, eval_split, save_file=False) 78 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 79 | 80 | train_spikes_heldin = train_dict['train_spikes_heldin'] 81 | train_spikes_heldout = train_dict['train_spikes_heldout'] 82 | train_spikes_heldin_fp = train_dict['train_spikes_heldin_forward'] 83 | train_spikes_heldout_fp = train_dict['train_spikes_heldout_forward'] 84 | 85 | train_spikes = np.concatenate([ 86 | np.concatenate([train_spikes_heldin, train_spikes_heldin_fp], axis=1), 87 | np.concatenate([train_spikes_heldout, train_spikes_heldout_fp], axis=1), 88 | ], axis=2) 89 | 90 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 91 | eval_spikes = np.full((eval_spikes_heldin.shape[0], train_spikes_heldin.shape[1] + train_spikes_heldin_fp.shape[1], train_spikes.shape[2]), 0.0) 92 | masks = np.full((eval_spikes_heldin.shape[0], train_spikes_heldin.shape[1] + train_spikes_heldin_fp.shape[1], train_spikes.shape[2]), False) 93 | eval_spikes[:, :eval_spikes_heldin.shape[1], :eval_spikes_heldin.shape[2]] = eval_spikes_heldin 94 | masks[:, :eval_spikes_heldin.shape[1], :eval_spikes_heldin.shape[2]] = True 95 | 96 | numheldin = train_spikes_heldin.shape[2] 97 | tlen = train_spikes_heldin.shape[1] 98 | 99 | # ---- Prepare run ---- # 100 | T = train_spikes.shape[1] 101 | K = states 102 | D = factors 103 | N = train_spikes.shape[2] 104 | transitions = "standard" 105 | emissions = "poisson" 106 | 107 | train_datas = [train_spikes[i, :, :].astype(int) for i in range(len(train_spikes))] 108 | eval_datas = [eval_spikes[i, :, :].astype(int) for i in range(len(eval_spikes))] 109 | train_masks = [np.full(masks[0, :, :].shape, True) for _ in range(len(train_datas))] 110 | eval_masks = [masks[i, :, :] for i in range(len(masks))] 111 | 112 | numtrain = len(train_datas) 113 | numeval = len(eval_datas) 114 | 115 | # ---- Run SLDS ---- # 116 | slds = SLDS(N, K, D, 117 | transitions=transitions, 118 | emissions=emissions, 119 | emission_kwargs=dict(link="softplus"), 120 | dynamics_kwargs=dynamics_kwargs, 121 | ) 122 | 123 | train_datas, train_inputs, train_masks, train_tags = slds.prep_inputs(datas=train_datas) 124 | eval_datas, eval_inputs, eval_masks, eval_tags = slds.prep_inputs(datas=eval_datas, masks=eval_masks) 125 | gc.collect() 126 | 127 | q_elbos_lem_train, q_lem_train, *_ = slds.fit( 128 | datas=train_datas, 129 | inputs=train_inputs, 130 | masks=train_masks, 131 | tags=train_tags, 132 | method="laplace_em", 133 | variational_posterior="structured_meanfield", 134 | initialize=True, 135 | num_init_iters=50, num_iters=num_iters, alpha=alpha 136 | ) 137 | 138 | q_elbos_lem_eval, q_lem_eval, *_ = slds.approximate_posterior( 139 | datas=eval_datas, 140 | inputs=eval_inputs, 141 | masks=eval_masks, 142 | tags=eval_tags, 143 | method="laplace_em", 144 | variational_posterior="structured_meanfield", 145 | num_iters=num_iters, alpha=alpha, 146 | ) 147 | 148 | train_rates = slds.smooth_3d(q_lem_train.mean_continuous_states, train_datas, train_inputs, train_masks, train_tags).cpu().numpy() 149 | eval_rates = slds.smooth_3d(q_lem_eval.mean_continuous_states, eval_datas, eval_inputs, eval_masks, eval_tags).cpu().numpy() 150 | 151 | # ---- Format output ---- # 152 | train_rates_heldin = train_rates[:, :tlen, :numheldin] 153 |
train_rates_heldout = train_rates[:, :tlen, numheldin:] 154 | eval_rates_heldin = eval_rates[:, :tlen, :numheldin] 155 | eval_rates_heldout = eval_rates[:, :tlen, numheldin:] 156 | eval_rates_heldin_forward = eval_rates[:, tlen:, :numheldin] 157 | eval_rates_heldout_forward = eval_rates[:, tlen:, numheldin:] 158 | 159 | # ---- Save output ---- # 160 | output_dict = { 161 | dataset_name + binsuf: { 162 | 'train_rates_heldin': train_rates_heldin, 163 | 'train_rates_heldout': train_rates_heldout, 164 | 'eval_rates_heldin': eval_rates_heldin, 165 | 'eval_rates_heldout': eval_rates_heldout, 166 | 'eval_rates_heldin_forward': eval_rates_heldin_forward, 167 | 'eval_rates_heldout_forward': eval_rates_heldout_forward, 168 | } 169 | } 170 | save_to_h5(output_dict, f'slds_output_{dataset_name}{binsuf}.h5') 171 | 172 | # ---- Evaluate ---- # 173 | if phase == 'val': 174 | target_dict = make_eval_target_tensors(dataset, dataset_name, train_split, eval_split, save_file=False, include_psth=True) 175 | print(evaluate(target_dict, output_dict)) 176 | -------------------------------------------------------------------------------- /examples/baselines/slds/run_slds_randsearch.py: -------------------------------------------------------------------------------- 1 | # ---- Imports ---- # 2 | from ssm.lds import SLDS 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import h5py 7 | import os 8 | from sklearn.linear_model import PoissonRegressor 9 | from itertools import product 10 | from datetime import datetime 11 | import time 12 | import json 13 | import traceback 14 | import gc 15 | import pickle 16 | 17 | from nlb_tools.nwb_interface import NWBDataset 18 | from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5 19 | from nlb_tools.evaluation import evaluate 20 | 21 | np.random.seed(1234) 22 | 23 | # ---- Default ranges ---- # 24 | default_dict = { # [K range, D range, l2_A range, l2_b range] 25 | 'mc_maze': [('int', 3, 7), ('int', 30, 45), ('log', 3.25, 4.25), ('log', -8, -4)], 26 | 'mc_rtt': [('int', 4, 10), ('int', 20, 32), ('log', 3.25, 4.5), ('log', -8, -4)], 27 | 'area2_bump': [('int', 3, 7), ('int', 12, 24), ('log', 3.25, 4.25), ('log', -8, -4)], 28 | 'dmfc_rsg': [('int', 5, 10), ('int', 20, 35), ('log', 3.75, 5.0), ('log', -8, -4)], 29 | 'mc_maze_large': [('int', 3, 7), ('int', 25, 40), ('log', 3.25, 4.25), ('log', -8, -4)], 30 | 'mc_maze_medium': [('int', 3, 7), ('int', 20, 35), ('log', 3.25, 4.25), ('log', -8, -4)], 31 | 'mc_maze_small': [('int', 3, 7), ('int', 15, 25), ('log', 3.25, 4.25), ('log', -8, -4)], 32 | } 33 | 34 | # ---- Constants ---- # 35 | dynamics = "gaussian" 36 | transitions = "standard" 37 | emissions = "poisson" 38 | emission_kwargs = dict(link="softplus") 39 | 40 | # ---- Run Params ---- # 41 | dataset_name = 'mc_rtt' 42 | bin_size_ms = 5 43 | 44 | n_runs = 20 45 | 46 | train_subset_size = 100 47 | eval_subset_size = 100 48 | num_init_iters = 50 49 | num_train_iters = 50 50 | num_repeats = 2 51 | 52 | init_param_values = { 53 | 'K': default_dict[dataset_name][0], # num states 54 | 'D': default_dict[dataset_name][1], # num factors 55 | 'dynamics_kwargs.l2_penalty_A': default_dict[dataset_name][2], 56 | 'dynamics_kwargs.l2_penalty_b': default_dict[dataset_name][3], 57 | } 58 | 59 | fit_param_values = { 60 | 'alpha': ('float', 0.2, 0.2), 61 | } 62 | 63 | # ---- Full sweep matrix ---- # 64 | def unpack_nested(param_dict): 65 | new_dict = param_dict.copy() 66 | for k, v in
param_dict.items(): 67 | if isinstance(v, dict): 68 | for nk, nv in v.items(): 69 | new_dict[k + '.' + nk] = nv 70 | new_dict.pop(k) 71 | return new_dict 72 | 73 | def dict_sample(param_dict, num_samples): 74 | keys = list(param_dict.keys()) 75 | vals = list(param_dict.values()) 76 | 77 | def sample(tup): 78 | if tup[0] == 'float': 79 | return np.random.uniform(low=tup[1], high=tup[2], size=num_samples).tolist() 80 | elif tup[0] == 'int': 81 | return np.random.randint(low=tup[1], high=tup[2], size=num_samples).tolist() 82 | elif tup[0] == 'log': 83 | return np.power(10, np.random.uniform(low=tup[1], high=tup[2], size=num_samples)).tolist() 84 | else: 85 | raise ValueError("Unsupported sampling method") 86 | 87 | vals = [sample(val) for val in vals] 88 | combs = zip(*vals) 89 | 90 | def make_dict(keys, vals): 91 | d = {} 92 | for k, v in zip(keys, vals): 93 | if '.' in k: 94 | assert k.count('.') == 1, 'cannot handle nesting of depth 2' 95 | k_p, k_c = k.split('.') 96 | if k_p not in d: 97 | d[k_p] = {} 98 | d[k_p][k_c] = v 99 | else: 100 | d[k] = v 101 | return d 102 | 103 | prod = [make_dict(keys, vals) for vals in combs] 104 | return prod 105 | 106 | init_param_list = dict_sample(init_param_values, n_runs) 107 | fit_param_list = dict_sample(fit_param_values, n_runs) 108 | 109 | # ---- Load data ---- # 110 | print("Loading dataset...") 111 | datapath_dict = { 112 | 'mc_maze': '~/data/000128/sub-Jenkins/', 113 | 'mc_rtt': '~/data/000129/sub-Indy/', 114 | 'area2_bump': '~/data/000127/sub-Han/', 115 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 116 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 117 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 118 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 119 | } 120 | prefix_dict = { 121 | 'mc_maze': '*full', 122 | 'mc_maze_large': '*large', 123 | 'mc_maze_medium': '*medium', 124 | 'mc_maze_small': '*small', 125 | } 126 | datapath = datapath_dict[dataset_name] 127 | prefix = prefix_dict.get(dataset_name, '') 128 | 129 | dataset = NWBDataset(datapath, prefix) 130 | dataset.resample(bin_size_ms) 131 | 132 | binsuf = '' if bin_size_ms == 5 else '_20' 133 | scaling_tdict = { 134 | 'mc_maze_small': '[100] ', 135 | 'mc_maze_medium': '[250] ', 136 | 'mc_maze_large': '[500] ', 137 | } 138 | dskey = f'mc_maze_scaling{binsuf}_split' if 'maze_' in dataset_name else dataset_name + binsuf + "_split" 139 | bpskey = scaling_tdict[dataset_name] + 'co-bps' if 'maze_' in dataset_name else 'co-bps' 140 | deckey = scaling_tdict[dataset_name] + 'vel R2' if 'maze_' in dataset_name else 'tp Corr' if 'dmfc' in dataset_name else 'vel R2' 141 | 142 | # ---- Prep Input ---- # 143 | print("Preparing input...") 144 | 145 | valid_mask = (dataset.trial_info.split != 'none').to_numpy() 146 | good_trials = valid_mask.nonzero()[0] 147 | 148 | trial_sels = [np.random.choice(good_trials, train_subset_size + eval_subset_size, replace=False) for _ in range(num_repeats)] 149 | train_splits = [np.isin(np.arange(len(valid_mask)), ts[:train_subset_size]) for ts in trial_sels] 150 | eval_splits = [np.isin(np.arange(len(valid_mask)), ts[train_subset_size:]) for ts in trial_sels] 151 | 152 | train_datas = [] 153 | eval_datas = [] 154 | target_datas = [] 155 | for ts, es in zip(train_splits, eval_splits): 156 | train_dict = make_train_input_tensors(dataset, dataset_name, ts, save_file=False, include_forward_pred=True) 157 | eval_dict = make_eval_input_tensors(dataset, dataset_name, es, save_file=False) 158 | target_dict = make_eval_target_tensors(dataset, dataset_name, ts, es, save_file=False, 
include_psth=('rtt' not in dataset_name)) 159 | 160 | train_spikes_heldin = train_dict['train_spikes_heldin'] 161 | train_spikes_heldout = train_dict['train_spikes_heldout'] 162 | train_spikes_heldin_fp = train_dict['train_spikes_heldin_forward'] 163 | train_spikes_heldout_fp = train_dict['train_spikes_heldout_forward'] 164 | train_spikes = np.concatenate([ 165 | np.concatenate([train_spikes_heldin, train_spikes_heldin_fp], axis=1), 166 | np.concatenate([train_spikes_heldout, train_spikes_heldout_fp], axis=1), 167 | ], axis=2) 168 | 169 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 170 | eval_spikes = np.full((eval_spikes_heldin.shape[0], train_spikes.shape[1], train_spikes.shape[2]), 0.0) 171 | masks = np.full((eval_spikes_heldin.shape[0], train_spikes.shape[1], train_spikes.shape[2]), False) 172 | eval_spikes[:, :eval_spikes_heldin.shape[1], :eval_spikes_heldin.shape[2]] = eval_spikes_heldin 173 | masks[:, :eval_spikes_heldin.shape[1], :eval_spikes_heldin.shape[2]] = True 174 | 175 | train_spklist = [train_spikes[i, :, :].astype(int) for i in range(len(train_spikes))] 176 | eval_spklist = [eval_spikes[i, :, :].astype(int) for i in range(len(eval_spikes))] 177 | eval_masks = [masks[i, :, :] for i in range(len(masks))] 178 | 179 | train_datas.append((train_spklist, None)) 180 | eval_datas.append((eval_spklist, eval_masks)) 181 | target_datas.append(target_dict) 182 | 183 | numheldin = train_spikes_heldin.shape[2] 184 | tlen = train_spikes_heldin.shape[1] 185 | 186 | def make_inputs(slds, datas): 187 | datas, inputs, masks, tags = slds.prep_inputs(datas=datas[0], masks=datas[1]) 188 | tensors = { 189 | 'datas': datas, 190 | 'inputs': inputs, 191 | 'masks': masks, 192 | 'tags': tags 193 | } 194 | return tensors 195 | 196 | temp_slds = SLDS(2, 1, 1) 197 | 198 | train_tensors = [make_inputs(temp_slds, d) for d in train_datas] 199 | eval_tensors = [make_inputs(temp_slds, d) for d in eval_datas] 200 | 201 | del dataset, temp_slds 202 | del train_dict, train_spikes_heldin, train_spikes_heldout, train_spikes_heldin_fp, train_spikes_heldout_fp 203 | del eval_dict, eval_spikes_heldin, eval_spikes 204 | del train_datas, eval_datas 205 | del masks, eval_masks, train_spklist, eval_spklist 206 | gc.collect() 207 | 208 | # ---- Define slds Wrapper ---- # 209 | def run_slds(init_params, fit_params, train_datas, eval_datas): 210 | N = train_datas['datas'].shape[2] 211 | slds = SLDS(N=N, 212 | transitions=transitions, 213 | emissions=emissions, 214 | emission_kwargs=emission_kwargs, 215 | **init_params, 216 | ) 217 | 218 | slds.initialize( 219 | verbose=2, 220 | num_init_iters=num_init_iters, 221 | **train_datas, 222 | ) 223 | 224 | q_elbos_lem_train, q_lem_train, *_ = slds.fit( 225 | method="laplace_em", 226 | variational_posterior="structured_meanfield", 227 | initialize=False, 228 | num_iters=num_train_iters, # score=True, 229 | **train_datas, 230 | **fit_params, 231 | ) 232 | 233 | q_elbos_lem_eval, q_lem_eval, *_ = slds.approximate_posterior( 234 | method="laplace_em", 235 | variational_posterior="structured_meanfield", 236 | num_iters=num_train_iters, 237 | **eval_datas, 238 | **fit_params, 239 | ) 240 | 241 | train_rates = slds.smooth_3d(q_lem_train.mean_continuous_states, **train_datas).cpu().numpy() 242 | eval_rates = slds.smooth_3d(q_lem_eval.mean_continuous_states, **eval_datas).cpu().numpy() 243 | 244 | train_factors = q_lem_train.mean_continuous_states.cpu().numpy() 245 | eval_factors = q_lem_eval.mean_continuous_states.cpu().numpy() 246 | 247 | del slds 248 | del q_lem_train 249 | del 
q_lem_eval 250 | gc.collect() 251 | 252 | return (train_rates, eval_rates) 253 | 254 | def dict_mean(dict_list): 255 | num_dicts = len(dict_list) 256 | if num_dicts == 0: 257 | return [] 258 | if num_dicts == 1: 259 | return dict_list[0] 260 | mean_dict = {} 261 | for d in dict_list: 262 | for key, val in d.items(): 263 | prev = mean_dict.get(key, 0) 264 | mean_dict[key] = prev + val / num_dicts 265 | return mean_dict 266 | 267 | # ---- Run Sweep ---- # 268 | res_list = [] 269 | 270 | search_name = f"./{dataset_name}_runs/search_{datetime.now().strftime('%Y%m%d_%H%M%S')}" 271 | 272 | print(f'Starting {n_runs} runs...') 273 | 274 | i = 0 275 | num_restarts = 0 276 | best_slds = None 277 | best_bps = 0 278 | 279 | while i < n_runs: 280 | init_params = dict_sample(init_param_values, 1)[0] 281 | fit_params = dict_sample(fit_param_values, 1)[0] 282 | print(f"Run {i}:\n init_params: {init_params}\n fit_params: {fit_params}") 283 | sub_list = [] 284 | for n in range(num_repeats): 285 | try: 286 | (train_rates, eval_rates) = run_slds(init_params, fit_params, train_tensors[n], eval_tensors[n]) 287 | except: 288 | print('Run failed!') 289 | continue 290 | 291 | # Reshape output 292 | train_rates_heldin = train_rates[:, :tlen, :numheldin] 293 | train_rates_heldout = train_rates[:, :tlen, numheldin:] 294 | eval_rates_heldin = eval_rates[:, :tlen, :numheldin] 295 | eval_rates_heldout = eval_rates[:, :tlen, numheldin:] 296 | eval_rates_heldin_forward = eval_rates[:, tlen:, :numheldin] 297 | eval_rates_heldout_forward = eval_rates[:, tlen:, numheldin:] 298 | 299 | submission_dict = { 300 | dataset_name + binsuf: { 301 | 'train_rates_heldin': train_rates_heldin, 302 | 'train_rates_heldout': train_rates_heldout, 303 | 'eval_rates_heldin': eval_rates_heldin, 304 | 'eval_rates_heldout': eval_rates_heldout, 305 | 'eval_rates_heldin_forward': eval_rates_heldin_forward, 306 | 'eval_rates_heldout_forward': eval_rates_heldout_forward, 307 | } 308 | } 309 | 310 | res = evaluate(target_datas[n], submission_dict)[0][dskey] 311 | sub_list.append(res) 312 | 313 | if not sub_list: 314 | i += 1 315 | continue 316 | res = dict_mean(sub_list) 317 | res['run_idx'] = i 318 | res.update(unpack_nested(fit_params)) 319 | res.update(unpack_nested(init_params)) 320 | res_list.append(res) 321 | 322 | del submission_dict, train_rates, eval_rates 323 | del train_rates_heldin, train_rates_heldout, eval_rates_heldin, eval_rates_heldout, eval_rates_heldin_forward, eval_rates_heldout_forward 324 | gc.collect() 325 | 326 | print('') 327 | time.sleep(10) # rest between models 328 | i += 1 329 | 330 | del train_tensors, eval_tensors 331 | 332 | # ---- Save results ---- # 333 | results = pd.DataFrame(res_list) 334 | results.to_csv(search_name + '_results.csv') 335 | 336 | -------------------------------------------------------------------------------- /examples/baselines/smoothing/README.md: -------------------------------------------------------------------------------- 1 | # Spike smoothing 2 | 3 | Spike smoothing is a simple approach to denoising firing rates by convolving spikes with a Gaussian kernel. 4 | 5 | This directory contains files used to optimize spike smoothing for NLB'21 (a minimal sketch of the smoothing step follows the file list below): 6 | * `smoothing_cv_sweep.py` runs a 5-fold cross-validated grid search over certain parameter values. 7 | * `run_smoothing.py` runs smoothing and generates a submission for NLB'21. The best parameters found by `smoothing_cv_sweep.py` are stored in `default_dict` in the file.
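For reference, here is a minimal sketch of the smoothing step itself, mirroring the kernel construction in `run_smoothing.py`; the spike array and the `kern_sd` value below are placeholders.

```python
import numpy as np
import scipy.signal as signal

bin_size_ms, kern_sd = 5, 40                          # bin width and kernel s.d., in ms
spikes = np.random.poisson(0.1, size=(10, 140, 30))   # trials x time x neurons

# Build a normalized Gaussian window spanning roughly +/-3 standard deviations
# (scipy.signal.windows.gaussian in newer SciPy versions).
window = signal.gaussian(int(6 * kern_sd / bin_size_ms), int(kern_sd / bin_size_ms), sym=True)
window /= np.sum(window)

# Convolve along the time axis to get smoothed firing-rate estimates.
smth_spikes = np.apply_along_axis(lambda x: np.convolve(x, window, 'same'), 1, spikes)
```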
8 | 9 | ## Dependencies 10 | * [nlb_tools](https://github.com/neurallatents/nlb_tools) 11 | * sklearn>=0.23 12 | -------------------------------------------------------------------------------- /examples/baselines/smoothing/run_smoothing.py: -------------------------------------------------------------------------------- 1 | # ---- Imports ---- # 2 | from nlb_tools.nwb_interface import NWBDataset 3 | from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5 4 | from nlb_tools.evaluation import evaluate 5 | import numpy as np 6 | import h5py 7 | import scipy.signal as signal 8 | from sklearn.linear_model import PoissonRegressor 9 | 10 | # ---- Default params ---- # 11 | default_dict = { # [kern_sd, alpha] 12 | 'mc_maze': [50, 0.01], 13 | 'mc_rtt': [30, 0.1], 14 | 'area2_bump': [30, 0.01], 15 | 'dmfc_rsg': [60, 0.001], 16 | 'mc_maze_large': [40, 0.1], 17 | 'mc_maze_medium': [60, 0.1], 18 | 'mc_maze_small': [60, 0.1], 19 | } 20 | 21 | # ---- Run Params ---- # 22 | dataset_name = "mc_rtt" # one of {'area2_bump', 'dmfc_rsg', 'mc_maze', 'mc_rtt', 23 | # 'mc_maze_large', 'mc_maze_medium', 'mc_maze_small'} 24 | bin_size_ms = 5 25 | kern_sd = default_dict[dataset_name][0] 26 | alpha = default_dict[dataset_name][1] 27 | phase = 'test' # one of {'test', 'val'} 28 | log_offset = 1e-4 # amount to add before taking log to prevent log(0) error 29 | 30 | # ---- Useful variables ---- # 31 | binsuf = '' if bin_size_ms == 5 else f'_{bin_size_ms}' 32 | dskey = f'mc_maze_scaling{binsuf}_split' if 'maze_' in dataset_name else (dataset_name + binsuf + "_split") 33 | pref_dict = {'mc_maze_small': '[100] ', 'mc_maze_medium': '[250] ', 'mc_maze_large': '[500] '} 34 | bpskey = pref_dict.get(dataset_name, '') + 'co-bps' 35 | 36 | # ---- Data locations ---- # 37 | datapath_dict = { 38 | 'mc_maze': '~/data/000128/sub-Jenkins/', 39 | 'mc_rtt': '~/data/000129/sub-Indy/', 40 | 'area2_bump': '~/data/000127/sub-Han/', 41 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 42 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 43 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 44 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 45 | } 46 | prefix_dict = { 47 | 'mc_maze': '*full', 48 | 'mc_maze_large': '*large', 49 | 'mc_maze_medium': '*medium', 50 | 'mc_maze_small': '*small', 51 | } 52 | datapath = datapath_dict[dataset_name] 53 | prefix = prefix_dict.get(dataset_name, '') 54 | savepath = f'{dataset_name}{"" if bin_size_ms == 5 else f"_{bin_size_ms}"}_smoothing_output_{phase}.h5' 55 | 56 | # ---- Load data ---- # 57 | dataset = NWBDataset(datapath, prefix, skip_fields=['hand_pos', 'cursor_pos', 'eye_pos', 'muscle_vel', 'muscle_len', 'joint_vel', 'joint_ang', 'force']) 58 | dataset.resample(bin_size_ms) 59 | 60 | # ---- Extract data ---- # 61 | if phase == 'val': 62 | train_split = 'train' 63 | eval_split = 'val' 64 | else: 65 | train_split = ['train', 'val'] 66 | eval_split = 'test' 67 | train_dict = make_train_input_tensors(dataset, dataset_name, train_split, save_file=False) 68 | train_spikes_heldin = train_dict['train_spikes_heldin'] 69 | train_spikes_heldout = train_dict['train_spikes_heldout'] 70 | eval_dict = make_eval_input_tensors(dataset, dataset_name, eval_split, save_file=False) 71 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 72 | 73 | # ---- Useful shape info ---- # 74 | tlen = train_spikes_heldin.shape[1] 75 | num_heldin = train_spikes_heldin.shape[2] 76 | num_heldout = train_spikes_heldout.shape[2] 77 | 78 | # ---- Define helpers ---- # 79 | def 
fit_poisson(train_factors_s, eval_factors_s, train_spikes_s, eval_spikes_s=None, alpha=0.0): 80 | """ 81 | Fit Poisson GLM from factors to spikes and return rate predictions 82 | """ 83 | train_in = train_factors_s if eval_spikes_s is None else np.vstack([train_factors_s, eval_factors_s]) 84 | train_out = train_spikes_s if eval_spikes_s is None else np.vstack([train_spikes_s, eval_spikes_s]) 85 | train_pred = [] 86 | eval_pred = [] 87 | for chan in range(train_out.shape[1]): 88 | pr = PoissonRegressor(alpha=alpha, max_iter=500) 89 | pr.fit(train_in, train_out[:, chan]) 90 | while pr.n_iter_ == pr.max_iter and pr.max_iter < 10000: 91 | print(f"didn't converge - retraining {chan} with max_iter={pr.max_iter * 5}") 92 | oldmax = pr.max_iter 93 | del pr 94 | pr = PoissonRegressor(alpha=alpha, max_iter=oldmax * 5) 95 | pr.fit(train_in, train_out[:, chan]) 96 | train_pred.append(pr.predict(train_factors_s)) 97 | eval_pred.append(pr.predict(eval_factors_s)) 98 | train_rates_s = np.vstack(train_pred).T 99 | eval_rates_s = np.vstack(eval_pred).T 100 | return np.clip(train_rates_s, 1e-9, 1e20), np.clip(eval_rates_s, 1e-9, 1e20) 101 | 102 | # ---- Smooth spikes ---- # 103 | window = signal.gaussian(int(6 * kern_sd / bin_size_ms), int(kern_sd / bin_size_ms), sym=True) 104 | window /= np.sum(window) 105 | def filt(x): 106 | return np.convolve(x, window, 'same') 107 | train_spksmth_heldin = np.apply_along_axis(filt, 1, train_spikes_heldin) 108 | eval_spksmth_heldin = np.apply_along_axis(filt, 1, eval_spikes_heldin) 109 | 110 | # ---- Reshape for regression ---- # 111 | flatten2d = lambda x: x.reshape(-1, x.shape[2]) 112 | train_spksmth_heldin_s = flatten2d(train_spksmth_heldin) 113 | train_spikes_heldin_s = flatten2d(train_spikes_heldin) 114 | train_spikes_heldout_s = flatten2d(train_spikes_heldout) 115 | eval_spikes_heldin_s = flatten2d(eval_spikes_heldin) 116 | eval_spksmth_heldin_s = flatten2d(eval_spksmth_heldin) 117 | 118 | # Taking log of smoothed spikes gives better results 119 | train_lograte_heldin_s = np.log(train_spksmth_heldin_s + log_offset) 120 | eval_lograte_heldin_s = np.log(eval_spksmth_heldin_s + log_offset) 121 | 122 | # ---- Predict rates ---- # 123 | train_spksmth_heldout_s, eval_spksmth_heldout_s = fit_poisson(train_lograte_heldin_s, eval_lograte_heldin_s, train_spikes_heldout_s, alpha=alpha) 124 | train_spksmth_heldout = train_spksmth_heldout_s.reshape((-1, tlen, num_heldout)) 125 | eval_spksmth_heldout = eval_spksmth_heldout_s.reshape((-1, tlen, num_heldout)) 126 | 127 | # OPTIONAL: Also use smoothed spikes + GLM for held-in rate predictions 128 | # train_spksmth_heldin_s, eval_spksmth_heldin_s = fit_poisson(train_lograte_heldin_s, eval_lograte_heldin_s, train_spikes_heldin_s, eval_spikes_heldin_s, alpha=alpha) 129 | # train_spksmth_heldin = train_spksmth_heldin_s.reshape((-1, tlen, num_heldin)) 130 | # eval_spksmth_heldin = eval_spksmth_heldin_s.reshape((-1, tlen, num_heldin)) 131 | 132 | # ---- Prepare/save output ---- # 133 | output_dict = { 134 | dataset_name + binsuf: { 135 | 'train_rates_heldin': train_spksmth_heldin, 136 | 'train_rates_heldout': train_spksmth_heldout, 137 | 'eval_rates_heldin': eval_spksmth_heldin, 138 | 'eval_rates_heldout': eval_spksmth_heldout, 139 | } 140 | } 141 | save_to_h5(output_dict, savepath, overwrite=True) 142 | 143 | # ---- Evaluate locally ---- # 144 | if phase == 'val': 145 | target_dict = make_eval_target_tensors(dataset, dataset_name, train_split, eval_split, save_file=False, include_psth=True) 146 | print(evaluate(target_dict, 
output_dict)) 147 | -------------------------------------------------------------------------------- /examples/baselines/smoothing/smoothing_cv_sweep.py: -------------------------------------------------------------------------------- 1 | # ---- Imports ----- # 2 | from nlb_tools.nwb_interface import NWBDataset 3 | from nlb_tools.make_tensors import make_train_input_tensors, \ 4 | make_eval_input_tensors, make_eval_target_tensors, save_to_h5 5 | from nlb_tools.evaluation import evaluate 6 | import h5py 7 | import sys, gc 8 | import numpy as np 9 | import pandas as pd 10 | import scipy.signal as signal 11 | from sklearn.linear_model import PoissonRegressor 12 | from datetime import datetime 13 | 14 | # ---- Run Params ---- # 15 | dataset_name = "area2_bump" # one of {'area2_bump', 'dmfc_rsg', 'mc_maze', 'mc_rtt', 16 | # 'mc_maze_large', 'mc_maze_medium', 'mc_maze_small'} 17 | bin_size_ms = 5 18 | kern_sds = np.linspace(30, 60, 4) 19 | alphas = np.logspace(-3, 0, 4) 20 | cv_fold = 5 21 | log_offset = 1e-4 # amount to add before taking log to prevent log(0) error 22 | 23 | # ---- Useful variables ---- # 24 | binsuf = '' if bin_size_ms == 5 else f'_{bin_size_ms}' 25 | dskey = f'mc_maze_scaling{binsuf}_split' if 'maze_' in dataset_name else (dataset_name + binsuf + "_split") 26 | pref_dict = {'mc_maze_small': '[100] ', 'mc_maze_medium': '[250] ', 'mc_maze_large': '[500] '} 27 | bpskey = pref_dict.get(dataset_name, '') + 'co-bps' 28 | 29 | # ---- Data locations ----# 30 | datapath_dict = { 31 | 'mc_maze': '~/data/000128/sub-Jenkins/', 32 | 'mc_rtt': '~/data/000129/sub-Indy/', 33 | 'area2_bump': '~/data/000127/sub-Han/', 34 | 'dmfc_rsg': '~/data/000130/sub-Haydn/', 35 | 'mc_maze_large': '~/data/000138/sub-Jenkins/', 36 | 'mc_maze_medium': '~/data/000139/sub-Jenkins/', 37 | 'mc_maze_small': '~/data/000140/sub-Jenkins/', 38 | } 39 | prefix_dict = { 40 | 'mc_maze': '*full', 41 | 'mc_maze_large': '*large', 42 | 'mc_maze_medium': '*medium', 43 | 'mc_maze_small': '*small', 44 | } 45 | datapath = datapath_dict[dataset_name] 46 | prefix = prefix_dict.get(dataset_name, '') 47 | 48 | # ---- Load data ---- # 49 | dataset = NWBDataset(datapath, prefix, 50 | skip_fields=['hand_pos', 'cursor_pos', 'eye_pos', 'force', 'muscle_vel', 'muscle_len', 'joint_vel', 'joint_ang']) 51 | dataset.resample(bin_size_ms) 52 | 53 | # ---- Prepare n folds ---- # 54 | all_mask = np.isin(dataset.trial_info.split, ['train', 'val']) 55 | all_idx = np.arange(all_mask.shape[0])[all_mask] 56 | train_masks = [] 57 | eval_masks = [] 58 | for i in range(cv_fold): 59 | eval_idx = all_idx[i::cv_fold] # take every n samples for each fold 60 | train_idx = all_idx[~np.isin(all_idx, eval_idx)] 61 | train_masks.append(np.isin(np.arange(all_mask.shape[0]), train_idx)) 62 | eval_masks.append(np.isin(np.arange(all_mask.shape[0]), eval_idx)) 63 | 64 | # ---- Extract data for each fold ---- # 65 | fold_data = [] 66 | for i in range(cv_fold): 67 | train_dict = make_train_input_tensors(dataset, dataset_name, train_masks[i], save_file=False) 68 | eval_dict = make_eval_input_tensors(dataset, dataset_name, eval_masks[i], save_file=False) 69 | 70 | train_spikes_heldin = train_dict['train_spikes_heldin'] 71 | train_spikes_heldout = train_dict['train_spikes_heldout'] 72 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 73 | 74 | target_dict = make_eval_target_tensors(dataset, dataset_name, train_masks[i], eval_masks[i], include_psth=True, save_file=False) 75 | fold_data.append((train_spikes_heldin, train_spikes_heldout, eval_spikes_heldin, 
target_dict)) 76 | del dataset 77 | gc.collect() 78 | 79 | # ---- Useful shape info ---- # 80 | tlen = fold_data[0][0].shape[1] 81 | num_heldin = fold_data[0][0].shape[2] 82 | num_heldout = fold_data[0][1].shape[2] 83 | results = [] 84 | 85 | # ---- Define helpers ---- # 86 | flatten2d = lambda x: x.reshape(-1, x.shape[2]) # flattens 3d -> 2d array 87 | 88 | def fit_poisson(train_factors_s, test_factors_s, train_spikes_s, test_spikes_s=None, alpha=0.0): 89 | """Fit Poisson GLM from factors to spikes and return rate predictions""" 90 | train_in = train_factors_s if test_spikes_s is None else np.vstack([train_factors_s, test_factors_s]) 91 | train_out = train_spikes_s if test_spikes_s is None else np.vstack([train_spikes_s, test_spikes_s]) 92 | train_pred = [] 93 | test_pred = [] 94 | for chan in range(train_out.shape[1]): 95 | pr = PoissonRegressor(alpha=alpha, max_iter=500) 96 | pr.fit(train_in, train_out[:, chan]) 97 | while pr.n_iter_ == pr.max_iter and pr.max_iter < 10000: 98 | print(f"didn't converge - retraining {chan} with max_iter={pr.max_iter * 5}") 99 | oldmax = pr.max_iter 100 | del pr 101 | pr = PoissonRegressor(alpha=alpha, max_iter=oldmax * 5) 102 | pr.fit(train_in, train_out[:, chan]) 103 | train_pred.append(pr.predict(train_factors_s)) 104 | test_pred.append(pr.predict(test_factors_s)) 105 | train_rates_s = np.vstack(train_pred).T 106 | test_rates_s = np.vstack(test_pred).T 107 | return np.clip(train_rates_s, 1e-9, 1e20), np.clip(test_rates_s, 1e-9, 1e20) 108 | 109 | # ---- Sweep kernel std ---- # 110 | for ks in kern_sds: 111 | print(f"Evaluating kern_sd = {ks}") 112 | 113 | # ---- Prepare smoothing kernel ---- # 114 | window = signal.gaussian(int(6 * ks / bin_size_ms), int(ks / bin_size_ms), sym=True) 115 | window /= np.sum(window) 116 | def filt(x): 117 | return np.convolve(x, window, 'same') 118 | 119 | # ---- Sweep GLM alpha ---- # 120 | for a in alphas: 121 | print(f" Evaluating alpha = {a}") 122 | res_list = [] 123 | 124 | # ---- Evaluate each fold ---- # 125 | for n, data in enumerate(fold_data): 126 | 127 | # ---- Smooth spikes ---- # 128 | train_spikes_heldin, train_spikes_heldout, eval_spikes_heldin, target_dict = data 129 | train_spksmth_heldin = np.apply_along_axis(filt, 1, train_spikes_heldin) 130 | eval_spksmth_heldin = np.apply_along_axis(filt, 1, eval_spikes_heldin) 131 | 132 | # ---- Reshape for regression ---- # 133 | train_spikes_heldin_s = flatten2d(train_spikes_heldin) 134 | train_spikes_heldout_s = flatten2d(train_spikes_heldout) 135 | train_spksmth_heldin_s = flatten2d(train_spksmth_heldin) 136 | eval_spikes_heldin_s = flatten2d(eval_spikes_heldin) 137 | eval_spksmth_heldin_s = flatten2d(eval_spksmth_heldin) 138 | 139 | # Taking log of smoothed spikes gives better results 140 | train_lograte_heldin_s = np.log(train_spksmth_heldin_s + log_offset) 141 | eval_lograte_heldin_s = np.log(eval_spksmth_heldin_s + log_offset) 142 | 143 | # ---- Predict rates ---- # 144 | train_spksmth_heldout_s, eval_spksmth_heldout_s = fit_poisson(train_lograte_heldin_s, eval_lograte_heldin_s, train_spikes_heldout_s, alpha=a) 145 | train_spksmth_heldout = train_spksmth_heldout_s.reshape((-1, tlen, num_heldout)) 146 | eval_spksmth_heldout = eval_spksmth_heldout_s.reshape((-1, tlen, num_heldout)) 147 | 148 | # OPTIONAL: Also use smoothed spikes for held-in rate predictions 149 | # train_spksmth_heldin_s, eval_spksmth_heldin_s = fit_poisson(train_lograte_heldin_s, eval_lograte_heldin_s, train_spikes_heldin_s, eval_spikes_heldin_s, alpha=0.0) 150 | # train_spksmth_heldin = 
train_spksmth_heldin_s.reshape((-1, tlen, num_heldin)) 151 | # eval_spksmth_heldin = eval_spksmth_heldin_s.reshape((-1, tlen, num_heldin)) 152 | 153 | # ---- Prepare output ---- # 154 | output_dict = { 155 | dataset_name + binsuf: { 156 | 'train_rates_heldin': train_spksmth_heldin, 157 | 'train_rates_heldout': train_spksmth_heldout, 158 | 'eval_rates_heldin': eval_spksmth_heldin, 159 | 'eval_rates_heldout': eval_spksmth_heldout 160 | } 161 | } 162 | 163 | # ---- Evaluate output ---- # 164 | res = evaluate(target_dict, output_dict)[0][dskey] 165 | res_list.append(res) 166 | print(f" Fold {n}: " + str(res)) 167 | 168 | # ---- Average across folds ---- # 169 | res = pd.DataFrame(res_list).mean().to_dict() 170 | print(" Mean: " + str(res)) 171 | res['kern_sd'] = ks 172 | res['alpha'] = a 173 | results.append(res) 174 | 175 | # ---- Save results ---- # 176 | results = pd.DataFrame(results) 177 | results.to_csv(f'{dataset_name}{binsuf}_smoothing_cv_sweep.csv') 178 | 179 | # ---- Find best parameters ---- # 180 | best_combo = results[bpskey].argmax() 181 | best_kern_sd = results.iloc[best_combo].kern_sd 182 | best_alpha = results.iloc[best_combo].alpha 183 | print(f'Best params: kern_sd={best_kern_sd}, alpha={best_alpha}') -------------------------------------------------------------------------------- /examples/tutorials/gpfa_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# GPFA Demo\n", 15 | "\n", 16 | "While `basic_example.ipynb` used a smoothing implementation to generate rate predictions for the benchmark, this notebook will run GPFA, a better modeling method, using the Python package [`elephant`](https://github.com/NeuralEnsemble/elephant), which should produce far better results. We recommend first viewing `basic_example.ipynb` for more explanation of the `nlb_tools` functions we use here." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## 1. Setup\n", 24 | "\n", 25 | "Below, we import the necessary functions from `nlb_tools` and additional standard packages. Note that you will need to install `elephant`, which should install with it `neo`, and `quantities` if you don't already have them. Additionally, you'll need `scikit-learn>=0.23` for the Poisson GLM used in this notebook." 
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "## Install packages if necessary\n", 35 | "# !pip install elephant\n", 36 | "# !pip install -U scikit-learn\n", 37 | "# !pip install nlb-tools" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "## Imports\n", 47 | "\n", 48 | "from nlb_tools.nwb_interface import NWBDataset\n", 49 | "from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5\n", 50 | "from nlb_tools.evaluation import evaluate\n", 51 | "\n", 52 | "import numpy as np\n", 53 | "import pandas as pd\n", 54 | "import h5py\n", 55 | "import neo\n", 56 | "import quantities as pq\n", 57 | "from elephant.gpfa import GPFA\n", 58 | "from sklearn.linear_model import PoissonRegressor, Ridge" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "## If necessary, download dataset from DANDI\n", 68 | "# !pip install dandi\n", 69 | "# !dandi download https://dandiarchive.org/dandiset/000138 # replace URL with URL for dataset you want\n", 70 | "# # URLS are:\n", 71 | "# # - MC_Maze: https://dandiarchive.org/dandiset/000128\n", 72 | "# # - MC_RTT: https://dandiarchive.org/dandiset/000129\n", 73 | "# # - Area2_Bump: https://dandiarchive.org/dandiset/000127\n", 74 | "# # - DMFC_RSG: https://dandiarchive.org/dandiset/000130\n", 75 | "# # - MC_Maze_Large: https://dandiarchive.org/dandiset/000138\n", 76 | "# # - MC_Maze_Medium: https://dandiarchive.org/dandiset/000139\n", 77 | "# # - MC_Maze_Small: https://dandiarchive.org/dandiset/000140" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## 2. Loading data\n", 85 | "\n", 86 | "Below, we enter the name of the dataset, the path to the dataset files, as well as a prefix to filter out specific files, in order to load the data." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "## Load dataset\n", 96 | "\n", 97 | "dataset_name = 'mc_maze_large'\n", 98 | "datapath = './000138/sub-Jenkins/'\n", 99 | "prefix = f'*ses-large'\n", 100 | "dataset = NWBDataset(datapath, prefix)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## 3. Input prep\n", 108 | "\n", 109 | "`elephant`'s implementation of GPFA takes its input in the form of lists of `neo.SpikeTrain`s. Here, we'll use `make_train_input_tensor` or `make_eval_input_tensor` to extract the data we want to model before converting it into the desired format." 
110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "## Dataset preparation\n", 119 | "\n", 120 | "# Choose the phase here, either 'val' or 'test'\n", 121 | "phase = 'val'\n", 122 | "\n", 123 | "# Choose bin width and resample\n", 124 | "bin_width = 5\n", 125 | "dataset.resample(bin_width)\n", 126 | "\n", 127 | "# Create suffix for group naming later\n", 128 | "suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "## Make train input data\n", 138 | "\n", 139 | "# Generate input tensors\n", 140 | "train_trial_split = 'train' if (phase == 'val') else ['train', 'val']\n", 141 | "train_dict = make_train_input_tensors(dataset, dataset_name=dataset_name, trial_split=train_trial_split, save_file=False)\n", 142 | "\n", 143 | "# Unpack input data\n", 144 | "train_spikes_heldin = train_dict['train_spikes_heldin']\n", 145 | "train_spikes_heldout = train_dict['train_spikes_heldout']" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "## Make eval input data\n", 155 | "\n", 156 | "# Generate input tensors\n", 157 | "eval_split = phase\n", 158 | "eval_dict = make_eval_input_tensors(dataset, dataset_name=dataset_name, trial_split=eval_split, save_file=False)\n", 159 | "\n", 160 | "# Unpack data\n", 161 | "eval_spikes_heldin = eval_dict['eval_spikes_heldin']" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 8, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "## Convert spiking array to SpikeTrains\n", 171 | "\n", 172 | "def array_to_spiketrains(array):\n", 173 | " \"\"\"Converts trial x time x channel spiking arrays to list of list of neo.SpikeTrain\"\"\"\n", 174 | " stList = []\n", 175 | " # Loop through trials\n", 176 | " for trial in range(len(array)):\n", 177 | " trialList = []\n", 178 | " # Loop through channels\n", 179 | " for channel in range(array.shape[2]):\n", 180 | " # Get spike times and counts\n", 181 | " times = np.where(array[trial, :, channel])[0]\n", 182 | " counts = array[trial, times, channel].astype(int)\n", 183 | " train = np.repeat(times, counts)\n", 184 | " # Create neo.SpikeTrain\n", 185 | " st = neo.SpikeTrain(times*bin_width*pq.ms, t_stop=array.shape[1]*bin_width*pq.ms)\n", 186 | " trialList.append(st)\n", 187 | " stList.append(trialList)\n", 188 | " return stList\n", 189 | "\n", 190 | "# Run conversion\n", 191 | "train_st_heldin = array_to_spiketrains(train_spikes_heldin)\n", 192 | "eval_st_heldin = array_to_spiketrains(eval_spikes_heldin)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "## 4. Running GPFA\n", 200 | "\n", 201 | "Now that we have properly formatted data, we'll run GPFA. This step may take quite a while, depending on your machine and the chosen parameters." 
202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 9, 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "Initializing parameters using factor analysis...\n", 214 | "\n", 215 | "Fitting GPFA model...\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "## Run GPFA\n", 221 | "\n", 222 | "# Set parameters\n", 223 | "bin_size = bin_width * pq.ms\n", 224 | "latent_dim = 20\n", 225 | "\n", 226 | "# Train GPFA on train data and apply on test data\n", 227 | "gpfa = GPFA(bin_size=bin_size, x_dim=latent_dim)\n", 228 | "train_factors = gpfa.fit_transform(train_st_heldin)\n", 229 | "eval_factors = gpfa.transform(eval_st_heldin)\n", 230 | "\n", 231 | "# Extract and reshape factors to 3d array\n", 232 | "train_factors = np.stack([train_factors[i].T for i in range(len(train_factors))])\n", 233 | "eval_factors = np.stack([eval_factors[i].T for i in range(len(eval_factors))])\n" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "## 5. Generating rate predictions\n", 241 | "\n", 242 | "Now that we have our latent factors at the specified resolution, we can map these factors to the spiking data." 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 10, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "## Basic data prep\n", 252 | "\n", 253 | "# Get input arrays\n", 254 | "train_spikes_heldin = train_dict['train_spikes_heldin']\n", 255 | "train_spikes_heldout = train_dict['train_spikes_heldout']\n", 256 | "\n", 257 | "# Assign variables\n", 258 | "tlength = train_spikes_heldin.shape[1]\n", 259 | "numtrain = train_spikes_heldin.shape[0]\n", 260 | "numeval = eval_spikes_heldin.shape[0]\n", 261 | "numheldin = train_spikes_heldin.shape[2]\n", 262 | "numheldout = train_spikes_heldout.shape[2]\n", 263 | "\n", 264 | "# Reshape data to 2d for regression\n", 265 | "flatten3d = lambda x: x.reshape(-1, x.shape[2])\n", 266 | "train_spikes_heldin_s = flatten3d(train_spikes_heldin)\n", 267 | "train_spikes_heldout_s = flatten3d(train_spikes_heldout)\n", 268 | "train_factors_s = flatten3d(train_factors)\n", 269 | "eval_factors_s = flatten3d(eval_factors)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 11, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "## Define fitting functions\n", 279 | "\n", 280 | "def fit_rectlin(train_input, eval_input, train_output, alpha=0.0):\n", 281 | " # Fit linear regression\n", 282 | " lr = Ridge(alpha=alpha)\n", 283 | " lr.fit(train_input, train_output)\n", 284 | " train_pred = lr.predict(train_input)\n", 285 | " eval_pred = lr.predict(eval_input)\n", 286 | " # Rectify to prevent negative or 0 rate predictions\n", 287 | " train_pred[train_pred < 1e-10] = 1e-10\n", 288 | " eval_pred[eval_pred < 1e-10] = 1e-10\n", 289 | " return train_pred, eval_pred\n", 290 | "\n", 291 | "def fit_poisson(train_input, eval_input, train_output, alpha=0.0):\n", 292 | " train_pred = []\n", 293 | " eval_pred = []\n", 294 | " # train Poisson GLM for each output column\n", 295 | " for chan in range(train_output.shape[1]):\n", 296 | " pr = PoissonRegressor(alpha=alpha, max_iter=500)\n", 297 | " pr.fit(train_input, train_output[:, chan])\n", 298 | " train_pred.append(pr.predict(train_input))\n", 299 | " eval_pred.append(pr.predict(eval_input))\n", 300 | " train_pred = np.vstack(train_pred).T\n", 301 | " eval_pred = np.vstack(eval_pred).T\n", 302 | " return train_pred, eval_pred" 
303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 12, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "## Make rate predictions\n", 312 | "\n", 313 | "# fit GLMs for rate predictions\n", 314 | "train_rates_heldin_s, eval_rates_heldin_s = fit_rectlin(train_factors_s, eval_factors_s, train_spikes_heldin_s)\n", 315 | "train_rates_heldout_s, eval_rates_heldout_s = fit_poisson(train_rates_heldin_s, eval_rates_heldin_s, train_spikes_heldout_s)\n", 316 | "\n", 317 | "# reshape output back to 3d\n", 318 | "train_rates_heldin = train_rates_heldin_s.reshape((numtrain, tlength, numheldin))\n", 319 | "train_rates_heldout = train_rates_heldout_s.reshape((numtrain, tlength, numheldout))\n", 320 | "eval_rates_heldin = eval_rates_heldin_s.reshape((numeval, tlength, numheldin))\n", 321 | "eval_rates_heldout = eval_rates_heldout_s.reshape((numeval, tlength, numheldout))" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "## 6. Making the submission\n", 329 | "\n", 330 | "Now, we'll make the submission dict manually. As described in `basic_example.ipynb`, you can also use the function `save_to_h5` from `make_tensors.py` to save the output as an h5 file for submission on EvalAI." 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 13, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "## Prepare submission data\n", 340 | "\n", 341 | "output_dict = {\n", 342 | " dataset_name + suffix: {\n", 343 | " 'train_rates_heldin': train_rates_heldin,\n", 344 | " 'train_rates_heldout': train_rates_heldout,\n", 345 | " 'eval_rates_heldin': eval_rates_heldin,\n", 346 | " 'eval_rates_heldout': eval_rates_heldout\n", 347 | " }\n", 348 | "}\n", 349 | "\n", 350 | "# To save as an h5 file:\n", 351 | "# save_to_h5(output_dict, 'submission.h5')" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## 7. Evaluation\n", 359 | "\n", 360 | "Finally, we will create the target data with `make_eval_target_tensors` and evaluate our model if we ran on the 'val' phase. If the notebook was run on the 'test' phase, you would need to submit to the EvalAI challenge to get results." 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 14, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "[{'mc_maze_scaling_split': {'[500] co-bps': 0.22863305534919023, '[500] vel R2': 0.5843404497000466, '[500] psth R2': 0.18814023505537117}}]\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "## Make data to test predictions with and evaluate\n", 378 | "\n", 379 | "if phase == 'val':\n", 380 | " target_dict = make_eval_target_tensors(dataset, dataset_name=dataset_name, train_trial_split='train', eval_trial_split='val', include_psth=True, save_file=False)\n", 381 | "\n", 382 | " print(evaluate(target_dict, output_dict))" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": {}, 388 | "source": [ 389 | "## Summary\n", 390 | "\n", 391 | "In this notebook, we used `nlb_tools` and `elephant` to run GPFA on a dataset for the Neural Latents Benchmark." 
392 | ] 393 | } 394 | ], 395 | "metadata": { 396 | "interpreter": { 397 | "hash": "d876f2d84ebe613ecb987c3cdf86da35455b4fa2dba2ba72805210f4933655ff" 398 | }, 399 | "kernelspec": { 400 | "display_name": "Python 3.7.6 64-bit ('tf2-gpu': conda)", 401 | "name": "python3" 402 | }, 403 | "language_info": { 404 | "codemirror_mode": { 405 | "name": "ipython", 406 | "version": 3 407 | }, 408 | "file_extension": ".py", 409 | "mimetype": "text/x-python", 410 | "name": "python", 411 | "nbconvert_exporter": "python", 412 | "pygments_lexer": "ipython3", 413 | "version": "3.7.6" 414 | }, 415 | "metadata": { 416 | "interpreter": { 417 | "hash": "d876f2d84ebe613ecb987c3cdf86da35455b4fa2dba2ba72805210f4933655ff" 418 | } 419 | } 420 | }, 421 | "nbformat": 4, 422 | "nbformat_minor": 2 423 | } 424 | -------------------------------------------------------------------------------- /examples/tutorials/img/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neurallatents/nlb_tools/1ddc15f45b56388ff093d1396b7b87b36fa32a68/examples/tutorials/img/pipeline.png -------------------------------------------------------------------------------- /examples/tutorials/img/split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neurallatents/nlb_tools/1ddc15f45b56388ff093d1396b7b87b36fa32a68/examples/tutorials/img/split.png -------------------------------------------------------------------------------- /examples/tutorials/slds_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# SLDS Demo\n", 15 | "\n", 16 | "In this notebook, we will use a switching linear dynamical system (SLDS) to model the neural data. We will use the Linderman Lab's [`ssm` package](https://github.com/lindermanlab/ssm), which you should install before running this demo. We recommend first viewing `basic_example.ipynb` for more explanation of the `nlb_tools` functions we use here." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## 1. Setup\n", 24 | "\n", 25 | "Below, we import the necessary functions from `nlb_tools` and additional standard packages." 
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "## Install packages if necessary\n", 35 | "# !pip install git+https://github.com/lindermanlab/ssm\n", 36 | "# !pip install -U scikit-learn\n", 37 | "# !pip install nlb-tools" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "## Imports\n", 47 | "\n", 48 | "from nlb_tools.nwb_interface import NWBDataset\n", 49 | "from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5\n", 50 | "from nlb_tools.evaluation import evaluate\n", 51 | "\n", 52 | "import ssm\n", 53 | "import numpy as np\n", 54 | "import h5py\n", 55 | "import sys" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "## If necessary, download dataset from DANDI\n", 65 | "# !pip install dandi\n", 66 | "# !dandi download https://dandiarchive.org/dandiset/000138 # replace URL with URL for dataset you want\n", 67 | "# # URLS are:\n", 68 | "# # - MC_Maze: https://dandiarchive.org/dandiset/000128\n", 69 | "# # - MC_RTT: https://dandiarchive.org/dandiset/000129\n", 70 | "# # - Area2_Bump: https://dandiarchive.org/dandiset/000127\n", 71 | "# # - DMFC_RSG: https://dandiarchive.org/dandiset/000130\n", 72 | "# # - MC_Maze_Large: https://dandiarchive.org/dandiset/000138\n", 73 | "# # - MC_Maze_Medium: https://dandiarchive.org/dandiset/000139\n", 74 | "# # - MC_Maze_Small: https://dandiarchive.org/dandiset/000140" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## 2. Loading data\n", 82 | "\n", 83 | "Below, please enter the path to the dataset, as well as the name of the dataset, to load the data. In addition, you can choose a bin size (0.005 or 0.02 s) to run the notebook at." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "## Load dataset\n", 93 | "\n", 94 | "dataset_name = 'mc_maze_small'\n", 95 | "datapath = './000140/sub-Jenkins/'\n", 96 | "prefix = f'*ses-small'\n", 97 | "dataset = NWBDataset(datapath, prefix)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## 3. Input prep\n", 105 | "\n", 106 | "`ssm` expects inputs as a list of 2d arrays of type `int`, so we will use functions from `make_tensors` to create 3d arrays, and split the arrays along the trial axis to get our list. Note that since SLDS can perform forward prediction, we indicate `include_forward_pred=True` in `make_train_input_tensors`, which includes the next 200 ms of spiking activity after the required window for each trial in separate tensors called `'train_spikes_heldin_forward'` and `'train_spikes_heldout_forward'`." 
107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 5, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "## Dataset preparation\n", 116 | "\n", 117 | "# Choose the phase here, either 'val' or 'test'\n", 118 | "phase = 'val'\n", 119 | "\n", 120 | "# Choose bin width and resample\n", 121 | "bin_width = 5\n", 122 | "dataset.resample(bin_width)\n", 123 | "\n", 124 | "# Create suffix for group naming later\n", 125 | "suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "## Make train input data\n", 135 | "\n", 136 | "# Generate input tensors\n", 137 | "train_trial_split = 'train' if (phase == 'val') else ['train', 'val']\n", 138 | "train_dict = make_train_input_tensors(dataset, dataset_name=dataset_name, trial_split=train_trial_split, save_file=False, include_forward_pred=True)\n", 139 | "\n", 140 | "# Unpack input data\n", 141 | "train_spikes_heldin = train_dict['train_spikes_heldin']\n", 142 | "train_spikes_heldout = train_dict['train_spikes_heldout']" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 7, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "## Make eval input data\n", 152 | "\n", 153 | "# Generate input tensors\n", 154 | "eval_trial_split = phase\n", 155 | "eval_dict = make_eval_input_tensors(dataset, dataset_name=dataset_name, trial_split=eval_trial_split, save_file=False)\n", 156 | "\n", 157 | "# Unpack data\n", 158 | "eval_spikes_heldin = eval_dict['eval_spikes_heldin']" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "## Prep input\n", 168 | "\n", 169 | "# Combine train spiking data into one array\n", 170 | "train_spikes_heldin = train_dict['train_spikes_heldin']\n", 171 | "train_spikes_heldout = train_dict['train_spikes_heldout']\n", 172 | "train_spikes_heldin_fp = train_dict['train_spikes_heldin_forward']\n", 173 | "train_spikes_heldout_fp = train_dict['train_spikes_heldout_forward']\n", 174 | "train_spikes = np.concatenate([\n", 175 | " np.concatenate([train_spikes_heldin, train_spikes_heldin_fp], axis=1),\n", 176 | " np.concatenate([train_spikes_heldout, train_spikes_heldout_fp], axis=1),\n", 177 | "], axis=2)\n", 178 | "\n", 179 | "# Fill missing test spiking data with zeros and make masks\n", 180 | "eval_spikes_heldin = eval_dict['eval_spikes_heldin']\n", 181 | "eval_spikes = np.full((eval_spikes_heldin.shape[0], train_spikes.shape[1], train_spikes.shape[2]), 0.0)\n", 182 | "masks = np.full((eval_spikes_heldin.shape[0], train_spikes.shape[1], train_spikes.shape[2]), False)\n", 183 | "eval_spikes[:, :eval_spikes_heldin.shape[1], :eval_spikes_heldin.shape[2]] = eval_spikes_heldin\n", 184 | "masks[:, :eval_spikes_heldin.shape[1], :eval_spikes_heldin.shape[2]] = True\n", 185 | "\n", 186 | "# Make lists of arrays\n", 187 | "train_datas = [train_spikes[i, :, :].astype(int) for i in range(len(train_spikes))]\n", 188 | "eval_datas = [eval_spikes[i, :, :].astype(int) for i in range(len(eval_spikes))]\n", 189 | "eval_masks = [masks[i, :, :].astype(bool) for i in range(len(masks))]\n", 190 | "\n", 191 | "num_heldin = train_spikes_heldin.shape[2]\n", 192 | "tlen = train_spikes_heldin.shape[1]\n", 193 | "num_train = len(train_datas)\n", 194 | "num_eval = len(eval_datas)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | 
"metadata": {}, 200 | "source": [ 201 | "## 4. Running SLDS\n", 202 | "\n", 203 | "Now that we have our input data prepared, we can fit an SLDS to it. Feel free to vary the parameters as you see fit" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 9, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "data": { 213 | "application/vnd.jupyter.widget-view+json": { 214 | "model_id": "0f56b9137d754192821a85315bf33de4", 215 | "version_major": 2, 216 | "version_minor": 0 217 | }, 218 | "text/plain": [ 219 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | }, 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "Initializing with an ARHMM using 25 steps of EM.\n" 230 | ] 231 | }, 232 | { 233 | "data": { 234 | "application/vnd.jupyter.widget-view+json": { 235 | "model_id": "99254f1a302c4fb9a37595dbd4bde07d", 236 | "version_major": 2, 237 | "version_minor": 0 238 | }, 239 | "text/plain": [ 240 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))" 241 | ] 242 | }, 243 | "metadata": {}, 244 | "output_type": "display_data" 245 | }, 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "\n", 251 | "\n" 252 | ] 253 | }, 254 | { 255 | "data": { 256 | "application/vnd.jupyter.widget-view+json": { 257 | "model_id": "3f8a30354b2449ce9aa374a62f466542", 258 | "version_major": 2, 259 | "version_minor": 0 260 | }, 261 | "text/plain": [ 262 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))" 263 | ] 264 | }, 265 | "metadata": {}, 266 | "output_type": "display_data" 267 | }, 268 | { 269 | "name": "stdout", 270 | "output_type": "stream", 271 | "text": [ 272 | "\n" 273 | ] 274 | }, 275 | { 276 | "data": { 277 | "application/vnd.jupyter.widget-view+json": { 278 | "model_id": "e212c9b881d746fd9899a991536064d7", 279 | "version_major": 2, 280 | "version_minor": 0 281 | }, 282 | "text/plain": [ 283 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=25.0), HTML(value='')))" 284 | ] 285 | }, 286 | "metadata": {}, 287 | "output_type": "display_data" 288 | }, 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "## Run SLDS\n", 299 | "\n", 300 | "# Set parameters\n", 301 | "T = train_datas[0].shape[0] # trial length\n", 302 | "K = 3 # number of discrete states\n", 303 | "D = 15 # dimensionality of latent states\n", 304 | "N = train_datas[0].shape[1] # input dimensionality\n", 305 | "\n", 306 | "slds = ssm.SLDS(N, K, D,\n", 307 | " transitions='standard',\n", 308 | " emissions='poisson',\n", 309 | " emission_kwargs=dict(link=\"log\"),\n", 310 | " dynamics_kwargs={\n", 311 | " 'l2_penalty_A': 3000.0,\n", 312 | " }\n", 313 | ")\n", 314 | "\n", 315 | "# Train\n", 316 | "q_elbos_lem_train, q_lem_train = slds.fit(\n", 317 | " datas=train_datas,\n", 318 | " method=\"laplace_em\",\n", 319 | " variational_posterior=\"structured_meanfield\",\n", 320 | " num_init_iters=25, num_iters=25, alpha=0.2,\n", 321 | ")\n", 322 | "\n", 323 | "# Pass eval data\n", 324 | "q_elbos_lem_eval, q_lem_eval = slds.approximate_posterior(\n", 325 | " datas=eval_datas,\n", 326 | " masks=eval_masks,\n", 327 | " method=\"laplace_em\",\n", 328 | " variational_posterior=\"structured_meanfield\",\n", 329 | " num_iters=25, alpha=0.2,\n", 330 | ")" 331 | ] 332 | }, 333 | { 334 | 
"cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "## 5. Generating rate predictions\n", 338 | "\n", 339 | "We now have our estimates of continuous neural population state, so we'll now use them to predict neuron firing rates. `SLDS` does this by smoothing the input data." 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 10, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "## Generate rate predictions\n", 349 | "\n", 350 | "# Smooth observations using inferred states\n", 351 | "train_rates = [slds.smooth(q_lem_train.mean_continuous_states[i], train_datas[i]) for i in range(num_train)]\n", 352 | "eval_rates = [slds.smooth(q_lem_eval.mean_continuous_states[i], eval_datas[i], mask=eval_masks[i]) for i in range(num_eval)]\n", 353 | "\n", 354 | "# Reshape output\n", 355 | "train_rates = np.stack(train_rates)\n", 356 | "eval_rates = np.stack(eval_rates)\n", 357 | "\n", 358 | "train_rates_heldin = train_rates[:, :tlen, :num_heldin]\n", 359 | "train_rates_heldout = train_rates[:, :tlen, num_heldin:]\n", 360 | "eval_rates_heldin = eval_rates[:, :tlen, :num_heldin]\n", 361 | "eval_rates_heldout = eval_rates[:, :tlen, num_heldin:]\n", 362 | "eval_rates_heldin_forward = eval_rates[:, tlen:, :num_heldin]\n", 363 | "eval_rates_heldout_forward = eval_rates[:, tlen:, num_heldin:]" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "## 6. Making the submission\n", 371 | "\n", 372 | "Now, we'll make the submission dict manually. As described in `basic_example.ipynb`, you can also use the function `save_to_h5` from `make_tensors.py` to save the output as an h5 file for submission on EvalAI." 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 11, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "## Prepare submission data\n", 382 | "\n", 383 | "output_dict = {\n", 384 | " dataset_name + suffix: {\n", 385 | " 'train_rates_heldin': train_rates_heldin,\n", 386 | " 'train_rates_heldout': train_rates_heldout,\n", 387 | " 'eval_rates_heldin': eval_rates_heldin,\n", 388 | " 'eval_rates_heldout': eval_rates_heldout,\n", 389 | " 'eval_rates_heldin_forward': eval_rates_heldin_forward,\n", 390 | " 'eval_rates_heldout_forward': eval_rates_heldout_forward,\n", 391 | " }\n", 392 | "}\n", 393 | "\n", 394 | "# To save as an h5 file:\n", 395 | "# save_to_h5(output_dict, 'submission.h5')" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": {}, 401 | "source": [ 402 | "## 7. Evaluation\n", 403 | "\n", 404 | "Finally, we will create the test data with make_test_tensor and evaluate our model." 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 12, 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "name": "stderr", 414 | "output_type": "stream", 415 | "text": [ 416 | "Zero rate predictions found. 
Replacing zeros with 1e-9\n" 417 | ] 418 | }, 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "[{'mc_maze_scaling_split': {'[100] co-bps': 0.15331211388512617, '[100] vel R2': 0.6514445468827901, '[100] psth R2': 0.21864536599656217, '[100] fp-bps': -2.7198655704068924}}]\n" 424 | ] 425 | } 426 | ], 427 | "source": [ 428 | "## Make data to test predictions with and evaluate\n", 429 | "\n", 430 | "if phase == 'val':\n", 431 | " target_dict = make_eval_target_tensors(dataset, dataset_name=dataset_name, train_trial_split='train', eval_trial_split='val', include_psth=('mc_rtt' not in dataset_name), save_file=False)\n", 432 | "\n", 433 | " print(evaluate(target_dict, output_dict))" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "## Summary\n", 441 | "\n", 442 | "In this notebook, we used `nlb_tools` and `ssm` to run and evaluate SLDS on our benchmark." 443 | ] 444 | } 445 | ], 446 | "metadata": { 447 | "interpreter": { 448 | "hash": "d876f2d84ebe613ecb987c3cdf86da35455b4fa2dba2ba72805210f4933655ff" 449 | }, 450 | "kernelspec": { 451 | "display_name": "Python 3.7.6 64-bit ('tf2-gpu': conda)", 452 | "name": "python3" 453 | }, 454 | "language_info": { 455 | "codemirror_mode": { 456 | "name": "ipython", 457 | "version": 3 458 | }, 459 | "file_extension": ".py", 460 | "mimetype": "text/x-python", 461 | "name": "python", 462 | "nbconvert_exporter": "python", 463 | "pygments_lexer": "ipython3", 464 | "version": "3.7.6" 465 | }, 466 | "metadata": { 467 | "interpreter": { 468 | "hash": "d876f2d84ebe613ecb987c3cdf86da35455b4fa2dba2ba72805210f4933655ff" 469 | } 470 | }, 471 | "orig_nbformat": 2 472 | }, 473 | "nbformat": 4, 474 | "nbformat_minor": 2 475 | } 476 | -------------------------------------------------------------------------------- /nlb_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neurallatents/nlb_tools/1ddc15f45b56388ff093d1396b7b87b36fa32a68/nlb_tools/__init__.py -------------------------------------------------------------------------------- /nlb_tools/chop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import logging 4 | from os import path 5 | import pandas as pd 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class SegmentRecord: 12 | """Stores information needed to reconstruct a segment from chops""" 13 | def __init__(self, seg_id, clock_time, offset, n_chops, overlap): 14 | """Stores the information needed to reconstruct a segment 15 | 16 | Parameters 17 | ---------- 18 | seg_id : int 19 | The ID of this segment. 
20 | clock_time : pd.Series 21 | The TimeDeltaIndex of the original data from this segment 22 | offset : int 23 | The offset of the chops from the start of the segment 24 | n_chops : int 25 | The number of chops that make up this segment 26 | overlap : int 27 | The number of bins of overlap between adjacent chops 28 | """ 29 | self.seg_id = seg_id 30 | self.clock_time = clock_time 31 | self.offset = offset 32 | self.n_chops = n_chops 33 | self.overlap = overlap 34 | 35 | def rebuild_segment(self, chops, smooth_pwr=2): 36 | """Reassembles a segment from its chops 37 | 38 | Parameters 39 | ---------- 40 | chops : np.ndarray 41 | A 3D numpy array of shape n_chops x seg_len x data_dim that 42 | holds the data from all of the chops in this segment 43 | smooth_pwr : float, optional 44 | The power to use for smoothing. See `merge_chops` 45 | function for more details, by default 2 46 | 47 | Returns 48 | ------- 49 | pd.DataFrame 50 | A DataFrame of reconstructed segment data, indexed by the 51 | clock_time of the original segment 52 | """ 53 | # Merge the chops for this segment 54 | merged_array = merge_chops( 55 | chops, 56 | overlap=self.overlap, 57 | orig_len=len(self.clock_time) - self.offset, 58 | smooth_pwr=smooth_pwr) 59 | # Add NaNs for points that were not modeled due to offset 60 | data_dim = merged_array.shape[1] 61 | offset_nans = np.full((self.offset, data_dim), np.nan) 62 | merged_array = np.concatenate([offset_nans, merged_array]) 63 | # Recreate segment DataFrame with the appropriate `clock_time`s 64 | try: 65 | segment_df = pd.DataFrame(merged_array, index=self.clock_time) 66 | except: 67 | import pdb; pdb.set_trace() 68 | segment_df = None 69 | return segment_df 70 | 71 | class ChopInterface: 72 | """Chops data from NWBDatasets into segments with fixed overlap""" 73 | def __init__(self, 74 | window, 75 | overlap, 76 | max_offset=0, 77 | chop_margins=0, 78 | random_seed=None): 79 | """Initializes a ChopInterface 80 | 81 | Parameters 82 | ---------- 83 | window : int 84 | The length of chopped segments in ms 85 | overlap : int 86 | The overlap between chopped segments in ms 87 | max_offset : int, optional 88 | The maximum offset of the first chop from the beginning of 89 | each segment in ms. The actual offset will be chose 90 | randomly. By default, 0 adds no offset 91 | chop_margins : int, optional 92 | The size of extra margins to add to either end of each chop 93 | in bins, designed for use with the temporal_shift operation, 94 | by default 0 95 | random_seed : int, optional 96 | The random seed for generating the dataset, by default None 97 | does not use a random seed 98 | """ 99 | def to_timedelta(ms): 100 | return pd.to_timedelta(ms, unit='ms') 101 | 102 | self.window = to_timedelta(window) 103 | self.overlap = to_timedelta(overlap) 104 | self.max_offset = to_timedelta(max_offset) 105 | self.chop_margins = chop_margins 106 | self.random_seed = random_seed 107 | 108 | def chop(self, neural_df, chop_fields): 109 | """Chops a trialized or continuous RDS DataFrame. 110 | 111 | Parameters 112 | ---------- 113 | neural_df : pd.DataFrame 114 | A continuous or trialized DataFrame from RDS. 115 | chop_fields : str or list of str 116 | `signal_type` or list of `signal_type`s in neural_df to chop 117 | 118 | Returns 119 | ------- 120 | dict of np.array 121 | A data_dict of the chopped data. Consists of a dictionary 122 | with data tags mapping to 3D numpy arrays with dimensions 123 | corresponding to samples x time x features. 
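        Examples
        --------
        A minimal sketch on synthetic continuous data; a real `NWBDataset`
        DataFrame is richer, and the column name ``spikes`` here is only
        illustrative, but this sketch assumes a TimedeltaIndex named
        ``clock_time``:

        >>> import numpy as np
        >>> import pandas as pd
        >>> idx = pd.to_timedelta(np.arange(1000), unit='ms').rename('clock_time')
        >>> df = pd.DataFrame({'spikes': np.random.poisson(0.1, 1000)}, index=idx)
        >>> ci = ChopInterface(window=100, overlap=20)
        >>> chopped = ci.chop(df, 'spikes')
        >>> chopped['spikes'].shape
        (12, 100, 1)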
124 | """ 125 | # Set the random seed for the offset 126 | if self.random_seed is not None: 127 | np.random.seed(self.random_seed) 128 | 129 | if type(chop_fields) != list: 130 | chop_fields = [chop_fields] 131 | 132 | # Get info about the column groups to be chopped 133 | data_fields = sorted(chop_fields) 134 | get_field_dim = lambda field: getattr(neural_df, field).shape[1] if len(getattr(neural_df, field).shape) > 1 else 1 135 | data_dims = [get_field_dim(f) for f in data_fields] 136 | data_splits = np.cumsum(data_dims[:-1]) 137 | # Report information about the fields that are being chopped 138 | logger.info(f'Chopping data field(s) {data_fields} with dimension(s) {data_dims}.') 139 | 140 | # Calculate bin widths and set up segments for chopping 141 | if 'trial_id' in neural_df: 142 | # Trialized data 143 | bin_width = neural_df.clock_time[1] - neural_df.clock_time[0] 144 | segments = neural_df.groupby('trial_id') 145 | else: 146 | # Continuous data 147 | bin_width = neural_df.index[1] - neural_df.index[0] 148 | if np.any(np.isnan(neural_df[chop_fields])): 149 | splits = np.where(neural_df[chop_fields].sum(axis=1, min_count=1).isna().diff())[0].tolist() + [len(neural_df)] 150 | segments = {n: neural_df.iloc[splits[n]:splits[n+1]].reset_index() for n in range(0, len(splits) - 1, 2)} 151 | else: 152 | segments = {1: neural_df.reset_index()}.items() 153 | 154 | # Calculate the number of bins to use for chopping parameters 155 | window = int(self.window / bin_width) 156 | overlap = int(self.overlap / bin_width) 157 | chop_margins_td = pd.to_timedelta( 158 | self.chop_margins * bin_width, unit='ms') 159 | 160 | # Get correct offset based on data type 161 | if 'trial_id' in neural_df: 162 | # Trialized data 163 | max_offset = int(self.max_offset / bin_width) 164 | max_offset_td = self.max_offset 165 | get_offset = lambda: np.random.randint(max_offset+1) 166 | else: 167 | # Continuous data 168 | max_offset = 0 169 | max_offset_td = pd.to_timedelta(max_offset) 170 | get_offset = lambda: 0 171 | if self.max_offset > pd.to_timedelta(0): 172 | # Doesn't make sense to use offset on continuous data 173 | logger.info("Ignoring offset for continuous data.") 174 | 175 | def to_ms(timedelta): 176 | return int(timedelta.total_seconds() * 1000) 177 | 178 | # Log information about the chopping to be performed 179 | chop_message = ' - '.join([ 180 | 'Chopping data', 181 | f'Window: {window} bins, {to_ms(self.window)} ms', 182 | f'Overlap: {overlap} bins, {to_ms(self.overlap)} ms', 183 | f'Max offset: {max_offset} bins, {to_ms(max_offset_td)} ms', 184 | f'Chop margins: {self.chop_margins} bins, {to_ms(chop_margins_td)} ms', 185 | ]) 186 | logger.info(chop_message) 187 | 188 | # Iterate through segments, which can be trials or continuous data 189 | data_dict = defaultdict(list) 190 | segment_records = [] 191 | for segment_id, segment_df in segments: 192 | # Get the data from all of the column groups to extract 193 | data_arrays = [getattr(segment_df, f).to_numpy() if len(getattr(segment_df, f).shape) > 1 194 | else getattr(segment_df, f).to_numpy()[:, None] for f in data_fields] 195 | # Concatenate all data types into a single segment array 196 | segment_array = np.concatenate(data_arrays, axis=1) 197 | if self.chop_margins > 0: 198 | # Add padding to segment if we are using chop margins 199 | seg_dim = segment_array.shape[1] 200 | pad = np.full((self.chop_margins, seg_dim), 0.0001) 201 | segment_array = np.concatenate([pad, segment_array, pad]) 202 | # Sample an offset for this segment 203 | offset = 
get_offset() 204 | # Chop all of the data in this segment 205 | chops = chop_data( 206 | segment_array, 207 | overlap + 2*self.chop_margins, 208 | window + 2*self.chop_margins, 209 | offset) 210 | # Split the chops back up into the original fields 211 | data_chops = np.split(chops, data_splits, axis=2) 212 | # Create the data_dict with LFADS input names 213 | for field, data_chop in zip(data_fields, data_chops): 214 | data_dict[field].append(data_chop) 215 | # Keep a record to represent each original segment 216 | seg_rec = SegmentRecord( 217 | segment_id, 218 | segment_df.clock_time, 219 | offset, 220 | len(chops), 221 | overlap) 222 | segment_records.append(seg_rec) 223 | # Store the information for reassembling segments 224 | self.segment_records = segment_records 225 | # Consolidate data from all segments into a single array 226 | data_dict = {name: np.concatenate(c) for name, c in data_dict.items()} 227 | # Report diagnostic info 228 | dict_key = list(data_dict.keys())[0] 229 | n_chops = len(data_dict[dict_key]) 230 | n_segments = len(segment_records) 231 | logger.info(f'Created {n_chops} chops from {n_segments} segment(s).') 232 | 233 | return data_dict 234 | 235 | def merge(self, chopped_data, smooth_pwr=2): 236 | """Merges chopped data to reconstruct the original input 237 | sequence 238 | 239 | Parameters 240 | ---------- 241 | chopped_data : dict 242 | Dict mapping the keys to chopped 3d numpy arrays 243 | smooth_pwr : float, optional 244 | The power to use for smoothing. See `merge_chops` 245 | function for more details, by default 2 246 | 247 | Returns 248 | ------- 249 | pd.DataFrame 250 | A merged DataFrame indexed by the clock time of the original 251 | chops. Columns are multiindexed using `fields_map`. 252 | Unmodeled data is indicated by NaNs. 253 | """ 254 | # Get the desired arrays from the output 255 | output_fields = sorted(chopped_data.keys()) 256 | output_arrays = [chopped_data[f] for f in output_fields] 257 | # Keep track of boundaries between the different signals 258 | output_dims = [a.shape[-1] for a in output_arrays] 259 | # Concatenate the output arrays for more efficient merging 260 | output_full = np.concatenate(output_arrays, axis=-1) 261 | # Get info for separating the chops related to each segment 262 | seg_splits = np.cumsum([s.n_chops for s in self.segment_records])[:-1] 263 | # Separate out the chops for each segment 264 | seg_chops = np.split(output_full, seg_splits, axis=0) 265 | # Reconstruct the segment DataFrames 266 | segment_dfs = [record.rebuild_segment(chops, smooth_pwr) \ 267 | for record, chops in zip(self.segment_records, seg_chops)] 268 | # Concatenate the segments with clock_time indices 269 | merged_df = pd.concat(segment_dfs) 270 | # Add multiindexed columns 271 | midx_tuples = [(sig, f'{i:04}') \ 272 | for sig, dim in zip(output_fields, output_dims) \ 273 | for i in range(dim)] 274 | merged_df.columns = pd.MultiIndex.from_tuples(midx_tuples) 275 | 276 | return merged_df 277 | 278 | 279 | def chop_data(data, overlap, window, offset=0): 280 | """Rearranges an array of continuous data into overlapping segments. 281 | 282 | This low-level function takes a 2-D array of features measured 283 | continuously through time and breaks it up into a 3-D array of 284 | partially overlapping time segments. 285 | 286 | Parameters 287 | ---------- 288 | data : np.ndarray 289 | A TxN numpy array of N features measured across T time points. 290 | overlap : int 291 | The number of points to overlap between subsequent segments. 
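offset : int, optional
    The number of leading time points of `data` to skip before
    chopping, by default 0.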
292 | window : int 293 | The number of time points in each segment. 294 | Returns 295 | ------- 296 | np.ndarray 297 | An SxTxN numpy array of S overlapping segments spanning 298 | T time points with N features. 299 | 300 | See Also 301 | -------- 302 | chop.merge_chops : Performs the opposite of this operation. 303 | """ 304 | # Random offset breaks temporal connection between trials and chops 305 | offset_data = data[offset:] 306 | shape = ( 307 | int((offset_data.shape[0] - window)/(window - overlap)) + 1, 308 | window, 309 | offset_data.shape[-1], 310 | ) 311 | strides = ( 312 | offset_data.strides[0]*(window - overlap), 313 | offset_data.strides[0], 314 | offset_data.strides[1], 315 | ) 316 | chopped = np.lib.stride_tricks.as_strided( 317 | offset_data, shape=shape, strides=strides).copy().astype('f') 318 | return chopped 319 | 320 | 321 | def merge_chops(data, overlap, orig_len=None, smooth_pwr=2): 322 | """Merges an array of overlapping segments back into continuous data. 323 | This low-level function takes a 3-D array of partially overlapping 324 | time segments and merges it back into a 2-D array of features measured 325 | continuously through time. 326 | 327 | Parameters 328 | ---------- 329 | data : np.ndarray 330 | An SxTxN numpy array of S overlapping segments spanning 331 | T time points with N features. 332 | overlap : int 333 | The number of overlapping points between subsequent segments. 334 | orig_len : int, optional 335 | The original length of the continuous data, by default None 336 | will cause the length to depend on the input data. 337 | smooth_pwr : float, optional 338 | The power of smoothing. To keep only the ends of chops and 339 | discard the beginnings, use np.inf. To linearly blend the 340 | chops, use 1. Raising above 1 will increasingly prefer the 341 | ends of chops and lowering towards 0 will increasingly 342 | prefer the beginnings of chops (not recommended). To use 343 | only the beginnings of chops, use 0 (not recommended). By 344 | default, 2 slightly prefers the ends of segments. 345 | Returns 346 | ------- 347 | np.ndarray 348 | A TxN numpy array of N features measured across T time points. 349 | 350 | See Also 351 | -------- 352 | chop.chop_data : Performs the opposite of this operation. 
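    Examples
    --------
    A minimal round-trip sketch on synthetic data, pairing this function
    with `chop_data` (values chosen so the chops tile the array exactly):

    >>> import numpy as np
    >>> from nlb_tools.chop import chop_data, merge_chops
    >>> x = np.arange(40, dtype=float).reshape(20, 2)
    >>> chops = chop_data(x, overlap=4, window=8)
    >>> chops.shape
    (4, 8, 2)
    >>> merged = merge_chops(chops, overlap=4, orig_len=20)
    >>> merged.shape
    (20, 2)
    >>> np.allclose(merged, x)
    True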
353 | """ 354 | if smooth_pwr < 1: 355 | logger.warning('Using `smooth_pwr` < 1 for merging ' 356 | 'chops is not recommended.') 357 | 358 | merged = [] 359 | full_weight_len = data.shape[1] - 2*overlap 360 | # Create x-values for the ramp 361 | x = np.linspace(1/overlap, 1 - 1/overlap, overlap) \ 362 | if overlap != 0 else np.array([]) 363 | # Compute a power-function ramp to transition 364 | ramp = 1 - x ** smooth_pwr 365 | ramp = np.expand_dims(ramp, axis=-1) 366 | # Compute the indices to split up each chop 367 | split_ixs = np.cumsum([overlap, full_weight_len]) 368 | for i in range(len(data)): 369 | # Split the chop into overlapping and non-overlapping 370 | first, middle, last = np.split(data[i], split_ixs) 371 | # Ramp each chop and combine it with the previous chop 372 | if i == 0: 373 | last = last * ramp 374 | elif i == len(data) - 1: 375 | first = first * (1-ramp) + merged.pop(-1) 376 | else: 377 | first = first * (1-ramp) + merged.pop(-1) 378 | last = last * ramp 379 | # Track all of the chops in a list 380 | merged.extend([first, middle, last]) 381 | 382 | merged = np.concatenate(merged) 383 | # Indicate unmodeled data with NaNs 384 | if orig_len is not None and len(merged) < orig_len: 385 | nans = np.full((orig_len-len(merged), merged.shape[1]), np.nan) 386 | merged = np.concatenate([merged, nans]) 387 | 388 | return merged -------------------------------------------------------------------------------- /nlb_tools/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | from sklearn.linear_model import Ridge 4 | from scipy.stats import pearsonr 5 | from scipy.special import gammaln 6 | from sklearn.metrics import r2_score, explained_variance_score 7 | from sklearn.model_selection import GridSearchCV 8 | 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def evaluate(test_annotation_file, user_submission_file): 15 | """ 16 | Runs evaluation as it would be run on EvalAI servers 17 | 18 | Parameters 19 | ---------- 20 | test_annotation_file : str or dict 21 | Path to the eval target .h5 file or dict of eval target 22 | data to evaluate against 23 | user_submission_file : str 24 | Path to the .h5 file or dict with user 25 | rate predictions 26 | 27 | Returns 28 | ------- 29 | list 30 | List containing a dict for each dataset that was 31 | evaluated. 
Each dict contains the calculated metrics 32 | """ 33 | logger.info("Starting Evaluation.....") 34 | 35 | # define prefixes for scaling metrics 36 | scaling_tcount = { 37 | "mc_maze_large": "[500]", 38 | "mc_maze_medium": "[250]", 39 | "mc_maze_small": "[100]", 40 | } 41 | 42 | # read data from files 43 | if type(test_annotation_file) == str: 44 | target_data = h5py.File(test_annotation_file, "r") 45 | else: 46 | target_data = test_annotation_file 47 | if type(user_submission_file) == str: 48 | user_data = h5py.File(user_submission_file, "r") 49 | else: 50 | user_data = user_submission_file 51 | 52 | result_list = [] 53 | scaling_dict = {} 54 | scaling_dict_20 = {} 55 | # evaluate on datasets that are included in both submission and evaluation data 56 | for dataset in [ 57 | "mc_maze", 58 | "mc_rtt", 59 | "area2_bump", 60 | "dmfc_rsg", 61 | "mc_maze_large", 62 | "mc_maze_medium", 63 | "mc_maze_small", 64 | ]: 65 | for bin_size_ms, suf in zip([5, 20], ["", "_20"]): 66 | if (dataset + suf) not in user_data.keys(): 67 | continue 68 | dataset_name = dataset + suf 69 | logger.info(f"Evaluating {dataset_name}") 70 | result_dict = {} 71 | # check that both submission and evaluation dicts have data for this dataset 72 | if "eval_rates_heldout" not in user_data[dataset_name].keys(): 73 | continue 74 | elif ( 75 | dataset_name 76 | ) not in target_data.keys() or "eval_spikes_heldout" not in target_data[ 77 | dataset_name 78 | ].keys(): 79 | logger.warning(f"Evaluation data for {dataset_name} not found") 80 | continue 81 | 82 | # extract evaluation data 83 | eval_spikes_heldout = target_data[dataset_name]["eval_spikes_heldout"][ 84 | () 85 | ].astype("float") 86 | train_behavior = target_data[dataset_name]["train_behavior"][()].astype( 87 | "float" 88 | ) 89 | eval_behavior = target_data[dataset_name]["eval_behavior"][()].astype( 90 | "float" 91 | ) 92 | 93 | # extract submitted data 94 | eval_rates_heldin = user_data[dataset_name]["eval_rates_heldin"][()].astype( 95 | "float" 96 | ) 97 | eval_rates_heldout = user_data[dataset_name]["eval_rates_heldout"][ 98 | () 99 | ].astype("float") 100 | eval_rates = np.concatenate( 101 | [eval_rates_heldin, eval_rates_heldout], axis=-1 102 | ) 103 | 104 | # calculate co-smoothing bits per spike 105 | result_dict["co-bps"] = float( 106 | bits_per_spike(eval_rates_heldout, eval_spikes_heldout) 107 | ) 108 | 109 | if dataset == "dmfc_rsg": 110 | # Compute Pearson's r for the correlation between neural speed and tp 111 | result_dict["tp corr"] = speed_tp_correlation( 112 | eval_spikes_heldout, eval_rates, eval_behavior 113 | ) 114 | else: 115 | # extract train rates for regression 116 | train_rates_heldin = user_data[dataset_name]["train_rates_heldin"][ 117 | () 118 | ].astype("float") 119 | train_rates_heldout = user_data[dataset_name]["train_rates_heldout"][ 120 | () 121 | ].astype("float") 122 | train_rates = np.concatenate( 123 | [train_rates_heldin, train_rates_heldout], axis=-1 124 | ) 125 | # make decode mask if not provided 126 | if "train_decode_mask" in target_data[dataset_name].keys(): 127 | train_decode_mask = target_data[dataset_name]["train_decode_mask"][ 128 | () 129 | ] 130 | eval_decode_mask = target_data[dataset_name]["eval_decode_mask"][()] 131 | else: 132 | train_decode_mask = np.full(train_rates.shape[0], True)[:, None] 133 | eval_decode_mask = np.full(eval_rates.shape[0], True)[:, None] 134 | result_dict["vel R2"] = velocity_decoding( 135 | train_rates, 136 | train_behavior, 137 | train_decode_mask, 138 | eval_rates, 139 | eval_behavior, 140 | 
eval_decode_mask, 141 | ) 142 | if "psth" in target_data[dataset_name].keys(): 143 | # get PSTH information and evaluate 144 | psth = target_data[dataset_name]["psth"][()].astype("float") 145 | eval_cond_idx = target_data[dataset_name]["eval_cond_idx"][()] 146 | if "eval_jitter" in target_data[dataset_name].keys(): 147 | jitter = target_data[dataset_name]["eval_jitter"][()] 148 | else: 149 | jitter = np.zeros(eval_rates.shape[0]).astype(int) 150 | psth_r2 = eval_psth(psth, eval_rates, eval_cond_idx, jitter=jitter) 151 | result_dict["psth R2"] = float(psth_r2) 152 | 153 | if ( 154 | "eval_rates_heldin_forward" in user_data[dataset_name].keys() 155 | and "eval_spikes_heldin_forward" in target_data[dataset_name].keys() 156 | ): 157 | # extract forward prediction data 158 | eval_spikes_heldin_forward = target_data[dataset_name][ 159 | "eval_spikes_heldin_forward" 160 | ][()].astype("float") 161 | eval_spikes_heldout_forward = target_data[dataset_name][ 162 | "eval_spikes_heldout_forward" 163 | ][()].astype("float") 164 | eval_rates_heldin_forward = user_data[dataset_name][ 165 | "eval_rates_heldin_forward" 166 | ][()].astype("float") 167 | eval_rates_heldout_forward = user_data[dataset_name][ 168 | "eval_rates_heldout_forward" 169 | ][()].astype("float") 170 | # combine held-in and held-out 171 | eval_spikes_forward = np.dstack( 172 | [eval_spikes_heldin_forward, eval_spikes_heldout_forward] 173 | ) 174 | eval_rates_forward = np.dstack( 175 | [eval_rates_heldin_forward, eval_rates_heldout_forward] 176 | ) 177 | # calculate forward prediction bits per spike 178 | result_dict["fp-bps"] = float( 179 | bits_per_spike(eval_rates_forward, eval_spikes_forward) 180 | ) 181 | 182 | if dataset in ["mc_maze_large", "mc_maze_medium", "mc_maze_small"]: 183 | sd = scaling_dict if suf == "" else scaling_dict_20 184 | for key, val in result_dict.items(): 185 | sd[scaling_tcount[dataset] + " " + key] = val 186 | elif dataset in ["mc_maze", "mc_rtt", "area2_bump", "dmfc_rsg"]: 187 | result_list.append({f"{dataset_name}_split": result_dict}) 188 | 189 | # put scaling data in proper split 190 | if len(scaling_dict) > 0: 191 | result_list.append({"mc_maze_scaling_split": scaling_dict}) 192 | if len(scaling_dict_20) > 0: 193 | result_list.append({"mc_maze_scaling_20_split": scaling_dict_20}) 194 | 195 | logger.info("Completed evaluation") 196 | 197 | try: 198 | target_data.close() 199 | except: 200 | pass 201 | try: 202 | user_data.close() 203 | except: 204 | pass 205 | 206 | return result_list 207 | 208 | 209 | def neg_log_likelihood(rates, spikes, zero_warning=True): 210 | """Calculates Poisson negative log likelihood given rates and spikes. 211 | formula: -log(e^(-r) / n! * r^n) 212 | = r - n*log(r) + log(n!) 213 | 214 | Parameters 215 | ---------- 216 | rates : np.ndarray 217 | numpy array containing rate predictions 218 | spikes : np.ndarray 219 | numpy array containing true spike counts 220 | zero_warning : bool, optional 221 | Whether to print out warning about 0 rate 222 | predictions or not 223 | 224 | Returns 225 | ------- 226 | float 227 | Total negative log-likelihood of the data 228 | """ 229 | assert ( 230 | spikes.shape == rates.shape 231 | ), f"neg_log_likelihood: Rates and spikes should be of the same shape. 
spikes: {spikes.shape}, rates: {rates.shape}" 232 | 233 | if np.any(np.isnan(spikes)): 234 | mask = np.isnan(spikes) 235 | rates = rates[~mask] 236 | spikes = spikes[~mask] 237 | 238 | assert not np.any(np.isnan(rates)), "neg_log_likelihood: NaN rate predictions found" 239 | 240 | assert np.all(rates >= 0), "neg_log_likelihood: Negative rate predictions found" 241 | if np.any(rates == 0): 242 | if zero_warning: 243 | logger.warning( 244 | "neg_log_likelihood: Zero rate predictions found. Replacing zeros with 1e-9" 245 | ) 246 | rates[rates == 0] = 1e-9 247 | 248 | result = rates - spikes * np.log(rates) + gammaln(spikes + 1.0) 249 | return np.sum(result) 250 | 251 | 252 | def bits_per_spike(rates, spikes): 253 | """Computes bits per spike of rate predictions given spikes. 254 | Bits per spike is equal to the difference between the log-likelihoods (in base 2) 255 | of the rate predictions and the null model (i.e. predicting mean firing rate of each neuron) 256 | divided by the total number of spikes. 257 | 258 | Parameters 259 | ---------- 260 | rates : np.ndarray 261 | 3d numpy array containing rate predictions 262 | spikes : np.ndarray 263 | 3d numpy array containing true spike counts 264 | 265 | Returns 266 | ------- 267 | float 268 | Bits per spike of rate predictions 269 | """ 270 | nll_model = neg_log_likelihood(rates, spikes) 271 | null_rates = np.tile( 272 | np.nanmean(spikes, axis=tuple(range(spikes.ndim - 1)), keepdims=True), 273 | spikes.shape[:-1] + (1,), 274 | ) 275 | nll_null = neg_log_likelihood(null_rates, spikes, zero_warning=False) 276 | return (nll_null - nll_model) / np.nansum(spikes) / np.log(2) 277 | 278 | 279 | def fit_and_eval_decoder( 280 | train_rates, 281 | train_behavior, 282 | eval_rates, 283 | eval_behavior, 284 | grid_search=True, 285 | ): 286 | """Fits ridge regression on train data passed 287 | in and evaluates on eval data 288 | 289 | Parameters 290 | ---------- 291 | train_rates : np.ndarray 292 | 2d array with 1st dimension being samples (time) and 293 | 2nd dimension being input variables (units). 294 | Used to train regressor 295 | train_behavior : np.ndarray 296 | 2d array with 1st dimension being samples (time) and 297 | 2nd dimension being output variables (channels). 298 | Used to train regressor 299 | eval_rates : np.ndarray 300 | 2d array with same dimension ordering as train_rates. 301 | Used to evaluate regressor 302 | eval_behavior : np.ndarray 303 | 2d array with same dimension ordering as train_behavior. 304 | Used to evaluate regressor 305 | grid_search : bool 306 | Whether to perform a cross-validated grid search to find 307 | the best regularization hyperparameters. 
308 | 309 | Returns 310 | ------- 311 | float 312 | R2 score on eval data 313 | """ 314 | if np.any(np.isnan(train_behavior)): 315 | train_rates = train_rates[~np.isnan(train_behavior)[:, 0]] 316 | train_behavior = train_behavior[~np.isnan(train_behavior)[:, 0]] 317 | if np.any(np.isnan(eval_behavior)): 318 | eval_rates = eval_rates[~np.isnan(eval_behavior)[:, 0]] 319 | eval_behavior = eval_behavior[~np.isnan(eval_behavior)[:, 0]] 320 | assert not np.any(np.isnan(train_rates)) and not np.any( 321 | np.isnan(eval_rates) 322 | ), "fit_and_eval_decoder: NaNs found in rate predictions within required trial times" 323 | 324 | if grid_search: 325 | decoder = GridSearchCV(Ridge(), {"alpha": np.logspace(-4, 0, 9)}) 326 | else: 327 | decoder = Ridge(alpha=1e-2) 328 | decoder.fit(train_rates, train_behavior) 329 | return decoder.score(eval_rates, eval_behavior) 330 | 331 | 332 | def eval_psth(psth, eval_rates, eval_cond_idx, jitter=None): 333 | """Evaluates match to PSTH across conditions 334 | Parameters 335 | ---------- 336 | psth : np.ndarray 337 | 3d array, with dimensions condition x time x neuron, 338 | containing PSTHs for each unit in each condition 339 | eval_rates : np.ndarray 340 | 3d array, with dimensions trial x time x neuron, 341 | containing rate predictions for all test split trials 342 | eval_cond_idx : list of np.array 343 | List of arrays containing indices of test trials 344 | corresponding to conditions in `psth` 345 | jitter : np.ndarray, optional 346 | 1d array containing jitter applied to each eval trial 347 | 348 | Returns 349 | ------- 350 | float 351 | R2 of PSTHs computed from rate predictions 352 | to true PSTHs across all conditions, averaged 353 | across neurons 354 | """ 355 | jitter_trial = ( 356 | lambda x: x[0] 357 | if x[1] == 0 358 | else np.vstack([np.full((x[1], x[0].shape[1]), np.nan), x[0][: -x[1]]]) 359 | if x[1] > 0 360 | else np.vstack([x[0][-x[1] :], np.full((-x[1], x[0].shape[1]), np.nan)]) 361 | ) 362 | if jitter is None: 363 | jitter = np.zeros(eval_rates.shape[0]).astype(int) 364 | true_list = [] 365 | pred_list = [] 366 | for i in range(len(eval_cond_idx)): 367 | if eval_cond_idx[i].size == 0: 368 | continue 369 | pred_psth = np.mean( 370 | [jitter_trial((eval_rates[idx], jitter[idx])) for idx in eval_cond_idx[i]], 371 | axis=0, 372 | ) 373 | true_psth = psth[i, :, :][~np.isnan(psth[i, :, 0])] 374 | pred_psth = pred_psth[~np.isnan(psth[i, :, 0])] 375 | assert not np.any( 376 | np.isnan(pred_psth) 377 | ), "eval_psth: NaNs found in rate predictions within required trial times" 378 | true_list.append(true_psth) 379 | pred_list.append(pred_psth) 380 | 381 | true_psth = np.vstack(true_list) 382 | pred_psth = np.vstack(pred_list) 383 | return r2_score(true_psth, pred_psth) 384 | 385 | 386 | def speed_tp_correlation(eval_spikes_heldout, eval_rates, eval_behavior): 387 | """Computes speed-tp correlation for DMFC datasets. 388 | 389 | Parameters 390 | ---------- 391 | eval_spikes_heldout : np.ndarray 392 | 3d array, with dimensions trial x time x neuron, 393 | containing heldout spikes for all test split trials. Contains NaNs 394 | during non set-go time points. 395 | eval_rates : np.ndarray 396 | 3d array, with dimensions trial x time x neuron, 397 | containing rate predictions for all test split trials 398 | eval_behavior : np.ndarray 399 | 2d array with same dimension ordering as train_behavior. 
400 | Used to determine conditions and evaluate correlation coefficient 401 | 402 | Returns 403 | ------- 404 | float 405 | Pearson's r between neural speed computed from rates and actual 406 | produced time interval tp, averaged across conditions. 407 | """ 408 | 409 | # Find NaNs that indicate data outside of the set-go period 410 | masks = ~np.isnan(eval_spikes_heldout[..., 0]) 411 | # Compute neural speed during the set-go period for each trial 412 | def compute_speed(trial): 413 | return np.mean(np.linalg.norm(np.diff(trial, axis=0), axis=1)) 414 | 415 | eval_speeds = [compute_speed(trial[mask]) for trial, mask in zip(eval_rates, masks)] 416 | eval_speeds = np.array(eval_speeds) 417 | # Compute correlation within each condition 418 | decoding_rs = [] 419 | # conditions based only off prior, response modality, and direction 420 | # because there aren't many trials for each t_s in the test split 421 | # (behavior columns are in `make_tensors.py`) 422 | cond_cols = [0, 1, 2] 423 | # Get unique conditions and ignore NaNs 424 | conds = np.unique(eval_behavior[:, cond_cols], axis=0) 425 | conds = conds[~np.all(np.isnan(conds), axis=1)] 426 | for cond in conds: 427 | cmask = np.all(eval_behavior[:, cond_cols] == cond, axis=1) 428 | cond_eval_speeds = eval_speeds[cmask] 429 | cond_eval_tp = eval_behavior[cmask][:, -1] 430 | cond_r, _ = pearsonr(cond_eval_speeds, cond_eval_tp) 431 | decoding_rs.append(cond_r) 432 | decoding_r = np.mean(decoding_rs) 433 | 434 | return decoding_r 435 | 436 | 437 | def velocity_decoding( 438 | train_rates, 439 | train_behavior, 440 | train_decode_mask, 441 | eval_rates, 442 | eval_behavior, 443 | eval_decode_mask, 444 | grid_search=True, 445 | ): 446 | """Computes hand velocity decoding performance for mc_maze, area2_bump, and mc_rtt. 447 | 448 | Parameters 449 | ---------- 450 | train_rates : np.ndarray 451 | 3d array, with dimensions trial x time x neuron, 452 | containing rate predictions for all train split trials. 453 | train_behavior : np.ndarray 454 | 3d array, with dimensions trial x time x 2, containing x and y hand velocity 455 | for all train split trials. 456 | train_decode_mask : np.ndarray 457 | 2d array, with dimensions trial x n_masks, containing masks that group trials 458 | with the same decoder for all train split trials. 459 | eval_rates : np.ndarray 460 | 3d array, with dimensions trial x time x neuron, 461 | containing rate predictions for all test split trials. 462 | eval_behavior : np.ndarray 463 | 3d array, with dimensions trial x time x 2, containing x and y hand velocity 464 | for all test split trials. 465 | eval_decode_mask : np.ndarray 466 | 2d array, with dimensions trial x n_masks, containing masks that group trials 467 | with the same decoder for all test split trials. 468 | grid_search : bool, optional 469 | Whether to use a cross-validated grid search over the ridge regularization 470 | penalty, by default True 471 | 472 | Returns 473 | ------- 474 | float 475 | Average coefficient of determination for hand velocity decoding across masked 476 | groups. 
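    Examples
    --------
    A minimal sketch with synthetic, noiseless, linearly decodable data and
    a single decoder group (real NLB inputs would come from
    `make_eval_target_tensors` and a model's rate predictions):

    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> rates = rng.standard_normal((20, 50, 10))
    >>> behavior = rates @ rng.standard_normal((10, 2))
    >>> masks = np.full((20, 1), True)
    >>> r2 = velocity_decoding(rates[:15], behavior[:15], masks[:15],
    ...                        rates[15:], behavior[15:], masks[15:])
    >>> bool(r2 > 0.9)
    True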
477 | """ 478 | flatten3d = lambda x: x.reshape(-1, x.shape[2]) if (len(x.shape) > 2) else x 479 | decoding_r2s = [] 480 | # train/evaluate regression for each mask 481 | for i in range(train_decode_mask.shape[1]): 482 | decoding_r2 = fit_and_eval_decoder( 483 | flatten3d(train_rates[train_decode_mask[:, i]]), 484 | flatten3d(train_behavior[train_decode_mask[:, i]]), 485 | flatten3d(eval_rates[eval_decode_mask[:, i]]), 486 | flatten3d(eval_behavior[eval_decode_mask[:, i]]), 487 | grid_search=grid_search, 488 | ) 489 | decoding_r2s.append(decoding_r2) 490 | # average R2s across masks 491 | return np.mean(decoding_r2s) 492 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=1.0.0,<=1.3.4 2 | scipy>=1.1.0 3 | numpy 4 | scikit-learn 5 | h5py<4,>=2.9 6 | pynwb 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md") as f: 4 | long_description = f.read() 5 | 6 | with open("requirements.txt") as f: 7 | requirements = f.readlines() 8 | 9 | setup( 10 | name="nlb_tools", 11 | version="0.0.3", 12 | description="Python tools for participating in Neural Latents Benchmark '21", 13 | packages=find_packages(), 14 | install_requires=requirements, 15 | author="Felix Pei", 16 | classifiers=[ 17 | "Intended Audience :: Science/Research", 18 | "Operating System :: Microsoft :: Windows", 19 | "Operating System :: MacOS", 20 | "Operating System :: Unix", 21 | "License :: OSI Approved :: MIT License", 22 | "Programming Language :: Python :: 3.7", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | ], 26 | extras_require={ 27 | "dev": ["pytest", "dandi"], 28 | }, 29 | license="MIT", 30 | long_description=long_description, 31 | long_description_content_type="text/markdown", 32 | python_requires=">=3.7", 33 | setup_requires=["setuptools>=61.0.0", "wheel"], 34 | url="https://github.com/neurallatents/nlb_tools", 35 | ) 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neurallatents/nlb_tools/1ddc15f45b56388ff093d1396b7b87b36fa32a68/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | import pytest 4 | import numpy as np 5 | from pynwb import NWBHDF5IO, NWBFile 6 | from datetime import datetime, timezone 7 | from dandi.download import download 8 | 9 | from nlb_tools.nwb_interface import NWBDataset 10 | 11 | SEED = 0 12 | 13 | DATA_DIR = Path(os.path.dirname(__file__), "temp_data") 14 | DUMMY_RAW_FILE_NAME = "dummy.npz" 15 | DUMMY_NWB_TRAIN_FILE_NAME = "dummy_train.nwb" 16 | DUMMY_NWB_TEST_FILE_NAME = "dummy_test.nwb" 17 | DANDISET_URL = "https://dandiarchive.org/dandiset/000138" 18 | 
NLB_FILE_NAME = "000138/sub-Jenkins/sub-Jenkins_ses-large_desc-train_behavior+ecephys.nwb" 19 | 20 | DUMMY_BIN_SIZE = 0.001 21 | DUMMY_N_NEUR = 50 22 | DUMMY_TRIAL_LEN = 700 # to match MC_Maze 23 | DUMMY_FP_LEN = 200 # to match MC_Maze 24 | DUMMY_ALIGN_OFFSET = 250 # to match MC_Maze 25 | DUMMY_N_TRIALS = 10 26 | DUMMY_ITI = (150, 400) 27 | DUMMY_TRIAL_INFO = ['start_time', 'end_time', 'onset', 'split'] 28 | DUMMY_BEHAVIOR_LAG = 120 # to match MC_Maze 29 | DUMMY_SPIKE_THRESH = 0.98 30 | 31 | np.random.seed(SEED) 32 | 33 | # TODO: generate some random data to more thoroughly test all operations 34 | # @pytest.fixture(scope="session") 35 | # def dummy_true_filepath(): 36 | # # Make sure dir exists 37 | # if not os.path.exists(DATA_DIR): 38 | # os.mkdir(DATA_DIR) 39 | # # Choose trial info 40 | # t = 0 41 | # trial_data = np.empty((DUMMY_N_TRIALS, len(DUMMY_TRIAL_INFO))) 42 | # val_idx = np.random.choice(np.arange(DUMMY_N_TRIALS * 0.8)) 43 | # for i in range(DUMMY_N_TRIALS): 44 | # trial_data[i, 0] = t 45 | # trial_data[i, 1] = t + DUMMY_TRIAL_LEN + DUMMY_FP_LEN 46 | # trial_data[i, 2] = t + DUMMY_ALIGN_OFFSET 47 | # trial_data[i, 3] = 2 if (i > DUMMY_N_TRIALS * 0.8) else 1 if (i in val_idx) else 0 48 | # t += np.random.randint(*DUMMY_ITI) 49 | # spikes = np.random.random( 50 | # (t, DUMMY_N_NEUR)) > DUMMY_SPIKE_THRESH 51 | # spikes = spikes.astype(float) 52 | # behavior = np.stack([ 53 | # np.sin(np.arange(t) * 0.01 * np.pi), # pos x 54 | # np.cos(np.arange(t) * 0.01 * np.pi), # pos y 55 | # np.cos(np.arange(t) * 0.01 * np.pi), # vel x 56 | # -np.sin(np.arange(t) * 0.01 * np.pi), # vel y 57 | # ], axis=-1) 58 | # save_path = Path(DATA_DIR, DUMMY_RAW_FILE_NAME) 59 | # np.savez( 60 | # save_path, 61 | # spikes=spikes, 62 | # behavior=behavior, 63 | # trial_data=trial_data, 64 | # ) 65 | # return save_path 66 | 67 | # @pytest.fixture(scope="session") 68 | # def dummy_nwb_filepath(dummy_true_filepath): 69 | # if not os.path.exists(DATA_DIR): 70 | # os.mkdir(DATA_DIR) 71 | 72 | # nwb_train = NWBFile( 73 | # session_description='dummy train data for testing', 74 | # identifier='train', 75 | # session_start_time=datetime.now(timezone.utc), 76 | # ) 77 | # nwb_test = NWBFile( 78 | # session_description='dummy test data for testing', 79 | # identifier='test', 80 | # session_start_time=datetime.now(timezone.utc), 81 | # ) 82 | # raw_data = np.load(dummy_true_filepath) 83 | 84 | # spikes = raw_data['spikes'] 85 | # behavior = raw_data['behavior'] 86 | # trial_data = raw_data['trial_data'] 87 | 88 | # first_test = np.nonzero(trial_data[:, 3] == 2)[0][0] 89 | # test_start = trial_data[0, first_test] 90 | 91 | # train_spikes, test_spikes = np.split(spikes, [test_start]) 92 | # train_behavior, test_behavior = np.split(spikes, [test_start]) 93 | # train_spikes, test_spikes = np.split(spikes, [test_start]) 94 | 95 | @pytest.fixture(scope="session") 96 | def nlb_filepath(): 97 | if not os.path.exists(DATA_DIR): 98 | os.mkdir(DATA_DIR) 99 | download(urls=[DANDISET_URL], output_dir=DATA_DIR, existing='refresh') 100 | return Path(DATA_DIR, NLB_FILE_NAME) 101 | 102 | @pytest.fixture 103 | def nlb_dataset(nlb_filepath): 104 | dataset = NWBDataset(nlb_filepath) 105 | return dataset 106 | -------------------------------------------------------------------------------- /tests/test_evaluate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from scipy.special import gammaln 4 | 5 | from nlb_tools.evaluation import ( 6 | evaluate, 7 | 
neg_log_likelihood, 8 | bits_per_spike, 9 | fit_and_eval_decoder, 10 | eval_psth, 11 | speed_tp_correlation, 12 | velocity_decoding, 13 | ) 14 | 15 | 16 | # -- NLL and bits/spike ---------- 17 | 18 | 19 | def test_neg_log_likelihood(): 20 | """Test that NLL computation is correct""" 21 | # randomized test 22 | for _ in range(20): 23 | spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) 24 | rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) 25 | 26 | expected_nll = np.sum(rates - spikes * np.log(rates) + gammaln(spikes + 1.0)) 27 | actual_nll = neg_log_likelihood(rates, spikes) 28 | assert np.isclose(expected_nll, actual_nll) 29 | 30 | 31 | def test_neg_log_likelihood_mismatched_shapes(): 32 | """Test that NLL computation fails when shapes don't match""" 33 | # randomized test 34 | spikes = np.random.randint(low=0, high=5, size=(10, 100, 8)).astype(float) 35 | rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) 36 | 37 | with pytest.raises(AssertionError): 38 | neg_log_likelihood(rates, spikes) 39 | 40 | 41 | def test_neg_log_likelihood_negative_rates(): 42 | """Test that NLL computation fails when rates are negative""" 43 | # randomized test 44 | spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) 45 | rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) 46 | rates -= np.min(rates) + 5 # guarantee negative rates 47 | 48 | with pytest.raises(AssertionError): 49 | neg_log_likelihood(rates, spikes) 50 | 51 | 52 | def test_neg_log_likelihood_drop_nans(): 53 | """Test that NLL computation is correct when there are nans in either rates or spikes""" 54 | # randomized test 55 | for _ in range(20): 56 | spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) 57 | rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) 58 | mask = np.random.rand(10, 100, 10) > 0.9 59 | spikes[mask] = np.nan 60 | if np.random.rand() > 0.5: # rates does not have to have nans 61 | rates[mask] = np.nan 62 | 63 | expected_nll = np.sum( 64 | rates[~mask] 65 | - spikes[~mask] * np.log(rates[~mask]) 66 | + gammaln(spikes[~mask] + 1.0) 67 | ) 68 | actual_nll = neg_log_likelihood(rates, spikes) 69 | assert np.isclose(expected_nll, actual_nll) 70 | 71 | 72 | def test_neg_log_likelihood_mismatched_nans(): 73 | """Test that NLL computation fails when spikes and rates have mismatched NaNs""" 74 | # randomized test 75 | spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) 76 | rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) 77 | mask = np.random.rand(10, 100, 10) 78 | # make sure spikes and rates have different nans 79 | spikes[mask < 0.1] = np.nan 80 | rates[mask > 0.9] = np.nan 81 | 82 | with pytest.raises(AssertionError): 83 | neg_log_likelihood(rates, spikes) 84 | 85 | 86 | def test_bits_per_spike(): 87 | for _ in range(20): 88 | spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float) 89 | rates = np.random.exponential(scale=1.0, size=(10, 100, 10)) 90 | null_rates = np.tile( 91 | spikes.mean(axis=(0, 1), keepdims=True), 92 | (spikes.shape[0], spikes.shape[1], 1), 93 | ).squeeze() 94 | 95 | expected_rate_nll = np.sum( 96 | rates - spikes * np.log(rates) + gammaln(spikes + 1.0) 97 | ) 98 | expected_null_nll = np.sum( 99 | null_rates - spikes * np.log(null_rates) + gammaln(spikes + 1.0) 100 | ) 101 | expected_bps = ( 102 | (expected_null_nll - expected_rate_nll) / np.sum(spikes) / np.log(2) 103 | ) 104 | actual_bps = bits_per_spike(rates, spikes) 105 | assert np.isclose(expected_bps, actual_bps) 106 | 107 | 108 | 
108 | def test_bits_per_spike_drop_nans():
109 |     for _ in range(20):
110 |         spikes = np.random.randint(low=0, high=5, size=(10, 100, 10)).astype(float)
111 |         rates = np.random.exponential(scale=1.0, size=(10, 100, 10))
112 |         mask = np.random.rand(10, 100, 10) > 0.9
113 |         spikes[mask] = np.nan
114 |         if np.random.rand() > 0.5:  # rates do not need to have nans
115 |             rates[mask] = np.nan
116 |         null_rates = np.tile(
117 |             np.nanmean(spikes, axis=(0, 1), keepdims=True),
118 |             (spikes.shape[0], spikes.shape[1], 1),
119 |         ).squeeze()
120 | 
121 |         expected_rate_nll = np.sum(
122 |             rates[~mask]
123 |             - spikes[~mask] * np.log(rates[~mask])
124 |             + gammaln(spikes[~mask] + 1.0)
125 |         )
126 |         expected_null_nll = np.sum(
127 |             null_rates[~mask]
128 |             - spikes[~mask] * np.log(null_rates[~mask])
129 |             + gammaln(spikes[~mask] + 1.0)
130 |         )
131 |         expected_bps = (
132 |             (expected_null_nll - expected_rate_nll) / np.nansum(spikes) / np.log(2)
133 |         )
134 |         actual_bps = bits_per_spike(rates, spikes)
135 |         assert np.isclose(expected_bps, actual_bps)
136 | 
137 | 
138 | # -- Ridge regression ---------------
139 | 
140 | 
141 | def test_fit_and_eval_decoder():
142 |     rng = np.random.default_rng(0)
143 |     x = rng.standard_normal(size=(1000, 10))
144 |     y = x @ rng.standard_normal(size=(10, 2))
145 | 
146 |     # noiseless should have high R^2
147 |     score = fit_and_eval_decoder(
148 |         train_rates=x[:800],
149 |         train_behavior=y[:800],
150 |         eval_rates=x[800:],
151 |         eval_behavior=y[800:],
152 |     )
153 |     assert score > 0.95
154 | 
155 |     # with noise should still have decent R^2
156 |     y += rng.standard_normal(size=(1000, 2)) * 0.1
157 |     score = fit_and_eval_decoder(
158 |         train_rates=x[:800],
159 |         train_behavior=y[:800],
160 |         eval_rates=x[800:],
161 |         eval_behavior=y[800:],
162 |     )
163 |     assert score > 0.25  # arbitrary heuristic
164 | 
165 |     # regressing on noise should have poor R^2
166 |     y = rng.standard_normal(size=(1000, 2))
167 |     score = fit_and_eval_decoder(
168 |         train_rates=x[:800],
169 |         train_behavior=y[:800],
170 |         eval_rates=x[800:],
171 |         eval_behavior=y[800:],
172 |     )
173 |     assert score < 0.95  # arbitrary heuristic
174 | 
175 | 
176 | # -- PSTH evaluation ----------
177 | 
178 | # def test_eval_psth():
179 | #     return
180 | 
--------------------------------------------------------------------------------
/tests/test_make_tensors.py:
--------------------------------------------------------------------------------
1 | # TODO: write tests of main make_tensors functions
--------------------------------------------------------------------------------
/tests/test_nlb.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import scipy.signal as signal
4 | from sklearn.linear_model import PoissonRegressor
5 | 
6 | from nlb_tools.nwb_interface import NWBDataset
7 | from nlb_tools.make_tensors import (
8 |     make_train_input_tensors,
9 |     make_eval_input_tensors,
10 |     make_eval_target_tensors,
11 | )
12 | from nlb_tools.evaluation import evaluate
13 | 
14 | 
15 | def fit_ridge_regression(train_input, test_input, train_target, alpha=0.0):
16 |     # closed-form solution for better reproducibility
17 |     train_input = np.concatenate(
18 |         [train_input, np.ones((train_input.shape[0], 1))], axis=-1
19 |     )
20 |     test_input = np.concatenate(
21 |         [test_input, np.ones((test_input.shape[0], 1))], axis=-1
22 |     )
23 |     w = np.linalg.pinv(
24 |         train_input.T @ train_input + alpha * np.eye(train_input.shape[1])
25 |     ) @ (train_input.T @ train_target)
26 |     train_pred = train_input @ w
27 |     test_pred = test_input @ w
28 |     return train_pred, test_pred
29 | 
30 | 
31 | def test_nlb_5ms(nlb_dataset):
32 |     # Prepare data
33 |     nlb_dataset.resample(5)
34 |     # Make input tensors
35 |     train_dict = make_train_input_tensors(
36 |         nlb_dataset, "mc_maze_large", "train", save_file=False
37 |     )
38 |     eval_dict = make_eval_input_tensors(
39 |         nlb_dataset, "mc_maze_large", "val", save_file=False
40 |     )
41 |     # Extract spikes
42 |     train_spikes_heldin = train_dict["train_spikes_heldin"]
43 |     train_spikes_heldout = train_dict["train_spikes_heldout"]
44 |     eval_spikes_heldin = eval_dict["eval_spikes_heldin"]
45 |     # Make target tensors
46 |     target_dict = make_eval_target_tensors(
47 |         nlb_dataset, "mc_maze_large", "train", "val", include_psth=True, save_file=False
48 |     )
49 |     # 50ms std kernel
50 |     window = signal.windows.gaussian(int(6 * 50 / 5), int(50 / 5), sym=True)
51 |     window /= np.sum(window)
52 | 
53 |     def filt(x):
54 |         return np.convolve(x, window, "same")
55 | 
56 |     # Prep useful things
57 |     flatten2d = lambda x: x.reshape(-1, x.shape[-1])
58 |     log_offset = 1e-4
59 |     tlen = train_spikes_heldin.shape[1]
60 |     num_heldin = train_spikes_heldin.shape[2]
61 |     num_heldout = train_spikes_heldout.shape[2]
62 |     # Smooth spikes
63 |     train_spksmth_heldin = np.apply_along_axis(filt, 1, train_spikes_heldin)
64 |     eval_spksmth_heldin = np.apply_along_axis(filt, 1, eval_spikes_heldin)
65 |     train_spksmth_heldout = np.apply_along_axis(filt, 1, train_spikes_heldout)
66 |     # Prep for regression
67 |     train_spksmth_heldin_s = flatten2d(train_spksmth_heldin)
68 |     train_spksmth_heldout_s = flatten2d(train_spksmth_heldout)
69 |     eval_spksmth_heldin_s = flatten2d(eval_spksmth_heldin)
70 |     # Make lograte input
71 |     train_lograte_heldin_s = np.log(train_spksmth_heldin_s + log_offset)
72 |     eval_lograte_heldin_s = np.log(eval_spksmth_heldin_s + log_offset)
73 |     # Regress
74 |     def rectify(arr):
75 |         return np.clip(arr, 1e-9, 1e20)
76 | 
77 |     train_spksmth_heldout_s, eval_spksmth_heldout_s = fit_ridge_regression(
78 |         train_lograte_heldin_s,
79 |         eval_lograte_heldin_s,
80 |         train_spksmth_heldout_s,
81 |         alpha=1e6,
82 |     )
83 |     train_spksmth_heldout = rectify(
84 |         train_spksmth_heldout_s.reshape((-1, tlen, num_heldout))
85 |     )
86 |     eval_spksmth_heldout = rectify(
87 |         eval_spksmth_heldout_s.reshape((-1, tlen, num_heldout))
88 |     )
89 | 
90 |     output_dict = {
91 |         "mc_maze_large": {
92 |             "train_rates_heldin": train_spksmth_heldin,
93 |             "train_rates_heldout": train_spksmth_heldout,
94 |             "eval_rates_heldin": eval_spksmth_heldin,
95 |             "eval_rates_heldout": eval_spksmth_heldout,
96 |         }
97 |     }
98 | 
99 |     res = evaluate(target_dict, output_dict)[0]["mc_maze_scaling_split"]
100 | 
101 |     assert np.abs(res["[500] co-bps"] - 0.1528) < 1e-4, (
102 |         "Co-smoothing bits/spike does not match expected value. "
103 |         + f"Expected: 0.1528. Result: {res['[500] co-bps']}"
104 |     )
105 |     assert np.abs(res["[500] vel R2"] - 0.5224) < 1e-4, (
106 |         "Velocity decoding R^2 does not match expected value. "
107 |         + f"Expected: 0.5224. Result: {res['[500] vel R2']}"
108 |     )
109 |     assert np.abs(res["[500] psth R2"] - 0.3871) < 1e-4, (
110 |         "PSTH R^2 does not match expected value. "
111 |         + f"Expected: 0.3871. Result: {res['[500] psth R2']}"
112 |     )
113 | 
--------------------------------------------------------------------------------
/tests/test_nwb_interface.py:
--------------------------------------------------------------------------------
1 | # TODO: write tests of main nwb_interface functions
--------------------------------------------------------------------------------
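tests/test_make_tensors.py and tests/test_nwb_interface.py above are still TODO stubs. Below is a minimal sketch of a possible first make_tensors test, assuming the `nlb_dataset` fixture from tests/conftest.py and reusing only the calls, dataset name, and dictionary keys already exercised in tests/test_nlb.py; the test names and shape checks here are hypothetical rather than part of the existing suite, and, like test_nlb_5ms, the sketch can only run where the fixture's dataset file is available. A similar fixture-based approach could cover the nwb_interface TODO.

from nlb_tools.make_tensors import (
    make_train_input_tensors,
    make_eval_input_tensors,
)


def test_make_train_input_tensors_keys_and_shapes(nlb_dataset):
    # same resampling as test_nlb_5ms so tensors are built at 5 ms bins
    nlb_dataset.resample(5)
    train_dict = make_train_input_tensors(
        nlb_dataset, "mc_maze_large", "train", save_file=False
    )
    # the returned dict should expose held-in and held-out spike tensors
    assert "train_spikes_heldin" in train_dict
    assert "train_spikes_heldout" in train_dict
    heldin = train_dict["train_spikes_heldin"]
    heldout = train_dict["train_spikes_heldout"]
    # both are (trials x time x neurons) and share trial and time dimensions
    assert heldin.ndim == 3 and heldout.ndim == 3
    assert heldin.shape[:2] == heldout.shape[:2]


def test_make_eval_input_tensors_keys_and_shapes(nlb_dataset):
    nlb_dataset.resample(5)
    eval_dict = make_eval_input_tensors(
        nlb_dataset, "mc_maze_large", "val", save_file=False
    )
    assert "eval_spikes_heldin" in eval_dict
    assert eval_dict["eval_spikes_heldin"].ndim == 3

Checking only dictionary keys and shared trial/time dimensions keeps the sketch independent of dataset-specific neuron counts.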