├── .coveragerc
├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── circle.yml
├── doc
│   ├── Makefile
│   ├── api
│   │   ├── config.rst
│   │   ├── io.rst
│   │   ├── ml.rst
│   │   ├── plot.rst
│   │   └── utils.rst
│   ├── conf.py
│   ├── contributing.rst
│   ├── getting-started.rst
│   ├── images
│   ├── index.rst
│   ├── make.bat
│   └── tutorial.rst
├── examples
│   ├── README.txt
│   ├── plot_crcns_dataset_example.py
│   ├── plot_neural_coding_reward_example.py
│   ├── plot_neuropixels_example.py
│   ├── plot_neuropop_simul_example.py
│   ├── plot_popvis_example.py
│   └── plot_reaching_dataset_example.py
├── images
│   ├── journal.pone.0160851.g002.PNG
│   ├── nature14178-f2.jpg
│   ├── psth_PMd_n91.png
│   └── spykes-logo.png
├── setup.cfg
├── setup.py
├── spykes
│   ├── __init__.py
│   ├── config.py
│   ├── io
│   │   ├── __init__.py
│   │   └── datasets.py
│   ├── ml
│   │   ├── __init__.py
│   │   ├── neuropop.py
│   │   ├── strf.py
│   │   └── tensorflow
│   │       ├── __init__.py
│   │       ├── poisson_models.py
│   │       └── sparse_filtering.py
│   ├── plot
│   │   ├── __init__.py
│   │   ├── neurovis.py
│   │   └── popvis.py
│   └── utils.py
└── tests
    ├── __init__.py
    ├── io
    │   └── __init__.py
    ├── ml
    │   ├── __init__.py
    │   ├── tensorflow
    │   │   ├── __init__.py
    │   │   ├── test_poisson_models.py
    │   │   └── test_sparse_filtering.py
    │   ├── test_neuropop.py
    │   └── test_strf.py
    ├── plot
    │   ├── __init__.py
    │   ├── test_neurovis.py
    │   └── test_popvis.py
    ├── test_config.py
    └── test_utils.py
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
[run]
branch = True
source = spykes
include = */spykes/*
omit =
    */setup.py
    */spykes/io/datasets.py
    */__init__.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.DS_Store
*.pyc
*.swp
temp/*
dist/*
build/*

# documentation
doc/_build/
doc/auto_examples/

# development
spykes.egg-info/
notebooks/

# test-related
.coverage
.cache

# developer environments
.idea
.vscode

# virtual environments
venv/
venvs/

# data
*.mat
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python
env:
  - PYTHON=2.7
  - PYTHON=3.5
install:
  - pip install -e .[develop]
  - pip install coveralls
script:
  - python setup.py test
  - python setup.py flake
after_success:
  - coveralls
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Guidelines

- Check the [Gitter](https://gitter.im/KordingLab/spykes) channel to discuss issues
- Please follow [Google's Python style guide](https://google.github.io/styleguide/pyguide.html), especially for docstrings. You will probably notice if you've made mistakes with docstrings, because they won't render correctly when you build the documentation.

# Testing

To run testing and build documentation, install the development dependencies:

```bash
pip install -e .[develop]
```

Make sure that tests are passing for both Python 2.7 and Python 3.6:

```bash
python setup.py test  # Unit tests
python setup.py flake # Linting
```

Making sure these steps pass will help the continuous integration step go smoothly.

# Building Documentation

To build the documentation locally, run

```bash
cd doc/
make html
```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 Pavan Ramkumar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# spykes

[![License](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](https://github.com/KordingLab/spykes/blob/master/LICENSE) [![Join the chat at https://gitter.im/KordingLab/spykes](https://badges.gitter.im/KordingLab/spykes.svg)](https://gitter.im/KordingLab/spykes?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Travis](https://api.travis-ci.org/KordingLab/spykes.svg?branch=master "Travis")](https://travis-ci.org/KordingLab/spykes)
[![Circle](https://circleci.com/gh/KordingLab/spykes/tree/master.svg?style=shield&circle-token=:circle-token)](https://circleci.com/gh/KordingLab/spykes/tree/master)
[![Coverage Status](https://coveralls.io/repos/github/KordingLab/spykes/badge.svg?branch=master)](https://coveralls.io/github/KordingLab/spykes?branch=master)

![Spykes!](images/spykes-logo.png)

Almost any electrophysiology study of neural spiking data relies on a battery of standard analyses. Raster plots and peri-stimulus time histograms aligned to stimuli and behavior provide a snapshot visual description of neural activity. Similarly, tuning curves are the most standard way to characterize how neurons encode stimuli or behavioral preferences. With increasing popularity of population recordings, maximum-likelihood decoders based on tuning models are becoming part of this standard.

Yet, virtually every lab relies on a set of in-house analysis scripts to go from raw data to summaries. We want to improve this status quo in order to enable easier sharing, better reproducibility and fewer bugs.

Spykes is a collection of Python tools to make the visualization and analysis of neural data easy and reproducible.

For more, see the [documentation](http://kordinglab.com/spykes/getting-started.html).
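### Quick Example

As a taste of the workflow, here is a minimal sketch using simulated spike times and a made-up trial table (the column names `stim_on` and `direction` are purely illustrative; see the examples gallery for real datasets):

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from spykes.plot.neurovis import NeuroVis

# Simulate a spike train (spike times in seconds) and a trial table.
np.random.seed(0)
spike_times = np.sort(np.random.uniform(0, 100, size=2000))
trials = pd.DataFrame({
    'stim_on': np.arange(2.0, 92.0),              # event times, in seconds
    'direction': np.tile(['left', 'right'], 45),  # one label per trial
})

neuron = NeuroVis(spike_times, name='neuron 1')

# Spike raster and PSTH aligned to an event and split by condition;
# the window around the event is specified in milliseconds.
neuron.get_raster(event='stim_on', conditions='direction', df=trials,
                  window=[-500, 1500], binsize=10)
plt.show()

neuron.get_psth(event='stim_on', conditions='direction', df=trials,
                window=[-500, 1500], binsize=100)
plt.show()
```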

### Installation

Spykes can be installed using

```
pip install spykes
```

For more detailed installation options, see the [documentation](http://kordinglab.com/spykes/getting-started.html#installing).

### Authors

- [Pavan Ramkumar](http://github.com/pavanramkumar)
- [Hugo Fernandes](http://github.com/hugoguh)

### Acknowledgments

* [Konrad Kording](http://kordinglab.com) for funding and support
--------------------------------------------------------------------------------
/circle.yml:
--------------------------------------------------------------------------------
machine:
  environment:
    # We need to set this variable to let Anaconda take precedence
    PATH: "/home/ubuntu/miniconda/envs/circleenv/bin:/home/ubuntu/miniconda/bin:$PATH"
    DISPLAY: ":99.0"

dependencies:
  cache_directories:
    - "~/spykes_data"
    - "~/miniconda"
  # Various dependencies
  pre:
    # Get a running Python
    - cd ~;
    # Disable pyenv (no cleaner way provided by CircleCI as it prepends pyenv version to PATH)
    - rm -rf ~/.pyenv;
    - rm -rf ~/virtualenvs;
    # Get Anaconda and conda-based requirements
    - >
      if [ ! -d "/home/ubuntu/miniconda" ]; then
        echo "Setting up conda";
        wget -q http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O ~/miniconda.sh;
        chmod +x ~/miniconda.sh;
        ~/miniconda.sh -b -p /home/ubuntu/miniconda;
        conda update --yes --quiet conda;
        conda create -n circleenv --yes pip python=2.7 pip;
        sed -i "s/ENABLE_USER_SITE = .*/ENABLE_USER_SITE = False/g" /home/ubuntu/miniconda/envs/circleenv/lib/python2.7/site.py;
      else
        echo "Conda already set up.";
      fi
    - conda install -n circleenv --yes numpy scipy scikit-learn matplotlib sphinx pillow six IPython pandas;
  override:
    - pip install sphinx coverage
    - pip install sphinx-gallery m2r sphinx_bootstrap_theme sphinx_rtd_theme
    - pip install deepdish requests
    - pip install -e .
    # we need to do this here so the datasets will be cached
    # pipefail is necessary to propagate exit codes
    - set -o pipefail && cd doc && SPYKES_DATA=~/spykes_data/ make html_dev-pattern PATTERN=plot_\\\(?\\!crcns_dataset_example\\\) 2>&1 | tee ~/log.txt

test:
  override:
    # workaround - make html returns 0 even if examples fail to build
    # (see https://github.com/sphinx-gallery/sphinx-gallery/issues/45)
    - cat ~/log.txt && if grep -q "Traceback (most recent call last):" ~/log.txt; then false; else true; fi

general:
  artifacts:
    - "doc/_build/html"
    - "coverage"
    - "~/log.txt"
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
# Makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = sphinx-build
PAPER       =
BUILDDIR    = _build

# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH.
If you don\'t have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help 23 | help: 24 | @echo "Please use \`make ' where is one of" 25 | @echo " html to make standalone HTML files" 26 | @echo " dirhtml to make HTML files named index.html in directories" 27 | @echo " singlehtml to make a single large HTML file" 28 | @echo " pickle to make pickle files" 29 | @echo " json to make JSON files" 30 | @echo " htmlhelp to make HTML files and a HTML help project" 31 | @echo " qthelp to make HTML files and a qthelp project" 32 | @echo " applehelp to make an Apple Help Book" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " epub3 to make an epub3" 36 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 37 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 38 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 39 | @echo " text to make text files" 40 | @echo " man to make manual pages" 41 | @echo " texinfo to make Texinfo files" 42 | @echo " info to make Texinfo files and run them through makeinfo" 43 | @echo " gettext to make PO message catalogs" 44 | @echo " changes to make an overview of all changed/added/deprecated items" 45 | @echo " xml to make Docutils-native XML files" 46 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 47 | @echo " linkcheck to check all external links for integrity" 48 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 49 | @echo " coverage to run coverage check of the documentation (if enabled)" 50 | @echo " dummy to check syntax errors of document sources" 51 | 52 | .PHONY: clean 53 | clean: 54 | rm -rf $(BUILDDIR)/* 55 | rm -rf $ auto_examples/ 56 | rm -rf $ modules/ 57 | 58 | 59 | .PHONY: html 60 | html: 61 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 62 | @echo 63 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 64 | 65 | html_dev-pattern: 66 | BUILD_DEV_HTML=1 $(SPHINXBUILD) -D plot_gallery=1 -D raise_gallery=1 -D abort_on_example_error=1 -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -b html $(ALLSPHINXOPTS) _build/html 67 | @echo 68 | @echo "Build finished. The HTML pages are in _build/html" 69 | 70 | .PHONY: dirhtml 71 | dirhtml: 72 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 73 | @echo 74 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 75 | 76 | .PHONY: singlehtml 77 | singlehtml: 78 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 79 | @echo 80 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 81 | 82 | .PHONY: pickle 83 | pickle: 84 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 85 | @echo 86 | @echo "Build finished; now you can process the pickle files." 87 | 88 | .PHONY: json 89 | json: 90 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 91 | @echo 92 | @echo "Build finished; now you can process the JSON files." 
93 | 94 | .PHONY: htmlhelp 95 | htmlhelp: 96 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 97 | @echo 98 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 99 | ".hhp project file in $(BUILDDIR)/htmlhelp." 100 | 101 | .PHONY: qthelp 102 | qthelp: 103 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 104 | @echo 105 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 106 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 107 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/spykes.qhcp" 108 | @echo "To view the help file:" 109 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/spykes.qhc" 110 | 111 | .PHONY: applehelp 112 | applehelp: 113 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 114 | @echo 115 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 116 | @echo "N.B. You won't be able to view it unless you put it in" \ 117 | "~/Library/Documentation/Help or install it in your application" \ 118 | "bundle." 119 | 120 | .PHONY: devhelp 121 | devhelp: 122 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 123 | @echo 124 | @echo "Build finished." 125 | @echo "To view the help file:" 126 | @echo "# mkdir -p $$HOME/.local/share/devhelp/spykes" 127 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/spykes" 128 | @echo "# devhelp" 129 | 130 | .PHONY: epub 131 | epub: 132 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 133 | @echo 134 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 135 | 136 | .PHONY: epub3 137 | epub3: 138 | $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 139 | @echo 140 | @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." 141 | 142 | .PHONY: latex 143 | latex: 144 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 145 | @echo 146 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 147 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 148 | "(use \`make latexpdf' here to do that automatically)." 149 | 150 | .PHONY: latexpdf 151 | latexpdf: 152 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 153 | @echo "Running LaTeX files through pdflatex..." 154 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 155 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 156 | 157 | .PHONY: latexpdfja 158 | latexpdfja: 159 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 160 | @echo "Running LaTeX files through platex and dvipdfmx..." 161 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 162 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 163 | 164 | .PHONY: text 165 | text: 166 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 167 | @echo 168 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 169 | 170 | .PHONY: man 171 | man: 172 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 173 | @echo 174 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 175 | 176 | .PHONY: texinfo 177 | texinfo: 178 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 179 | @echo 180 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 181 | @echo "Run \`make' in that directory to run these through makeinfo" \ 182 | "(use \`make info' here to do that automatically)." 183 | 184 | .PHONY: info 185 | info: 186 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 187 | @echo "Running Texinfo files through makeinfo..." 
188 | make -C $(BUILDDIR)/texinfo info 189 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 190 | 191 | .PHONY: gettext 192 | gettext: 193 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 194 | @echo 195 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 196 | 197 | .PHONY: changes 198 | changes: 199 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 200 | @echo 201 | @echo "The overview file is in $(BUILDDIR)/changes." 202 | 203 | .PHONY: linkcheck 204 | linkcheck: 205 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 206 | @echo 207 | @echo "Link check complete; look for any errors in the above output " \ 208 | "or in $(BUILDDIR)/linkcheck/output.txt." 209 | 210 | .PHONY: doctest 211 | doctest: 212 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 213 | @echo "Testing of doctests in the sources finished, look at the " \ 214 | "results in $(BUILDDIR)/doctest/output.txt." 215 | 216 | .PHONY: coverage 217 | coverage: 218 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 219 | @echo "Testing of coverage in the sources finished, look at the " \ 220 | "results in $(BUILDDIR)/coverage/python.txt." 221 | 222 | .PHONY: xml 223 | xml: 224 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 225 | @echo 226 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 227 | 228 | .PHONY: pseudoxml 229 | pseudoxml: 230 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 231 | @echo 232 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 233 | 234 | .PHONY: dummy 235 | dummy: 236 | $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy 237 | @echo 238 | @echo "Build finished. Dummy builder generates no files." 239 | 240 | install: 241 | rm -rf _build/doctrees _build/kordinglab.github.io 242 | # first clone the kordinglab/kordinglab.github.io repo because it may ask 243 | # for password and we don't want to delay this long build in 244 | # the middle of it 245 | # --no-checkout just fetches the root folder without content 246 | # --depth 1 is a speed optimization since we don't need the 247 | # history prior to the last commit 248 | # -b gh-pages fetches only the branch for the gh-pages 249 | git clone -b gh-pages --single-branch --no-checkout --depth 1 https://github.com/kordinglab/spykes _build/kordinglab.github.io 250 | touch _build/kordinglab.github.io/.nojekyll 251 | make html 252 | cd _build/ && \ 253 | cp -r html/* kordinglab.github.io && \ 254 | cd kordinglab.github.io && \ 255 | git add * && \ 256 | git add .nojekyll && \ 257 | git commit -a -m 'Make install' && \ 258 | git push 259 | -------------------------------------------------------------------------------- /doc/api/config.rst: -------------------------------------------------------------------------------- 1 | .. _config_documentation: 2 | 3 | Config 4 | ------ 5 | 6 | .. automodule:: spykes.config 7 | :members: 8 | -------------------------------------------------------------------------------- /doc/api/io.rst: -------------------------------------------------------------------------------- 1 | .. _io_documentation: 2 | 3 | ============== 4 | Input / Output 5 | ============== 6 | 7 | Datasets 8 | -------- 9 | 10 | This submodule includes functions for loading datasets. By default, datasets are downloaded and stored in the :data:`~/.spykes` directory. 
This can be overridden by setting the :data:`SPYKES_DATA` environment variable to point to your own directory. 11 | 12 | .. automodule:: spykes.io.datasets 13 | :members: 14 | -------------------------------------------------------------------------------- /doc/api/ml.rst: -------------------------------------------------------------------------------- 1 | .. _ml_documentation: 2 | 3 | ================ 4 | Machine Learning 5 | ================ 6 | 7 | NeuroPop 8 | ~~~~~~~~ 9 | 10 | .. automodule:: spykes.ml.neuropop 11 | :members: 12 | 13 | STRF 14 | ~~~~ 15 | 16 | .. automodule:: spykes.ml.strf 17 | :members: 18 | 19 | Sparse Filtering 20 | ~~~~~~~~~~~~~~~~ 21 | 22 | This module makes use of Tensorflow. Make sure your system is configured correctly before using it. 23 | 24 | .. automodule:: spykes.ml.tensorflow.sparse_filtering 25 | :members: 26 | 27 | Poisson Layers 28 | ~~~~~~~~~~~~~~ 29 | 30 | This module provides a TensorFlow implementation of the Poisson estimators used in the NeuroPop modules. 31 | 32 | .. automodule:: spykes.ml.tensorflow.poisson_models 33 | :members: 34 | 35 | -------------------------------------------------------------------------------- /doc/api/plot.rst: -------------------------------------------------------------------------------- 1 | .. _plot_documentation: 2 | 3 | ======== 4 | Plotting 5 | ======== 6 | 7 | NeuroVis 8 | ~~~~~~~~ 9 | 10 | .. automodule:: spykes.plot.neurovis 11 | :members: 12 | 13 | PopVis 14 | ~~~~~~ 15 | 16 | .. automodule:: spykes.plot.popvis 17 | :members: 18 | -------------------------------------------------------------------------------- /doc/api/utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils_documentation: 2 | 3 | Utils 4 | ----- 5 | 6 | .. automodule:: spykes.utils 7 | :members: 8 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # spykes documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Nov 1 20:30:09 2016. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 21 | sys.path.insert(0, os.path.abspath(os.pardir)) 22 | 23 | # -- General configuration ------------------------------------------------ 24 | 25 | # If your documentation needs a minimal Sphinx version, state it here. 26 | #needs_sphinx = '1.0' 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 
31 | extensions = [ 32 | 'sphinx.ext.autodoc', 33 | 'sphinx.ext.doctest', 34 | 'sphinx.ext.todo', 35 | 'sphinx.ext.mathjax', 36 | 'sphinx.ext.viewcode', 37 | 'sphinx.ext.githubpages', 38 | 'sphinx.ext.autosummary', 39 | 'sphinx.ext.napoleon', 40 | 'm2r', 41 | 'sphinx_gallery.gen_gallery' 42 | ] 43 | 44 | # generate autosummary even if no references 45 | autosummary_generate = True 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # The suffix(es) of source filenames. 51 | # You can specify multiple suffix as a list of string: 52 | source_suffix = ['.rst', '.md'] 53 | # source_suffix = '.rst' 54 | 55 | # The encoding of source files. 56 | #source_encoding = 'utf-8-sig' 57 | 58 | # The master toctree document. 59 | master_doc = 'index' 60 | 61 | # General information about the project. 62 | project = u'spykes' 63 | copyright = u'2016, KordingLab' 64 | author = u'KordingLab' 65 | 66 | # The version info for the project you're documenting, acts as replacement for 67 | # |version| and |release|, also used in various other places throughout the 68 | # built documents. 69 | # 70 | 71 | # The language for content autogenerated by Sphinx. Refer to documentation 72 | # for a list of supported languages. 73 | # 74 | # This is also used if you do content translation via gettext catalogs. 75 | # Usually you set "language" from the command line for these cases. 76 | language = None 77 | 78 | # There are two options for replacing |today|: either, you set today to some 79 | # non-false value, then it is used: 80 | #today = '' 81 | # Else, today_fmt is used as the format for a strftime call. 82 | #today_fmt = '%B %d, %Y' 83 | 84 | # List of patterns, relative to source directory, that match files and 85 | # directories to ignore when looking for source files. 86 | # This patterns also effect to html_static_path and html_extra_path 87 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 88 | 89 | # The reST default role (used for this markup: `text`) to use for all 90 | # documents. 91 | #default_role = None 92 | 93 | # If true, '()' will be appended to :func: etc. cross-reference text. 94 | #add_function_parentheses = True 95 | 96 | # If true, the current module name will be prepended to all description 97 | # unit titles (such as .. function::). 98 | #add_module_names = True 99 | 100 | # If true, sectionauthor and moduleauthor directives will be shown in the 101 | # output. They are ignored by default. 102 | #show_authors = False 103 | 104 | # The name of the Pygments (syntax highlighting) style to use. 105 | pygments_style = 'sphinx' 106 | 107 | # A list of ignored prefixes for module index sorting. 108 | #modindex_common_prefix = [] 109 | 110 | # If true, keep warnings as "system message" paragraphs in the built documents. 111 | #keep_warnings = False 112 | 113 | # If true, `todo` and `todoList` produce output, else they produce nothing. 114 | todo_include_todos = True 115 | 116 | 117 | # -- Options for HTML output ---------------------------------------------- 118 | 119 | # The theme to use for HTML and HTML Help pages. See the documentation for 120 | # a list of builtin themes. 121 | html_theme = 'sphinx_rtd_theme' 122 | 123 | # Theme options are theme-specific and customize the look and feel of a theme 124 | # further. For a list of options available for each theme, see the 125 | # documentation. 
126 | html_theme_options = { 127 | 'collapse_navigation': False, 128 | 'navigation_depth': 4, 129 | 'display_version': True, 130 | } 131 | 132 | # Add any paths that contain custom themes here, relative to this directory. 133 | # html_theme_path = [] 134 | 135 | # The name for this set of Sphinx documents. 136 | # " v documentation" by default. 137 | #html_title = u'spykes v0.1' 138 | 139 | # A shorter title for the navigation bar. Default is the same as html_title. 140 | #html_short_title = None 141 | 142 | # The name of an image file (relative to this directory) to place at the top 143 | # of the sidebar. 144 | #html_logo = None 145 | 146 | # The name of an image file (relative to this directory) to use as a favicon of 147 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 148 | # pixels large. 149 | #html_favicon = None 150 | 151 | # Add any paths that contain custom static files (such as style sheets) here, 152 | # relative to this directory. They are copied after the builtin static files, 153 | # so a file named "default.css" will overwrite the builtin "default.css". 154 | html_static_path = ['_static'] 155 | 156 | # Add any extra paths that contain custom files (such as robots.txt or 157 | # .htaccess) here, relative to this directory. These files are copied 158 | # directly to the root of the documentation. 159 | #html_extra_path = [] 160 | 161 | # If not None, a 'Last updated on:' timestamp is inserted at every page 162 | # bottom, using the given strftime format. 163 | # The empty string is equivalent to '%b %d, %Y'. 164 | #html_last_updated_fmt = None 165 | 166 | # If true, SmartyPants will be used to convert quotes and dashes to 167 | # typographically correct entities. 168 | #html_use_smartypants = True 169 | 170 | # Custom sidebar templates, maps document names to template names. 171 | #html_sidebars = {} 172 | 173 | # Additional templates that should be rendered to pages, maps page names to 174 | # template names. 175 | #html_additional_pages = {} 176 | 177 | # If false, no module index is generated. 178 | #html_domain_indices = True 179 | 180 | # If false, no index is generated. 181 | #html_use_index = True 182 | 183 | # If true, the index is split into individual pages for each letter. 184 | #html_split_index = False 185 | 186 | # If true, links to the reST sources are added to the pages. 187 | #html_show_sourcelink = True 188 | 189 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 190 | #html_show_sphinx = True 191 | 192 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 193 | #html_show_copyright = True 194 | 195 | # If true, an OpenSearch description file will be output, and all pages will 196 | # contain a tag referring to it. The value of this option must be the 197 | # base URL from which the finished HTML is served. 198 | #html_use_opensearch = '' 199 | 200 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 201 | #html_file_suffix = None 202 | 203 | # Language to be used for generating the HTML full-text search index. 204 | # Sphinx supports the following languages: 205 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 206 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr', 'zh' 207 | #html_search_language = 'en' 208 | 209 | # A dictionary with options for the search language support, empty by default. 210 | # 'ja' uses this config value. 211 | # 'zh' user can custom change `jieba` dictionary path. 
212 | #html_search_options = {'type': 'default'} 213 | 214 | # The name of a javascript file (relative to the configuration directory) that 215 | # implements a search results scorer. If empty, the default will be used. 216 | #html_search_scorer = 'scorer.js' 217 | 218 | # Output file base name for HTML help builder. 219 | htmlhelp_basename = 'spykesdoc' 220 | 221 | # -- Options for LaTeX output --------------------------------------------- 222 | 223 | latex_elements = { 224 | # The paper size ('letterpaper' or 'a4paper'). 225 | #'papersize': 'letterpaper', 226 | 227 | # The font size ('10pt', '11pt' or '12pt'). 228 | #'pointsize': '10pt', 229 | 230 | # Additional stuff for the LaTeX preamble. 231 | #'preamble': '', 232 | 233 | # Latex figure (float) alignment 234 | #'figure_align': 'htbp', 235 | } 236 | 237 | # Grouping the document tree into LaTeX files. List of tuples 238 | # (source start file, target name, title, 239 | # author, documentclass [howto, manual, or own class]). 240 | latex_documents = [ 241 | (master_doc, 'spykes.tex', u'spykes Documentation', 242 | u'KordingLab', 'manual'), 243 | ] 244 | 245 | # The name of an image file (relative to this directory) to place at the top of 246 | # the title page. 247 | #latex_logo = None 248 | 249 | # For "manual" documents, if this is true, then toplevel headings are parts, 250 | # not chapters. 251 | #latex_use_parts = False 252 | 253 | # If true, show page references after internal links. 254 | #latex_show_pagerefs = False 255 | 256 | # If true, show URL addresses after external links. 257 | #latex_show_urls = False 258 | 259 | # Documents to append as an appendix to all manuals. 260 | #latex_appendices = [] 261 | 262 | # If false, no module index is generated. 263 | #latex_domain_indices = True 264 | 265 | 266 | # -- Options for manual page output --------------------------------------- 267 | 268 | # One entry per manual page. List of tuples 269 | # (source start file, name, description, authors, manual section). 270 | man_pages = [ 271 | (master_doc, 'spykes', u'spykes Documentation', 272 | [author], 1) 273 | ] 274 | 275 | # If true, show URL addresses after external links. 276 | #man_show_urls = False 277 | 278 | 279 | # -- Options for Texinfo output ------------------------------------------- 280 | 281 | # Grouping the document tree into Texinfo files. List of tuples 282 | # (source start file, target name, title, author, 283 | # dir menu entry, description, category) 284 | texinfo_documents = [ 285 | ( 286 | master_doc, 287 | 'spykes', 288 | u'spykes Documentation', 289 | author, 290 | 'spykes', 291 | 'One line description of project.', 292 | 'Miscellaneous', 293 | ), 294 | ] 295 | 296 | # Documents to append as an appendix to all manuals. 297 | #texinfo_appendices = [] 298 | 299 | # If false, no module index is generated. 300 | #texinfo_domain_indices = True 301 | 302 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 303 | #texinfo_show_urls = 'footnote' 304 | 305 | # If true, do not generate a @detailmenu in the "Top" node's menu. 306 | #texinfo_no_detailmenu = False 307 | 308 | sphinx_gallery_conf = { 309 | 'examples_dirs': '../examples', 310 | 'gallery_dirs': 'auto_examples', 311 | 'backreferences_dir': False, 312 | } 313 | -------------------------------------------------------------------------------- /doc/contributing.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | Spykes is under active development! 
See the contribution guidelines below for more info. 6 | 7 | .. mdinclude:: ../CONTRIBUTING.md 8 | -------------------------------------------------------------------------------- /doc/getting-started.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Getting Started 3 | =============== 4 | 5 | What is Spykes? 6 | --------------- 7 | 8 | Almost any electrophysiology study of neural spiking data relies on a battery of standard analyses. Raster plots and peri-stimulus time histograms aligned to stimuli and behavior provide a snapshot visual description of neural activity. Similarly, tuning curves are the most standard way to characterize how neurons encode stimuli or behavioral preferences. With increasing popularity of population recordings, maximum-likelihood decoders based on tuning models are becoming part of this standard. 9 | 10 | Yet, virtually every lab relies on a set of in-house analysis scripts to go from raw data to summaries. We want to improve this status quo in order to enable easier sharing, better reproducibility and fewer bugs. 11 | 12 | Spykes is a collection of Python tools to make the visualization and analysis of neural data easy and reproducible. 13 | 14 | At present, spykes comes with four classes: 15 | 16 | * :class:`NeuroVis` helps you plot beautiful spike rasters and peri-stimulus time histograms (PSTHs). 17 | * :class:`PopVis` helps you plot population summaries of PSTHs as normalized averages or heat maps. 18 | * :class:`NeuroPop` helps you estimate tuning curves of neural populations and decode stimuli from population vectors with maximum-likelihood decoding. 19 | * :class:`STRF` helps you estimate spatiotemporal receptive fields. 20 | 21 | Spykes deliberately does not aim to provide tools for spike sorting or file I/O with popular electrophysiology formats, but only aims to fill the missing niche for neural data analysis and easy visualization. For file I/O, see `Neo`_ and `OpenElectrophy`_. For spike sorting, see `Klusta`_. 22 | 23 | Installing 24 | ---------- 25 | 26 | For most cases (including following along with the examples) it is sufficient to just install the vanilla version. 27 | 28 | Vanilla 29 | ~~~~~~~ 30 | 31 | This installs the current version from PyPi. 32 | 33 | .. code-block:: bash 34 | 35 | pip install spykes 36 | 37 | Bleeding-Edge 38 | ~~~~~~~~~~~~~ 39 | 40 | This installs the most recent version from Github. 41 | 42 | .. code-block:: bash 43 | 44 | pip install git+git://github.com/KordingLab/spykes 45 | 46 | Local Version 47 | ~~~~~~~~~~~~~ 48 | 49 | This creates a local copy of the repo, where you can make changes to Spykes that get propagated to your project. 50 | 51 | .. code-block:: bash 52 | 53 | git clone http://github.com/KordingLab/spykes # Clone this somewhere useful 54 | pip install -e .[develop] 55 | 56 | Datasets 57 | -------- 58 | 59 | The examples use real datasets. Instructions for downloading these datasets are included in the notebooks. We recommend `deepdish`_ for reading the HDF5 datafile. 60 | 61 | .. _OpenElectrophy: http://neuralensemble.org/OpenElectrophy/ 62 | .. _Neo: http://neuralensemble.org/neo/ 63 | .. _Klusta: http://klusta.readthedocs.io/en/latest/ 64 | .. 
_deepdish: https://github.com/uchicago-cs/deepdish 65 | -------------------------------------------------------------------------------- /doc/images: -------------------------------------------------------------------------------- 1 | ../images -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../README.md 2 | 3 | Contents 4 | ======== 5 | 6 | .. toctree:: 7 | :maxdepth: -1 8 | 9 | getting-started 10 | tutorial 11 | auto_examples/index 12 | contributing 13 | 14 | .. toctree:: 15 | :maxdepth: -1 16 | :caption: API 17 | 18 | api/plot 19 | api/ml 20 | api/io 21 | api/config 22 | api/utils 23 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. epub3 to make an epub3 31 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 32 | echo. text to make text files 33 | echo. man to make manual pages 34 | echo. texinfo to make Texinfo files 35 | echo. gettext to make PO message catalogs 36 | echo. changes to make an overview over all changed/added/deprecated items 37 | echo. xml to make Docutils-native XML files 38 | echo. pseudoxml to make pseudoxml-XML files for display purposes 39 | echo. linkcheck to check all external links for integrity 40 | echo. doctest to run all doctests embedded in the documentation if enabled 41 | echo. coverage to run coverage check of the documentation if enabled 42 | echo. dummy to check syntax errors of document sources 43 | goto end 44 | ) 45 | 46 | if "%1" == "clean" ( 47 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 48 | del /q /s %BUILDDIR%\* 49 | goto end 50 | ) 51 | 52 | 53 | REM Check if sphinx-build is available and fallback to Python version if any 54 | %SPHINXBUILD% 1>NUL 2>NUL 55 | if errorlevel 9009 goto sphinx_python 56 | goto sphinx_ok 57 | 58 | :sphinx_python 59 | 60 | set SPHINXBUILD=python -m sphinx.__init__ 61 | %SPHINXBUILD% 2> nul 62 | if errorlevel 9009 ( 63 | echo. 64 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 65 | echo.installed, then set the SPHINXBUILD environment variable to point 66 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 67 | echo.may add the Sphinx directory to PATH. 68 | echo. 
69 | echo.If you don't have Sphinx installed, grab it from 70 | echo.http://sphinx-doc.org/ 71 | exit /b 1 72 | ) 73 | 74 | :sphinx_ok 75 | 76 | 77 | if "%1" == "html" ( 78 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 79 | if errorlevel 1 exit /b 1 80 | echo. 81 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 82 | goto end 83 | ) 84 | 85 | if "%1" == "dirhtml" ( 86 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 87 | if errorlevel 1 exit /b 1 88 | echo. 89 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 90 | goto end 91 | ) 92 | 93 | if "%1" == "singlehtml" ( 94 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 95 | if errorlevel 1 exit /b 1 96 | echo. 97 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 98 | goto end 99 | ) 100 | 101 | if "%1" == "pickle" ( 102 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 103 | if errorlevel 1 exit /b 1 104 | echo. 105 | echo.Build finished; now you can process the pickle files. 106 | goto end 107 | ) 108 | 109 | if "%1" == "json" ( 110 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 111 | if errorlevel 1 exit /b 1 112 | echo. 113 | echo.Build finished; now you can process the JSON files. 114 | goto end 115 | ) 116 | 117 | if "%1" == "htmlhelp" ( 118 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 119 | if errorlevel 1 exit /b 1 120 | echo. 121 | echo.Build finished; now you can run HTML Help Workshop with the ^ 122 | .hhp project file in %BUILDDIR%/htmlhelp. 123 | goto end 124 | ) 125 | 126 | if "%1" == "qthelp" ( 127 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 128 | if errorlevel 1 exit /b 1 129 | echo. 130 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 131 | .qhcp project file in %BUILDDIR%/qthelp, like this: 132 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\spykes.qhcp 133 | echo.To view the help file: 134 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\spykes.ghc 135 | goto end 136 | ) 137 | 138 | if "%1" == "devhelp" ( 139 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 140 | if errorlevel 1 exit /b 1 141 | echo. 142 | echo.Build finished. 143 | goto end 144 | ) 145 | 146 | if "%1" == "epub" ( 147 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 148 | if errorlevel 1 exit /b 1 149 | echo. 150 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 151 | goto end 152 | ) 153 | 154 | if "%1" == "epub3" ( 155 | %SPHINXBUILD% -b epub3 %ALLSPHINXOPTS% %BUILDDIR%/epub3 156 | if errorlevel 1 exit /b 1 157 | echo. 158 | echo.Build finished. The epub3 file is in %BUILDDIR%/epub3. 159 | goto end 160 | ) 161 | 162 | if "%1" == "latex" ( 163 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 164 | if errorlevel 1 exit /b 1 165 | echo. 166 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 167 | goto end 168 | ) 169 | 170 | if "%1" == "latexpdf" ( 171 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 172 | cd %BUILDDIR%/latex 173 | make all-pdf 174 | cd %~dp0 175 | echo. 176 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 177 | goto end 178 | ) 179 | 180 | if "%1" == "latexpdfja" ( 181 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 182 | cd %BUILDDIR%/latex 183 | make all-pdf-ja 184 | cd %~dp0 185 | echo. 186 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 
187 | goto end 188 | ) 189 | 190 | if "%1" == "text" ( 191 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 192 | if errorlevel 1 exit /b 1 193 | echo. 194 | echo.Build finished. The text files are in %BUILDDIR%/text. 195 | goto end 196 | ) 197 | 198 | if "%1" == "man" ( 199 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 200 | if errorlevel 1 exit /b 1 201 | echo. 202 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 203 | goto end 204 | ) 205 | 206 | if "%1" == "texinfo" ( 207 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 208 | if errorlevel 1 exit /b 1 209 | echo. 210 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 211 | goto end 212 | ) 213 | 214 | if "%1" == "gettext" ( 215 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 216 | if errorlevel 1 exit /b 1 217 | echo. 218 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 219 | goto end 220 | ) 221 | 222 | if "%1" == "changes" ( 223 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 224 | if errorlevel 1 exit /b 1 225 | echo. 226 | echo.The overview file is in %BUILDDIR%/changes. 227 | goto end 228 | ) 229 | 230 | if "%1" == "linkcheck" ( 231 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 232 | if errorlevel 1 exit /b 1 233 | echo. 234 | echo.Link check complete; look for any errors in the above output ^ 235 | or in %BUILDDIR%/linkcheck/output.txt. 236 | goto end 237 | ) 238 | 239 | if "%1" == "doctest" ( 240 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 241 | if errorlevel 1 exit /b 1 242 | echo. 243 | echo.Testing of doctests in the sources finished, look at the ^ 244 | results in %BUILDDIR%/doctest/output.txt. 245 | goto end 246 | ) 247 | 248 | if "%1" == "coverage" ( 249 | %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage 250 | if errorlevel 1 exit /b 1 251 | echo. 252 | echo.Testing of coverage in the sources finished, look at the ^ 253 | results in %BUILDDIR%/coverage/python.txt. 254 | goto end 255 | ) 256 | 257 | if "%1" == "xml" ( 258 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 259 | if errorlevel 1 exit /b 1 260 | echo. 261 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 262 | goto end 263 | ) 264 | 265 | if "%1" == "pseudoxml" ( 266 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 267 | if errorlevel 1 exit /b 1 268 | echo. 269 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 270 | goto end 271 | ) 272 | 273 | if "%1" == "dummy" ( 274 | %SPHINXBUILD% -b dummy %ALLSPHINXOPTS% %BUILDDIR%/dummy 275 | if errorlevel 1 exit /b 1 276 | echo. 277 | echo.Build finished. Dummy builder generates no files. 278 | goto end 279 | ) 280 | 281 | :end 282 | -------------------------------------------------------------------------------- /doc/tutorial.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Tutorials 3 | ========= 4 | 5 | Fitting Tuning Curves with Gradient Descent 6 | ------------------------------------------- 7 | 8 | The firing rates :math:`y_j` of neuron :math:`j` can be modeled as a 9 | Poisson random variable. 10 | 11 | .. math:: 12 | 13 | 14 | y_j = \text{Poisson}(\lambda_j) 15 | 16 | We will drop the subscript :math:`j` for convenience of notation and 17 | figure out how to fit the tuning curves of a given neuron :math:`j`. 18 | 19 | The mean :math:`\lambda` is given by the von Mises tuning model as 20 | follows. 21 | 22 | .. 
math::


   \lambda = b + g\exp\Big(\kappa_0 + \kappa \cos(x - \mu)\Big)

However, this formulation is non-convex in :math:`\mu`. Therefore, we
re-parameterize it to be more tractable (still non-convex in :math:`b`
and :math:`g`) as follows:

.. math::


   \lambda = b + g\exp\Big(\kappa_0 + \kappa_1 \cos(x) + \kappa_2 \sin(x) \Big),

where :math:`\kappa_1 = \kappa \cos(\mu)` and
:math:`\kappa_2 = \kappa \sin(\mu)`.

Once we estimate :math:`\kappa_1` and :math:`\kappa_2`, we can back out
:math:`\kappa` and :math:`\mu` as
:math:`\kappa = \sqrt{\kappa_1^2 + \kappa_2^2}` and
:math:`\mu = \tan^{-1}\Big(\frac{\kappa_2}{\kappa_1}\Big)`.

We estimate two special cases of this generalized von Mises model.

Special Case 1: Poisson Generalized Linear Model (GLM)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If we set :math:`b = 0` and :math:`g = 1`, we get:

.. math::


   \lambda = \exp\Big(\kappa_0 + \kappa_1 \cos(x) + \kappa_2 \sin(x) \Big),

This is identical to a Poisson GLM.

The advantage of this formulation is that it is convex; the
disadvantage is that the parameters are not all straightforward to
interpret, with :math:`\kappa_0` playing the role of both a baseline and
a gain term.

Special Case 2: Generalized von Mises Model (GVM)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If we set :math:`\kappa_0 = 0`, we get:

.. math::


   \lambda = b + g\exp\Big(\kappa_1 \cos(x) + \kappa_2 \sin(x) \Big),

This is identical to Eq. (4) in `Amirikian & Georgopoulos
(2000) `__.

The advantage of this formulation is that although it is non-convex,
we can easily interpret the parameters:

- :math:`b`, as the baseline firing rate
- :math:`g`, as the gain
- :math:`\kappa`, as the width or shape

Minimizing Negative Log Likelihood with Gradient Descent
--------------------------------------------------------

Given a set of observations :math:`(x_i, y_i)`, to identify the
parameters
:math:`\Theta = \left\{\kappa_0, \kappa_1, \kappa_2, g, b\right\}` we
use gradient descent on the loss function :math:`J`, specified by the
negative Poisson log-likelihood,

.. math::


   J = -\log\mathcal{L} = \sum_{i} \lambda_i - y_i \log \lambda_i

Taking the gradients, we get:

.. math::


   \frac{\partial J}{\partial \kappa_0} = \sum_{i} g \exp\Big(\kappa_0 + \kappa_1 \cos(x_i) + \kappa_2 \sin(x_i) \Big) \bigg(1 - \frac{y_i}{\lambda_i}\bigg)

.. math::


   \frac{\partial J}{\partial \kappa_1} = \sum_{i} g \exp\Big(\kappa_0 + \kappa_1 \cos(x_i) + \kappa_2 \sin(x_i) \Big) \cos(x_i) \bigg(1 - \frac{y_i}{\lambda_i}\bigg)

.. math::


   \frac{\partial J}{\partial \kappa_2} = \sum_{i} g \exp\Big(\kappa_0 + \kappa_1 \cos(x_i) + \kappa_2 \sin(x_i) \Big) \sin(x_i) \bigg(1 - \frac{y_i}{\lambda_i}\bigg)

.. math::


   \frac{\partial J}{\partial g} = \sum_{i} \exp\Big(\kappa_0 + \kappa_1 \cos(x_i) + \kappa_2 \sin(x_i) \Big) \bigg(1 - \frac{y_i}{\lambda_i}\bigg)

.. math::


   \frac{\partial J}{\partial b} = \sum_{i} \bigg(1 - \frac{y_i}{\lambda_i}\bigg)
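
To make these update rules concrete, here is a minimal NumPy sketch of batch
gradient descent on :math:`J` (illustrative only: the helper name, learning
rate, initialization, and the clipping that keeps :math:`\lambda` positive are
arbitrary choices, not spykes' ``NeuroPop`` implementation):

.. code-block:: python

    import numpy as np

    def fit_gvm(x, y, lr=1e-4, n_iter=20000):
        # Fit lambda = b + g * exp(k0 + k1 * cos(x) + k2 * sin(x)).
        k0, k1, k2, g, b = 0.0, 0.0, 0.0, 1.0, 0.1
        for _ in range(n_iter):
            u = np.exp(k0 + k1 * np.cos(x) + k2 * np.sin(x))
            lam = b + g * u
            r = 1.0 - y / lam  # the common factor (1 - y_i / lambda_i)
            k0 -= lr * np.sum(g * u * r)
            k1 -= lr * np.sum(g * u * np.cos(x) * r)
            k2 -= lr * np.sum(g * u * np.sin(x) * r)
            g -= lr * np.sum(u * r)
            b -= lr * np.sum(r)
            g, b = max(g, 1e-6), max(b, 1e-6)  # keep lambda positive
        return k0, k1, k2, g, b

    # Synthetic check: Poisson counts drawn from a known tuning curve.
    x = np.random.uniform(-np.pi, np.pi, size=500)
    y = np.random.poisson(2.0 + 3.0 * np.exp(0.5 * np.cos(x - 1.0)))
    k0, k1, k2, g, b = fit_gvm(x, y)

The preferred direction and width can then be recovered as
``mu = np.arctan2(k2, k1)`` and ``kappa = np.hypot(k1, k2)``, as described
above.
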
Decoding a Feature from Population Activity
-------------------------------------------

Under the same Poisson firing rate model for each neuron, whose mean is
specified by the von Mises tuning curve, as above, we can decode the
stimulus :math:`\hat{x}` that is most likely to have produced the
observed population activity
:math:`Y = \left\{y_j, j = 1, 2, \dots, n_\text{neurons}\right\}`.

We will assume that the neurons are conditionally independent given the
tuning parameters :math:`\Theta`. Thus the likelihood of observing the
population activity :math:`Y` is given by

.. math::


   P(Y | \Theta) = \prod_j P(y_j | \Theta)

As before, the loss function for the decoder is given by the negative
Poisson log-likelihood:

.. math::


   J = -\log\mathcal{L} = \sum_j \lambda_j - y_j \log \lambda_j

where

.. math::


   \lambda_j = b_j + g_j \exp\Big(\kappa_{0,j} + \kappa_{1,j} \cos(x) + \kappa_{2,j} \sin(x) \Big)

To minimize this loss function with gradient descent, we need to take
the gradient of :math:`J` with respect to :math:`x`:

.. math::


   \frac{\partial J}{\partial x} = \sum_{j} g_j \exp\Big(\kappa_{0,j} + \kappa_{1,j} \cos(x) + \kappa_{2,j} \sin(x) \Big) \Big(\kappa_{2,j} \cos(x) - \kappa_{1,j} \sin(x)\Big) \bigg(1 - \frac{y_j}{\lambda_j}\bigg)
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
.. _general_examples_benchmark:

Examples Gallery
----------------

.. contents:: Contents
   :local:
   :depth: 2
--------------------------------------------------------------------------------
/examples/plot_crcns_dataset_example.py:
--------------------------------------------------------------------------------
"""
=====================
CRCNS DataSet Example
=====================

A demonstration of Spykes's functionality to reproduce a figure from Li et
al.'s "A motor cortex circuit for motor planning and movement."

"""
# Authors: Eric Larson
#
# License: BSD (3-clause)

########################################################

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io

from spykes.plot.neurovis import NeuroVis
from spykes.config import get_data_directory

plt.style.use('seaborn-ticks')

########################################################
# 0 Overview: Reproduce Figure
# -----------------------------
#
# 0.1 Article
# ~~~~~~~~~~~~~
#
# Li, Nuo, et al. "A motor cortex circuit for motor planning and movement."
# Nature 519.7541 (2015): 51-56.
# [`link to
# paper
# `__].
# We aim to reproduce
# `this figure
# `__
#
#
# 0.2 Dataset
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Nuo Li, Charles R Gerfen, Karel Svoboda (2014); Extracellular recordings
# from anterior lateral motor cortex (ALM) neurons of adult mice
# performing a tactile decision behavior. CRCNS.org [`link to
# dataset `__]

########################################################
#
# 1 Data
# --------------------
#
# 1.1 Download Data
# ~~~~~~~~~~~~~~~~~
#
# Register in ``CRCNS`` [`link to request account
# `__]
#
# Download file ``data_structure_ANM218457.tar.gz`` [`link
# `__]
#
# Unzip it and you will find the file ``data_structure_ANM218457_20131006.mat``.
#
# Move this file to the Spykes data directory (``~/.spykes`` by default).
#
# 1.2 Load Data
# ~~~~~~~~~~~~~

# look for the file in the Spykes data directory
fpath = os.path.join(get_data_directory(),
                     'data_structure_ANM218457_20131006.mat')
mat = scipy.io.loadmat(fpath)

########################################################
#
# 2 Get Spike Times
# --------------------

neuron_n = 9
spike_times = mat['obj']['eventSeriesHash'][0][
    0]['value'][0][0][0][neuron_n - 1][0][0][1]
spike_times = [i[0] for i in spike_times]

# instantiate neuron
neuron = NeuroVis(spike_times, neuron_n)
print('neuron %d has a firing rate of %0.2f spikes per second' %
      (neuron_n, neuron.firingrate))


########################################################
#
# Let's use all the goodness of ``pandas`` to define all our conditions.
# Here, we will create a set of extra columns in the data frame that are
# going to be useful to select and plot PSTHs for specific conditions. We
# aim to follow the principles outlined in `Hadley Wickham's white paper on
# Tidy Data `__.

########################################################
#
# 3 Get Event Times
# --------------------


data_df = pd.DataFrame()
data_df['trialStart'] = mat['obj']['trialStartTimes'][0][0][0]
data_df['cueTimes'] = np.squeeze(
    mat['obj']['trialPropertiesHash'][0][0][0][0][2][0][2])
data_df['RealCueTimes'] = data_df['trialStart'] + data_df['cueTimes']

# Collect all the events and display them
events = ['trialStart', 'cueTimes', 'RealCueTimes']
data_df[events].head()

########################################################
#
# 4 Get Features
# --------------------


trialTypeMat = mat['obj']['trialTypeMat'][0][0].astype(np.bool_)
trialTypeStr = np.squeeze(
    np.stack(np.squeeze(mat['obj']['trialTypeStr'][0][0])))

for ind, feat in enumerate(trialTypeStr):
    data_df[str(feat)] = trialTypeMat[ind]

data_df['GoodTrials'] = np.squeeze(
    mat['obj']['trialPropertiesHash'][0][0][0][0][2][0][3]).astype(np.bool_)

# Collect all features and display them
features = ['HitR', 'HitL', 'ErrR', 'ErrL', 'NoLickR',
            'NoLickL', 'LickEarly', 'StimTrials', 'GoodTrials']
data_df[features].head()

########################################################
#
# 5 Define Features
# --------------------

########################################################

features_of_interest = ['HitR', 'HitL', 'ErrR', 'ErrL']

# label each trial by the feature column that is True (if any);
# idxmax() returns the column label of the first True entry
data_df['response'] = data_df[features_of_interest].apply(
    lambda row: row.idxmax() if row.max() else '', axis=1)
data_df['correct'] = data_df['response'].map(
    lambda s: {'Hit': True, 'Err': False, '': np.nan}[s[:3]])
data_df['response'] = data_df['response'].map(
    lambda s: {'L': 'Lick left', 'R': 'Lick right', '': ''}[s[-1:]])

data_df[['HitR', 'HitL', 'ErrR', 'ErrL', 'correct', 'response']].head()
'ErrL', 'correct', 'response']].head()
155 |
156 | ########################################################
157 | #
158 | # Let's put the events and the augmented features into a new data frame,
159 | # which we will use everywhere below.
160 |
161 | ########################################################
162 |
163 | # isolate trials of interest
164 | trials_df = data_df[((data_df['GoodTrials'] == True) &
165 | (data_df['correct'] == True))]
166 |
167 | ########################################################
168 | #
169 | # 6 Plots
170 | # --------------------
171 | #
172 | # 6.1 Rasters
173 | # ~~~~~~~~~~~~~~~~~~~
174 |
175 | event = 'RealCueTimes'
176 | conditions = 'response'
177 | window = [-3000, 2000]
178 |
179 | rasters_fig2b1 = neuron.get_raster(event=event,
180 | conditions=conditions,
181 | df=trials_df,
182 | window=window,
183 | binsize=20,
184 | sortby='rate',
185 | sortorder='ascend')
186 | ########################################################
187 | #
188 | # 6.2 PSTH
189 | # ~~~~~~~~~~~~~~~~~~~
190 |
191 | plt.figure(figsize=(8, 5))
192 | neuron.get_psth(event=event,
193 | df=trials_df,
194 | conditions=conditions,
195 | window=window,
196 | binsize=100)
197 | plt.title('')
198 | plt.show()
199 |
200 | ########################################################
201 | #
202 | # 6.3 Reproduce Figure
203 | # ~~~~~~~~~~~~~~~~~~~~~~
204 |
205 | plt.style.use('seaborn-ticks')
206 |
207 | cmap = [plt.get_cmap('Blues'), plt.get_cmap('Reds')]
208 | colors = ['#E82F3A', '#3B439A']
209 |
210 |
211 | # get rasters------------------------------------------------------
212 | rasters_fig2b1 = neuron.get_raster(event=event,
213 | conditions=conditions,
214 | df=trials_df,
215 | window=window,
216 | binsize=20,
217 | plot=False)
218 |
219 |
220 | plt.figure(figsize=(8, 10))
221 | cond_ids = list(rasters_fig2b1['data'].keys())[::-1]  # list() for Python 3
222 |
223 | # plot rasters-------------------------------------------------------
224 | for i, cond_id in enumerate(cond_ids):
225 | plt.subplot(4, 1, i + 1)
226 | neuron.plot_raster(rasters=rasters_fig2b1,
227 | cond_id=cond_id,
228 | cmap=cmap[i],
229 | sortby=None,
230 | has_title=False)
231 | plt.xlabel('')
232 |
233 |
234 | # plot psth-------------------------------------------------------
235 | plt.subplot(212)
236 | psth = neuron.get_psth(event=event,
237 | conditions=conditions,
238 | df=trials_df,
239 | window=window,
240 | binsize=100,
241 | plot=True,
242 | colors=colors)
243 | plt.title('')
244 | plt.show()
245 |
246 |
247 | ########################################################
248 | #
249 | # 6.4 ggplot
250 | # ~~~~~~~~~~~~
251 |
252 | plt.style.use('ggplot')
253 |
254 | plt.figure(figsize=(8, 5))
255 | neuron.get_psth(event=event,
256 | conditions=conditions,
257 | df=trials_df,
258 | event_name='Cue On',
259 | window=window,
260 | conditions_names=['Lick Left', 'Lick Right'],
261 | binsize=100)
262 | plt.title('')
263 | plt.show()
264 |
--------------------------------------------------------------------------------
/examples/plot_neural_coding_reward_example.py:
--------------------------------------------------------------------------------
1 | """
2 | ============================
3 | Neural Coding Reward Example
4 | ============================
5 |
6 | A demonstration of Spykes' functionality to reproduce Ramkumar et al.'s
7 | "Premotor and Motor Cortices Encode Reward."
8 | 9 | """ 10 | # Authors: Mayank Agrawal 11 | # 12 | # License: MIT 13 | 14 | ######################################################## 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import pandas as pd 19 | from spykes.plot.neurovis import NeuroVis 20 | from spykes.io.datasets import load_reward_data 21 | 22 | ######################################################## 23 | # 0 Overview: Reproduce Figure 24 | # ----------------------------- 25 | # 26 | # 0.1 Article 27 | # ~~~~~~~~~~~~~ 28 | # 29 | # Ramkumar, Pavan, et al. "Premotor and Motor Cortices Encode Reward." 30 | # PloS one 11.8 (2016) 31 | # 32 | # [`link to 33 | # paper 34 | # `__] 35 | # 36 | # 37 | # 0.2 Dataset 38 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 39 | # 40 | # Download all files [`here 41 | # `__] 42 | # However, we'll only be looking at Mihili_07112013.mat (Monkey M, Session 1) 43 | # and Mihili_08062013.mat (Monkey M, Session 4) 44 | # 45 | # 0.3 Initialization 46 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 47 | # 48 | event = 'rewardTime' 49 | condition = 'rewardBool' 50 | window = [-500, 1500] 51 | binsize = 10 52 | 53 | ######################################################## 54 | # 55 | # 1 First Graph of Panel A 56 | # -------------------- 57 | 58 | sess_one, sess_four = load_reward_data() 59 | 60 | ######################################################## 61 | # 62 | # 1.1 Initiate all Neurons 63 | # ~~~~~~~~~~~~~~~~~ 64 | # 65 | 66 | 67 | def get_spike_time(raw_data, neuron_number): 68 | 69 | spike_times = raw_data['alldays'][0]['PMd_units'][0][:] 70 | spike_times = spike_times[neuron_number - 1][0][1:] 71 | spike_times = [i[0] for i in spike_times] 72 | 73 | return spike_times 74 | 75 | ######################################################## 76 | 77 | 78 | def initiate_neurons(raw_data): 79 | 80 | neuron_list = list() 81 | 82 | for i in range((raw_data['alldays'][0]['PMd_units'][0][:]).shape[0]): 83 | spike_times = get_spike_time(raw_data, i + 1) 84 | 85 | # instantiate neuron 86 | neuron = NeuroVis(spike_times, name='PMd %d' % (i + 1)) 87 | neuron_list.append(neuron) 88 | 89 | return neuron_list 90 | 91 | ######################################################## 92 | 93 | neuron_list = initiate_neurons(sess_four) 94 | 95 | ######################################################## 96 | # 97 | # 1.2 Get Event Times 98 | # ~~~~~~~~~~~~~ 99 | 100 | 101 | def create_data_frame(raw_data): 102 | 103 | data_df = pd.DataFrame() 104 | 105 | uncertainty_conditions = list() 106 | center_target_times = list() 107 | reward_times = list() 108 | reward_outcomes = list() 109 | 110 | for i in range(raw_data['alldays'].shape[0]): 111 | 112 | meta_data = raw_data['alldays'][i]['tt'][0] 113 | 114 | uncertainty_conditions.append(meta_data[:, 2]) 115 | center_target_times.append(meta_data[:, 3]) 116 | reward_times.append(meta_data[:, 6]) 117 | reward_outcomes.append(meta_data[:, 7]) 118 | 119 | data_df['uncertaintyCondition'] = np.concatenate(uncertainty_conditions) 120 | data_df['centerTargetTime'] = np.concatenate(center_target_times) 121 | data_df['rewardTime'] = np.concatenate(reward_times) 122 | data_df['rewardOutcome'] = np.concatenate(reward_outcomes) 123 | 124 | data_df['rewardBool'] = data_df['rewardOutcome'].map(lambda s: s == 32) 125 | 126 | # find time in between previous reward onset and start of current trial 127 | # shouldn't be more than 1500ms 128 | 129 | start_times = data_df['centerTargetTime'] 130 | last_reward_times = np.roll(data_df['rewardTime'], 1) 131 | 132 | diffs = 
start_times - last_reward_times 133 | diffs[0] = 0 134 | 135 | data_df['consecutiveBool'] = diffs.map(lambda s: s <= 1.5) 136 | 137 | return data_df[((data_df['uncertaintyCondition'] == 5.0) | 138 | (data_df['uncertaintyCondition'] == 50.0)) & 139 | data_df['consecutiveBool']] 140 | 141 | ######################################################## 142 | 143 | data_df = create_data_frame(sess_four) 144 | print(len(data_df)) 145 | data_df.head() 146 | 147 | ######################################################## 148 | # 149 | # 1.3 Match Peak Velocities 150 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~ 151 | 152 | 153 | def find_velocities_in_range(raw_data, dataframe, min_vel, max_vel, min_time, 154 | max_time): 155 | 156 | all_velocities = raw_data['alldays'][0]['kin'][0]['vel'][0][0] 157 | 158 | max_velocities = np.empty(len(dataframe)) 159 | peak_times = np.empty(len(dataframe)) 160 | 161 | for i in range(len(dataframe)): 162 | 163 | # find time range for potential peak velocity 164 | start_time = dataframe['rewardTime'][i] + .2 165 | end_time = dataframe['rewardTime'][i] + 1.5 166 | 167 | # find velocities in the time range 168 | indices = (all_velocities[:, 0] >= start_time) & ( 169 | all_velocities[:, 0] <= end_time) 170 | in_time = all_velocities[indices] 171 | 172 | # find max velocity in given time range 173 | velocity_norms = np.square(in_time[:, 1]) + np.square(in_time[:, 2]) 174 | 175 | max_velocity_index = np.argmax(velocity_norms) 176 | 177 | max_velocities[i] = velocity_norms[max_velocity_index]**.5 178 | peak_times[i] = in_time[max_velocity_index, 0] 179 | 180 | dataframe['maxVelocity'] = max_velocities 181 | dataframe['peakTimesDiff'] = peak_times - dataframe['rewardTime'] 182 | 183 | return dataframe[((dataframe['maxVelocity'] >= min_vel) & 184 | (dataframe['maxVelocity'] <= max_vel)) & 185 | ((dataframe['peakTimesDiff'] >= min_time) & 186 | (dataframe['peakTimesDiff'] <= max_time))] 187 | 188 | ######################################################## 189 | 190 | trials_df = find_velocities_in_range( 191 | sess_four, data_df.reset_index(), 11, 16, .55, .95) 192 | print(len(trials_df)) 193 | trials_df.head() 194 | 195 | ######################################################## 196 | # 197 | # 1.4 Plot PSTHs 198 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~ 199 | # 200 | # Before Matching 201 | 202 | neuron_number = 60 203 | neuron = neuron_list[neuron_number - 1] 204 | 205 | plt.figure(figsize=(10, 5)) 206 | psth = neuron.get_psth(event=event, 207 | conditions=condition, 208 | df=data_df, 209 | window=[-500, 1500], 210 | binsize=25, 211 | event_name='Reward Time') 212 | 213 | plt.title('neuron %s: Reward' % neuron.name) 214 | plt.show() 215 | 216 | ######################################################## 217 | # 218 | # After Velocity Matching 219 | 220 | neuron_number = 60 221 | neuron = neuron_list[neuron_number - 1] 222 | 223 | plt.figure(figsize=(10, 5)) 224 | psth = neuron.get_psth(event=event, 225 | conditions=condition, 226 | df=trials_df, 227 | window=[-500, 1500], 228 | binsize=25, 229 | event_name='Reward Time') 230 | 231 | plt.title('neuron %s: Reward' % neuron.name) 232 | plt.show() 233 | 234 | ######################################################## 235 | # 236 | # 2 First Graph of Panel C 237 | # -------------------- 238 | 239 | neuron_list = initiate_neurons(sess_one) 240 | data_df = create_data_frame(sess_one) 241 | 242 | ######################################################## 243 | # 244 | # 2.1 Normalize PSTHs 245 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~ 246 | 247 | 248 | def 
normalize_psth(neuron, dataframe):
249 |
250 | psth = neuron.get_psth(event=event,
251 | conditions=condition,
252 | df=dataframe,
253 | window=window,
254 | binsize=binsize,
255 | plot=False)
256 |
257 | # find all max rates, and find max of max rates
258 |
259 | max_rates = list()
260 |
261 | for i, cond_id in enumerate(np.sort(list(psth['data'].keys()))):
262 | max_rates.append(np.amax(psth['data'][cond_id]['mean']))
263 |
264 | max_rate = max(max_rates)
265 |
266 | # divide all means by max to normalize
267 |
268 | for i, cond_id in enumerate(np.sort(list(psth['data'].keys()))):
269 |
270 | psth['data'][cond_id]['mean'] /= max_rate
271 | psth['data'][cond_id]['sem'] = 0 # population SEM calculated later
272 |
273 | return psth
274 |
275 | ########################################################
276 |
277 | neuron = neuron_list[0] # example
278 | new_psth = normalize_psth(neuron, data_df)
279 | neuron.plot_psth(new_psth, event, condition)
280 |
281 | ########################################################
282 | #
283 | # 2.2 Find Population Average
284 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285 |
286 | psth_dict = {}  # keyed by the condition ids of the example PSTH above
287 | for cond_id in np.sort(list(new_psth['data'].keys())):
288 | psth_dict[cond_id] = list()
289 |
290 |
291 | # add all normalized psth's
292 | for neuron in neuron_list:
293 |
294 | norm_psth = normalize_psth(neuron, data_df)
295 |
296 | for cond_id in np.sort(list(new_psth['data'].keys())):
297 | psth_dict[cond_id].append(norm_psth['data'][cond_id]['mean'])
298 |
299 | for key in psth_dict:
300 | psth_dict[key] = np.array(psth_dict[key])
301 |
302 | # get base psth
303 |
304 | base_neuron = neuron_list[0]
305 | psth = normalize_psth(base_neuron, data_df)
306 |
307 | # update mean and SEM to reflect population
308 |
309 | for cond_id in np.sort(list(psth['data'].keys())):
310 |
311 | psth['data'][cond_id]['mean'] = np.mean(psth_dict[cond_id], axis=0)
312 | psth['data'][cond_id]['sem'] = (
313 | np.var(psth_dict[cond_id], axis=0) / len(neuron_list))**.5
314 |
315 | ########################################################
316 | #
317 | # 2.3 Plot PSTH
318 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
319 |
320 | plt.figure(figsize=(10, 5))
321 | neuron.plot_psth(psth, event, condition)
322 | plt.title("")
323 | plt.show()
324 |
--------------------------------------------------------------------------------
/examples/plot_neuropixels_example.py:
--------------------------------------------------------------------------------
1 | """
2 | ===================
3 | Neuropixels Example
4 | ===================
5 |
6 | Use Spykes to analyze data recorded with UCL's Neuropixels probes.
7 |
8 | """
9 | # Authors: Mayank Agrawal
10 | #
11 | # License: MIT
12 |
13 | ########################################################
14 |
15 | import numpy as np
16 | import pandas as pd
17 | from spykes.plot.neurovis import NeuroVis
18 | from spykes.plot.popvis import PopVis
19 | import matplotlib.pyplot as plt
20 | from spykes.io.datasets import load_neuropixels_data
21 |
22 |
23 | plt.style.use('seaborn-ticks')
24 |
25 | ########################################################
26 | # Neuropixels
27 | # -----------------------------
28 | # Neuropixels is a recording technology from UCL's `Cortex Lab
29 | # `__ that can record from hundreds of
30 | # neurons simultaneously. Below we show how to work with this data in Spykes.
31 | #
32 | # 0 Download Data
33 | # -----------------------------
34 | #
35 | # Download all data `here `__.
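#
# The loader below caches everything under the Spykes data directory. As a
# sketch (using the ``SPYKES_DATA`` environment variable described in
# ``spykes.config``; the path below is hypothetical), the cache can be
# pointed at another disk before loading::
#
#     import os
#     os.environ['SPYKES_DATA'] = '/path/to/big/disk'
#
#     from spykes.io.datasets import load_neuropixels_data
#     data_dict = load_neuropixels_data()  # downloads under the new path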
36 | # 37 | # 1 Read In Data 38 | # ----------------------------- 39 | 40 | folder_names = ['posterior', 'frontal'] 41 | Fs = 30000.0 42 | 43 | striatum = list() 44 | motor_ctx = list() 45 | thalamus = list() 46 | hippocampus = list() 47 | visual_ctx = list() 48 | 49 | # a lot of this code is adapted from Cortex Lab's MATLAB script 50 | # see here: http://data.cortexlab.net/dualPhase3/data/script_dualPhase3.m 51 | 52 | data_dict = load_neuropixels_data() 53 | 54 | for name in folder_names: 55 | 56 | clusters = np.squeeze(data_dict[name + '/spike_clusters.npy']) 57 | spike_times = (np.squeeze(data_dict[(name + '/spike_times.npy')])) / Fs 58 | spike_templates = (np.squeeze(data_dict[(name + '/spike_templates.npy')])) 59 | temps = (np.squeeze(data_dict[(name + '/templates.npy')])) 60 | winv = (np.squeeze(data_dict[(name + '/whitening_mat_inv.npy')])) 61 | y_coords = (np.squeeze(data_dict[(name + '/channel_positions.npy')]))[:, 1] 62 | 63 | # frontal times need to align with posterior 64 | if (name == 'frontal'): 65 | time_correction = data_dict[('timeCorrection.npy')] 66 | spike_times *= time_correction[0] 67 | spike_times += time_correction[1] 68 | 69 | data = data_dict[(name + '/cluster_groups.csv')] 70 | cids = np.array([x[0] for x in data]) 71 | cfg = np.array([x[1] for x in data]) 72 | 73 | # find good clusters and only use those spikes 74 | good_clusters = cids[cfg == 'good'] 75 | good_indices = (np.in1d(clusters, good_clusters)) 76 | 77 | real_spikes = spike_times[good_indices] 78 | real_clusters = clusters[good_indices] 79 | real_spike_templates = spike_templates[good_indices] 80 | 81 | # find how many spikes per cluster and then order spikes by which cluster 82 | # they are in 83 | 84 | counts_per_cluster = np.bincount(real_clusters) 85 | 86 | sort_idx = np.argsort(real_clusters) 87 | sorted_clusters = real_clusters[sort_idx] 88 | sorted_spikes = real_spikes[sort_idx] 89 | sorted_spike_templates = real_spike_templates[sort_idx] 90 | 91 | # find depth for each spike 92 | # this is translated from Cortex Lab's MATLAB code 93 | # for more details, check out the original code here: 94 | # https://github.com/cortex-lab/spikes/blob/master/analysis/templatePositionsAmplitudes.m 95 | 96 | temps_unw = np.zeros(temps.shape) 97 | for t in range(temps.shape[0]): 98 | temps_unw[t, :, :] = np.dot(temps[t, :, :], winv) 99 | 100 | temp_chan_amps = np.ptp(temps_unw, axis=1) 101 | temps_amps = np.max(temp_chan_amps, axis=1) 102 | thresh_vals = temps_amps * 0.3 103 | 104 | thresh_vals = [thresh_vals for i in range(temp_chan_amps.shape[1])] 105 | thresh_vals = np.stack(thresh_vals, axis=1) 106 | 107 | temp_chan_amps[temp_chan_amps < thresh_vals] = 0 108 | 109 | y_coords = np.reshape(y_coords, (y_coords.shape[0], 1)) 110 | temp_depths = np.sum( 111 | np.dot(temp_chan_amps, y_coords), axis=1) / (np.sum(temp_chan_amps, 112 | axis=1)) 113 | 114 | sorted_spike_depths = temp_depths[sorted_spike_templates] 115 | 116 | # create neurons and find region 117 | 118 | accumulator = 0 119 | 120 | for idx, count in enumerate(counts_per_cluster): 121 | 122 | if count > 0: 123 | 124 | spike_times = sorted_spikes[accumulator:accumulator + count] 125 | neuron = NeuroVis(spiketimes=spike_times, name='%d' % (idx)) 126 | cluster_depth = np.mean( 127 | sorted_spike_depths[accumulator:accumulator + count]) 128 | 129 | if name == 'frontal': 130 | 131 | if (cluster_depth > 0 and cluster_depth < 1550): 132 | striatum.append(neuron) 133 | elif (cluster_depth > 1550 and cluster_depth < 3840): 134 | motor_ctx.append(neuron) 135 | 
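            # on the posterior probe, depth instead distinguishes thalamus,
            # hippocampus and visual cortex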
136 | elif name == 'posterior': 137 | 138 | if (cluster_depth > 0 and cluster_depth < 1634): 139 | thalamus.append(neuron) 140 | elif (cluster_depth > 1634 and cluster_depth < 2797): 141 | hippocampus.append(neuron) 142 | elif (cluster_depth > 2797 and cluster_depth < 3840): 143 | visual_ctx.append(neuron) 144 | 145 | accumulator += count 146 | 147 | 148 | print("Striatum (n = %d)" % len(striatum)) 149 | print("Motor Cortex (n = %d)" % len(motor_ctx)) 150 | print("Thalamus (n = %d)" % len(thalamus)) 151 | print("Hippocampus (n = %d)" % len(hippocampus)) 152 | print("Visual Cortex (n = %d)" % len(visual_ctx)) 153 | 154 | ######################################################## 155 | # 2 Create Data Frame 156 | # ----------------------------- 157 | 158 | df = pd.DataFrame() 159 | 160 | raw_data = data_dict['experiment1stimInfo.mat'] 161 | 162 | df['start'] = np.squeeze(raw_data['stimStarts']) 163 | df['stop'] = np.squeeze(raw_data['stimStops']) 164 | df['stimulus'] = np.squeeze(raw_data['stimIDs']) 165 | 166 | print(df.head()) 167 | 168 | ######################################################## 169 | # 3 Start Plotting 170 | # ----------------------------- 171 | # 3.1 Striatum 172 | # ~~~~~~~~~~~~ 173 | 174 | pop = PopVis(striatum, name='Striatum') 175 | 176 | fig = plt.figure(figsize=(30, 20)) 177 | 178 | all_psth = pop.get_all_psth( 179 | event='start', df=df, conditions='stimulus', plot=False, binsize=100, 180 | window=[-500, 2000]) 181 | 182 | pop.plot_heat_map(all_psth, cond_id=[ 183 | 2, 7, 13], sortorder='descend', neuron_names=False) 184 | 185 | ######################################################## 186 | 187 | pop.plot_population_psth(all_psth=all_psth, cond_id=[1, 7, 12]) 188 | 189 | ######################################################## 190 | # 3.2 Frontal 191 | # ~~~~~~~~~~~~ 192 | 193 | pop = PopVis(striatum + motor_ctx, name='Frontal') 194 | 195 | fig = plt.figure(figsize=(30, 20)) 196 | 197 | all_psth = pop.get_all_psth( 198 | event='start', df=df, conditions='stimulus', plot=False, binsize=100, 199 | window=[-500, 2000]) 200 | 201 | pop.plot_heat_map( 202 | all_psth, cond_id=[2, 7, 13], sortorder='descend', neuron_names=False) 203 | 204 | ######################################################## 205 | 206 | pop.plot_population_psth(all_psth=all_psth, cond_id=[1, 7, 12]) 207 | 208 | ######################################################## 209 | # 3.3 All Neurons 210 | # ~~~~~~~~~~~~ 211 | 212 | pop = PopVis(striatum + motor_ctx + thalamus + hippocampus + visual_ctx) 213 | 214 | fig = plt.figure(figsize=(30, 20)) 215 | 216 | all_psth = pop.get_all_psth( 217 | event='start', df=df, conditions='stimulus', plot=False, binsize=100, 218 | window=[-500, 2000]) 219 | 220 | pop.plot_heat_map( 221 | all_psth, cond_id=[2, 7, 13], sortorder='descend', neuron_names=False) 222 | 223 | ######################################################## 224 | 225 | pop.plot_population_psth(all_psth=all_psth, cond_id=[1, 7, 12]) 226 | 227 | ######################################################## 228 | # 3.4 Striatum vs. 
Motor Cortex 229 | # ~~~~~~~~~~~~ 230 | 231 | striatum_pop = PopVis(striatum, name='Striatum') 232 | motor_ctx_pop = PopVis(motor_ctx, name='Motor Cortex') 233 | 234 | striatum_psth = striatum_pop.get_all_psth( 235 | event='start', df=df, conditions='stimulus', plot=False, binsize=100, 236 | window=[-500, 2000]) 237 | motor_ctx_psth = motor_ctx_pop.get_all_psth( 238 | event='start', df=df, conditions='stimulus', plot=False, binsize=100, 239 | window=[-500, 2000]) 240 | 241 | ######################################################## 242 | 243 | striatum_pop.plot_population_psth(all_psth=striatum_psth, cond_id=[1, 7, 12]) 244 | 245 | ######################################################## 246 | 247 | motor_ctx_pop.plot_population_psth(all_psth=motor_ctx_psth, cond_id=[1, 7, 12]) 248 | -------------------------------------------------------------------------------- /examples/plot_neuropop_simul_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================ 3 | Neuropop Example 4 | ================ 5 | 6 | A demonstration of Neuropop using simulated data 7 | 8 | """ 9 | 10 | ######################################################## 11 | 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | 15 | from spykes.ml.neuropop import NeuroPop 16 | from spykes.utils import train_test_split 17 | 18 | ######################################################## 19 | # Create a NeuroPop object 20 | # ----------------------------- 21 | 22 | n_neurons = 10 23 | pop = NeuroPop(n_neurons=n_neurons, tunemodel='glm') 24 | 25 | ######################################################## 26 | # Simulate a population of neurons 27 | # ----------------------------- 28 | 29 | n_samples = 1000 30 | x, Y, mu, k0, k, g, b = pop.simulate(pop.tunemodel, n_samples=n_samples, 31 | winsize=400.0) 32 | 33 | ######################################################## 34 | # Split into training and testing sets 35 | # ----------------------------- 36 | 37 | np.random.seed(42) 38 | (Y_train, Y_test), (x_train, x_test) = train_test_split(Y, x, percent=0.5) 39 | 40 | ######################################################## 41 | # Fit the tuning curves with gradient descent 42 | # ----------------------------- 43 | 44 | pop.fit(x_train, Y_train) 45 | 46 | ######################################################## 47 | # Predict the population activity with the fit tuning curves 48 | # ----------------------------- 49 | 50 | Yhat_test = pop.predict(x_test) 51 | 52 | ######################################################## 53 | # Score the prediction 54 | # ----------------------------- 55 | 56 | Ynull = np.mean(Y_train, axis=0) 57 | pseudo_R2 = pop.score(Y_test, Yhat_test, Ynull, method='pseudo_R2') 58 | print(pseudo_R2) 59 | 60 | ######################################################## 61 | # Plot the simulated and fit tuning curves 62 | # ----------------------------- 63 | 64 | plt.figure(figsize=[15, 15]) 65 | 66 | for neuron in range(pop.n_neurons): 67 | plt.subplot(4, 3, neuron + 1) 68 | pop.display(x_test, Y_test[:, neuron], neuron=neuron, 69 | ylim=[0.8 * np.min(Y_test[:, neuron]), 1.2 * 70 | np.max(Y_test[:, neuron])]) 71 | 72 | plt.show() 73 | 74 | ######################################################## 75 | # Decode feature from the population activity 76 | # ----------------------------- 77 | 78 | xhat_test = pop.decode(Y_test) 79 | 80 | ######################################################## 81 | # Visualize ground truth vs. 
decoded estimates 82 | # ----------------------------- 83 | 84 | plt.figure(figsize=[6, 5]) 85 | 86 | plt.plot(x_test, xhat_test, 'k.', alpha=0.5) 87 | plt.xlim([-1.2 * np.pi, 1.2 * np.pi]) 88 | plt.ylim([-1.2 * np.pi, 1.2 * np.pi]) 89 | plt.xlabel('Ground truth [radians]') 90 | plt.ylabel('Decoded [radians]') 91 | plt.tick_params(axis='y', right='off') 92 | plt.tick_params(axis='x', top='off') 93 | ax = plt.gca() 94 | ax.spines['top'].set_visible(False) 95 | ax.spines['right'].set_visible(False) 96 | 97 | plt.figure(figsize=[15, 5]) 98 | jitter = 0.2 * np.random.rand(x_test.shape[0]) 99 | plt.subplot(132, polar=True) 100 | plt.plot(x_test, np.ones(x_test.shape[0]) + jitter, 'ko', alpha=0.5) 101 | plt.title('Ground truth') 102 | 103 | plt.subplot(133, polar=True) 104 | plt.plot(xhat_test, np.ones(xhat_test.shape[0]) + jitter, 'co', alpha=0.5) 105 | plt.title('Decoded') 106 | plt.show() 107 | 108 | ######################################################## 109 | # Score decoding performance 110 | # ----------------------------- 111 | 112 | # Circular correlation 113 | circ_corr = pop.score(x_test, xhat_test, method='circ_corr') 114 | print('Circular correlation: %f' % (circ_corr)) 115 | 116 | ######################################################## 117 | 118 | # Cosine distance 119 | cosine_dist = pop.score(x_test, xhat_test, method='cosine_dist') 120 | print('Cosine distance: %f' % (cosine_dist)) 121 | -------------------------------------------------------------------------------- /examples/plot_popvis_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============== 3 | PopVis Example 4 | ============== 5 | 6 | """ 7 | # Authors: Mayank Agrawal 8 | # 9 | # License: MIT 10 | 11 | ######################################################## 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | 16 | import pandas as pd 17 | from spykes.plot.neurovis import NeuroVis 18 | from spykes.plot.popvis import PopVis 19 | from spykes.io.datasets import load_reward_data 20 | import random 21 | 22 | ######################################################## 23 | # 0 Initialization 24 | # ----------------------------- 25 | # 26 | # 0.1 Download Data 27 | # ~~~~~~~~~~~~~ 28 | # 29 | # Download all files [`here 30 | # `__] 31 | # However, we'll only be looking at Mihili_08062013.mat (Monkey M, Session 4) 32 | # 33 | # 0.2 Read In Data 34 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 35 | _, mat = load_reward_data() 36 | 37 | ######################################################## 38 | # 39 | # 0.3 Initialize Variables 40 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | event = 'rewardTime' 42 | condition = 'rewardBool' 43 | window = [-500, 1500] 44 | binsize = 10 45 | 46 | ######################################################## 47 | # 1 PopVis 48 | # ----------------------------- 49 | # 50 | # 1.1 Initiate all Neurons 51 | # ~~~~~~~~~~~~~ 52 | 53 | 54 | def get_spike_time(raw_data, neuron_number): 55 | 56 | spike_times = raw_data['alldays'][0][ 57 | 'PMd_units'][0][:][neuron_number - 1][0][1:] 58 | spike_times = [i[0] for i in spike_times] 59 | 60 | return spike_times 61 | 62 | ######################################################## 63 | 64 | 65 | def initiate_neurons(raw_data): 66 | 67 | neuron_list = list() 68 | 69 | for i in range((raw_data['alldays'][0]['PMd_units'][0][:]).shape[0]): 70 | spike_times = get_spike_time(raw_data, i + 1) 71 | 72 | # instantiate neuron 73 | neuron = NeuroVis(spike_times, name='PMd %d' % (i + 1)) 
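        # the name given here reappears in plot titles later (neuron.name)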
74 | neuron_list.append(neuron) 75 | 76 | return neuron_list 77 | 78 | ######################################################## 79 | 80 | neuron_list = initiate_neurons(mat) 81 | 82 | ######################################################## 83 | # 84 | # 1.2 Get Event Times 85 | # ~~~~~~~~~~~~~ 86 | 87 | 88 | def create_data_frame(raw_data): 89 | 90 | data_df = pd.DataFrame() 91 | 92 | uncertainty_conditions = list() 93 | center_target_times = list() 94 | reward_times = list() 95 | reward_outcomes = list() 96 | 97 | for i in range(raw_data['alldays'].shape[0]): 98 | 99 | meta_data = raw_data['alldays'][i]['tt'][0] 100 | 101 | uncertainty_conditions.append(meta_data[:, 2]) 102 | center_target_times.append(meta_data[:, 3]) 103 | reward_times.append(meta_data[:, 6]) 104 | reward_outcomes.append(meta_data[:, 7]) 105 | 106 | data_df['uncertaintyCondition'] = np.concatenate(uncertainty_conditions) 107 | data_df['centerTargetTime'] = np.concatenate(center_target_times) 108 | data_df['rewardTime'] = np.concatenate(reward_times) 109 | data_df['rewardOutcome'] = np.concatenate(reward_outcomes) 110 | 111 | data_df['rewardBool'] = data_df['rewardOutcome'].map(lambda s: s == 32) 112 | 113 | # find time in between previous reward onset and start of current trial 114 | # shouldn't be more than 1500ms 115 | 116 | start_times = data_df['centerTargetTime'] 117 | last_reward_times = np.roll(data_df['rewardTime'], 1) 118 | 119 | diffs = start_times - last_reward_times 120 | diffs[0] = 0 121 | 122 | data_df['consecutiveBool'] = diffs.map(lambda s: s <= 1.5) 123 | 124 | return data_df[((data_df['uncertaintyCondition'] == 5.0) | 125 | (data_df['uncertaintyCondition'] == 50.0)) & 126 | data_df['consecutiveBool']] 127 | 128 | ######################################################## 129 | 130 | data_df = create_data_frame(mat) 131 | print(len(data_df)) 132 | data_df.head() 133 | 134 | ######################################################## 135 | # 136 | # 1.3 Create PopVis Object 137 | # ~~~~~~~~~~~~~ 138 | 139 | neuron_list = initiate_neurons(mat)[:10] # let's just look at first 10 neurons 140 | pop = PopVis(neuron_list) 141 | 142 | ######################################################## 143 | # 144 | # 1.3.1 Plot Heat Map 145 | # ^^^^^^^^^^^^^^^^^^^ 146 | 147 | fig = plt.figure(figsize=(10, 10)) 148 | fig.subplots_adjust(hspace=.3) 149 | all_psth = pop.get_all_psth( 150 | event=event, df=data_df, conditions=condition, window=window, 151 | binsize=binsize, plot=True) 152 | 153 | ######################################################## 154 | # 155 | # 1.3.2 Plot Heat Map. Sort by Peak Latency 156 | # ^^^^^^^^^^^^^^^^^^^ 157 | 158 | fig = plt.figure(figsize=(10, 10)) 159 | fig.subplots_adjust(hspace=.3) 160 | pop.plot_heat_map(all_psth, sortby='latency') 161 | 162 | ######################################################## 163 | # 164 | # 1.3.3 Plot Heat Map. Sort by Avg Firing Rate in Ascending Order. 165 | # ^^^^^^^^^^^^^^^^^^^ 166 | 167 | fig = plt.figure(figsize=(10, 10)) 168 | fig.subplots_adjust(hspace=.3) 169 | pop.plot_heat_map(all_psth, sortby='rate', sortorder='ascend') 170 | 171 | ######################################################## 172 | # 173 | # 1.3.4 Plot Heat Map. Normalize Each Neuron Individually. 174 | # ^^^^^^^^^^^^^^^^^^^ 175 | 176 | fig = plt.figure(figsize=(10, 10)) 177 | fig.subplots_adjust(hspace=.3) 178 | pop.plot_heat_map(all_psth, normalize='each') 179 | 180 | ######################################################## 181 | # 182 | # 1.3.5 Plot Heat Map. 
Normalize All Neurons and Sort in Specified Order.
183 | # ^^^^^^^^^^^^^^^^^^^
184 |
185 | random_list = list(range(10))  # list() so it can be shuffled under Python 3
186 | random.shuffle(random_list)
187 | print(random_list)
188 | fig = plt.figure(figsize=(10, 10))
189 | fig.subplots_adjust(hspace=.3)
190 | pop.plot_heat_map(all_psth, normalize='all', sortby=random_list)
191 |
192 | ########################################################
193 | #
194 | # 1.3.6. Plot Population PSTH
195 | # ^^^^^^^^^^^^^^^^^^^
196 |
197 | plt.figure(figsize=(10, 5))
198 | pop.plot_population_psth(all_psth=all_psth)
--------------------------------------------------------------------------------
/examples/plot_reaching_dataset_example.py:
--------------------------------------------------------------------------------
1 | """
2 | ========================
3 | Reaching Dataset Example
4 | ========================
5 |
6 | A demonstration of Spykes' functionality using the reaching dataset.
7 |
8 | """
9 |
10 |
11 | ########################################################
12 |
13 | import numpy as np
14 | import pandas as pd
15 | from spykes.plot.neurovis import NeuroVis
16 | from spykes.ml.neuropop import NeuroPop
17 | from spykes.io.datasets import load_reaching_data
18 | from spykes.utils import train_test_split
19 | import matplotlib.pyplot as plt
20 |
21 | ########################################################
22 | # Initialization
23 | # -----------------------------
24 | #
25 | # Download Reaching Dataset
26 | # ~~~~~~~~~~~~~
27 | #
28 | # [`download link
29 | # `__]
30 | #
31 | # Load Data
32 |
33 | reach_data = load_reaching_data()
34 |
35 | print('dataset keys:', reach_data.keys())
36 | print('events:', reach_data['events'].keys())
37 | print('features:', reach_data['features'].keys())
38 | print('number of PMd neurons:', len(reach_data['neurons_PMd']))
39 | print('number of M1 neurons:', len(reach_data['neurons_M1']))
40 |
41 | ########################################################
42 | # Part I: NeuroVis
43 | # -----------------------------
44 | #
45 | #
46 | # Instantiate Example PMd Neuron
47 | # ~~~~~~~~~~~~~
48 |
49 | neuron_number = 91
50 | spike_times = reach_data['neurons_PMd'][neuron_number - 1]
51 | neuron_PMd = NeuroVis(spike_times, name='PMd %d' % neuron_number)
52 |
53 | ########################################################
54 | #
55 | # Raster plot and PSTH aligned to target onset
56 | # ~~~~~~~~~~~~~
57 |
58 | neuron_PMd.get_raster(event='targetOnTime', df=reach_data['events'])
59 |
60 | ########################################################
61 |
62 | neuron_PMd.get_psth(event='targetOnTime', df=reach_data['events'])
63 |
64 | ########################################################
65 | # Let's put the data into a DataFrame
66 | #
67 | # Events
68 | # ~~~~~~~~~~~~~
69 |
70 | data_df = pd.DataFrame()
71 | events = ['targetOnTime', 'goCueTime', 'rewardTime']
72 |
73 | for i in events:
74 | data_df[i] = np.squeeze(reach_data['events'][i])
75 |
76 |
77 | data_df[events].head()
78 |
79 | ########################################################
80 | #
81 | # Features
82 | # ~~~~~~~~~~~~~
83 |
84 | data_df['endpointOfReach'] = np.squeeze(
85 | reach_data['features']['endpointOfReach'])
86 | data_df['reward_code'] = reach_data['features']['reward']
87 |
88 | features = ['endpointOfReach', 'reward_code']
89 | data_df[features].head()
90 |
91 | ########################################################
92 | # Let's visualize PSTHs separated by conditions
93 | #
94 | # Example 1: Reward vs No Reward
95 | # ~~~~~~~~~~~~~
96 |
97 | # We
could use 'reward_code' but let's create a boolean feature for reward
98 | data_df['reward'] = data_df['reward_code'].map(
99 | lambda s: {32: True, 34: False, '': np.nan}[s])
100 | data_df.head()
101 |
102 | ########################################################
103 | # Plot PSTH
104 |
105 | psth = neuron_PMd.get_psth(event='rewardTime',
106 | conditions='reward',
107 | df=data_df)
108 |
109 |
110 | ########################################################
111 | # Make it look nicer
112 |
113 | plt.figure(figsize=(10, 5))
114 | psth = neuron_PMd.get_psth(event='rewardTime',
115 | conditions='reward',
116 | df=data_df,
117 | window=[-200, 600],
118 | binsize=20,
119 | event_name='Reward Time')
120 |
121 | plt.title('neuron %s: Reward' % neuron_PMd.name)
122 | plt.show()
123 |
124 | ########################################################
125 | #
126 | # Example 2: split by quadrant of reach direction
127 | # ~~~~~~~~~~~~~
128 |
129 | data_df['endpointOfReach_quad'] = pd.cut(
130 | data_df['endpointOfReach'], np.linspace(0, 360, 5))
131 | data_df[['endpointOfReach', 'endpointOfReach_quad']].head()
132 |
133 | ########################################################
134 |
135 | plt.figure(figsize=(10, 5))
136 | psth_PMd = neuron_PMd.get_psth(event='targetOnTime',
137 | conditions='endpointOfReach_quad',
138 | df=data_df,
139 | window=[-200, 1000],
140 | binsize=20,
141 | event_name='target onset')
142 |
143 | plt.title('%s: Reach angle quadrant' % neuron_PMd.name)
144 | plt.show()
145 |
146 | ########################################################
147 | # Raster plots for the same neuron and conditions
148 |
149 | # get rasters
150 | rasters_PMd = neuron_PMd.get_raster(event='targetOnTime',
151 | conditions='endpointOfReach_quad',
152 | df=data_df,
153 | window=[-200, 1000],
154 | binsize=20,
155 | plot=False)
156 |
157 |
158 | # plot rasters
159 | plt.figure(figsize=(15, 6))
160 | plot_order = np.array([2, 3, 4, 1])
161 | cmap = ['Oranges', 'Blues', 'Reds', 'Greens']
162 | for i, cond_id in enumerate(np.sort(list(rasters_PMd['data'].keys()))):
163 | plt.subplot(2, 2, plot_order[i])
164 | neuron_PMd.plot_raster(rasters_PMd,
165 | cond_id=cond_id,
166 | cmap=cmap[i],
167 | cond_name='reach angle: %s' % cond_id,
168 | sortby='rate', sortorder='ascend')
169 | if plot_order[i] < 3:
170 | plt.xlabel('')
171 |
172 |
173 | ########################################################
174 | # Example 3: Same as Example 2 but for an M1 neuron, aligned to goCueTime
175 | # ~~~~~~~~~~~~~
176 |
177 | neuron_number = 100
178 | spike_times = reach_data['neurons_M1'][neuron_number - 1]
179 | neuron_M1 = NeuroVis(spike_times, name='M1 %d' % neuron_number)
180 |
181 | ########################################################
182 |
183 | plt.figure(figsize=(10, 5))
184 | psth_M1 = neuron_M1.get_psth(event='goCueTime',
185 | df=data_df,
186 | conditions='endpointOfReach_quad',
187 | window=[-1000, 2000],
188 | binsize=40,
189 | plot=True,
190 | ylim=[0, 70],
191 | event_name='Go Cue')
192 | plt.show()
193 |
194 | ########################################################
195 | # Example 4: split by direction, only for rewarded trials
196 | # ~~~~~~~~~~~~~
197 |
198 | # use standard pandas filtering to isolate trials of interest
199 | trials_df = data_df[data_df['reward'] == True]
200 | trials_df.head()
201 |
202 | ########################################################
203 |
204 | plt.figure(figsize=(10, 5))
205 | psth_M1 = neuron_M1.get_psth(event='rewardTime',
206 | conditions='endpointOfReach_quad',
207 | df=trials_df, 208 | window=[-1000, 2000], 209 | binsize=40, 210 | ylim=[0, 70], 211 | event_name='Reward Time' 212 | ) 213 | plt.show() 214 | 215 | ######################################################## 216 | # One last thing before moving on to `spykes.neuropop` 217 | # We can use `get_spikecounts()` to count the number of spikes within a 218 | # certain time window relative to event onset 219 | 220 | spike_counts = neuron_PMd.get_spikecounts( 221 | 'targetOnTime', df=data_df, window=[0, 1200]) 222 | 223 | ######################################################## 224 | 225 | conditions_names = np.unique(data_df['endpointOfReach_quad']) 226 | conditions_names = conditions_names[[0, 3, 1, 2]] 227 | conditions_names 228 | 229 | ######################################################## 230 | # 231 | # Let's visualize the spike counts per trial for each condition 232 | 233 | colors = ['#F5A21E', '#134B64', '#EF3E34', '#02A68E'] 234 | 235 | plt.figure(figsize=(10, 5)) 236 | for i, cond in enumerate(conditions_names): 237 | idx = np.where(data_df['endpointOfReach_quad'] == cond)[0] 238 | x_noise = 0.08 * np.random.randn(np.size(idx)) 239 | plt.plot(i + x_noise + 1, spike_counts[idx], '.', color=colors[i], 240 | alpha=0.3, markersize=20) 241 | 242 | plt.xlabel('condition') 243 | plt.ylabel('number of spikes') 244 | plt.xlim([0, 5]) 245 | plt.xticks(np.arange(np.size(conditions_names)) + 1) 246 | ax = plt.gca() 247 | ax.spines['top'].set_visible(False) 248 | ax.spines['right'].set_visible(False) 249 | plt.tick_params(axis='y', right='off') 250 | plt.tick_params(axis='x', top='off') 251 | plt.legend(conditions_names, frameon=False) 252 | plt.show() 253 | 254 | ######################################################## 255 | # Part II: NeuroPop 256 | # ----------------------------- 257 | # 258 | # Organize data as preferred features and spike counts (x Y) 259 | # 260 | # Extract reach direction x 261 | # ~~~~~~~~~~~~~ 262 | 263 | # Get reach direction, ensure it is between [-pi, pi] 264 | x = np.arctan2(np.sin(reach_data['features']['endpointOfReach'] * 265 | np.pi / 180.0), 266 | np.cos(reach_data['features']['endpointOfReach'] * 267 | np.pi / 180.0)) 268 | 269 | ######################################################## 270 | # Extract M1 spike counts Y 271 | # ~~~~~~~~~~~~~ 272 | # - Select only neurons above a threshold firing rate 273 | # - Align spike counts to the GO cue 274 | # - Use the convenience function ```get_spikecounts()``` from ```NeuroVis``` 275 | 276 | # Select only high firing rate neurons 277 | M1_select = list() 278 | threshold = 10.0 279 | 280 | # Specify timestamps of events to which trials are aligned 281 | align = 'goCueTime' 282 | 283 | # Specify a window of around the go cue for spike counts 284 | window = [0., 500.] # milliseconds 285 | 286 | # Get spike counts 287 | Y = np.zeros([x.shape[0], len(reach_data['neurons_M1'])]) 288 | 289 | for n in range(len(reach_data['neurons_M1'])): 290 | this_neuron = NeuroVis(spiketimes=reach_data['neurons_M1'][n]) 291 | Y[:, n] = np.squeeze( 292 | this_neuron.get_spikecounts(event=align, df=data_df, window=window)) 293 | 294 | # Short list a few high-firing neurons 295 | if this_neuron.firingrate > threshold: 296 | M1_select.append(n) 297 | 298 | # Rescale spike counts to units of spikes/s 299 | Y = Y / float(window[1] - window[0]) * 1e3 300 | 301 | # How many neurons shortlisted? 
302 | print('%d M1 neurons had firing rates over %4.1f spks/s' % 303 | (len(M1_select), threshold)) 304 | 305 | ######################################################## 306 | # Split into train and test sets 307 | # ~~~~~~~~~~~~~ 308 | 309 | np.random.seed(42) 310 | (Y_train, Y_test), (x_train, x_test) = train_test_split(Y, x, percent=0.33) 311 | 312 | ######################################################## 313 | # Create an instance of NeuroPop 314 | # ~~~~~~~~~~~~~ 315 | 316 | pop = NeuroPop(n_neurons=len(M1_select), 317 | tunemodel='gvm', 318 | n_repeats=3, 319 | verbose=False) 320 | 321 | # Let's fit tuning curves to the population 322 | # ~~~~~~~~~~~~~ 323 | 324 | pop.fit(np.squeeze(x_train), Y_train[:, M1_select]) 325 | 326 | ######################################################## 327 | # Predict firing rates 328 | # ~~~~~~~~~~~~~ 329 | 330 | Yhat_test = pop.predict(np.squeeze(x_test)) 331 | 332 | ######################################################## 333 | # Score the prediction 334 | # ~~~~~~~~~~~~~ 335 | 336 | # calculate and plot the pseudo R2 337 | Ynull = np.mean(Y_train[:, M1_select], axis=0) 338 | pseudo_R2 = pop.score( 339 | Y_test[:, M1_select], Yhat_test, Ynull, method='pseudo_R2') 340 | 341 | plt.figure(figsize=(10, 5)) 342 | plt.plot(pseudo_R2, 'co', markeredgecolor='c', alpha=0.5, markersize=8) 343 | plt.xlim([-5, len(M1_select) + 5]) 344 | plt.ylim([-0.1, 1]) 345 | plt.xlabel('neurons') 346 | plt.ylabel('pseudo-$R^2$ (test)') 347 | ax = plt.gca() 348 | ax.spines['top'].set_visible(False) 349 | ax.spines['right'].set_visible(False) 350 | plt.tick_params(axis='y', right='off') 351 | plt.tick_params(axis='x', top='off') 352 | 353 | ######################################################## 354 | # Visualize tuning curves 355 | # ~~~~~~~~~~~~~ 356 | 357 | plt.figure(figsize=[15, 70]) 358 | 359 | for neuron in range(len(M1_select)): 360 | plt.subplot(27, 4, neuron + 1) 361 | pop.display(x_test, Y_test[:, M1_select[neuron]], neuron=neuron, 362 | ylim=[0.8 * np.min(Y_test[:, M1_select[neuron]]), 363 | 1.2 * np.max(Y_test[:, M1_select[neuron]])]) 364 | # plt.axis('off') 365 | 366 | plt.show() 367 | 368 | 369 | ######################################################## 370 | # Decode reach direction from population vector 371 | # ~~~~~~~~~~~~~ 372 | 373 | xhat_test = pop.decode(Y_test[:, M1_select]) 374 | 375 | ######################################################## 376 | # Visualize decoded reach direction 377 | # ~~~~~~~~~~~~~ 378 | 379 | plt.figure(figsize=[6, 5]) 380 | 381 | plt.plot(x_test, xhat_test, 'k.', alpha=0.5) 382 | plt.xlim([-1.2 * np.pi, 1.2 * np.pi]) 383 | plt.ylim([-1.2 * np.pi, 1.2 * np.pi]) 384 | plt.xlabel('Ground truth [radians]') 385 | plt.ylabel('Decoded [radians]') 386 | plt.tick_params(axis='y', right='off') 387 | plt.tick_params(axis='x', top='off') 388 | ax = plt.gca() 389 | ax.spines['top'].set_visible(False) 390 | ax.spines['right'].set_visible(False) 391 | 392 | plt.figure(figsize=[15, 5]) 393 | jitter = 0.2 * np.random.rand(x_test.shape[0]) 394 | plt.subplot(121, polar=True) 395 | plt.plot(x_test, np.ones(x_test.shape[0]) + jitter, 'ko', alpha=0.5) 396 | plt.title('Ground truth') 397 | 398 | plt.subplot(122, polar=True) 399 | plt.plot(xhat_test, np.ones(xhat_test.shape[0]) + jitter, 'co', alpha=0.5) 400 | plt.title('Decoded') 401 | plt.show() 402 | 403 | ######################################################## 404 | # Score decoding performance 405 | # ~~~~~~~~~~~~~ 406 | 407 | circ_corr = pop.score(x_test, xhat_test, method='circ_corr') 408 | 
print('Circular Correlation: %f' % (circ_corr)) 409 | cosine_dist = pop.score(x_test, xhat_test, method='cosine_dist') 410 | print('Cosine Distance: %f' % (cosine_dist)) 411 | -------------------------------------------------------------------------------- /images/journal.pone.0160851.g002.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/images/journal.pone.0160851.g002.PNG -------------------------------------------------------------------------------- /images/nature14178-f2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/images/nature14178-f2.jpg -------------------------------------------------------------------------------- /images/psth_PMd_n91.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/images/psth_PMd_n91.png -------------------------------------------------------------------------------- /images/spykes-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/images/spykes-logo.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [nosetests] 2 | verbosity = 2 3 | detailed-errors = 1 4 | with-coverage = 1 5 | cover-package = spykes 6 | cover-inclusive = 1 7 | nologcapture = 1 8 | 9 | [flake8] 10 | exclude = 11 | .git, 12 | __pycache__, 13 | build, 14 | dist 15 | count = 1 16 | 17 | [metadata] 18 | description-file = README.md 19 | 20 | [aliases] 21 | test = nosetests 22 | flake = flake8 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | 5 | DISTNAME = 'spykes' 6 | DESCRIPTION = """Basic tools for neural data analysis and visualization.""" 7 | MAINTAINER = 'Pavan Ramkumar and Hugo Fernandes' 8 | MAINTAINER_EMAIL = 'pavan.ramkumar@gmail.com' 9 | LICENSE = 'MIT' 10 | URL = 'https://github.com/KordingLab/spykes.git' 11 | VERSION = '0.3.dev' 12 | 13 | if __name__ == "__main__": 14 | setup( 15 | name=DISTNAME, 16 | maintainer=MAINTAINER, 17 | maintainer_email=MAINTAINER_EMAIL, 18 | description=DESCRIPTION, 19 | license=LICENSE, 20 | version=VERSION, 21 | url=URL, 22 | long_description=open('README.md').read(), 23 | classifiers=[ 24 | 'Intended Audience :: Science/Research', 25 | 'Intended Audience :: Developers', 26 | 'License :: OSI Approved', 27 | 'Programming Language :: Python', 28 | 'Topic :: Software Development', 29 | 'Topic :: Scientific/Engineering', 30 | 'Operating System :: Microsoft :: Windows', 31 | 'Operating System :: POSIX', 32 | 'Operating System :: Unix', 33 | 'Operating System :: MacOS', 34 | ], 35 | install_requires=[ 36 | 'numpy', 37 | 'scipy', 38 | 'matplotlib', 39 | 'pandas', 40 | 'requests', 41 | ], 42 | extras_require={ 43 | 'deepdish': ['deepdish'], 44 | 'develop': [ 45 | 'nose', 46 | 'coverage', 47 | 'flake8', 48 | 'sphinx', 49 | 'sphinx-gallery', 50 | 'sphinx_rtd_theme', 51 | 'm2r', 52 | 'deepdish', 53 | 'image', 54 | 'tensorflow>=1.4.0', 55 | 'h5py', 56 | ], 57 | 'ml': [ 58 | 'tensorflow>=1.4.0', 59 | ], 60 | }, 61 | platforms='any', 62 | packages=find_packages(exclude=['tests', 'tests.*']), 63 | ) 64 | -------------------------------------------------------------------------------- /spykes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/spykes/__init__.py -------------------------------------------------------------------------------- /spykes/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | # Environment variables are formulated using SPYKES_KEY_%s. 8 | SPYKES_KEY = 'SPYKES' 9 | 10 | # The default data path. 11 | DEFAULT_DATA_DIR = '.spykes' 12 | 13 | # Defines the default colors for the population plot. 14 | DEFAULT_POPULATION_COLORS = [ 15 | '#F5A21E', 16 | '#134B64', 17 | '#EF3E34', 18 | '#02A68E', 19 | '#FF07CD', 20 | ] 21 | 22 | 23 | def get_home_directory(): 24 | '''Returns the home directory, as a string. 25 | 26 | The home directory is either the :data:`HOME` environment variable, or the 27 | directory pointed to by :data:`~`, or the :data:`/tmp` directory if neither 28 | of those have write access. 29 | 30 | Returns: 31 | str: The path to the home directory. 32 | ''' 33 | if 'HOME' in os.environ and os.access(os.environ['HOME'], os.W_OK): 34 | return os.environ['HOME'] 35 | elif os.access(os.path.expanduser('~'), os.W_OK): 36 | return os.path.expanduser('~') 37 | else: 38 | return '/tmp' # Default is to return a temp directory. 39 | 40 | 41 | def get_data_directory(): 42 | '''Returns the home directory for Spykes data. 43 | 44 | By default, this points to :data:`~/.spykes`. This can be overridden by 45 | setting the :data:`SPYKES_DATA` environment variable to point to the 46 | directory of your choice. 47 | 48 | Returns: 49 | str: The path to the data directory. 
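    Example:
        A usage sketch; the exact value depends on your environment, and
        the path shown here is hypothetical::

            >>> from spykes.config import get_data_directory
            >>> get_data_directory()  # doctest: +SKIP
            '/home/user/.spykes'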
50 | ''' 51 | data_key = '{prefix}_DATA'.format(prefix=SPYKES_KEY) 52 | 53 | if data_key not in os.environ: 54 | home = get_home_directory() 55 | dir_path = os.path.join(home, DEFAULT_DATA_DIR) 56 | if not os.path.exists(dir_path): 57 | os.makedirs(dir_path) 58 | return dir_path 59 | else: 60 | return os.environ[data_key] 61 | -------------------------------------------------------------------------------- /spykes/io/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .datasets import ( 4 | load_reward_data, 5 | load_neuropixels_data, 6 | load_reaching_data, 7 | ) 8 | 9 | __all__ = [ 10 | 'load_reward_data', 11 | 'load_neuropixels_data', 12 | 'load_reaching_data', 13 | ] 14 | -------------------------------------------------------------------------------- /spykes/io/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import scipy.io 7 | import numpy as np 8 | import requests 9 | import zipfile 10 | 11 | from .. import config 12 | 13 | 14 | def _urlretrieve(url, filename): 15 | '''Defines a convenience method for downloading files with requests. 16 | 17 | Args: 18 | url (str): The URL of the file to download. 19 | filename (str): The path to save the file. 20 | ''' 21 | r = requests.get(url, stream=True) 22 | with open(filename, 'wb') as f: 23 | for chunk in r.iter_content(chunk_size=1024): 24 | if chunk: 25 | f.write(chunk) 26 | 27 | 28 | def _load_file(fpath): 29 | '''Checks whether a file is a .mat or .npy file and loads it. 30 | 31 | This is a convenience method for the other loading functions. 32 | 33 | Args: 34 | fpath (str): The exact path of where data is located. 35 | 36 | Returns: 37 | mat or numpy array: The loaded dataset. 38 | ''' 39 | if fpath[-4:] == '.mat': 40 | data = scipy.io.loadmat(fpath) 41 | elif fpath[-4:] == '.npy': 42 | data = np.load(fpath) 43 | else: 44 | raise ValueError('Invalid file type: {}'.format(fpath)) 45 | return data 46 | 47 | 48 | def _arg_check(name, arg, valid_args): 49 | '''Convenience function for doing argument cleaning and checking.''' 50 | 51 | # Makes sure that the argument is valid. 52 | if arg not in valid_args: 53 | valid_args = list(valid_args) 54 | formatted_args = ', '.join('"{}"'.format(i) for i in valid_args[:-1]) 55 | formatted_args += ' or "{}"'.format(valid_args[-1]) 56 | raise ValueError('Invalid {}: "{}". Expected {}.' 57 | .format(name, arg, formatted_args)) 58 | 59 | return arg 60 | 61 | 62 | def load_spikefinder_data(dir_name='spikefinder'): 63 | '''Downloads and returns a dataset of paired calcium recordings. 64 | 65 | This dataset was used for the Spikefinder competition 66 | (DOI: 10.1101/177956), and consists of datasets of paired calcium traces 67 | and spike trains collected from multiple sources. 68 | 69 | Args: 70 | dir_name (str): Specifies the directory to which the data files should 71 | be downloaded. This is concatenated with the user-set data 72 | directory. 73 | 74 | Returns: 75 | tuple: Paths to the downloaded training and testing datasets. Each 76 | dataset is a CSV which can be loaded using Pandas, 77 | :data:`pd.read_csv(path)`. 78 | 79 | * :data:`train_data`: List of pairs of strings, where each pair 80 | consists of the path to the calcium data (inputs) and the path to 81 | the spike data (ground truth) for that dataset pair. 
82 | * :data:`test_data`: List of strings, where each string is the path 83 | to a testing dataset. 84 | ''' 85 | dpath = os.path.join(config.get_data_directory(), dir_name) 86 | if not os.path.exists(dpath): 87 | os.makedirs(dpath) 88 | 89 | url_template = ( 90 | 'https://s3.amazonaws.com/neuro.datasets/' 91 | 'challenges/spikefinder/spikefinder.{version}.zip' 92 | ) 93 | 94 | # Downloads the two datasets. 95 | def _download(version): 96 | zipname = os.path.join(dpath, '{}.zip'.format(version)) 97 | if not os.path.exists(zipname): 98 | url = url_template.format(version=version) 99 | _urlretrieve(url, zipname) 100 | 101 | # Unzips the associated files. 102 | unzip_path = os.path.join(dpath, 'spikefinder.{}'.format(version)) 103 | if not os.path.exists(unzip_path): 104 | zipref = zipfile.ZipFile(zipname, 'r') 105 | zipref.extractall(dpath) 106 | zipref.close() 107 | return unzip_path 108 | 109 | # Downloads the two datasets. 110 | train_path, test_path = _download('train'), _download('test') 111 | train_template = os.path.join(train_path, '{index}.train.{mode}.csv') 112 | test_template = os.path.join(test_path, '{index}.test.calcium.csv') 113 | 114 | # Converts each dataset to a file path. 115 | train_paths = [( 116 | train_template.format(index=i, mode='calcium'), 117 | train_template.format(index=i, mode='spikes'), 118 | ) for i in range(1, 11)] 119 | test_paths = [test_template.format(index=i) for i in range(1, 6)] 120 | 121 | # Checks that all of the files exist. 122 | assert all(os.path.exists(i) and os.path.exists(j) for i, j in train_paths) 123 | assert all(os.path.exists(i) for i in test_paths) 124 | 125 | return train_paths, test_paths 126 | 127 | 128 | def load_reward_data(dir_name='reward'): 129 | '''Downloads and returns the data for the PopVis example. 130 | 131 | Downloads and returns data for Neural Coding Reward Example as well as 132 | PopVis Example. Dataset comes from `Ramkumar et al's` "Premotor and Motor 133 | Cortices Encode Reward" paper. 134 | 135 | Args: 136 | dir_name (str): Specifies the directory to which the data files should 137 | be downloaded. This is concatenated with the user-set data 138 | directory. 139 | 140 | Returns: 141 | tuple: The two downloaded files. 142 | 143 | * :data:`sess_one_mat`: :data:`.mat` file for Monkey M, Session 1. 144 | * :data:`sess_four_mat`: :data:`.mat` file for Monkey M, Session 4. 145 | ''' 146 | dpath = os.path.join(config.get_data_directory(), dir_name) 147 | if not os.path.exists(dpath): 148 | os.makedirs(dpath) 149 | 150 | def download_mat(fname, url): 151 | '''Helper function for downloading the existing MAT files.''' 152 | fpath = os.path.join(dpath, fname) 153 | if not os.path.exists(fname): 154 | _urlretrieve(url, fpath) 155 | return _load_file(fpath) 156 | 157 | # Downloads sess_one_mat. 158 | sess_one_mat = download_mat( 159 | fname='Mihili_07112013.mat', 160 | url='https://ndownloader.figshare.com/files/5652051', 161 | ) 162 | 163 | # Downloads sess_four_mat. 164 | sess_four_mat = download_mat( 165 | fname='Mihili_08062013.mat', 166 | url='https://ndownloader.figshare.com/files/5652060', 167 | ) 168 | 169 | return sess_one_mat, sess_four_mat 170 | 171 | 172 | def load_neuropixels_data(dir_name='neuropixels'): 173 | '''Downloads and returns data for the Neuropixels example. 174 | 175 | The dataset comes from `UCL's Cortex Lab 176 | `_. 177 | 178 | Args: 179 | dir_name (str): Specifies the directory to which the data files 180 | should be downloaded. This is concatenated with the user-set 181 | data directory. 
182 | 183 | Returns: 184 | dict: A dictionary where each key corresponds to a needed file. 185 | ''' 186 | dpath = os.path.join(config.get_data_directory(), dir_name) 187 | if not os.path.exists(dpath): 188 | os.makedirs(dpath) 189 | 190 | base_url = 'http://data.cortexlab.net/dualPhase3/data/' 191 | file_dict = dict() 192 | 193 | parent_fnames = [ 194 | 'experiment1stimInfo.mat', 195 | 'experiment2stimInfo.mat', 196 | 'experiment3stimInfo.mat', 197 | 'timeCorrection.mat', 198 | 'timeCorrection.npy', 199 | ] 200 | parent_dir = [ 201 | 'frontal/', 202 | 'posterior/', 203 | ] 204 | subdir_fnames = [ 205 | 'spike_clusters.npy', 206 | 'spike_templates.npy', 207 | 'spike_times.npy', 208 | 'templates.npy', 209 | 'whitening_mat_inv.npy', 210 | 'cluster_groups.csv', 211 | 'channel_positions.npy', 212 | ] 213 | 214 | for name in parent_fnames: 215 | fname = os.path.join(dpath, name) 216 | url = os.path.join(base_url, name) 217 | if not os.path.exists(fname): 218 | _urlretrieve(url, fname) 219 | file_dict[name] = _load_file(fname) 220 | 221 | for directory in parent_dir: 222 | if not os.path.exists(os.path.join(dpath, directory)): 223 | os.makedirs(os.path.join(dpath, directory)) 224 | for subdir in subdir_fnames: 225 | fname = os.path.join(dpath, directory, subdir) 226 | url = os.path.join(base_url, directory, subdir) 227 | if not os.path.exists(fname): 228 | _urlretrieve(url, fname) 229 | key = os.path.join(directory, subdir) 230 | if subdir == 'cluster_groups.csv': 231 | file_dict[key] = np.recfromcsv(fname, delimiter='\t') 232 | else: 233 | file_dict[key] = _load_file(fname) 234 | 235 | return file_dict 236 | 237 | 238 | def load_neuropixels_times(location, cutoff=0.3, dir_name='neuropixels'): 239 | '''Extracts the neuropixel spike deltas. 240 | 241 | This code is adopted from the Cortex Lab's Matlab implementation. This 242 | method provides a simpler interface for loading that data by location. 243 | 244 | Args: 245 | location (str): One of :data:`striatum`, :data:`motor_ctx`, 246 | :data:`thalamus`, :data:`hippocampus`, or :data:`visual_ctx`. 247 | cutoff (double): The cutoff threshold for spike templates. 248 | dir_name (str): Specifies the directory to which the data files 249 | should be downloaded. This is concatenated with the user-set 250 | data directory. 251 | 252 | Returns: 253 | list of arrays: each element contains the spike times for one cluster. 254 | ''' 255 | 256 | # Cleans and validates arguments. 257 | location = location.lower().replace('cortex', 'ctx') 258 | _arg_check('location', location, ('striatum', 'motor_ctx', 'thalamus', 259 | 'hippocampus', 'visual_ctx')) 260 | 261 | fname = 'processed_{}_{}.npy'.format(location, cutoff) 262 | fpath = os.path.join(config.get_data_directory(), dir_name, fname) 263 | 264 | # Loads a cached version, if one exists. 265 | if os.path.exists(fpath): 266 | return np.load(fpath) 267 | 268 | # Parses the mode from the location. 269 | mode = 'frontal' if location in ('striatum', 'motor_ctx') else 'posterior' 270 | 271 | # Initializes the recording frequency. 272 | frequency = 30000.0 273 | 274 | # Loads the data normally. 275 | data_dict = load_neuropixels_data(dir_name=dir_name) 276 | 277 | def _load_key(name, ext='npy', squeeze=True): 278 | key = '{}/{}.{}'.format(mode, name, ext) 279 | return np.squeeze(data_dict[key]) if squeeze else data_dict[key] 280 | 281 | # Loads data that is common to any of the analysis. 
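    # A sketch of the logic below: spikes from clusters labeled 'good' are
    # kept and sorted by cluster; each cluster is then assigned a depth from
    # its template's amplitude profile across channels, so that clusters can
    # be binned into the requested brain region.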
282 | clusters = _load_key('spike_clusters') # The cluster ID of each spike. 283 | spike_times = _load_key('spike_times') / frequency 284 | spike_templates = _load_key('spike_templates') 285 | templates = _load_key('templates') 286 | winv = _load_key('whitening_mat_inv') 287 | y_coords = _load_key('channel_positions')[:, 1] 288 | 289 | # Performs time correction on the spike times if needed. 290 | if mode == 'frontal': 291 | time_correction = data_dict['timeCorrection.npy'] 292 | spike_times = spike_times * time_correction[0] + time_correction[1] 293 | 294 | data = _load_key('cluster_groups', ext='csv', squeeze=False) 295 | 296 | # Gets the cluster IDs and their quality labels. 297 | cids = np.array([x[0] for x in data]) 298 | cfg = np.array([x[1] for x in data]) 299 | 300 | good_indices = np.in1d(clusters, cids[cfg == b'good']) 301 | 302 | # Sorts the remaining ('good') spikes by their cluster ID. 303 | real_clusters = clusters[good_indices] 304 | sort_idx = np.argsort(real_clusters) 305 | sorted_spikes = spike_times[good_indices][sort_idx] 306 | sorted_spike_templates = spike_templates[good_indices][sort_idx] 307 | 308 | # Gets the counts per cluster. 309 | counts_per_cluster = np.bincount(real_clusters) 310 | 311 | # Computes the depth for each spike. 312 | templates_unw = np.array([np.dot(t, winv) for t in templates]) 313 | template_amps = np.ptp(templates_unw, axis=1) 314 | template_thresholds = np.max(template_amps, axis=1, keepdims=True) * cutoff 315 | template_amps[template_amps < template_thresholds] = 0 316 | amp_sums = np.sum(template_amps, axis=1, keepdims=True) 317 | template_depths = ((y_coords * template_amps) / amp_sums).sum(axis=1) 318 | sorted_spike_depths = template_depths[sorted_spike_templates] 319 | 320 | # Splits by cluster and computes the average cluster depth. 321 | split_idxs = np.cumsum(counts_per_cluster[counts_per_cluster != 0])[:-1] 322 | times = np.split(sorted_spikes, split_idxs) 323 | depths = [np.mean(i) for i in np.split(sorted_spike_depths, split_idxs)] 324 | 325 | def _get_range(lo, hi): 326 | return np.array([np.sort(t) for t, d in 327 | zip(times, depths) if lo < d <= hi]) 328 | 329 | if location == 'striatum': 330 | data = _get_range(0, 1550) 331 | elif location == 'motor_ctx': 332 | data = _get_range(1550, 3840) 333 | elif location == 'thalamus': 334 | data = _get_range(0, 1634) 335 | elif location == 'hippocampus': 336 | data = _get_range(1634, 2797) 337 | else: # visual_ctx 338 | data = _get_range(2797, 3840) 339 | 340 | # Caches the data to avoid recomputation. 341 | np.save(fpath, data) 342 | 343 | return data 344 | 345 | 346 | def load_reaching_data(dir_name='reaching'): 347 | '''Downloads and returns data for the Reaching Dataset example. 348 | 349 | The dataset is publicly available `here `_. It is hosted on 350 | DropBox; if the automatic download fails, visit the link manually and 351 | save the file to the appropriate location (usually 352 | :data:`~/.spykes/reaching/reaching_dataset.h5`). 353 | 354 | Args: 355 | dir_name (str): Specifies the directory to which the data files 356 | should be downloaded. This is concatenated with the user-set 357 | data directory. 358 | 359 | Returns: 360 | deep dish dataset: The dataset, loaded using :meth:`deepdish.io.load`. 361 | ''' 362 | # Import is performed here so that deepdish is not required for all of 363 | # the "datasets" functions.
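    # NOTE: deepdish is assumed to be installed separately (for example,
    # via `pip install deepdish`); it is treated as an optional dependency
    # so that the other dataset loaders do not require it.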
364 | import deepdish 365 | 366 | dpath = os.path.join(config.get_data_directory(), dir_name) 367 | if not os.path.exists(dpath): 368 | os.makedirs(dpath) 369 | 370 | # Downloads the file if it doesn't exist already. 371 | fpath = os.path.join(dpath, 'reaching_dataset.h5') 372 | if not os.path.exists(fpath): 373 | url = 'http://goo.gl/eXeUz8' 374 | _urlretrieve(url, fpath) 375 | 376 | data = deepdish.io.load(fpath) 377 | return data 378 | 379 | 380 | def _load_reaching_helper(transformer, identifier, event, feature, neuron, 381 | window_min, window_max, threshold, dir_name): 382 | # Loads the formatted data, if it has already been processed. 383 | fname = '{}.npz'.format('_'.join('{}'.format(i) for i in [ 384 | event, feature, neuron, window_min, window_max, threshold, identifier 385 | ])) 386 | fpath = os.path.join(config.get_data_directory(), dir_name, fname) 387 | if os.path.exists(fpath): 388 | with open(fpath, 'rb') as f: 389 | data = np.load(f) 390 | return data['x'], data['y'] 391 | 392 | # Converts to seconds. 393 | window_max, window_min = window_max * 1e-3, window_min * 1e-3 394 | 395 | # Loads the reaching data normally. 396 | reaching_data = load_reaching_data(dir_name) 397 | 398 | events = list(reaching_data['events'].keys()) 399 | features = list(reaching_data['features'].keys()) 400 | 401 | # Checks the input arguments, throwing helpful error messages if needed. 402 | _arg_check('align event', event, events) 403 | _arg_check('feature', feature, features) 404 | _arg_check('neuron type', neuron, ('M1', 'PMd')) 405 | 406 | neuron_key = 'neurons_{}'.format(neuron) 407 | spike_times = np.asarray([ 408 | np.squeeze(np.sort(s)) for s in reaching_data[neuron_key] 409 | ]) 410 | 411 | # Applies the cutoff threshold. 412 | spike_freqs = np.asarray([len(t) / (t[-1] - t[0]) for t in spike_times]) 413 | thresh_idxs = np.where(spike_freqs > threshold)[0] 414 | spike_times = spike_times[thresh_idxs] 415 | 416 | # Gets the reach angle, in radians. 417 | if feature == 'endpointOfReach': 418 | x = reaching_data['features'][feature] * np.pi / 180.0 419 | x = np.arctan2(np.sin(x), np.cos(x)) 420 | else: 421 | x = reaching_data['features'][feature] 422 | 423 | # Gets the spike responses. 424 | event_data = reaching_data['events'][event].reshape(-1) 425 | 426 | # Gets the on and off times. 427 | on_off = np.sort(np.concatenate([ 428 | event_data + window_min, event_data + window_max 429 | ])) 430 | 431 | # Checks that we haven't violated the order. 432 | window_diff = window_max - window_min 433 | for i in range(1, len(on_off), 2): 434 | d = on_off[i] - on_off[i-1] 435 | if abs(d - window_diff) > 1e-9: 436 | raise ValueError('Samples were found that overlap! Make sure that ' 437 | '`window_max` - `window_min` is small enough. ' 438 | 'Time difference is {:.3f}s, average is {:.3f}s.' 439 | .format(d, window_diff)) 440 | 441 | # Applies the transformation. 442 | y = np.stack([ 443 | transformer(n, on_off) for n in spike_times 444 | ]).transpose(1, 0) 445 | 446 | # Saves the dataset after processing it. 447 | with open(fpath, 'wb') as f: 448 | np.savez(f, x=x, y=y) 449 | 450 | return x, y 451 | 452 | 453 | def load_reaching_rates(event='goCueTime', feature='endpointOfReach', 454 | neuron='M1', window_min=0., window_max=500., 455 | threshold=10., dir_name='reaching'): 456 | '''Extracts the reach direction and spike rates from the reaching dataset. 457 | 458 | Args: 459 | event (str): Event to which to align each trial; :data:`goCueTime`, 460 | :data:`targetOnTime` or :data:`rewardTime`. 
461 | feature (str): The feature to get; :data:`endpointOfReach` or 462 | :data:`reward`. 463 | neuron (str): The neuron response to use, either :data:`M1` or 464 | :data:`PMd`. 465 | window_min (double): The lower edge of the window around the alignment 466 | event in which to count spikes, in milliseconds. 467 | window_max (double): The upper edge of the window around the alignment 468 | event in which to count spikes, in milliseconds. 469 | threshold (double): The threshold for selecting high-firing neurons, 470 | representing the minimum firing rate in Hz. 471 | dir_name (str): Specifies the directory to which the data files 472 | should be downloaded. This is concatenated with the user-set 473 | data directory. 474 | 475 | Returns: 476 | tuple: The :data:`x` and :data:`y` features of the dataset. 477 | 478 | * :data:`x`: Array with shape :data:`(samples, features)` 479 | * :data:`y`: Array with shape :data:`(samples, neurons)` 480 | ''' 481 | 482 | def _get_spike_rates(n, on_off): 483 | arrs = np.split(n, n.searchsorted(on_off)) 484 | windows = [arrs[i] for i in range(1, len(arrs), 2)] 485 | return np.asarray([len(w) for w in windows]) 486 | 487 | return _load_reaching_helper( 488 | transformer=_get_spike_rates, 489 | identifier='rates', 490 | event=event, 491 | feature=feature, 492 | neuron=neuron, 493 | window_min=window_min, 494 | window_max=window_max, 495 | threshold=threshold, 496 | dir_name=dir_name, 497 | ) 498 | 499 | 500 | def load_reaching_deltas(event='goCueTime', feature='endpointOfReach', 501 | neuron='M1', window_min=0., window_max=500., 502 | threshold=10., dir_name='reaching'): 503 | '''Extracts the reach direction and spike deltas from the reaching dataset. 504 | 505 | The first spike delta is the difference between the window onset (the 506 | event time plus :data:`window_min`) and the first spike. The remaining 507 | spike deltas are the differences between consecutive spike times. 508 | The last dimension of the :data:`y` data is a list of variable length, 509 | since there are a variable number of spikes. 510 | 511 | Args: 512 | event (str): Event to which to align each trial; :data:`goCueTime`, 513 | :data:`targetOnTime` or :data:`rewardTime`. 514 | feature (str): The feature to get; :data:`endpointOfReach` or 515 | :data:`reward`. 516 | neuron (str): The neuron response to use, either :data:`M1` or 517 | :data:`PMd`. 518 | window_min (double): The lower edge of the window around the alignment 519 | event in which to count spikes, in milliseconds. 520 | window_max (double): The upper edge of the window around the alignment 521 | event in which to count spikes, in milliseconds. 522 | threshold (double): The threshold for selecting high-firing neurons, 523 | representing the minimum firing rate in Hz. 524 | dir_name (str): Specifies the directory to which the data files 525 | should be downloaded. This is concatenated with the user-set 526 | data directory. 527 | 528 | Returns: 529 | tuple: The :data:`x` and :data:`y` features of the dataset.
530 | 531 | * :data:`x`: Array with shape :data:`(samples, features)` 532 | * :data:`y`: Array with shape :data:`(samples, neurons, deltas)` 533 | ''' 534 | 535 | def _get_spike_deltas(n, on_off): 536 | arrs = np.split(n, n.searchsorted(on_off)) 537 | windows = [np.concatenate(([on_off[i-1]], arrs[i])) for i in range(1, len(arrs), 2)] 538 | deltas = [w[1:] - w[:-1] for w in windows] 539 | return deltas 540 | 541 | return _load_reaching_helper( 542 | transformer=_get_spike_deltas, 543 | identifier='deltas', 544 | event=event, 545 | feature=feature, 546 | neuron=neuron, 547 | window_min=window_min, 548 | window_max=window_max, 549 | threshold=threshold, 550 | dir_name=dir_name, 551 | ) 552 | -------------------------------------------------------------------------------- /spykes/ml/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .neuropop import NeuroPop 4 | from .strf import STRF 5 | 6 | __all__ = ['NeuroPop', 'STRF'] 7 | -------------------------------------------------------------------------------- /spykes/ml/strf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | class STRF(object): 6 | '''Allows the estimation of spatiotemporal receptive fields. 7 | 8 | Args: 9 | patch_size (int): Dimension of the square patch spanned by the 10 | spatial basis. 11 | sigma (float): Standard deviation of the Gaussian distribution. 12 | n_spatial_basis (int): Number of spatial basis functions for the 13 | Gaussian basis (has to be a perfect square). 14 | n_temporal_basis (int): Number of temporal basis functions. 15 | ''' 16 | def __init__(self, patch_size=100, sigma=0.5, 17 | n_spatial_basis=25, n_temporal_basis=3): 18 | self.patch_size = patch_size 19 | self.sigma = sigma 20 | self.n_spatial_basis = n_spatial_basis 21 | self.n_temporal_basis = n_temporal_basis 22 | 23 | def make_2d_gaussian(self, center=(0, 0)): 24 | '''Makes a 2D Gaussian filter with arbitrary mean and variance. 25 | 26 | Args: 27 | center (tuple): The coordinates of the center of the Gaussian, 28 | specified as :data:`(row, col)`. The center of the image is 29 | :data:`(0, 0)`. 30 | 31 | Returns: 32 | numpy array: The Gaussian mask. 33 | ''' 34 | sigma = self.sigma 35 | n_rows = (self.patch_size - 1.) / 2. 36 | n_cols = (self.patch_size - 1.) / 2. 37 | 38 | y, x = np.ogrid[-n_rows: n_rows + 1, -n_cols: n_cols + 1] 39 | y0, x0 = center[1], center[0] 40 | gaussian_mask = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / 41 | (2. * sigma ** 2)) 42 | gaussian_mask[gaussian_mask < 43 | np.finfo(gaussian_mask.dtype).eps * 44 | gaussian_mask.max()] = 0 45 | gaussian_mask = 1. / gaussian_mask.max() * gaussian_mask 46 | return gaussian_mask 47 | 48 | def make_gaussian_basis(self): 49 | '''Makes a list of Gaussian filters. 50 | 51 | Returns: 52 | list: A list where each entry is a 2D array of size 53 | :data:`(patch_size, patch_size)` specifying the spatial basis. 54 | ''' 55 | spatial_basis = list() 56 | n_tiles = int(np.sqrt(self.n_spatial_basis)) 57 | n_pixels = self.patch_size 58 | centers = np.linspace(start=(-n_pixels / 2. + 59 | n_pixels / (n_tiles + 1.)), 60 | stop=(n_pixels / 2. -
61 | n_pixels / (n_tiles + 1.)), 62 | num=n_tiles) 63 | 64 | for y in range(n_tiles): 65 | for x in range(n_tiles): 66 | gaussian_mask = self.make_2d_gaussian(center=(centers[x], 67 | centers[y])) 68 | spatial_basis.append(gaussian_mask) 69 | return spatial_basis 70 | 71 | def make_cosine_basis(self): 72 | '''Makes a spatial cosine and sine basis. 73 | 74 | Returns: 75 | list: A list where each entry is a 2D array of size 76 | :data:`(patch_size, patch_size)` specifying the spatial basis. 77 | ''' 78 | patch_size = self.patch_size 79 | cosine_mask = np.zeros((patch_size, patch_size)) 80 | sine_mask = np.zeros((patch_size, patch_size)) 81 | for row in np.arange(patch_size): 82 | for col in np.arange(patch_size): 83 | theta = np.arctan2(patch_size / 2 - row, col - patch_size / 2) 84 | cosine_mask[row, col] = np.cos(theta) 85 | sine_mask[row, col] = np.sin(theta) 86 | 87 | spatial_basis = list() 88 | spatial_basis.append(cosine_mask) 89 | spatial_basis.append(sine_mask) 90 | return spatial_basis 91 | 92 | def visualize_gaussian_basis(self, spatial_basis, color='Greys', 93 | show=True): 94 | '''Plots spatial basis functions in a tile of images. 95 | 96 | Args: 97 | spatial_basis (list): A list where each entry is a 2D array of size 98 | :data:`(patch_size, patch_size)` specifying the spatial basis. 99 | color (str): The color for the figure. 100 | show (bool): Whether or not to show the image when it is plotted. 101 | ''' 102 | n_spatial_basis = len(spatial_basis) 103 | n_tiles = np.sqrt(n_spatial_basis) 104 | plt.figure(figsize=(7, 7)) 105 | for i in range(n_spatial_basis): 106 | plt.subplot(np.int(n_tiles), np.int(n_tiles), i + 1) 107 | plt.imshow(spatial_basis[i], cmap=color) 108 | plt.axis('off') 109 | if show: 110 | plt.show() 111 | 112 | def project_to_spatial_basis(self, image, spatial_basis): 113 | '''Projects a given image into a spatial basis. 114 | 115 | Args: 116 | image (numpy array): The image to project into the spatial 117 | basis; a 2D array of size :data:`(patch_size, 118 | patch_size)`. 119 | spatial_basis (list): A list where each entry is a 2D array of size 120 | :data:`(patch_size, patch_size)` specifying the spatial basis. 121 | 122 | Returns: 123 | numpy array: A 1D array of coefficients for the projection. 124 | ''' 125 | n_spatial_basis = len(spatial_basis) 126 | weights = np.zeros(n_spatial_basis) 127 | for b in range(n_spatial_basis): 128 | weights[b] = np.sum(spatial_basis[b] * image) 129 | return weights 130 | 131 | def make_image_from_spatial_basis(self, basis, weights): 132 | '''Recovers an image from a basis given a set of weights. 133 | 134 | Args: 135 | basis (list): A list where each entry is a 2D array of size 136 | :data:`(patch_size, patch_size)` specifying the spatial basis. 137 | weights (numpy array): A 1D array of coefficients. 138 | 139 | Returns: 140 | numpy array: 2D array of size :data:`(patch_size, patch_size)`, the 141 | resulting image. 142 | ''' 143 | image = np.zeros(basis[0].shape) 144 | n_basis = len(basis) 145 | for b in range(n_basis): 146 | image += weights[b] * basis[b] 147 | return image 148 | 149 | def make_raised_cosine_temporal_basis(self, time_points, centers, widths): 150 | '''Makes a series of raised cosine temporal basis functions. 151 | 152 | Args: 153 | time_points (numpy array): List of time points at which the basis 154 | function is computed. 155 | centers (numpy array or list): List of coordinates at which each 156 | basis function is centered; a 1D array of size 157 | :data:`(n_temporal_basis)`.
158 | widths (numpy array or list): List of widths, one per basis 159 | function; a 1D array of size 160 | :data:`(n_temporal_basis)`. 161 | 162 | Returns: 163 | numpy array: 2D array of size :data:`(n_basis, n_timepoints)`. 164 | ''' 165 | temporal_basis = list() 166 | for idx, center in enumerate(centers): 167 | this_basis = np.zeros(len(time_points)) 168 | arg_cos = (time_points - center) * np.pi / widths[idx] / 2. 169 | arg_cos[arg_cos > np.pi] = np.pi 170 | arg_cos[arg_cos < -np.pi] = -np.pi 171 | this_basis = 0.5 * (np.cos(arg_cos) + 1.) 172 | temporal_basis.append(this_basis) 173 | temporal_basis = np.transpose(np.array(temporal_basis)) 174 | return temporal_basis 175 | 176 | def convolve_with_temporal_basis(self, design_matrix, temporal_basis): 177 | '''Convolves a design matrix with temporal basis functions. 178 | 179 | Convolves each column of the design matrix with a series of temporal 180 | basis functions. 181 | 182 | Args: 183 | design_matrix (numpy array): 2D array of size 184 | :data:`(n_samples, n_features)`. 185 | temporal_basis (numpy array): 2D array of size 186 | :data:`(n_basis, n_timepoints)`. 187 | 188 | Returns: 189 | numpy array: 2D array of size 190 | :data:`(n_samples, n_features * n_basis)`. 191 | ''' 192 | n_temporal_basis = temporal_basis.shape[1] 193 | n_features = design_matrix.shape[1] 194 | convolved_design_matrix = list() 195 | for feat in range(n_features): 196 | for b in range(n_temporal_basis): 197 | convolved_design_matrix.append( 198 | np.convolve(design_matrix[:, feat], temporal_basis[:, b], 199 | mode='same')) 200 | convolved_design_matrix = \ 201 | np.transpose(np.array(convolved_design_matrix)) 202 | return convolved_design_matrix 203 | 204 | def design_prior_covariance(self, sigma_temporal=2., sigma_spatial=5.): 205 | '''Designs a prior covariance matrix for STRF estimation. 206 | 207 | Args: 208 | sigma_temporal (float): Standard deviation of temporal prior 209 | covariance. 210 | sigma_spatial (float): Standard deviation of spatial prior 211 | covariance. 212 | 213 | Returns: 214 | numpy array: 2D array of size :data:`(n_spatial_basis * 215 | n_temporal_basis, n_spatial_basis * n_temporal_basis)`; the 216 | ordering of rows and columns is such that all temporal basis 217 | functions are consecutive for each spatial basis. 218 | ''' 219 | n_spatial_basis = self.n_spatial_basis 220 | n_temporal_basis = self.n_temporal_basis 221 | 222 | n_features = n_temporal_basis * n_spatial_basis 223 | sp_covariance = np.zeros([n_features, n_features]) 224 | te_covariance = np.zeros([n_features, n_features]) 225 | prior_covariance = np.zeros([n_features, n_features]) 226 | for i in np.arange(0, n_features): 227 | # Get spatiotemporal indices 228 | s_i = np.floor(np.float(i) % 229 | (n_temporal_basis * n_spatial_basis) / 230 | n_temporal_basis) 231 | t_i = i % n_temporal_basis 232 | # Convert spatial indices to (x,y) coordinates 233 | x_i = s_i % np.sqrt(n_spatial_basis) 234 | y_i = np.floor(np.float(s_i) / np.sqrt(n_spatial_basis)) 235 | 236 | for j in np.arange(i, n_features): 237 | # Get spatiotemporal indices 238 | s_j = np.floor(np.float(j) % 239 | (n_temporal_basis * n_spatial_basis) / 240 | n_temporal_basis) 241 | t_j = j % n_temporal_basis 242 | # Convert spatial indices to (x,y) coordinates 243 | x_j = s_j % np.sqrt(n_spatial_basis) 244 | y_j = np.floor(np.float(s_j) / np.sqrt(n_spatial_basis)) 245 | 246 | sp_covariance[i, j] = np.exp(-1. /
(sigma_spatial ** 2) * 247 | ((x_i - x_j) ** 2 + 248 | (y_i - y_j) ** 2)) 249 | sp_covariance[j, i] = sp_covariance[i, j] 250 | te_covariance[i, j] = np.exp(-1. / (sigma_temporal ** 2) * 251 | (t_i - t_j) ** 2) 252 | te_covariance[j, i] = te_covariance[i, j] 253 | 254 | prior_covariance = sp_covariance * te_covariance 255 | prior_covariance = 1. / np.max(prior_covariance) * prior_covariance 256 | return prior_covariance 257 | -------------------------------------------------------------------------------- /spykes/ml/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .sparse_filtering import SparseFiltering 4 | 5 | # Checks that the correct version of TensorFlow is installed. 6 | MIN_TF_VERSION = '1.4.0' 7 | try: 8 | from distutils.version import LooseVersion 9 | import tensorflow as tf 10 | assert LooseVersion(tf.__version__) >= LooseVersion(MIN_TF_VERSION) 11 | except (ImportError, AssertionError): 12 | raise RuntimeError('To use the `tensorflow` submodule, your TensorFlow ' 13 | 'installation must be at least version {version}.' 14 | .format(version=MIN_TF_VERSION)) 15 | 16 | __all__ = ['SparseFiltering'] 17 | -------------------------------------------------------------------------------- /spykes/ml/tensorflow/poisson_models.py: -------------------------------------------------------------------------------- 1 | from math import pi as PI 2 | 3 | import tensorflow as tf 4 | from tensorflow import keras as ks 5 | 6 | 7 | class PoissonLayer(ks.layers.Layer): 8 | '''Defines a TensorFlow implementation of the NeuroPop layers. 9 | 10 | Two types of models are available. `The generalized von Mises model by 11 | Amirikian & Georgopoulos (2000) `_ is 12 | defined by 13 | 14 | .. math:: 15 | 16 | f(x) = b + g \\exp(k \\cos(x - \\mu)) 17 | 18 | f(x) = b + g \\exp(k_1 \\cos(x) + k_2 \\sin(x)) 19 | 20 | The Poisson generalized linear model is defined by 21 | 22 | .. math:: 23 | 24 | f(x) = \\exp(k_0 + k \\cos(x - \\mu)) 25 | 26 | f(x) = \\exp(k_0 + k_1 \\cos(x) + k_2 \\sin(x)) 27 | 28 | 29 | Args: 30 | model_type (str): Can be either :data:`gvm`, the generalized von Mises 31 | model, or :data:`glm`, the Poisson generalized linear model. 32 | num_neurons (int): Number of neurons in the population (whose 33 | responses are predicted from the input features). 34 | num_features (int): Number of input features. Convenience parameter 35 | for setting the input shape. 36 | mu_initializer (Keras initializer): The initializer for :data:`mu`. 37 | k_initializer (Keras initializer): The initializer for :data:`k`. 38 | g_initializer (Keras initializer): The initializer for :data:`g`. 39 | If :data:`model_type` is :data:`glm`, this is ignored. 40 | b_initializer (Keras initializer): The initializer for :data:`b`. 41 | If :data:`model_type` is :data:`glm`, this is ignored. 42 | k0_initializer (Keras initializer): The initializer for :data:`k0`. 43 | If :data:`model_type` is :data:`gvm`, this is ignored.
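    Example:
        A minimal sketch of how this layer might be trained (the shapes,
        optimizer, and random data here are illustrative assumptions, not
        values from the original papers)::

            import numpy as np
            from spykes.ml.tensorflow.poisson_models import PoissonLayer
            from tensorflow import keras as ks

            model = ks.models.Sequential([
                PoissonLayer('glm', num_neurons=10, num_features=1),
            ])
            model.compile(loss='poisson', optimizer='adam')

            x = np.random.uniform(-np.pi, np.pi, size=(256, 1))
            y = np.random.poisson(1.0, size=(256, 10))
            model.fit(x, y, epochs=2)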
44 | ''' 45 | 46 | def __init__(self, 47 | model_type, 48 | num_neurons, 49 | num_features=None, 50 | mu_initializer=ks.initializers.RandomUniform(-PI, PI), 51 | k_initializer=ks.initializers.RandomNormal(stddev=.2), 52 | g_initializer=ks.initializers.RandomNormal(stddev=.05), 53 | b_initializer=ks.initializers.RandomNormal(stddev=.1), 54 | k0_initializer=ks.initializers.RandomNormal(stddev=.01), 55 | **kwargs): 56 | if num_features is not None: 57 | kwargs['input_shape'] = (num_features,) 58 | super(PoissonLayer, self).__init__(**kwargs) 59 | self.model_type = model_type.lower() 60 | if self.model_type not in ('gvm', 'glm'): 61 | raise ValueError('Invalid model type: "{}". Must be either "gvm" ' 62 | '(generalized von Mises model) or "glm" ' 63 | '(generalized linear model)'.format(model_type)) 64 | 65 | self.num_neurons = num_neurons 66 | self.mu_initializer = ks.initializers.get(mu_initializer) 67 | self.g_initializer = ks.initializers.get(g_initializer) 68 | self.b_initializer = ks.initializers.get(b_initializer) 69 | self.k_initializer = ks.initializers.get(k_initializer) 70 | self.k0_initializer = ks.initializers.get(k0_initializer) 71 | 72 | def build(self, input_shape): 73 | assert len(input_shape) == 2 74 | input_dim = input_shape[-1] 75 | 76 | self.mu = self.add_weight( 77 | shape=(input_dim, self.num_neurons), 78 | initializer=self.mu_initializer, 79 | name='mu', 80 | ) 81 | self.k1 = self.add_weight( 82 | shape=(input_dim, self.num_neurons), 83 | initializer=self.k_initializer, 84 | name='k1', 85 | ) 86 | self.k2 = self.add_weight( 87 | shape=(input_dim, self.num_neurons), 88 | initializer=self.k_initializer, 89 | name='k2', 90 | ) 91 | 92 | # Adds generalized von Mises parameters, one per neuron. 93 | if self.model_type == 'gvm': 94 | self.g = self.add_weight( 95 | shape=(1, self.num_neurons), 96 | initializer=self.g_initializer, 97 | name='g', 98 | ) 99 | self.b = self.add_weight( 100 | shape=(1, self.num_neurons), 101 | initializer=self.b_initializer, 102 | name='b', 103 | ) 104 | 105 | # Adds generalized linear model parameters, one per neuron. 106 | if self.model_type == 'glm': 107 | self.k0 = self.add_weight( 108 | shape=(1, self.num_neurons), 109 | initializer=self.k0_initializer, 110 | name='k0', 111 | ) 112 | 113 | def call(self, inputs): 114 | k1 = tf.matmul(tf.cos(inputs), self.k1 * tf.cos(self.mu)) 115 | k2 = tf.matmul(tf.sin(inputs), self.k2 * tf.sin(self.mu)) 116 | 117 | # Defines the two model formulations: "glm" vs "gvm".
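        # "glm": f(x) = exp(k0 + k1 * cos(x) + k2 * sin(x)).
        # "gvm": f(x) = softplus(b) + g * exp(k1 * cos(x) + k2 * sin(x)),
        # where the softplus keeps the learned baseline positive.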
118 | if self.model_type == 'glm': 119 | return tf.exp(k1 + k2 + self.k0) 120 | else: 121 | return tf.nn.softplus(self.b) + self.g * tf.exp(k1 + k2) 122 | 123 | def get_config(self): 124 | config = { 125 | 'model_type': self.model_type, 126 | 'mu_initializer': ks.initializers.serialize(self.mu_initializer), 127 | 'g_initializer': ks.initializers.serialize(self.g_initializer), 128 | 'b_initializer': ks.initializers.serialize(self.b_initializer), 129 | 'k_initializer': ks.initializers.serialize(self.k_initializer), 130 | 'k0_initializer': ks.initializers.serialize(self.k0_initializer), 131 | } 132 | base_config = super(PoissonLayer, self).get_config() 133 | return dict(list(base_config.items()) + list(config.items())) 134 | 135 | def compute_output_shape(self, input_shape): 136 | output_shape = list(input_shape) 137 | output_shape[-1] = self.num_neurons 138 | return tuple(output_shape) 139 | -------------------------------------------------------------------------------- /spykes/ml/tensorflow/sparse_filtering.py: -------------------------------------------------------------------------------- 1 | import six 2 | import collections 3 | 4 | import tensorflow as tf 5 | from tensorflow import keras as ks 6 | 7 | 8 | def sparse_filtering_loss(_, y_pred): 9 | '''Defines the sparse filtering loss function. 10 | 11 | Args: 12 | y_true (tensor): The ground truth tensor (not used, since this is an 13 | unsupervised learning algorithm). 14 | y_pred (tensor): Tensor representing the feature vector at a 15 | particular layer. 16 | 17 | Returns: 18 | scalar tensor: The sparse filtering loss. 19 | ''' 20 | y = tf.reshape(y_pred, tf.stack([-1, tf.reduce_prod(y_pred.shape[1:])])) 21 | l2_normed = tf.nn.l2_normalize(y, axis=1) 22 | l1_norm = tf.norm(l2_normed, ord=1, axis=1) 23 | return tf.reduce_sum(l1_norm) 24 | 25 | 26 | class SparseFiltering(object): 27 | '''Defines a class for performing sparse filtering on a dataset. 28 | 29 | The MATLAB code on which this is based is available `here 30 | `_, from the paper 31 | `Sparse Filtering by Ngiam et. al. 32 | `_. 33 | 34 | Args: 35 | model (Keras model): The trainable model, which takes as inputs the 36 | data you are training on and outputs a feature vector, which is 37 | minimized according to the loss function described above. 38 | layers (str or list of str): An optional name or list of names of 39 | layers in the provided model whose outputs to apply the sparse 40 | filtering loss to. If none are provided, the sparse filtering loss 41 | is applied to each layer in the model. 42 | ''' 43 | 44 | def __init__(self, model, layers=None): 45 | assert isinstance(model, ks.models.Model) 46 | assert len(model.inputs) == 1 and len(model.outputs) == 1 47 | self.model = model 48 | 49 | # Parses the "layers" argument. 50 | if layers is None: 51 | self.layer_names = [layer.name for layer in model.layers] 52 | elif isinstance(layers, six.string_types): 53 | self.layer_names = [layers] 54 | elif isinstance(layers, collections.Iterable): 55 | self.layer_names = list(layers) 56 | else: 57 | raise ValueError('`layers` must be a string (a single layer) or ' 58 | 'a list of strings. 
Got: "{}"'.format(layers)) 59 | 60 | self._submodels = None 61 | self._submodel_map = None 62 | 63 | @property 64 | def submodels(self): 65 | if self._submodels is None: 66 | raise RuntimeError('This model must be compiled before you can ' 67 | 'access the `submodels` parameter.') 68 | return self._submodels 69 | 70 | @property 71 | def num_layers(self): 72 | return len(self.layer_names) 73 | 74 | def get_submodel(self, model): 75 | if self._submodel_map is None: 76 | raise RuntimeError('This model must be compiled before you can ' 77 | 'get a particular submodel.') 78 | 79 | if model not in self._submodel_map: 80 | raise ValueError('Submodel not found: "{}". Must be one of {}' 81 | .format(model, list(self._submodel_map.keys()))) 82 | 83 | return self._submodel_map[model] 84 | 85 | def _clean_maybe_iterable_param(self, it, param): 86 | '''Converts a potential iterable or single value to a list of values. 87 | 88 | After being cleaned, the iterable is guaranteed to be a list of the 89 | same length as the number of layer names. 90 | 91 | Args: 92 | it (single value or iterable): The iterable to clean. 93 | param (str): The name of the parameter being set. 94 | 95 | Returns: 96 | list: A list of values of the same length as the layer names. 97 | ''' 98 | if isinstance(it, six.string_types): 99 | return [it] * self.num_layers 100 | elif isinstance(it, collections.Iterable): 101 | it = list(it) 102 | if len(it) != self.num_layers: 103 | raise ValueError('Provided {} values for `{}`, ' 104 | 'but one parameter is needed for each ' 105 | 'requested layer ({} layers).' 106 | .format(len(it), param, self.num_layers)) 107 | return it 108 | else: 109 | return [it] * self.num_layers 110 | 111 | def compile(self, optimizer, freeze=False, **kwargs): 112 | '''Compiles the model to create all the submodels. 113 | 114 | Args: 115 | optimizer (str or list of str): The optimizer to use. If a list is 116 | provided, it must specify one optimizer for each layer 117 | passed to the constructor. 118 | freeze (bool): If set, for each submodel, all the previous layers 119 | are frozen, so that only the last layer is "learned". 120 | kwargs (dict): Extra arguments to be passed to the :data:`compile` 121 | function of each submodel. 122 | ''' 123 | if self._submodels is not None: 124 | raise RuntimeError('This model has already been compiled!') 125 | optimizer = self._clean_maybe_iterable_param(optimizer, 'optimizer') 126 | 127 | # Creates each submodel. 128 | self._submodels = [] 129 | input_layer = self.model.input 130 | for layer_name, o in zip(self.layer_names, optimizer): 131 | output_layer = self.model.get_layer(layer_name).output 132 | 133 | submodel = ks.models.Model( 134 | inputs=input_layer, 135 | outputs=output_layer, 136 | ) 137 | 138 | # Freezes all but the selected layer. 139 | if freeze: 140 | for layer in submodel.layers: 141 | layer.trainable = layer.name == layer_name 142 | 143 | submodel.compile(loss=sparse_filtering_loss, optimizer=o, **kwargs) 144 | submodel._make_train_function() # Forces submodel compilation. 145 | self._submodels.append(submodel) 146 | 147 | # Maps the layer names to the submodels. 148 | self._submodel_map = dict(zip(self.layer_names, self._submodels)) 149 | 150 | def fit(self, x, epochs=1, **kwargs): 151 | '''Fits the model to the provided data. 152 | 153 | Args: 154 | x (Numpy array): An array where the first dimension is the batch 155 | size, and the remaining dimensions match the input dimensions 156 | of the provided model.
157 | epochs (int or list of ints): The number of epochs to train. 158 | If a list is provided, there must be one value for each named 159 | layer, specifying the number of epochs to train that layer for. 160 | 161 | Returns: 162 | list: A list of histories, the training history for each submodel. 163 | ''' 164 | histories = [] 165 | nb_epochs = self._clean_maybe_iterable_param(epochs, 'epochs') 166 | for submodel, n in zip(self._submodels, nb_epochs): 167 | histories.append(submodel.fit(x=x, y=x, epochs=n, **kwargs)) 168 | return histories 169 | -------------------------------------------------------------------------------- /spykes/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .neurovis import NeuroVis 4 | from .popvis import PopVis 5 | 6 | __all__ = ['NeuroVis', 'PopVis'] 7 | -------------------------------------------------------------------------------- /spykes/plot/neurovis.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from .. import utils 7 | from ..config import DEFAULT_POPULATION_COLORS 8 | 9 | 10 | class NeuroVis(object): 11 | '''This class is used to visualize firing activity of single neurons. 12 | 13 | This class implements several conveniences for visualizing firing 14 | activity of single neurons. 15 | 16 | Args: 17 | spiketimes (Numpy array): Array of spike times. 18 | name (str): The name of the visualization. 19 | ''' 20 | def __init__(self, spiketimes, name='neuron'): 21 | self.name = name 22 | self.spiketimes = np.squeeze(np.sort(spiketimes)) 23 | n_seconds = (self.spiketimes[-1] - self.spiketimes[0]) 24 | n_spikes = np.size(spiketimes) 25 | self.firingrate = (n_spikes / n_seconds) 26 | 27 | def get_raster(self, event=None, conditions=None, df=None, 28 | window=[-100, 500], binsize=10, plot=True, 29 | sortby=None, sortorder='descend'): 30 | '''Computes the raster and plots it. 31 | 32 | Args: 33 | event (str): Column/key name of DataFrame/dictionary "data" which 34 | contains event times in milliseconds (e.g. 35 | stimulus/trial/fixation onset, etc.) 36 | conditions (str): Column/key name of DataFrame/dictionary 37 | :data:`data` which contains the conditions by which the trials 38 | must be grouped. 39 | df (DataFrame or dictionary): The dataframe containing the data, 40 | or a dictionary with the equivalent structure. 41 | window (list of 2 elements): Time interval to consider, in 42 | milliseconds. 43 | binsize (int): Bin size in milliseconds 44 | plot (bool): If True then plot 45 | sortby (str or list): If :data:`rate`, sort by firing rate. If 46 | :data:`latency`, sort by peak latency. If a list, integers to 47 | be used as sorting indices. 48 | sortorder (str): Direction to sort, either :data:`descend` or 49 | :data:`ascend`. 50 | 51 | Returns: 52 | dict: :data:`rasters` with keys :data:`event`, :data:`conditions`, 53 | :data:`binsize`, :data:`window`, and :data:`data`. 54 | :data:`rasters['data']` is a dictionary where each value is a 55 | raster for each unique entry of :data:`df['conditions']`. 
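        Example:
            A minimal sketch, assuming a :class:`NeuroVis` instance
            :data:`neuron` and a DataFrame :data:`df` with hypothetical
            columns :data:`stimOnset` (event times) and :data:`reward`::

                rasters = neuron.get_raster(event='stimOnset',
                                            conditions='reward', df=df,
                                            window=[-100, 500],
                                            binsize=10, plot=False)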
56 | ''' 57 | 58 | if not type(df) is dict: 59 | df = df.reset_index() 60 | 61 | window = [np.floor(window[0] / binsize) * binsize, 62 | np.ceil(window[1] / binsize) * binsize] 63 | 64 | # Get a set of binary indicators for trials of interest 65 | if conditions: 66 | trials = dict() 67 | for cond_id in np.sort(df[conditions].unique()): 68 | trials[cond_id] = \ 69 | np.where((df[conditions] == cond_id).apply( 70 | lambda x: (0, 1)[x]).values)[0] 71 | else: 72 | trials = dict() 73 | trials[0] = np.where(np.ones(np.size(df[event])))[0] 74 | 75 | # Initialize rasters 76 | rasters = { 77 | 'event': event, 78 | 'conditions': conditions, 79 | 'window': window, 80 | 'binsize': binsize, 81 | 'data': {}, 82 | } 83 | 84 | # Loop over each raster 85 | for cond_id in trials: 86 | # Select events relevant to this raster 87 | selected_events = df[event][trials[cond_id]] 88 | 89 | raster = [] 90 | 91 | bin_template = 1e-3 * \ 92 | np.arange(window[0], window[1] + binsize, binsize) 93 | for event_time in selected_events: 94 | bins = event_time + bin_template 95 | 96 | # consider only spikes within window 97 | searchsorted_idx = np.squeeze(np.searchsorted(self.spiketimes, 98 | [event_time + 1e-3 * 99 | window[0], 100 | event_time + 1e-3 * 101 | window[1]])) 102 | 103 | # bin the spikes into time bins 104 | 105 | spike_counts = np.histogram( 106 | self.spiketimes[searchsorted_idx[0]:searchsorted_idx[1]], 107 | bins)[0] 108 | 109 | raster.append(spike_counts) 110 | 111 | rasters['data'][cond_id] = np.array(raster) 112 | 113 | # Show the raster 114 | if plot is True: 115 | self.plot_raster(rasters, cond_id=None, sortby=sortby, 116 | sortorder=sortorder) 117 | 118 | # Return all the rasters 119 | return rasters 120 | 121 | def plot_raster(self, rasters, cond_id=None, cond_name=None, sortby=None, 122 | sortorder='descend', cmap='Greys', has_title=True): 123 | '''Plot a single raster. 124 | 125 | Args: 126 | rasters (dict): Output of get_raster method 127 | cond_id (str): Which raster to plot indicated by the key in 128 | :data:`rasters['data']`. If None then all are plotted. 129 | cond_name (str): Name to appear in the title. 130 | sortby (str or list): If :data:`rate`, sort by firing rate. If 131 | :data:`latency`, sort by peak latency. If a list, integers to 132 | be used as sorting indices. 133 | sortorder (str): Direction to sort in, either :data:`descend` or 134 | :data:`ascend`. 135 | cmap (str): Colormap for raster. 136 | has_title (bool): If True then adds title. 
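        Example:
            For instance, to re-plot a single condition from the
            :data:`rasters` dictionary above, sorted by firing rate (the
            condition key :data:`1` is hypothetical)::

                neuron.plot_raster(rasters, cond_id=1, sortby='rate')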
137 | ''' 138 | window = rasters['window'] 139 | binsize = rasters['binsize'] 140 | 141 | xtics = [window[0], 0, window[1]] 142 | xtics = [str(i) for i in xtics] 143 | xtics_loc = [-0.5, (-window[0]) / binsize - 0.5, 144 | (window[1] - window[0]) / binsize - 0.5] 145 | 146 | if cond_id is None: 147 | for cond in list(rasters['data']): 148 | self.plot_raster(rasters, cond_id=cond, cond_name=cond_name, 149 | sortby=sortby, sortorder=sortorder, cmap=cmap, 150 | has_title=has_title) 151 | plt.show() 152 | else: 153 | raster = rasters['data'][cond_id] 154 | 155 | if len(raster) > 0: 156 | sort_idx = utils.get_sort_indices( 157 | data=raster, 158 | by=sortby, 159 | order=sortorder, 160 | ) 161 | raster_sorted = raster[sort_idx] 162 | 163 | plt.imshow(raster_sorted, aspect='auto', 164 | interpolation='none', cmap=plt.get_cmap(cmap)) 165 | 166 | plt.axvline( 167 | (-window[0]) / binsize - 0.5, color='r', linestyle='--') 168 | plt.ylabel('trials') 169 | plt.xlabel('time [ms]') 170 | plt.xticks(xtics_loc, xtics) 171 | 172 | if has_title: 173 | if cond_id: 174 | if cond_name: 175 | plt.title('neuron %s. %s' % 176 | (self.name, cond_name)) 177 | else: 178 | plt.title('neuron %s. %s: %s' % 179 | (self.name, rasters['conditions'], 180 | cond_id)) 181 | else: 182 | plt.title('neuron %s' % self.name) 183 | 184 | ax = plt.gca() 185 | ax.spines['top'].set_visible(False) 186 | ax.spines['right'].set_visible(False) 187 | ax.spines['bottom'].set_visible(False) 188 | ax.spines['left'].set_visible(False) 189 | plt.tick_params(axis='x', which='both', top='off') 190 | plt.tick_params(axis='y', which='both', right='off') 191 | 192 | else: 193 | print('No trials for this condition!') 194 | 195 | def get_psth(self, event=None, df=None, conditions=None, cond_id=None, 196 | window=[-100, 500], binsize=10, plot=True, event_name=None, 197 | conditions_names=None, ylim=None, 198 | colors=DEFAULT_POPULATION_COLORS): 199 | '''Computes the PSTH and plots it. 200 | 201 | Args: 202 | event (str): Column/key name of DataFrame/dictionary :data:`data` 203 | which contains event times in milliseconds (e.g. 204 | stimulus/trial/fixation onset, etc.) 205 | conditions (str): Column/key name of DataFrame/dictionary 206 | :data:`data` which contains the conditions by which the trials 207 | must be grouped. 208 | cond_id (list): Which psth to plot indicated by the key in 209 | :data:`all_psth['data']`. If None then all are plotted. 210 | df (DataFrame or dictionary): The dataframe containing the data. 211 | window (list of 2 elements): Time interval to consider, in 212 | milliseconds. 213 | binsize (int): Bin size in milliseconds. 214 | plot (bool): If True then plot. 215 | event_name (string): Legend name for event. Default is the actual 216 | event name. 217 | conditions_names (list of str): Legend names for conditions. 218 | Default are the unique values in :data:`df['conditions']`. 219 | ylim (list): The lower and upper limits for Y. 220 | colors (list): The colors for the plot. 221 | 222 | Returns: 223 | dict: :data:`psth` with keys :data:`event`, :data:`conditions`, 224 | :data:`binsize`, :data:`window`, and :data:`data`. 225 | :data:`psth['data']` is a dictionary holding the :data:`mean` and 226 | :data:`sem` of the firing rate for each unique entry of :data:`df['conditions']`.
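        Example:
            A minimal sketch, reusing the hypothetical :data:`df` columns
            from :meth:`get_raster`::

                psth = neuron.get_psth(event='stimOnset',
                                       conditions='reward', df=df,
                                       window=[-100, 500], plot=False)
                mean_rate = psth['data'][1]['mean']  # hypothetical key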
227 | ''' 228 | 229 | window = [np.floor(window[0] / binsize) * binsize, 230 | np.ceil(window[1] / binsize) * binsize] 231 | # Get all the rasters first 232 | rasters = self.get_raster(event=event, df=df, 233 | conditions=conditions, 234 | window=window, binsize=binsize, plot=False) 235 | 236 | # Initialize PSTH 237 | psth = dict() 238 | 239 | psth['window'] = window 240 | psth['binsize'] = binsize 241 | psth['event'] = event 242 | psth['conditions'] = conditions 243 | psth['data'] = dict() 244 | 245 | # Compute the PSTH 246 | for cond_id in np.sort(list(rasters['data'])): 247 | psth['data'][cond_id] = dict() 248 | raster = rasters['data'][cond_id] 249 | mean_psth = np.mean(raster, axis=0) / (1e-3 * binsize) 250 | std_psth = np.sqrt(np.var(raster, axis=0)) / (1e-3 * binsize) 251 | 252 | sem_psth = std_psth / np.sqrt(float(np.shape(raster)[0])) 253 | 254 | psth['data'][cond_id]['mean'] = mean_psth 255 | psth['data'][cond_id]['sem'] = sem_psth 256 | 257 | if plot is True: 258 | if not event_name: 259 | event_name = event 260 | conditions_names = list(psth['data']) 261 | self.plot_psth(psth, ylim=ylim, event_name=event_name, 262 | conditions_names=conditions_names, 263 | colors=colors) 264 | 265 | return psth 266 | 267 | def plot_psth(self, psth, event_name='event_onset', conditions_names=None, 268 | cond_id=None, ylim=None, colors=DEFAULT_POPULATION_COLORS): 269 | '''Plots PSTH. 270 | 271 | Args: 272 | psth (dict): Output of :meth:`get_psth`. 273 | event_name (string): Legend name for event. Default is the actual 274 | event name. 275 | conditions_names (list of str): Legend names for conditions. 276 | Default are the keys in :data:`psth['data']`. 277 | cond_id (list): Which psth to plot indicated by the key in 278 | :data:`all_psth['data']``. If None then all are plotted. 279 | ylim (list): The lower and upper limits for Y. 280 | colors (list): The colors for the plot. 
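        Example:
            Continuing the sketch above, with hypothetical legend names
            and y-limits::

                neuron.plot_psth(psth, event_name='stimulus onset',
                                 conditions_names=['no reward', 'reward'],
                                 ylim=[0, 50])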
281 | ''' 282 | 283 | window = psth['window'] 284 | binsize = psth['binsize'] 285 | conditions = psth['conditions'] 286 | 287 | if cond_id is None: 288 | keys = np.sort(list(psth['data'].keys())) 289 | else: 290 | keys = cond_id 291 | 292 | if conditions_names is None: 293 | conditions_names = keys 294 | 295 | scale = 0.1 296 | y_min = (1.0 - scale) * np.nanmin([np.min( 297 | psth['data'][psth_idx]['mean']) 298 | for psth_idx in psth['data']]) 299 | y_max = (1.0 + scale) * np.nanmax([np.max( 300 | psth['data'][psth_idx]['mean']) 301 | for psth_idx in psth['data']]) 302 | 303 | legend = [event_name] 304 | 305 | time_bins = np.arange(window[0], window[1], binsize) + binsize / 2.0 306 | 307 | if ylim: 308 | plt.plot([0, 0], ylim, color='k', ls='--') 309 | else: 310 | plt.plot([0, 0], [y_min, y_max], color='k', ls='--') 311 | 312 | for i, cond_id in enumerate(keys): 313 | if np.all(np.isnan(psth['data'][cond_id]['mean'])): 314 | plt.plot(0, 0, alpha=1.0, color=colors[i % len(colors)]) 315 | else: 316 | plt.plot(time_bins, psth['data'][cond_id]['mean'], 317 | color=colors[i % len(colors)], lw=1.5) 318 | 319 | for i, cond_id in enumerate(keys): 320 | if conditions is not None: 321 | if conditions_names is not None: 322 | legend.append('%s' % conditions_names[i]) 323 | else: 324 | legend.append('%s' % str(cond_id)) 325 | else: 326 | legend.append('all') 327 | 328 | if not np.all(np.isnan(psth['data'][cond_id]['mean'])): 329 | plt.fill_between(time_bins, psth['data'][cond_id]['mean'] - 330 | psth['data'][cond_id]['sem'], 331 | psth['data'][cond_id]['mean'] + 332 | psth['data'][cond_id]['sem'], 333 | color=colors[i % len(colors)], 334 | alpha=0.2) 335 | 336 | if conditions: 337 | plt.title('neuron %s: %s' % (self.name, conditions)) 338 | else: 339 | plt.title('neuron %s' % self.name) 340 | 341 | plt.xlabel('time [ms]') 342 | plt.ylabel('spikes per second [spks/s]') 343 | 344 | if ylim: 345 | plt.ylim(ylim) 346 | else: 347 | plt.ylim([y_min, y_max]) 348 | 349 | ax = plt.gca() 350 | ax.spines['top'].set_visible(False) 351 | ax.spines['right'].set_visible(False) 352 | 353 | plt.tick_params(axis='y', right='off') 354 | plt.tick_params(axis='x', top='off') 355 | 356 | plt.legend(legend, frameon=False) 357 | 358 | def get_spikecounts(self, event=None, df=None, 359 | window=np.array([50.0, 100.0])): 360 | '''Counts spikes in the dataframe. 361 | 362 | Args: 363 | event (str): Column/key name of DataFrame/dictionary :data:`data` 364 | which contains event times in milliseconds (e.g. 365 | stimulus/trial/fixation onset, etc.) 366 | window (list of 2 elements): Time interval to consider, in 367 | milliseconds. 368 | 369 | Return: 370 | array: An :data:`n x 1` array of spike counts. 371 | ''' 372 | events = df[event].values 373 | spiketimes = self.spiketimes 374 | spikecounts = np.asarray([ 375 | np.sum(np.all(( 376 | spiketimes >= e + 1e-3 * window[0], 377 | spiketimes <= e + 1e-3 * window[1], 378 | ), axis=0)) 379 | for e in events 380 | ]) 381 | return spikecounts 382 | -------------------------------------------------------------------------------- /spykes/plot/popvis.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import copy 8 | from collections import defaultdict 9 | 10 | from fractions import gcd 11 | 12 | from .neurovis import NeuroVis 13 | from .. 
import utils 14 | from ..config import DEFAULT_POPULATION_COLORS 15 | 16 | # Defines the default colors for a PSTH plot. 17 | DEFAULT_PSTH_COLORS = ['Blues', 'Reds', 'Greens'] 18 | 19 | 20 | class PopVis(object): 21 | '''Facilitates visualization of neuron population firing activity. 22 | 23 | Args: 24 | neuron_list (list of NeuroVis objects): The list of neurons to 25 | visualize (see the NeuroVis class in 26 | :class:`spykes.plot.neurovis`). 27 | 28 | Attributes: 29 | n_neurons (int): The number of neurons in the visualization. 30 | name (str): The name of this visualization. 31 | ''' 32 | 33 | def __init__(self, neuron_list, name='PopVis'): 34 | self.neuron_list = neuron_list 35 | self._name = name 36 | 37 | @property 38 | def name(self): 39 | return self._name 40 | 41 | @property 42 | def n_neurons(self): 43 | return len(self.neuron_list) 44 | 45 | def get_all_psth(self, event=None, df=None, conditions=None, 46 | window=[-100, 500], binsize=10, conditions_names=None, 47 | plot=True, colors=DEFAULT_PSTH_COLORS): 48 | '''Iterates through all neurons and computes their PSTH's. 49 | 50 | Args: 51 | event (str): Column/key name of DataFrame/dictionary "data" which 52 | contains event times in milliseconds (e.g. 53 | stimulus/trial/ fixation onset, etc.). 54 | df (DataFrame or dictionary): The data to use. 55 | conditions (str): Column/key name of DataFrame/dictionary 56 | :data:`df` which contains the conditions by which the trials 57 | must be grouped 58 | window (list of 2 elements): Time interval to consider, in 59 | milliseconds. 60 | binsize (int): Bin size, in milliseconds. 61 | conditions_names (list of str): Legend names for conditions. 62 | Default are the unique values in :data:`df['conditions']`. 63 | plot (bool): If set, automatically plot; otherwise, don't. 64 | colors (list): List of colors for heatmap (only if plot is True). 65 | 66 | Returns: 67 | dict: With keys :data:`event`, :data:`conditions`, :data:`binsize`, 68 | :data:`window`, and :data:`data`. Each entry in 69 | :data:`psth['data']` is itself a dictionary with keys of 70 | each :data:`cond_id` that correspond to the means for that 71 | condition. 72 | ''' 73 | all_psth = { 74 | 'window': window, 75 | 'binsize': binsize, 76 | 'event': event, 77 | 'conditions': conditions, 78 | 'data': defaultdict(list), 79 | } 80 | 81 | for i, neuron in enumerate(self.neuron_list): 82 | psth = neuron.get_psth( 83 | event=event, 84 | df=df, 85 | conditions=conditions, 86 | window=window, 87 | binsize=binsize, 88 | plot=False, 89 | ) 90 | for cond_id in np.sort(list(psth['data'].keys())): 91 | all_psth['data'][cond_id].append(psth['data'][cond_id]['mean']) 92 | 93 | for cond_id in np.sort(list(all_psth['data'].keys())): 94 | all_psth['data'][cond_id] = np.stack(all_psth['data'][cond_id]) 95 | 96 | if plot is True: 97 | self.plot_heat_map( 98 | all_psth, 99 | conditions_names=conditions_names, 100 | colors=colors, 101 | ) 102 | 103 | return all_psth 104 | 105 | def plot_heat_map(self, psth_dict, cond_id=None, conditions_names=None, 106 | sortby=None, sortorder='descend', normalize=None, 107 | neuron_names=True, colors=None, show=False): 108 | '''Plots heat map for neuron population 109 | 110 | Args: 111 | psth_dict (dict): With keys :data:`event`, :data:`conditions`, 112 | :data:`binsize`, :data:`window`, and :data:`data`. Each entry 113 | in :data:`psth['data']` is itself a dictionary with keys of 114 | each :data:`cond_id` that correspond to the means for that 115 | condition. 
116 | cond_id (str): Which psth to plot indicated by the key in 117 | :data:`all_psth['data']`. If None then all are plotted. 118 | conditions_names (str or list of str): Name(s) to appear in the 119 | title. 120 | sortby (str or list): If :data:`rate`, sort by firing rate. If 121 | :data:`latency`, sort by peak latency. If a list or array is 122 | provided, it must correspond to integer indices to be used as 123 | sorting indices. If no sort order is provided, the data 124 | is returned as-is. 125 | sortorder (str): The direction to sort in, either :data:`descend` 126 | or :data:`ascend`. 127 | normalize (str): If :data:`all`, divide all PSTHs by highest peak 128 | firing rate in all neurons. If :data:`each`, divide each PSTH 129 | by its own peak firing rate. If None, do not normalize. 130 | neuron_names (bool): Whether or not to list the names of neurons on 131 | the side. 132 | colors (list of str): List of colors for the heatmap (as strings). 133 | show (bool): If set, show the plot once finished. 134 | ''' 135 | if colors is None: 136 | colors = ['Blues', 'Reds', 'Greens'] 137 | 138 | window = psth_dict['window'] 139 | binsize = psth_dict['binsize'] 140 | conditions = psth_dict['conditions'] 141 | 142 | if cond_id is None: 143 | keys = np.sort(list(psth_dict['data'].keys())) 144 | else: 145 | keys = cond_id 146 | 147 | if conditions_names is None: 148 | conditions_names = keys 149 | 150 | for i, cond_id in enumerate(keys): 151 | # Sorts and norms the data. 152 | orig_data = psth_dict['data'][cond_id] 153 | normed_data = self._get_normed_data(orig_data, normalize=normalize) 154 | sort_idx = utils.get_sort_indices( 155 | normed_data, 156 | by=sortby, 157 | order=sortorder, 158 | ) 159 | 160 | data = normed_data[sort_idx, :] 161 | 162 | plt.subplot(len(keys), 1, i + 1) 163 | plt.pcolormesh(data, cmap=colors[i % len(colors)]) 164 | 165 | # Makes it visually appealing. 166 | xtic_len = gcd(abs(window[0]), window[1]) 167 | xtic_labels = list(range(window[0], window[1] + xtic_len, xtic_len)) 168 | xtic_locs = [(j - window[0]) / binsize for j in xtic_labels] 169 | 170 | if 0 not in xtic_labels: 171 | xtic_labels.append(0) 172 | xtic_locs.append(-window[0] / binsize) 173 | 174 | plt.xticks(xtic_locs, xtic_labels) 175 | plt.axvline((-window[0]) / binsize, color='r', 176 | linestyle='--') 177 | 178 | if neuron_names: 179 | unsorted_ylabels = [neuron.name for neuron in self.neuron_list] 180 | ylabels = [unsorted_ylabels[j] for j in sort_idx] 181 | else: 182 | ylabels = ["" for neuron in self.neuron_list] 183 | 184 | plt.yticks(np.arange(data.shape[0]) + 0.5, ylabels) 185 | 186 | ax = plt.gca() 187 | ax.invert_yaxis() 188 | ax.set_frame_on(False) 189 | 190 | plt.tick_params(axis='x', which='both', top='off') 191 | plt.tick_params(axis='y', which='both', left='off', right='off') 192 | 193 | plt.xlabel('time [ms]') 194 | plt.ylabel('Neuron') 195 | plt.title("%s: %s" % 196 | (conditions, conditions_names[i])) 197 | plt.colorbar() 198 | 199 | if show: 200 | plt.show() 201 | 202 | def plot_population_psth(self, all_psth=None, event=None, df=None, 203 | conditions=None, cond_id=None, window=[-100, 500], 204 | binsize=10, conditions_names=None, 205 | event_name='event_onset', ylim=None, 206 | colors=DEFAULT_POPULATION_COLORS, show=False): 207 | '''Plots population PSTHs. 208 | 209 | This involves two steps. First, it normalizes each neuron's PSTH across 210 | the conditions. Second, it averages across neurons and plots the population PSTH.
211 | 212 | Args: 213 | all_psth (dict): With keys :data:`event`, :data:`conditions`, 214 | :data:`binsize`, :data:`window`, and :data:`data`. Each entry 215 | in :data:`psth['data']` is itself a dictionary with keys of 216 | each :data:`cond_id` that correspond to the means for that 217 | condition. 218 | event (str): Column/key name of the :data:`df` DataFrame/dictionary 219 | which contains event times in milliseconds 220 | (stimulus/trial/fixation onset, etc.). 221 | df (DataFrame or dictionary): DataFrame containing the data to 222 | plot, or a dictionary corresponding to such a DataFrame. 223 | conditions (str): Column/key name of DataFrame/dictionary 224 | :data:`df` which contains the conditions by which the trials 225 | must be grouped. 226 | cond_id (list): Which PSTHs to plot, indicated by keys in 227 | :data:`all_psth['data']`. If :data:`None` then all are 228 | plotted. 229 | window (list of 2 elements): Time interval to consider, in 230 | milliseconds. 231 | binsize (int): Bin size, in milliseconds. 232 | conditions_names (list): Legend names for conditions. Defaults to 233 | the unique values in :data:`df['conditions']`. 234 | event_name (str): Legend name for the event. Defaults to the actual 235 | event name. 236 | ylim (list): The minimum and maximum values for Y. 237 | colors (list): The colors to use for the plot (as strings). 238 | show (bool): If set, show the plot once it has been created. 239 | ''' 240 | 241 | # Placeholder neuron, used only to reuse the NeuroVis plotting machinery. 242 | base_neuron = NeuroVis(spiketimes=range(10), name=self.name) 243 | 244 | if all_psth is None: 245 | psth = self.get_all_psth( 246 | event=event, 247 | df=df, 248 | conditions=conditions, 249 | window=window, 250 | binsize=binsize, 251 | plot=False, 252 | ) 253 | else: 254 | psth = copy.deepcopy(all_psth) 255 | 256 | keys = np.sort(list(psth['data'].keys())) 257 | 258 | # Normalizes each neuron across all conditions. 259 | for i in range(self.n_neurons): 260 | max_rates = list() 261 | for key in keys: 262 | max_rates.append(np.amax(psth['data'][key][i, :])) 263 | norm_factor = max(max_rates) 264 | for key in keys: 265 | psth['data'][key][i, :] /= norm_factor 266 | 267 | # Averages across neurons and plots; SEM = std / sqrt(n). 268 | for i, key in enumerate(keys): 269 | normed_data = psth['data'][key] 270 | psth['data'][key] = dict() 271 | psth['data'][key]['mean'] = np.nanmean(normed_data, axis=0) 272 | psth['data'][key]['sem'] = \ 273 | np.nanstd(normed_data, axis=0) / (len(self.neuron_list) ** 0.5) 274 | 275 | # Plots the PSTH. 276 | base_neuron.plot_psth( 277 | psth=psth, 278 | event_name=event_name, 279 | cond_id=cond_id, 280 | conditions_names=conditions_names, 281 | ylim=ylim, 282 | colors=colors, 283 | ) 284 | 285 | # Adds the appropriate title. 286 | plt.title("%s Population PSTH: %s" % (self.name, psth['conditions'])) 287 | 288 | if show: 289 | plt.show() 290 | 291 | def _get_normed_data(self, data, normalize): 292 | '''Normalizes all PSTH data. 293 | 294 | Args: 295 | data (2-D numpy array): Array with shape 296 | :data:`(n_neurons, n_bins)` 297 | normalize (str): If :data:`all`, divide all PSTHs by the highest peak 298 | firing rate across all neurons. If :data:`each`, divide each PSTH 299 | by its own peak firing rate. If None, do not normalize. 300 | 301 | Returns: 302 | array: The data array divided by the requested normalization 303 | factors; for :data:`all` or :data:`each`, values fall between 0 and 1. 304 | ''' 305 | max_rates = np.amax(data, axis=1) 306 | 307 | # Computes the normalization factors.
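# For example, for data = [[1, 2, 4], [3, 6, 12]] the row peaks are [4, 12], so 'each' divides the rows by 4 and 12 respectively, while 'all' divides every entry by 12; both keep the result within [0, 1].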
308 | if normalize == 'all': 309 | norm_factors = np.ones([data.shape[0], 1]) * np.amax(max_rates) 310 | elif normalize == 'each': 311 | norm_factors = ( 312 | np.reshape(max_rates, (max_rates.shape[0], 1)) * 313 | np.ones((1, data.shape[1])) 314 | ) 315 | elif normalize is None: 316 | norm_factors = np.ones([data.shape[0], 1]) 317 | else: 318 | raise ValueError('Invalid normalize option: {}'.format(normalize)) 319 | 320 | return data / norm_factors 321 | -------------------------------------------------------------------------------- /spykes/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from scipy import stats 7 | import matplotlib.pyplot 8 | 9 | 10 | def train_test_split(*datasets, **split): 11 | '''Splits datasets into training and testing subsets. 12 | 13 | This is a replacement for the Scikit Learn version of the function (which 14 | is being deprecated). 15 | 16 | Args: 17 | datasets (list of Numpy arrays): The datasets as Numpy arrays, where 18 | the first dimension is the batch dimension. 19 | n (int): Number of test samples to split off (only `n` or `percent` 20 | may be specified). 21 | percent (float): Fraction of test samples to split off (e.g. 0.3 for 30%). 22 | 23 | Returns: 24 | tuple of train / test data, or list of tuples: If only one dataset is 25 | provided, this method returns a tuple of training and testing data; 26 | otherwise, it returns a list of such tuples. 27 | ''' 28 | if not datasets: 29 | return [] # Guarantees there's at least one dataset. 30 | num_batches = int(datasets[0].shape[0]) 31 | 32 | # Checks the input shapes. 33 | if not all(d.shape[0] == num_batches for d in datasets): 34 | raise ValueError('Not all of the datasets have the same batch size. ' 35 | 'Received batch sizes: {batch_sizes}' 36 | .format(batch_sizes=[d.shape[0] for d in datasets])) 37 | 38 | # Gets the split num or split percent. 39 | split_num = split.get('n', None) 40 | split_prct = split.get('percent', None) 41 | 42 | # Checks the splits 43 | if (split_num and split_prct) or not (split_num or split_prct): 44 | raise ValueError('Must specify exactly one of `n` or `percent`') 45 | 46 | # Splits all of the datasets. 47 | if split_prct is None: 48 | num_test = split_num 49 | else: 50 | num_test = int(num_batches * split_prct) 51 | 52 | # Checks that the test number is less than the number of batches. 53 | if num_test >= num_batches: 54 | raise ValueError('Invalid split number: {num_test}. There are only ' 55 | '{num_batches} samples.' 56 | .format(num_test=num_test, num_batches=num_batches)) 57 | 58 | # Splits each of the datasets. 59 | idxs = np.arange(num_batches) 60 | np.random.shuffle(idxs) 61 | train_idxs, test_idxs = idxs[num_test:], idxs[:num_test] 62 | datasets = [(d[train_idxs], d[test_idxs]) for d in datasets] 63 | return datasets if len(datasets) > 1 else datasets[0] 64 | 65 | 66 | def slow_exp(z, eta): 67 | '''Applies a slowly rising exponential function to some data. 68 | 69 | This function defines a slowly rising exponential that linearizes above 70 | the threshold parameter :data:`eta`. Mathematically, this is defined as: 71 | 72 | .. math:: 73 | 74 | q = \\begin{cases} 75 | (z + 1 - eta) * \\exp(eta) & \\text{if } z > eta \\\\ 76 | \\exp(z) & \\text{if } z \\leq eta 77 | \\end{cases} 78 | 79 | The gradient of this function is defined in :meth:`grad_slow_exp`.
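In particular the two branches agree at the threshold: at :data:`z = eta` the linear branch gives :math:`(eta + 1 - eta) \\exp(eta) = \\exp(eta)`, which equals the exponential branch, so the function is continuous (and, per :meth:`grad_slow_exp`, smooth) at :data:`eta`.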
80 | 81 | Args: 82 | z (array): The data to apply the :func:`slow_exp` function to. 83 | eta (float): The threshold parameter. 84 | 85 | Returns: 86 | array: The resulting slow exponential, with the same shape as 87 | :data:`z`. 88 | ''' 89 | qu = np.zeros(z.shape) 90 | slope = np.exp(eta) 91 | intercept = (1 - eta) * slope 92 | qu[z > eta] = z[z > eta] * slope + intercept 93 | qu[z <= eta] = np.exp(z[z <= eta]) 94 | return qu 95 | 96 | 97 | def grad_slow_exp(z, eta): 98 | '''Computes the gradient of a slowly rising exponential function. 99 | 100 | This is defined as: 101 | 102 | .. math:: 103 | 104 | \\nabla q = \\begin{cases} 105 | \\exp(eta) & \\text{if } z > eta \\\\ 106 | \\exp(z) & \\text{if } z \\leq eta 107 | \\end{cases} 108 | 109 | Args: 110 | z (array): The dependent variable, before calling the :func:`slow_exp` 111 | function. 112 | eta (float): The threshold parameter used in the original 113 | :func:`slow_exp` call. 114 | 115 | Returns: 116 | array: The gradient with respect to :data:`z` of the output of 117 | :func:`slow_exp`. 118 | ''' 119 | dqu_dz = np.zeros(z.shape) 120 | slope = np.exp(eta) 121 | dqu_dz[z > eta] = slope 122 | dqu_dz[z <= eta] = np.exp(z[z <= eta]) 123 | return dqu_dz 124 | 125 | 126 | def log_likelihood(y, yhat): 127 | '''Helper function to compute the log likelihood.''' 128 | eps = np.spacing(1) 129 | return np.nansum(y * np.log(eps + yhat) - yhat) 130 | 131 | 132 | def circ_corr(alpha1, alpha2): 133 | '''Helper function to compute the circular correlation.''' 134 | alpha1_bar = stats.circmean(alpha1) 135 | alpha2_bar = stats.circmean(alpha2) 136 | num = np.sum(np.sin(alpha1 - alpha1_bar) * np.sin(alpha2 - alpha2_bar)) 137 | den = np.sqrt(np.sum(np.sin(alpha1 - alpha1_bar) ** 2) * 138 | np.sum(np.sin(alpha2 - alpha2_bar) ** 2)) 139 | rho = num / den 140 | return rho 141 | 142 | 143 | def get_sort_indices(data, by=None, order='descend'): 144 | '''Helper function to calculate sorting indices given sorting condition. 145 | 146 | Args: 147 | data (2-D numpy array): Array with shape :data:`(n_neurons, n_bins)`. 148 | by (str or list): If :data:`rate`, sort by firing rate. If 149 | :data:`latency`, sort by peak latency. If a list or array is 150 | provided, it must correspond to integer indices to be used as 151 | sorting indices. If None, the data is 152 | left in its original order. 153 | order (str): Direction to sort in (either :data:`descend` or 154 | :data:`ascend`). 155 | 156 | Returns: 157 | array: The sort indices as a Numpy array, with one index per row of 158 | :data:`data` (i.e. :data:`data[sort_idxs]` gives the sorted data). 159 | ''' 160 | # Checks if the by indices are a list or array. 161 | if isinstance(by, list): 162 | by = np.array(by) 163 | if isinstance(by, np.ndarray): 164 | if np.array_equal(np.sort(by), list(range(data.shape[0]))): 165 | return by # Returns if it is a proper permutation. 166 | else: 167 | raise ValueError('The sorting indices are not a proper permutation: {}' 168 | .format(by)) 169 | 170 | # Sorts according to the requested criterion. 171 | if by == 'rate': 172 | sort_idx = np.sum(data, axis=1).argsort() 173 | elif by == 'latency': 174 | sort_idx = np.argmax(data, axis=1).argsort() 175 | elif by is None: 176 | sort_idx = np.arange(data.shape[0]) 177 | else: 178 | raise ValueError('Invalid sort preference: "{}". Must be "rate", ' 179 | '"latency" or None.'.format(by)) 180 | 181 | # Checks the sorting order.
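# np.argsort is ascending, so 'descend' simply reverses it; e.g. row sums [3, 1, 2] give [1, 2, 0] ascending and [0, 2, 1] descending.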
182 | if order == 'ascend': 183 | return sort_idx 184 | elif order == 'descend': 185 | return sort_idx[::-1] 186 | else: 187 | raise ValueError('Invalid sort order: {}'.format(order)) 188 | 189 | 190 | def set_matplotlib_defaults(plt=None): 191 | '''Sets publication quality defaults for matplotlib. 192 | 193 | Args: 194 | plt (matplotlib.pyplot instance): The pyplot module to configure; defaults to :data:`matplotlib.pyplot`. 195 | ''' 196 | if plt is None: 197 | plt = matplotlib.pyplot 198 | plt.rcParams.update({ 199 | 'font.family': 'sans-serif', 200 | 'font.sans-serif': 'Bitstream Vera Sans', 201 | 'font.size': 13, 202 | 'axes.titlesize': 12, 203 | 'xtick.labelsize': 10, 204 | 'ytick.labelsize': 10, 205 | 'xtick.direction': 'out', 206 | 'ytick.direction': 'out', 207 | 'xtick.major.size': 6, 208 | 'ytick.major.size': 6, 209 | 'legend.fontsize': 11, 210 | }) 211 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/tests/__init__.py -------------------------------------------------------------------------------- /tests/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/tests/io/__init__.py -------------------------------------------------------------------------------- /tests/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/tests/ml/__init__.py -------------------------------------------------------------------------------- /tests/ml/tensorflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/tests/ml/tensorflow/__init__.py -------------------------------------------------------------------------------- /tests/ml/tensorflow/test_poisson_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import uuid 5 | 6 | import numpy as np 7 | from nose.tools import ( 8 | assert_raises, 9 | assert_false, 10 | ) 11 | 12 | from tempfile import TemporaryFile 13 | 14 | from tensorflow import keras as ks 15 | 16 | from spykes.ml.tensorflow.poisson_models import PoissonLayer 17 | from spykes.io.datasets import load_reaching_rates 18 | 19 | 20 | def _build_model(model_type, num_features, num_neurons): 21 | i = ks.layers.Input(shape=(num_features,)) 22 | x = PoissonLayer(model_type, num_neurons)(i) 23 | model = ks.models.Model(inputs=i, outputs=x) 24 | model.compile(optimizer='sgd', loss='poisson') 25 | return model 26 | 27 | 28 | def test_poisson_layer(): 29 | with assert_raises(ValueError): 30 | PoissonLayer('invalid_type', 1) 31 | 32 | with assert_raises(AssertionError): 33 | i = ks.layers.Input(shape=(1, 2)) 34 | x = PoissonLayer('glm', 3, num_features=2)(i) 35 | 36 | # Loads the reaching dataset (with default parameters).
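# Here x holds the per-trial input features and y the per-neuron firing rates, so their second dimensions set the model's input and output sizes below.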
37 | x, y = load_reaching_rates() 38 | num_features, num_neurons = x.shape[1], y.shape[1] 39 | 40 | for model_type in ('gvm', 'glm'): 41 | model = _build_model(model_type, num_features, num_neurons) 42 | p = model.predict(x) 43 | h = model.fit(x, y, epochs=1, verbose=0) 44 | assert_false(np.any(np.isnan(h.history['loss']))) 45 | -------------------------------------------------------------------------------- /tests/ml/tensorflow/test_sparse_filtering.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from nose.tools import ( 5 | assert_raises, 6 | assert_equal, 7 | ) 8 | 9 | from tensorflow import keras as ks 10 | from spykes.ml.tensorflow.sparse_filtering import SparseFiltering 11 | 12 | # Keeps the number of training images small to reduce testing time. 13 | NUM_TRAIN = 100 14 | 15 | 16 | def test_sparse_filtering(): 17 | train_images = np.random.rand(NUM_TRAIN, 28 * 28) 18 | 19 | # Creates a simple model. 20 | model = ks.models.Sequential([ 21 | ks.layers.Dense(20, input_shape=(28 * 28,), name='a'), 22 | ks.layers.Dense(20, name='b'), 23 | ]) 24 | 25 | # Checks the four ways to pass layers. 26 | sf_model = SparseFiltering(model=model) 27 | assert_equal(len(sf_model.layer_names), len(model.layers)) 28 | with assert_raises(ValueError): 29 | sf_model = SparseFiltering(model=model, layers=1) 30 | sf_model = SparseFiltering(model=model, layers='a') 31 | assert_equal(sf_model.layer_names, ['a']) 32 | 33 | # Checks model compilation. 34 | sf_model.compile('sgd') 35 | assert_raises(RuntimeError, sf_model.compile, 'sgd') 36 | 37 | sf_model = SparseFiltering(model=model, layers=['a', 'b']) 38 | assert_equal(sf_model.layer_names, ['a', 'b']) 39 | 40 | # Checks that the submodels attribute is not available yet. 41 | with assert_raises(RuntimeError): 42 | print(sf_model.submodels) 43 | 44 | # Checks getting a submodel. 45 | with assert_raises(RuntimeError): 46 | sf_model.get_submodel('a') 47 | 48 | # Checks model freezing. 49 | sf_model.compile('sgd', freeze=True) 50 | assert_equal(len(sf_model.submodels), 2) 51 | 52 | # Checks getting an invalid submodel. 53 | with assert_raises(ValueError): 54 | sf_model.get_submodel('c') 55 | 56 | # Checks model fitting. 57 | h = sf_model.fit(x=train_images, epochs=1) 58 | assert_equal(len(h), 2) # One history for each layer name. 59 | 60 | # Checks the iterable cleaning part.
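# A scalar parameter is broadcast to one value per layer, while a list must already have one entry per layer (two here, 'a' and 'b'), so the length-1 list below raises a ValueError.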
61 | assert_raises(ValueError, sf_model._clean_maybe_iterable_param, ['a'], 'a') 62 | 63 | def _check_works(p): 64 | cleaned_v = sf_model._clean_maybe_iterable_param(p, '1337') 65 | assert_equal(len(cleaned_v), 2) 66 | 67 | _check_works('a') 68 | _check_works(1) 69 | _check_works(['a', 'b']) 70 | -------------------------------------------------------------------------------- /tests/ml/test_neuropop.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as p 5 | from nose.tools import ( 6 | assert_true, 7 | assert_equal, 8 | assert_raises, 9 | ) 10 | 11 | from spykes.ml.neuropop import NeuroPop 12 | from spykes.utils import train_test_split 13 | 14 | np.random.seed(42) 15 | p.switch_backend('Agg') 16 | 17 | 18 | def test_neuropop(): 19 | 20 | np.random.seed(1738) 21 | 22 | num_neurons = 10 23 | 24 | for num_neurons in [1, 10]: 25 | for tunemodel in ['glm', 'gvm']: 26 | for i in range(2): 27 | 28 | pop = NeuroPop(tunemodel=tunemodel, n_neurons=num_neurons, 29 | verbose=True) 30 | 31 | if i == 0: 32 | pop.set_params() 33 | else: 34 | pop.set_params(mu=np.random.randn(), 35 | k0=np.random.randn(), 36 | k=np.random.randn(), 37 | g=np.random.randn(), 38 | b=np.random.randn()) 39 | 40 | x, Y, mu, k0, k, g, b = pop.simulate(tunemodel) 41 | 42 | _helper_test_neuropop(pop, num_neurons, x, Y) 43 | 44 | 45 | def _helper_test_neuropop(pop, num_neurons, x, Y): 46 | 47 | # Splits into training and testing parts. 48 | x_split, Y_split = train_test_split(x, Y, percent=0.5) 49 | (x_train, x_test), (Y_train, Y_test) = x_split, Y_split 50 | 51 | pop.fit(x_train, Y_train) 52 | 53 | Yhat_test = pop.predict(x_test) 54 | 55 | assert_equal(Yhat_test.shape[0], x_test.shape[0]) 56 | assert_equal(Yhat_test.shape[1], num_neurons) 57 | 58 | Ynull = np.mean(Y_train, axis=0) 59 | 60 | score = pop.score(Y_test, Yhat_test, Ynull, method='pseudo_R2') 61 | assert_equal(len(score), num_neurons) 62 | 63 | xhat_test = pop.decode(Y_test) 64 | assert_equal(xhat_test.shape[0], Y_test.shape[0]) 65 | 66 | for method in ['circ_corr', 'cosine_dist']: 67 | score = pop.score(x_test, xhat_test, method=method) 68 | 69 | pop.display(x, Y, 0) 70 | 71 | pop.display(x, Y, 0, xjitter=True, yjitter=True) 72 | -------------------------------------------------------------------------------- /tests/ml/test_strf.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as p 5 | from nose.tools import assert_equal 6 | 7 | from spykes.ml.strf import STRF 8 | p.switch_backend('Agg') 9 | 10 | 11 | def test_strf(): 12 | 13 | n_spatial_basis = 36 14 | n_temporal_basis = 7 15 | patch_size = 50 16 | 17 | # Instantiate strf object 18 | strf_ = STRF(patch_size=patch_size, sigma=5, 19 | n_spatial_basis=n_spatial_basis, 20 | n_temporal_basis=n_temporal_basis) 21 | 22 | # Design a spatial basis 23 | spatial_basis = strf_.make_cosine_basis() 24 | assert_equal(len(spatial_basis), 2) 25 | for basis in spatial_basis: 26 | assert_equal(basis.shape[0], patch_size) 27 | assert_equal(basis.shape[1], patch_size) 28 | 29 | spatial_basis = strf_.make_gaussian_basis() 30 | assert_equal(len(spatial_basis), n_spatial_basis) 31 | 32 | # Visualize spatial basis 33 | strf_.visualize_gaussian_basis(spatial_basis) 34 | 35 | # Design temporal basis 36 | time_points = np.linspace(-100., 100., 10)  # The num argument must be an int.
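# Seven centers spaced 25 ms apart tile the time window, matching n_temporal_basis above.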
37 | centers = [-75., -50., -25., 0, 25., 50., 75.] 38 | width = 10. 39 | temporal_basis = strf_.make_raised_cosine_temporal_basis( 40 | time_points=time_points, 41 | centers=centers, 42 | widths=width * np.ones(n_temporal_basis)) 43 | assert_equal(temporal_basis.shape[0], len(time_points)) 44 | assert_equal(temporal_basis.shape[1], n_temporal_basis) 45 | 46 | # Project to spatial basis 47 | I = np.zeros(shape=(patch_size, patch_size)) 48 | row = 5 49 | col = 10 50 | I[row, col] = 1 51 | basis_projection = strf_.project_to_spatial_basis(I, spatial_basis) 52 | assert_equal(len(basis_projection), n_spatial_basis) 53 | 54 | # Recover image from basis projection 55 | weights = np.random.normal(size=n_spatial_basis) 56 | RF = strf_.make_image_from_spatial_basis(spatial_basis, weights) 57 | assert_equal(RF.shape[0], patch_size) 58 | assert_equal(RF.shape[1], patch_size) 59 | 60 | # Convolve with temporal basis 61 | n_samples = 100 62 | n_features = n_spatial_basis 63 | design_matrix = np.random.normal(size=(n_samples, n_features)) 64 | features = strf_.convolve_with_temporal_basis( 65 | design_matrix, temporal_basis) 66 | assert_equal(features.shape[0], n_samples) 67 | assert_equal(features.shape[1], n_features * n_temporal_basis) 68 | 69 | # Design prior covariance 70 | PriorCov = strf_.design_prior_covariance( 71 | sigma_temporal=3., sigma_spatial=5.) 72 | assert_equal(PriorCov.shape[0], PriorCov.shape[1]) 73 | assert_equal(PriorCov.shape[0], n_spatial_basis * n_temporal_basis) 74 | -------------------------------------------------------------------------------- /tests/plot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KordingLab/spykes/66db722f54c842630a3c7538aa6a955a17d340cb/tests/plot/__init__.py -------------------------------------------------------------------------------- /tests/plot/test_neurovis.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as p 6 | from nose.tools import ( 7 | assert_true, 8 | assert_equal, 9 | ) 10 | 11 | from spykes.plot.neurovis import NeuroVis 12 | p.switch_backend('Agg') 13 | 14 | 15 | def test_neurovis(): 16 | 17 | np.random.seed(1738) 18 | 19 | num_spikes = 500 20 | num_trials = 10 21 | 22 | binsize = 100 23 | window = [-500, 1500] 24 | 25 | rand_spiketimes = np.sort(num_trials * np.random.rand(num_spikes)) 26 | 27 | neuron = NeuroVis(spiketimes=rand_spiketimes) 28 | 29 | df = pd.DataFrame() 30 | 31 | event = 'realCueTime' 32 | condition_num = 'responseNum' 33 | condition_bool = 'responseBool' 34 | 35 | start_times = rand_spiketimes[0::int(num_spikes/num_trials)] 36 | df['trialStart'] = start_times 37 | 38 | df[event] = df['trialStart'] + np.random.rand(num_trials) 39 | 40 | event_times = ((start_times[:-1] + start_times[1:]) / 2).tolist() 41 | event_times.append(start_times[-1] + np.random.rand()) 42 | 43 | df[event] = event_times 44 | 45 | df[condition_num] = np.random.rand(num_trials) 46 | df[condition_bool] = df[condition_num] < 0.5 47 | 48 | for cond in [None, condition_bool]: 49 | 50 | raster = neuron.get_raster(event=event, conditions=cond, 51 | df=df, plot=True, binsize=binsize, 52 | window=window) 53 | 54 | neuron.plot_raster(raster, cond_name=raster['conditions']) 55 | 56 | assert_equal(raster['event'], event) 57 | assert_equal(raster['conditions'], cond) 58 | assert_equal(raster['binsize'], binsize) 59 |
assert_equal(raster['window'], window) 60 | 61 | total_trials = 0 62 | 63 | for cond_id in raster['data'].keys(): 64 | 65 | assert_true(cond_id in df[condition_bool].values)  # Checks values, not the index. 66 | assert_equal(raster['data'][cond_id].shape[1], 67 | (window[1] - window[0]) / binsize) 68 | total_trials += raster['data'][cond_id].shape[0] 69 | 70 | assert_equal(total_trials, num_trials) 71 | 72 | psth = neuron.get_psth(event=event, conditions=condition_bool, df=df, 73 | plot=True, binsize=binsize, window=window) 74 | 75 | neuron.plot_psth(psth=psth, ylim=np.random.randn(2).tolist()) 76 | 77 | assert_equal(psth['window'], window) 78 | assert_equal(psth['binsize'], binsize) 79 | assert_equal(psth['event'], event) 80 | assert_equal(psth['conditions'], condition_bool) 81 | 82 | for cond_id in psth['data'].keys(): 83 | 84 | assert_true(cond_id in df[condition_bool].values) 85 | assert_equal(psth['data'][cond_id]['mean'].shape[0], 86 | (window[1] - window[0]) / binsize) 87 | assert_equal(psth['data'][cond_id]['sem'].shape[0], 88 | (window[1] - window[0]) / binsize) 89 | 90 | neuron.get_spikecounts(event=event, df=df, window=[0, num_trials]) 91 | -------------------------------------------------------------------------------- /tests/plot/test_popvis.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as p 6 | from nose.tools import ( 7 | assert_true, 8 | assert_equal, 9 | assert_raises, 10 | ) 11 | 12 | from spykes.plot.popvis import PopVis 13 | from spykes.plot.neurovis import NeuroVis 14 | p.switch_backend('Agg') 15 | 16 | 17 | def test_popvis(): 18 | 19 | np.random.seed(1738)  # Fixed seed keeps this test deterministic. 20 | 21 | num_spikes = 500 22 | num_trials = 10 23 | 24 | binsize = 100 25 | window = [-500, 1500] 26 | 27 | num_neurons = 10 28 | neuron_list = list() 29 | 30 | for i in range(num_neurons): 31 | rand_spiketimes = num_trials * np.random.rand(num_spikes) 32 | neuron_list.append(NeuroVis(rand_spiketimes)) 33 | 34 | pop = PopVis(neuron_list) 35 | 36 | df = pd.DataFrame() 37 | 38 | event = 'realCueTime' 39 | condition_num = 'responseNum' 40 | condition_bool = 'responseBool' 41 | 42 | start_times = rand_spiketimes[0::int(num_spikes/num_trials)] 43 | 44 | df['trialStart'] = start_times 45 | 46 | df[event] = df['trialStart'] + np.random.rand(num_trials) 47 | 48 | event_times = ((start_times[:-1] + start_times[1:]) / 2).tolist() 49 | event_times.append(start_times[-1] + np.random.rand()) 50 | 51 | df[event] = event_times 52 | 53 | df[condition_num] = np.random.rand(num_trials) 54 | df[condition_bool] = df[condition_num] < 0.5 55 | 56 | all_psth = pop.get_all_psth(event=event, conditions=condition_bool, df=df, 57 | plot=True, binsize=binsize, window=window) 58 | 59 | assert_equal(all_psth['window'], window) 60 | assert_equal(all_psth['binsize'], binsize) 61 | assert_equal(all_psth['event'], event) 62 | assert_equal(all_psth['conditions'], condition_bool) 63 | 64 | for cond_id in all_psth['data'].keys(): 65 | 66 | assert_true(cond_id in df[condition_bool].values) 67 | assert_equal(all_psth['data'][cond_id].shape[0], 68 | num_neurons) 69 | assert_equal(all_psth['data'][cond_id].shape[1], 70 | (window[1] - window[0]) / binsize) 71 | 72 | assert_raises(ValueError, pop.plot_heat_map, all_psth, 73 | sortby=list(range(num_trials-1))) 74 | 75 | pop.plot_heat_map(all_psth, sortby=list(range(num_trials))) 76 | pop.plot_heat_map(all_psth, sortby='rate') 77 | pop.plot_heat_map(all_psth, sortby='latency') 78 |
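# The sort direction can also be varied independently of the sort key: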
pop.plot_heat_map(all_psth, sortorder='ascend') 79 | 80 | pop.plot_population_psth(all_psth=all_psth) 81 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | from nose.tools import assert_equal 5 | 6 | from spykes import config 7 | 8 | 9 | def test_get_home_directory(): 10 | # Tests using os.environ. 11 | assert_equal(config.get_home_directory(), os.environ['HOME']) 12 | 13 | # Tests using os.path.expanduser. 14 | home_tmp = os.environ['HOME'] 15 | os.environ.pop('HOME') 16 | assert_equal(config.get_home_directory(), os.path.expanduser('~')) 17 | os.environ['HOME'] = home_tmp 18 | 19 | 20 | def test_get_data_directory(): 21 | # Tests the data directory as inferred from the home directory. 22 | assert_equal( 23 | config.get_data_directory(), 24 | os.path.join(config.get_home_directory(), config.DEFAULT_DATA_DIR), 25 | ) 26 | 27 | # Tests the data directory after setting the SPYKES_DATA environment variable. 28 | os.environ['SPYKES_DATA'] = '~/data' 29 | assert_equal(config.get_data_directory(), '~/data') 30 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | from nose.tools import assert_equal 5 | import numpy as np 6 | 7 | from spykes import utils 8 | 9 | 10 | def test_train_test_split(): 11 | x, y, z = np.zeros((10,)), np.zeros((10, 10)), np.zeros((10, 10, 10)) 12 | 13 | def _check(xyz, xyzs): 14 | for s, (s_train, s_test) in zip(xyz, xyzs): 15 | assert_equal(s_train.shape[0], 7) 16 | assert_equal(s_test.shape[0], 3) 17 | assert_equal(np.ndim(s), np.ndim(s_train)) 18 | assert_equal(np.ndim(s), np.ndim(s_test)) 19 | 20 | # Checks number-wise and percent-wise. 21 | _check([x, y, z], utils.train_test_split(x, y, z, n=3)) 22 | _check([x, y, z], utils.train_test_split(x, y, z, percent=0.3)) 23 | --------------------------------------------------------------------------------