├── ignore_words.txt ├── mne_incubator ├── externals │ ├── __init__.py │ └── pacpy │ │ ├── __init__.py │ │ ├── util.py │ │ ├── filt.py │ │ └── pac.py ├── __init__.py ├── preprocessing │ ├── __init__.py │ ├── tests │ │ ├── test_dss.py │ │ ├── test_sns.py │ │ └── test_eog.py │ ├── _dss.py │ ├── eog.py │ └── _sns.py └── connectivity │ ├── __init__.py │ ├── simulation.py │ ├── viz.py │ ├── tests │ └── test_cfc.py │ └── cfc.py ├── .coveragerc ├── tools └── github_actions_test.sh ├── .gitignore ├── setup.cfg ├── README.md ├── setup.py ├── examples ├── preprocessing │ ├── plot_dss_array.py │ ├── plot_dss_epochs.py │ └── plot_eog_regression.py └── connectivity │ ├── plot_pac_viz.py │ └── plot_pac.py ├── .github └── workflows │ ├── codespell_and_flake.yml │ └── linux_conda.yml └── Makefile /ignore_words.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mne_incubator/externals/__init__.py: -------------------------------------------------------------------------------- 1 | from . import pacpy 2 | -------------------------------------------------------------------------------- /mne_incubator/__init__.py: -------------------------------------------------------------------------------- 1 | from . import preprocessing 2 | from . import connectivity 3 | from . import externals 4 | -------------------------------------------------------------------------------- /mne_incubator/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .eog import eog_regression 2 | from ._dss import dss 3 | from ._sns import SensorNoiseSuppression 4 | -------------------------------------------------------------------------------- /mne_incubator/externals/pacpy/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . 
import pac 4 | from . import filt 5 | from . import util 6 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = mne_incubator 4 | include = */mne_incubator/* 5 | omit = 6 | */mne_incubator/externals/* 7 | */setup.py 8 | -------------------------------------------------------------------------------- /tools/github_actions_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ef 2 | 3 | USE_DIRS="mne_incubator/" 4 | echo 'pytest --tb=short --cov=mne_incubator --cov-report xml -vv ${USE_DIRS}' 5 | pytest --tb=short --cov=mne_incubator --cov-report xml -vv ${USE_DIRS} 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | *.sh 4 | *.so 5 | *.fif 6 | *.tar.gz 7 | *.log 8 | *.stc 9 | *~ 10 | .#* 11 | *.swp 12 | *.lprof 13 | *.npy 14 | *.zip 15 | *.fif.gz 16 | *.nii.gz 17 | *.tar.* 18 | *.egg* 19 | *.tmproj 20 | *.png 21 | .DS_Store 22 | build 23 | *.orig 24 | junit-results.xml 25 | -------------------------------------------------------------------------------- /mne_incubator/connectivity/__init__.py: -------------------------------------------------------------------------------- 1 | from .cfc import (phase_amplitude_coupling, 2 | phase_locked_amplitude, 3 | phase_binned_amplitude) 4 | from .viz import (plot_phase_locked_amplitude, 5 | plot_phase_binned_amplitude) 6 | from .simulation import simulate_pac_signal 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | release = egg_info -RDb '' 3 | 4 | [nosetests] 5 | # with-coverage = 1 6 | # cover-html = 1 7 | # 
def fasthilbert(x, axis=-1):
    """Compute the analytic signal of ``x`` using a power-of-two FFT length.

    Redefinition of scipy.signal.hilbert, which is very slow for some lengths
    of the signal x. This version zero-pads the signal to the next power of 2
    for speed.

    Parameters
    ----------
    x : array-like
        Real-valued input signal(s).
    axis : int
        Axis of ``x`` along which the transform is computed. Defaults to -1.

    Returns
    -------
    xa : ndarray, complex
        Analytic signal with the same shape as ``x``.

    Notes
    -----
    The Nyquist bin of the padded spectrum is zeroed here (it keeps the
    initial 0 of ``h``), so the output is not bit-identical to
    scipy.signal.hilbert even apart from the zero-padding.
    """
    x = np.array(x)
    N = x.shape[axis]
    # The padded length must be derived from the length along ``axis``;
    # ``len(x)`` is the size of the first axis and truncated N-D input.
    N2 = 2**(int(math.log(N, 2)) + 1)
    Xf = np.fft.fft(x, N2, axis=axis)

    # One-sided spectrum multiplier: keep DC, double positive frequencies,
    # drop negative frequencies.
    h = np.zeros(N2)
    h[0] = 1
    h[1:(N2 + 1) // 2] = 2
    if x.ndim > 1:
        # Align ``h`` with the transformed axis so that it broadcasts
        # correctly whatever ``axis`` is.
        h_shape = [1] * x.ndim
        h_shape[axis] = N2
        h = h.reshape(h_shape)

    x = np.fft.ifft(Xf * h, axis=axis)

    # Remove the zero-padding along the transformed axis only.
    keep = [slice(None)] * x.ndim
    keep[axis] = slice(0, N)
    return x[tuple(keep)]
Once the implementation of a piece of functionality is considered to be bug-free and properly documented (both API docs and an example script), it can be incorporated into the master branch. Once it is in the master branch, it can be used by users of MNE-Python while the functionality awaits verification in a scientific manner (for new techniques, this means a paper). After the functionality has been verified, it can be integrated into MNE-Python. 9 | 10 | ## Code organization 11 | 12 | The directory structure of this repository mirrors that of MNE-Python. When you add new functionality, place it in the location where you would expect it to end up in the MNE-Python repository. Your code may depend on the development version of MNE-Python and other submodules of MNE-incubator. At least one example script should be placed in the matching subfolder of the top-level `examples` folder. 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!
def rms(data):
    """Return the root-mean-square of ``data`` along its last axis.

    The reduced axis is kept with length 1 so the result broadcasts against
    ``data``.
    """
    mean_of_squares = np.mean(data ** 2, axis=-1, keepdims=True)
    return np.sqrt(mean_of_squares)
rand = np.random.RandomState(123) 20 | 21 | # create synthetic data 22 | n_trials = 200 23 | n_times = 1000 24 | n_channels = 32 25 | pad = np.zeros(n_times // 3) 26 | signal_nsamps = n_times - 2 * pad.size 27 | sine = np.sin(2 * np.pi * np.arange(signal_nsamps) / float(signal_nsamps)) 28 | signal = rand.randn(n_channels, 1) * np.r_[pad, sine, pad][np.newaxis, :] 29 | channel_noise = rand.randn(n_channels, noise_dims) 30 | trial_noise = rand.randn(n_trials, n_times, noise_dims) 31 | noise = np.einsum('ijk,lk->ilj', trial_noise, channel_noise) 32 | data = 4e-6 * (noise / rms(noise) + snr * signal / rms(signal)) 33 | 34 | # perform DSS 35 | dss_mat, dss_data = dss(data, data_thresh=1e-3, bias_thresh=1e-3) 36 | 37 | # plot 38 | fig, axs = plt.subplots(3, 1, figsize=(7, 12), sharex=True) 39 | plotdata = [signal.T, data[0].T, dss_data[:, 0].T] 40 | linewidths = (1, 0.3, 0.4) 41 | titles = ('synthetic signal with random weights for each channel', 42 | 'one trial, all channels, after noise addition (SNR=0.1)', 43 | 'First DSS component from each trial') 44 | for ax, dat, lw, ti in zip(axs, plotdata, linewidths, titles): 45 | ax.xaxis.set_ticks_position('bottom') 46 | ax.yaxis.set_ticks_position('left') 47 | ax.spines['top'].set_visible(False) 48 | ax.spines['right'].set_visible(False) 49 | ax.plot(dat, linewidth=lw) 50 | ax.set_title(ti) 51 | ax.set_xlabel('samples') 52 | plt.tight_layout() 53 | -------------------------------------------------------------------------------- /.github/workflows/codespell_and_flake.yml: -------------------------------------------------------------------------------- 1 | name: 'codespell_and_flake' 2 | # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#concurrency 3 | # https://docs.github.com/en/developers/webhooks-and-events/events/github-event-types#pullrequestevent 4 | # workflow name, PR number (empty on push), push ref (empty on PR) 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ 
github.event.number }}-${{ github.event.ref }} 7 | cancel-in-progress: true 8 | on: 9 | push: 10 | branches: 11 | - '*' 12 | pull_request: 13 | branches: 14 | - '*' 15 | 16 | jobs: 17 | style: 18 | name: 'codespell and flake' 19 | runs-on: ubuntu-20.04 20 | env: 21 | CODESPELL_DIRS: 'mne_incubator/ examples/' 22 | CODESPELL_SKIPS: '*.fif,*.eve,*.gz,*.tgz,*.zip,*.mat,*.stc,*.label,*.w,*.bz2,*.annot,*.sulc,*.log,*.local-copy,*.orig_avg,*.inflated_avg,*.gii,*.pyc,*.doctree,*.pickle,*.inv,*.png,*.edf,*.touch,*.thickness,*.nofix,*.volume,*.defect_borders,*.mgh,lh.*,rh.*,COR-*,FreeSurferColorLUT.txt,*.examples,.xdebug_mris_calc,bad.segments,BadChannels,*.hist,empty_file,*.orig,*.js,*.map,*.ipynb,searchindex.dat,install_mne_c.rst,plot_*.rst,*.rst.txt,c_EULA.rst*,*.html,gdf_encodes.txt,*.svg,references.bib,*.css,*.edf,*.bdf,*.vhdr' 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - uses: actions/setup-python@v4 27 | with: 28 | python-version: '3.10' 29 | architecture: 'x64' 30 | - name: 'Install dependencies' 31 | run: | 32 | python -m pip install --upgrade pip setuptools wheel 33 | python -m pip install flake8 34 | - name: 'Setup flake8 annotations' 35 | uses: rbialon/flake8-annotations@v1 36 | - name: 'Run flake8' 37 | run: make flake 38 | - name: 'Run codespell' 39 | uses: codespell-project/actions-codespell@v1.0 40 | with: 41 | path: ${{ env.CODESPELL_DIRS }} 42 | skip: ${{ env.CODESPELL_SKIPS }} 43 | builtin: 'clear,rare,informal,names' 44 | ignore_words_file: 'ignore_words.txt' 45 | uri_ignore_words_list: 'bu' 46 | -------------------------------------------------------------------------------- /examples/preprocessing/plot_dss_epochs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Denoising source separation applied to an Epochs object""" 4 | 5 | # Authors: Daniel McCloy 6 | # 7 | # License: BSD (3-clause) 8 | 9 | from os import path as op 10 | import numpy as np 11 | import 
matplotlib.pyplot as plt 12 | 13 | import mne 14 | from mne.datasets import sample 15 | 16 | from mne_incubator.preprocessing import dss 17 | 18 | 19 | # file paths 20 | data_path = sample.data_path() 21 | raw_fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw.fif') 22 | events_fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis_raw-eve.fif') 23 | # import sample data 24 | raw = mne.io.read_raw_fif(raw_fname, preload=True) 25 | events = mne.read_events(events_fname) 26 | # pick channels, filter, epoch 27 | picks = mne.pick_types(raw.info, meg=False, eeg=True, eog=True) 28 | # reject = dict(eeg=180e-6, eog=150e-6) 29 | reject = None 30 | raw.filter(0.3, 30, method='iir', picks=picks) 31 | epochs = mne.Epochs(raw, events, event_id=1, preload=True, picks=picks, 32 | reject=reject) 33 | epochs.pick_types(eeg=True) 34 | evoked = epochs.average() 35 | 36 | # perform DSS 37 | dss_mat, dss_data = dss(epochs, data_thresh=1e-9, bias_thresh=1e-9) 38 | 39 | evoked_data_clean = np.dot(dss_mat, evoked.data) 40 | evoked_data_clean[4:] = 0. 
# select 4 components 41 | evoked_data_clean = np.dot(np.linalg.pinv(dss_mat), evoked_data_clean) 42 | 43 | # plot 44 | fig, axs = plt.subplots(3, 1, figsize=(7, 10), sharex=True) 45 | plotdata = [evoked.data.T, evoked_data_clean.T, dss_data[:, 0].T] 46 | linewidths = (1, 1, 0.6) 47 | titles = ('evoked data (EEG only)', 48 | 'evoked data after DSS (EEG only)', 49 | 'first DSS component from each epoch (EEG only)') 50 | for ax, dat, lw, ti in zip(axs, plotdata, linewidths, titles): 51 | ax.xaxis.set_ticks_position('bottom') 52 | ax.yaxis.set_ticks_position('left') 53 | ax.spines['top'].set_visible(False) 54 | ax.spines['right'].set_visible(False) 55 | ax.plot(1e3 * evoked.times, dat, linewidth=lw) 56 | ax.set_title(ti) 57 | ax.set_xlabel('Time (ms)') 58 | plt.tight_layout() 59 | plt.show() 60 | -------------------------------------------------------------------------------- /examples/preprocessing/plot_eog_regression.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================== 3 | EOG regression 4 | ======================== 5 | 6 | Reduce EOG artifacts by regressing the EOG channels onto the rest of the 7 | signal. 8 | 9 | References 10 | ---------- 11 | [1] Croft, R. J., & Barry, R. J. (2000). Removal of ocular artifact from 12 | the EEG: a review. Clinical Neurophysiology, 30(1), 5-19. 
13 | http://doi.org/10.1016/S0987-7053(00)00055-1 14 | 15 | Authors: Marijn van Vliet 16 | 17 | License: BSD (3-clause) 18 | """ 19 | 20 | import mne 21 | from mne_incubator.preprocessing import eog_regression 22 | from mne.datasets import sample 23 | from matplotlib import pyplot as plt 24 | 25 | print(__doc__) 26 | 27 | data_path = sample.data_path() 28 | 29 | raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif' 30 | 31 | # Read raw data 32 | raw = mne.io.Raw(raw_fname, preload=True) 33 | events = mne.find_events(raw, 'STI 014') 34 | 35 | # Bandpass filter the EEG and EOG 36 | picks = mne.pick_types(raw.info, meg=False, eeg=True, eog=True) 37 | raw.filter(0.3, 30, method='iir', picks=picks) 38 | 39 | # Create evokeds before EOG correction 40 | picks = mne.pick_types(raw.info, meg=False, eeg=True) 41 | tmin, tmax = -0.2, 0.5 42 | event_ids = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4} 43 | epochs_before = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks, 44 | preload=True) 45 | evoked_before = epochs_before.average() 46 | 47 | 48 | # Estimate blink onsets and create blink epochs 49 | eog_event_id = 512 50 | eog_events = mne.preprocessing.find_eog_events(raw, eog_event_id) 51 | picks = mne.pick_types(raw.info, meg=False, eeg=True, eog=True) 52 | blink_epochs = mne.Epochs(raw, eog_events, eog_event_id, tmin=-0.5, tmax=0.5, 53 | picks=picks, baseline=(-0.5, -0.3), preload=True) 54 | 55 | # Perform regression and remove EOG 56 | raw_clean, weights = eog_regression(raw, blink_epochs) 57 | 58 | # Create epochs after EOG correction 59 | picks = mne.pick_types(raw_clean.info, meg=False, eeg=True) 60 | tmin, tmax = -0.2, 0.5 61 | event_ids = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4} 62 | epochs_after = mne.Epochs(raw_clean, events, event_ids, tmin, tmax, 63 | picks=picks, preload=True) 64 | evoked_after = epochs_after.average() 65 | 66 | # Show the filter weights in a topomap 67 | l = mne.channels.make_eeg_layout(raw.info) 68 | 
mne.viz.plot_topomap(weights[0], l.pos[:, :2]) 69 | plt.title('Regression weights') 70 | 71 | # Plot the evoked before and after EOG regression 72 | evoked_before.plot() 73 | plt.ylim(-6, 6) 74 | plt.title('Before EOG regression') 75 | 76 | evoked_after.plot() 77 | plt.ylim(-6, 6) 78 | plt.title('After EOG regression') 79 | -------------------------------------------------------------------------------- /.github/workflows/linux_conda.yml: -------------------------------------------------------------------------------- 1 | name: 'linux / conda' 2 | concurrency: 3 | group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} 4 | cancel-in-progress: true 5 | on: 6 | push: 7 | branches: 8 | - '*' 9 | pull_request: 10 | branches: 11 | - '*' 12 | 13 | jobs: 14 | py310: 15 | runs-on: ubuntu-20.04 16 | name: 'linux conda 3.10' 17 | defaults: 18 | run: 19 | shell: bash 20 | env: 21 | CONDA_ENV: 'environment.yml' 22 | DISPLAY: ':99.0' 23 | MNE_LOGGING_LEVEL: 'warning' 24 | MKL_NUM_THREADS: '1' 25 | PYTHONUNBUFFERED: '1' 26 | PYTHON_VERSION: '3.10' 27 | steps: 28 | - name: 'checkout MNE-Python repo' 29 | uses: actions/checkout@v3 30 | with: 31 | name: mne-tools/mne-python 32 | ref: refs/heads/main 33 | - name: 'Setup xvfb' 34 | run: ./tools/setup_xvfb.sh 35 | - name: 'Setup conda' 36 | uses: conda-incubator/setup-miniconda@v2 37 | with: 38 | activate-environment: 'mne' 39 | python-version: ${{ env.PYTHON_VERSION }} 40 | environment-file: ${{ env.CONDA_ENV }} 41 | - name: 'Install MNE-Python dependencies' 42 | shell: bash -el {0} 43 | run: | 44 | ./tools/github_actions_dependencies.sh 45 | source tools/get_minimal_commands.sh 46 | - name: 'Check minimal commands' 47 | shell: bash -el {0} 48 | run: mne_surf2bem --version 49 | - name: 'Install MNE-Python' 50 | shell: bash -el {0} 51 | run: ./tools/github_actions_install.sh 52 | - name: 'Check Qt GL' 53 | shell: bash -el {0} 54 | run: | 55 | QT_QPA_PLATFORM=xcb LIBGL_DEBUG=verbose LD_DEBUG=libs python -c 
"import pyvistaqt; pyvistaqt.BackgroundPlotter(show=True)" 56 | - name: 'Show infos' 57 | shell: bash -el {0} 58 | run: ./tools/github_actions_infos.sh 59 | - name: 'Get testing version' 60 | shell: bash -el {0} 61 | run: ./tools/get_testing_version.sh 62 | - name: 'Cache testing data' 63 | uses: actions/cache@v3 64 | with: 65 | key: ${{ env.TESTING_VERSION }} 66 | path: ~/mne_data 67 | - name: 'Download testing data' 68 | shell: bash -el {0} 69 | run: ./tools/github_actions_download.sh 70 | - name: 'Print locale' 71 | shell: bash -el {0} 72 | run: ./tools/github_actions_locale.sh 73 | - name: 'Clone MNE-Incubator' 74 | uses: actions/checkout@v3 75 | with: 76 | name: mne-tools/mne-incubator 77 | ref: refs/heads/main 78 | - name: 'Run tests' 79 | shell: bash -el {0} 80 | run: ./tools/github_actions_test.sh 81 | - uses: codecov/codecov-action@v3 82 | if: success() 83 | name: 'Upload coverage to CodeCov' 84 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # simple makefile to simplify repetetive build env management tasks under posix 2 | 3 | # caution: testing won't work on windows, see README 4 | 5 | PYTHON ?= python 6 | NOSETESTS ?= nosetests 7 | CTAGS ?= ctags 8 | 9 | all: clean inplace test test-doc 10 | 11 | clean-pyc: 12 | find . -name "*.pyc" | xargs rm -f 13 | 14 | clean-so: 15 | find . -name "*.so" | xargs rm -f 16 | find . -name "*.pyd" | xargs rm -f 17 | 18 | clean-build: 19 | rm -rf _build 20 | 21 | clean-ctags: 22 | rm -f tags 23 | 24 | clean-cache: 25 | find . 
-name "__pycache__" | xargs rm -rf 26 | 27 | clean: clean-build clean-pyc clean-so clean-ctags clean-cache 28 | 29 | in: inplace # just a shortcut 30 | inplace: 31 | $(PYTHON) setup.py build_ext -i 32 | 33 | sample_data: 34 | @python -c "import mne; mne.datasets.sample.data_path(verbose=True);" 35 | 36 | testing_data: 37 | @python -c "import mne; mne.datasets.testing.data_path(verbose=True);" 38 | 39 | test: in 40 | rm -f .coverage 41 | $(NOSETESTS) -a '!ultra_slow_test' mne_incubator 42 | 43 | test-fast: in 44 | rm -f .coverage 45 | $(NOSETESTS) -a '!slow_test' mne_incubator 46 | 47 | test-full: in 48 | rm -f .coverage 49 | $(NOSETESTS) mne_incubator 50 | 51 | test-no-network: in 52 | sudo unshare -n -- sh -c 'MNE_SKIP_NETWORK_TESTS=1 nosetests mne_incubator' 53 | 54 | test-no-testing-data: in 55 | @MNE_SKIP_TESTING_DATASET_TESTS=true \ 56 | $(NOSETESTS) mne_incubator 57 | 58 | test-no-sample-with-coverage: in testing_data 59 | rm -rf coverage .coverage 60 | $(NOSETESTS) --with-coverage --cover-package=mne_incubator --cover-html --cover-html-dir=coverage 61 | 62 | test-doc: sample_data testing_data 63 | $(NOSETESTS) --with-doctest --doctest-tests --doctest-extension=rst doc/ 64 | 65 | test-coverage: testing_data 66 | rm -rf coverage .coverage 67 | $(NOSETESTS) --with-coverage --cover-package=mne_incubator --cover-html --cover-html-dir=coverage 68 | 69 | test-profile: testing_data 70 | $(NOSETESTS) --with-profile --profile-stats-file stats.pf mne_incubator 71 | hotshot2dot stats.pf | dot -Tpng -o profile.png 72 | 73 | test-mem: in testing_data 74 | ulimit -v 1097152 && $(NOSETESTS) 75 | 76 | trailing-spaces: 77 | find . 
def test_dss():
    """Test DSS computations.

    Builds synthetic multi-trial data (a fixed waveform mixed into every
    channel plus low-rank noise), checks that the first DSS component of the
    first trial matches the waveform, and checks that passing the same data
    wrapped in an ``EpochsArray`` gives the same DSS matrix.
    """

    def rms(data):
        # root-mean-square over the last axis, kept as a length-1 axis
        return np.sqrt(np.mean(data ** 2, axis=-1, keepdims=True))

    # fixed seed: the order of the ``rand.randn`` calls below determines the
    # data, so do not reorder them
    rand = np.random.RandomState(123)
    # parameters
    n_trials, n_times, n_channels, noise_dims, snr = [200, 1000, 16, 10, 0.1]
    # a single sine cycle with silence (zeros) before & after
    pad = np.zeros(n_times // 3)
    signal_nsamps = n_times - 2 * pad.size
    sine = np.sin(2 * np.pi * np.arange(signal_nsamps) / float(signal_nsamps))
    sine = np.r_[pad, sine, pad]
    # same waveform on every channel, with a random weight per channel;
    # shape (n_channels, n_times)
    signal = rand.randn(n_channels, 1) * sine[np.newaxis, :]
    # noise: ``noise_dims`` random sources mixed into the channels;
    # shape (n_trials, n_channels, n_times)
    noise = np.einsum('hjk,ik->hij',
                      rand.randn(n_trials, n_times, noise_dims),
                      rand.randn(n_channels, noise_dims))
    # signal plus noise, in a reasonable range for EEG
    data = 4e-6 * (noise / rms(noise) + snr * signal / rms(signal))
    # perform DSS
    dss_mat, dss_data = dss(data, data_thresh=1e-3, bias_thresh=1e-3,
                            bias_max_components=n_channels - 1)
    # handle scaling and possible 180 degree phase difference
    dss_trial1_comp1 = dss_data[0, 0] / np.abs(dss_data[0, 0]).max()
    dss_trial1_comp1 *= np.sign(np.dot(dss_trial1_comp1, sine))
    assert_allclose(dss_trial1_comp1, sine, rtol=0, atol=0.2)
    # test handling of epochs objects
    sfreq = 1000
    first_samp = 150
    # one event per trial, spaced 2 * n_times samples apart
    samps = np.arange(first_samp, first_samp + n_trials * n_times * 2,
                      n_times * 2)[:, np.newaxis]
    events = np.c_[samps, np.zeros_like(samps), np.ones_like(samps)]
    ch_names = ['EEG{0:03}'.format(n + 1) for n in range(n_channels)]
    ch_types = ['eeg'] * n_channels
    info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
    epochs = EpochsArray(data, info=info, events=events, event_id={'fake': 1})
    dss_mat_epochs = dss(epochs, data_thresh=1e-3, bias_thresh=1e-3,
                         bias_max_components=n_channels - 1, return_data=False)
    # make sure we get the same answer when data is an epochs object
    # (both matrices are normalized by their maximum before comparing)
    dss_mat = dss_mat / dss_mat.max()
    dss_mat_epochs = dss_mat_epochs / dss_mat_epochs.max()
    assert_allclose(dss_mat, dss_mat_epochs)
def simulate_pac_signal(time, freq_phase, freq_amp, max_amp_lo=2.,
                        max_amp_hi=.5, frac_pac=.1, snr_lo=4.,
                        snr_hi=4., mask_pac_times=None):
    """Simulate a signal with phase-amplitude coupling according to [1].

    Parameters
    ----------
    time : array-like, shape (n_times,)
        The times for the signal (which implicitly defines the sampling
        frequency).
    freq_phase : float
        The frequency of the low-frequency phase that modulates amplitude.
    freq_amp : float
        The frequency of the high-frequency amplitude that is modulated by
        phase.
    max_amp_lo : float
        The maximum amplitude for the low-frequency phase signal.
    max_amp_hi : float
        The maximum amplitude for the high-frequency amplitude signal.
    frac_pac : float, in [0., 1.]
        The fraction of the high-frequency amplitude that is modulated by
        low-frequency phase.
    snr_lo : float | None
        The ratio of signal to noise in the low-frequency signal.
        If None, the default of 4. is used.
    snr_hi : float | None
        The ratio of signal to noise in the high-frequency signal.
        If None, the default of 4. is used.
    mask_pac_times : array, dtype bool, shape (n_times,) | None
        Whether to mask specific times to induce PAC. Values where
        `mask_pac_times` is False will have `frac_pac` set to 0.
        If None, all times have PAC.

    Returns
    -------
    signal : array, shape (n_times,)
        The simulated PAC signal w/ both low and high frequency components
        (the sum of ``phase_signal`` and ``amp_signal``).
    phase_signal : array, shape (n_times,)
        The noisy low-frequency sinusoid that provides the phase.
    amp_signal : array, shape (n_times,)
        The noisy high-frequency sinusoid whose amplitude is modulated.

    Raises
    ------
    ValueError
        If ``frac_pac`` is outside [0., 1.] or if ``freq_amp`` /
        ``freq_phase`` is not a single real number.

    Notes
    -----
    The noise is drawn from NumPy's global random state
    (``np.random.randn``); seed it with ``np.random.seed`` for reproducible
    output.

    References
    ----------
    .. [1] Tort, et al. "Measuring Phase-Amplitude Coupling Between Neuronal
           Oscillations of Different Frequencies." Journal of Neurophysiology,
           vol. 104, issue 2, 2010.
    """
    if frac_pac < 0. or frac_pac > 1.:
        raise ValueError('frac_pac must be between 0. and 1.')
    # Accept NumPy scalars as well as builtin numbers.
    scalar_types = (float, int, np.floating, np.integer)
    if not all(isinstance(freq, scalar_types)
               for freq in (freq_amp, freq_phase)):
        raise ValueError('freq_amp and freq_phase must be a single float')
    # ``None`` is documented as valid for both SNR values; map it to the
    # default instead of failing in the division below.
    snr_lo = 4. if snr_lo is None else snr_lo
    snr_hi = 4. if snr_hi is None else snr_hi
    # Allow plain sequences for ``time`` (``.shape`` is used below).
    time = np.asarray(time)
    if mask_pac_times is None:
        mask_pac_times = np.ones_like(time).astype(bool)

    # Simulate noise: row 0 for the low-frequency signal, row 1 for the
    # high-frequency one, each scaled to its requested SNR.
    noise = np.random.randn(2, time.shape[0])
    noise = noise * np.array([max_amp_lo / snr_lo,
                              max_amp_hi / snr_hi])[:, np.newaxis]

    # Low-freq phase
    phase_signal = max_amp_lo * np.sin(2 * np.pi * freq_phase * time)
    phase_signal = phase_signal + noise[0]

    # High-freq amp: the envelope follows the low-frequency sinusoid only
    # for the coupled fraction; masked-out samples get no coupling at all.
    frac_non_pac = 1. - frac_pac
    frac_non_pac = np.where(mask_pac_times, frac_non_pac, 1)
    amp_signal = (1 - frac_non_pac) * np.sin(2 * np.pi * freq_phase * time)
    amp_signal += 1 + frac_non_pac
    amp_signal = max_amp_hi * amp_signal / 2.
    amp_signal = amp_signal * np.sin(2 * np.pi * freq_amp * time)
    amp_signal = amp_signal + noise[1]

    # Combine them
    signal = amp_signal + phase_signal

    return signal, phase_signal, amp_signal
Measuring phase-amplitude 14 | coupling between neuronal oscillations of different frequencies. Journal of 15 | Neurophysiology. 2010. 16 | """ 17 | # Author: Chris Holdgraf 18 | # Praveen Sripad 19 | # Alexandre Gramfort 20 | # 21 | # License: BSD (3-clause) 22 | 23 | import numpy as np 24 | 25 | import mne 26 | from mne import io 27 | from mne.datasets import sample 28 | from mne_incubator.connectivity import (phase_amplitude_coupling, 29 | plot_phase_locked_amplitude, 30 | plot_phase_binned_amplitude) 31 | import matplotlib.pyplot as plt 32 | 33 | print(__doc__) 34 | 35 | ############################################################################### 36 | # Set parameters 37 | data_path = sample.data_path() 38 | raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif' 39 | event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif' 40 | 41 | # Setup for reading the raw data 42 | raw = io.Raw(raw_fname, preload=True) 43 | events = mne.read_events(event_fname) 44 | ev_ixs = events[:, 0].astype(int) 45 | 46 | # Add a bad channel 47 | raw.info['bads'] += ['MEG 2443'] 48 | 49 | # Pick MEG gradiometers 50 | raw = raw.pick_types(meg='grad', eeg=False, stim=False, eog=True, 51 | exclude='bads') 52 | 53 | # Define a pair of indices 54 | ixs = [(4, 10)] 55 | ix_ph, ix_amp = ixs[0] 56 | 57 | # First we can simply calculate a PAC statistic for these signals 58 | f_range_phase = (6, 8) 59 | f_range_amp = (40, 60) 60 | 61 | # Create some artificial PAC data to show the effect 62 | # Calculate the low-frequency phase 63 | raw_phase = raw.copy() 64 | raw_phase = raw_phase.filter(*f_range_phase) 65 | raw_phase.apply_hilbert([ix_ph]) 66 | angles = np.angle(raw_phase._data[1]) 67 | msk_angles = angles > (.5 * np.pi) 68 | 69 | # Take the high-frequency component of the signal, and modulate it w/ the phase 70 | raw_band = raw.copy() 71 | raw_band = raw_band.filter(*f_range_amp) 72 | raw_band._data[ix_amp][~msk_angles] = 0 73 |
raw_band._data[ix_amp][msk_angles] *= 10. 74 | 75 | # Now add the high-freq signal back into the raw data 76 | raw_artificial = raw.copy() 77 | raw_artificial._data[ix_amp] += raw_band._data[ix_amp] 78 | 79 | for i_data in [raw, raw_artificial]: 80 | pac, freqs_pac = phase_amplitude_coupling( 81 | i_data, f_range_phase, f_range_amp, ixs, pac_func='glm', 82 | events=ev_ixs, tmin=0, tmax=.5) 83 | pac = pac.mean() # Average across events 84 | 85 | # We can also visualize these relationships 86 | # Create epochs for left-visual condition 87 | event_id, tmin, tmax = 3, -1, 4 88 | epochs = mne.Epochs(i_data, events, event_id, tmin, tmax, 89 | baseline=(None, 0.), 90 | reject=dict(grad=4000e-13, eog=150e-6), preload=True) 91 | ph_range = np.linspace(*f_range_phase, num=6) 92 | amp_range = np.linspace(*f_range_amp, num=20) 93 | 94 | # Show the amp for a range of frequencies, phase-locked to a low-freq 95 | ax = plot_phase_locked_amplitude(epochs, ph_range, amp_range, ixs[0][0], 96 | ixs[0][1], normalize=True) 97 | ax.set_title('Phase Locked Amplitude, PAC = {0}'.format(pac)) 98 | 99 | # Show the avg amp of the high freqs for bins of phase in the low freq 100 | ax = plot_phase_binned_amplitude(epochs, ph_range, amp_range, 101 | ixs[0][0], ixs[0][1], normalize=True, 102 | n_bins=20) 103 | ax.set_title('Phase Binned Amplitude, PAC = {0}'.format(pac)) 104 | 105 | plt.tight_layout() 106 | plt.show(block=True) 107 | -------------------------------------------------------------------------------- /mne_incubator/preprocessing/tests/test_sns.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os.path as op 4 | 5 | import numpy as np 6 | import pytest 7 | from numpy.testing import assert_allclose, assert_equal 8 | 9 | from mne import create_info, io, pick_types, pick_channels 10 | from mne.io import RawArray 11 | from mne.utils import run_tests_if_main 12 | from mne.datasets import testing 13 | 14 | from 
mne_incubator.preprocessing import SensorNoiseSuppression 15 | 16 | data_path = testing.data_path() 17 | raw_fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis_trunc_raw.fif') 18 | 19 | 20 | @testing.requires_testing_data 21 | def test_sns(): 22 | """Test sensor noise suppression""" 23 | # artificial (IID) data 24 | data = np.random.RandomState(0).randn(102, 5000) 25 | info = create_info(len(data), 1000., 'mag') 26 | raw = io.RawArray(data, info) 27 | pytest.raises(ValueError, SensorNoiseSuppression, 'foo') 28 | pytest.raises(TypeError, SensorNoiseSuppression(10).fit, 'foo') 29 | pytest.raises(ValueError, SensorNoiseSuppression, -1) 30 | raw.info['bads'] = [raw.ch_names[1]] 31 | pytest.raises(ValueError, SensorNoiseSuppression(101).fit, raw) 32 | for n_neighbors, bounds in ((2, (17, 20)), 33 | (5, (11, 15),), 34 | (10, (9, 12)), 35 | (20, (7, 10)), 36 | (50, (5, 9)), 37 | (100, (5, 8)), 38 | ): 39 | sns = SensorNoiseSuppression(n_neighbors) 40 | sns.fit(raw) 41 | raw_sns = sns.apply(raw.copy()) 42 | operator = sns.operator 43 | # bad channels are modified but not used 44 | assert_allclose(operator[:, 1], 0.) 45 | assert (np.sum(np.abs(operator)) > 0) 46 | assert_equal(operator[0].astype(bool).sum(), n_neighbors) 47 | assert_allclose(np.diag(operator), 0.) 
48 | picks = pick_types(raw.info) 49 | orig_power = np.linalg.norm(raw[picks][0]) 50 | # Test the suppression factor 51 | factor = orig_power / np.linalg.norm(raw_sns[picks][0]) 52 | assert bounds[0] < factor < bounds[1] 53 | # degenerate conditions 54 | pytest.raises(TypeError, sns.apply, 'foo') 55 | sub_raw = raw.copy().pick_channels(raw.ch_names[:-1]) 56 | pytest.raises(RuntimeError, sns.apply, sub_raw) # not all orig chs 57 | sub_sns = SensorNoiseSuppression(8) 58 | sub_sns.fit(sub_raw) 59 | pytest.raises(RuntimeError, sub_sns.apply, raw) # not all new chs 60 | # sample data 61 | raw = io.read_raw_fif(raw_fname) 62 | n_neighbors = 8 63 | sns = SensorNoiseSuppression(n_neighbors=n_neighbors) 64 | sns.fit(raw) 65 | raw_sns = sns.apply(raw.copy().load_data()) 66 | operator = sns.operator 67 | # bad channels are modified 68 | assert_equal(len(raw.info['bads']), 2) 69 | for pick in pick_channels(raw.ch_names, raw.info['bads']): 70 | expected = np.zeros(operator.shape[0]) 71 | sub_pick = sns._used_chs.index(raw.ch_names[pick]) 72 | expected[sub_pick] = 1. 73 | pytest.raises(AssertionError, assert_allclose, operator[sub_pick], 74 | expected) 75 | pytest.raises(AssertionError, assert_allclose, raw[pick][0], 76 | raw_sns[pick][0]) 77 | assert_equal(operator[0].astype(bool).sum(), n_neighbors) 78 | assert_equal(operator[0, 0], 0.) 
def plot_phase_locked_amplitude(epochs, freqs_phase, freqs_amp,
                                ix_ph, ix_amp, mask_times=None,
                                normalize=True,
                                tmin=-.5, tmax=.5, return_data=False,
                                amp_kwargs=None, ph_kwargs=None):
    """Make a phase-locked amplitude plot.

    Parameters
    ----------
    epochs : mne.Epochs
        The epochs to be used in phase locking computation
    freqs_phase : np.array
        The frequencies to use in phase calculation. The phase of each
        frequency will be averaged together.
    freqs_amp : np.array
        The frequencies to use in amplitude calculation.
    ix_ph : int
        The index of the signal to be used for phase calculation
    ix_amp : int
        The index of the signal to be used for amplitude calculation
    mask_times : array | None
        Forwarded unchanged to ``phase_locked_amplitude``; see that function
        for its exact meaning.
    normalize : bool
        Whether amplitudes are normalized before averaging together. Helps
        if some frequencies have a larger mean amplitude than others.
    tmin : float
        The time to include before each phase peak
    tmax : float
        The time to include after each phase peak
    return_data : bool
        If True, the amplitude/frequency data will be returned
    amp_kwargs : dict | None
        kwargs to be passed to pcolormesh for amplitudes. ``cmap`` defaults
        to ``plt.cm.RdBu_r`` when not given. The dict is not modified.
    ph_kwargs : dict | None
        kwargs to be passed to the line plot for phase. The dict is not
        modified.

    Returns
    -------
    ax : matplotlib axis
        The lower axis of the figure (the one showing the phase signal).
    data_am : array
        The (optionally scaled) amplitude data shown in the upper axis.
        Returned only if ``return_data`` is True.
    data_ph : array
        The phase data shown in the lower axis. Returned only if
        ``return_data`` is True.
    """
    import matplotlib.pyplot as plt
    from .cfc import phase_locked_amplitude
    from sklearn.preprocessing import scale
    # Work on copies so the defaults filled in below never leak into
    # dictionaries owned by the caller.
    amp_kwargs = dict() if amp_kwargs is None else dict(amp_kwargs)
    ph_kwargs = dict() if ph_kwargs is None else dict(ph_kwargs)

    # Handle kwargs defaults
    amp_kwargs.setdefault('cmap', plt.cm.RdBu_r)

    data_am, data_ph, times = phase_locked_amplitude(
        epochs, freqs_phase, freqs_amp,
        ix_ph, ix_amp, tmin=tmin, tmax=tmax, mask_times=mask_times)

    if normalize:
        # Scale within freqs across time
        data_am = scale(data_am, axis=-1)

    # Plotting: amplitudes on top, phase signal below
    _, axs = plt.subplots(2, 1)
    ax = axs[0]
    ax.pcolormesh(times, freqs_amp, data_am, **amp_kwargs)

    ax = axs[1]
    ax.plot(times, data_ph, **ph_kwargs)

    plt.setp(axs, xlim=[times[0], times[-1]])
    # Make the phase axis symmetric around zero
    ylim = np.max(np.abs(ax.get_ylim()))
    plt.setp(ax, ylim=[-ylim, ylim])
    if return_data:
        return ax, data_am, data_ph
    return ax


def plot_phase_binned_amplitude(epochs, freqs_phase, freqs_amp,
                                ix_ph, ix_amp, normalize=True,
                                n_bins=20, return_data=False,
                                mask_times=None, ax=None,
                                **kwargs):
    """Make a circular phase-binned amplitude plot.

    Parameters
    ----------
    epochs : mne.Epochs
        The epochs to be used in phase locking computation
    freqs_phase : np.array
        The frequencies to use in phase calculation. The phase of each
        frequency will be averaged together.
    freqs_amp : np.array
        The frequencies to use in amplitude calculation. The amplitude
        of each frequency will be averaged together.
    ix_ph : int
        The index of the signal to be used for phase calculation
    ix_amp : int
        The index of the signal to be used for amplitude calculation
    normalize : bool
        Whether amplitudes are normalized before averaging together. Helps
        if some frequencies have a larger mean amplitude than others.
    n_bins : int
        The number of bins to use when grouping amplitudes. Each bin will
        have size (2 * np.pi) / n_bins.
    return_data : bool
        If True, the amplitude/frequency data will be returned
    mask_times : array | None
        Forwarded unchanged to ``phase_binned_amplitude``; see that function
        for its exact meaning.
    ax : matplotlib axis | None
        If not None, plotting functions will be called on this object
    kwargs : dict
        kwargs to be passed to the ``bar`` call. ``color`` defaults to
        ``'r'`` and ``width`` to one bin width; both can be overridden.

    Returns
    -------
    ax : matplotlib axis
        The axis used for plotting.
    amps : array
        The (optionally normalized) binned amplitudes. Returned only if
        ``return_data`` is True.
    bins : array
        The phase bin edges. Returned only if ``return_data`` is True.
    """
    import matplotlib.pyplot as plt
    from .cfc import phase_binned_amplitude
    from sklearn.preprocessing import RobustScaler
    amps, bins = phase_binned_amplitude(epochs, freqs_phase, freqs_amp,
                                        ix_ph, ix_amp, n_bins=n_bins,
                                        mask_times=mask_times)
    if normalize:
        amps = RobustScaler().fit_transform(amps[:, np.newaxis])
    if ax is None:
        plt.figure()
        ax = plt.subplot(111, polar=True)
    bins_plt = bins[:-1]  # Because there is 1 more bins than amps
    # Defaults first, then caller overrides, so that ``kwargs`` is honoured
    # as documented instead of being silently dropped.
    bar_kwargs = dict(color='r', width=2 * np.pi / len(bins_plt))
    bar_kwargs.update(kwargs)
    # The scaler returns a column vector; ``bar`` needs one height per bin,
    # so flatten for plotting only (the returned ``amps`` is left untouched).
    ax.bar(bins_plt + np.pi, np.ravel(amps), **bar_kwargs)
    if return_data:
        return ax, amps, bins
    return ax
13 | clean = mne.io.RawArray( 14 | data=np.vstack([ 15 | np.tile(sine[np.newaxis, :], (2, 1)), 16 | np.zeros((3, 410)), # EOG 17 | ]), 18 | info=mne.create_info( 19 | ['EEG1', 'EEG2', 'HEOG', 'VEOG', 'REOG'], 20 | sfreq=100., 21 | ch_types=['eeg', 'eeg', 'eog', 'eog', 'eog'], 22 | ), 23 | ) 24 | HEOG_ind = clean.ch_names.index('HEOG') 25 | VEOG_ind = clean.ch_names.index('VEOG') 26 | REOG_ind = clean.ch_names.index('REOG') 27 | events = np.array([ 28 | [000, 0, 1], # Blink 29 | [100, 0, 1], # Blink 30 | [200, 0, 2], # Horizontal saccade 31 | [300, 0, 3], # Vertical saccade 32 | ]) 33 | 34 | # Scenario 1: Some blinks captured by the VEOG channel 35 | blinks = clean.copy() 36 | blink_shape = scipy.stats.norm(0.5, 0.05).pdf(np.arange(0, 1, 0.01)) 37 | blink_shape /= blink_shape.max() 38 | blink_shape = np.tile(blink_shape, 2) 39 | blinks._data[[0, 1, VEOG_ind], :200] += np.dot( 40 | [[1.1], [2.1], [1.0]], 41 | blink_shape[np.newaxis, :] 42 | ) 43 | cleaned, weights = eog_regression( 44 | raw=blinks, 45 | blink_epochs=mne.Epochs( 46 | blinks, events, {'blink': 1}, 47 | tmin=0, tmax=1, preload=True), 48 | eog_channels='VEOG', 49 | copy=True, 50 | ) 51 | assert_allclose(cleaned._data[:2], clean._data[:2], atol=1e-12) 52 | assert_allclose(weights, [[1.1, 2.1]]) 53 | 54 | # Scenario 2: Some blinks captured by VEOG and horizontal saccades captured 55 | # by HEOG 56 | blink_sacc = clean.copy() 57 | # Add blinks 58 | blink_sacc._data[[0, 1, VEOG_ind], :200] += np.dot( 59 | [[1.1], [2.1], [1.0]], 60 | blink_shape[np.newaxis, :] 61 | ) 62 | # Add saccades 63 | sacc_shape = np.hstack(( 64 | scipy.stats.norm(0.2, 0.05).pdf(np.arange(0, 0.2, 0.01)), 65 | scipy.stats.norm(0.2, 0.20).pdf(np.arange(0.2, 1, 0.01)) 66 | )) 67 | sacc_shape[:20] /= sacc_shape[:20].max() 68 | sacc_shape[20:] /= sacc_shape[20:].max() 69 | sacc_shape = np.tile(sacc_shape, 2) 70 | blink_sacc._data[[0, 1, HEOG_ind], 200:400] += np.dot( 71 | [[1.2], [2.2], [1.0]], 72 | sacc_shape[np.newaxis, :] 73 | ) 74 | 
cleaned, weights = eog_regression( 75 | raw=blink_sacc, 76 | blink_epochs=mne.Epochs( 77 | blink_sacc, events, {'blink': 1}, 78 | tmin=0, tmax=1, preload=True), 79 | saccade_epochs=mne.Epochs( 80 | blink_sacc, events, {'horiz-sacc': 2}, 81 | tmin=0, tmax=1, preload=True), 82 | eog_channels=['VEOG', 'HEOG'], 83 | copy=True, 84 | ) 85 | assert_allclose(cleaned._data[:2], clean._data[:2], atol=1e-3) 86 | assert_allclose(weights, [[1.1, 2.1], [1.2, 2.2]], atol=1e-3) 87 | 88 | # Scenario 3: Some blinks captured by REOG, some saccades captured 89 | # by HEOG. EOG also cross-contaminated. 90 | full_eog = clean.copy() 91 | # Add saccades (also contaminate VEOG and REOG) 92 | full_eog._data[[0, 1, VEOG_ind, HEOG_ind, REOG_ind], 200:400] += np.dot( 93 | [[1.2], [2.2], [0], [1.0], [0]], 94 | sacc_shape[np.newaxis, :] 95 | ) 96 | # Add blinks (also contaminate HEOG) 97 | full_eog._data[[0, 1, VEOG_ind, HEOG_ind, REOG_ind], :200] += np.dot( 98 | [[1.1], [2.1], [1.0], [0], [1.0]], 99 | blink_shape[np.newaxis, :] 100 | ) 101 | cleaned, weights = eog_regression( 102 | raw=full_eog, 103 | blink_epochs=mne.Epochs( 104 | full_eog, events, {'blink': 1}, 105 | tmin=0, tmax=1, preload=True), 106 | saccade_epochs=mne.Epochs( 107 | full_eog, events, {'horiz-sacc': 2, 'vert-sacc': 3}, 108 | tmin=0, tmax=1, preload=True), 109 | reog='REOG', 110 | eog_channels=['VEOG', 'HEOG', 'REOG'], 111 | copy=True, 112 | ) 113 | assert_allclose(cleaned._data[:2], clean._data[:2], atol=1e-3) 114 | assert_allclose(weights, [[0, 0], [1.2, 2.2], [1.1, 2.1]], atol=1e-3) 115 | 116 | # We use the last scenario to run a few more tests 117 | raw = full_eog 118 | blink_epochs = mne.Epochs( 119 | full_eog, events, {'blink': 1}, 120 | tmin=0, tmax=1, preload=True) 121 | saccade_epochs = mne.Epochs( 122 | full_eog, events, {'horiz-sacc': 2, 'vert-sacc': 3}, 123 | tmin=0, tmax=1, preload=True) 124 | 125 | # Default parameters 126 | raw2 = raw.copy() 127 | raw3, weights = eog_regression(raw2, blink_epochs) 128 | 
def firf(x, f_range, fs=1000, w=3):
    """
    Filter signal with an FIR filter
    *Like fir1 in MATLAB

    The filter is applied forwards and backwards (``filtfilt``) and one
    filter length of samples is then trimmed from each end of the result.

    x : array-like, 1d
        Time series to filter
    f_range : (low, high), Hz
        Cutoff frequencies of bandpass filter
    fs : float, Hz
        Sampling rate
    w : float
        Length of the filter in terms of the number of cycles
        of the oscillation whose frequency is the low cutoff of the
        bandpass filter

    Returns
    -------
    x_filt : array-like, 1d
        Filtered time series, shorter than ``x`` by two filter lengths

    Raises
    ------
    ValueError
        If ``w`` is not positive or a cutoff lies outside (0, nyquist].
    RuntimeError
        If the filter is longer than the data or the output contains nans.
    """

    if w <= 0:
        raise ValueError(
            'Number of cycles in a filter must be a positive number.')

    band = np.asarray(f_range, dtype=float)
    nyq = fs / 2.
    if np.any(band > nyq):
        raise ValueError('Filter frequencies must be below nyquist rate.')

    # A zero cutoff would make the filter length below infinite
    if np.any(band <= 0):
        raise ValueError('Filter frequencies must be positive.')

    # The scipy filter-design functions need an integer number of taps
    Ntaps = int(np.floor(w * fs / band[0]))
    if len(x) < Ntaps:
        raise RuntimeError(
            'Length of filter is longer than data. '
            'Provide more data or a shorter filter.')

    # Perform filtering
    taps = firwin(Ntaps, band / nyq, pass_zero=False)
    x_filt = filtfilt(taps, [1], x)

    if np.any(np.isnan(x_filt)):
        raise RuntimeError(
            'Filtered signal contains nans. Adjust filter parameters.')

    # Remove edge artifacts
    return _remove_edge(x_filt, Ntaps)


def firfls(x, f_range, fs=1000, w=3, tw=.15):
    """
    Filter signal with an FIR filter
    *Like firls in MATLAB

    The filter is applied forwards and backwards (``filtfilt``) and one
    filter length of samples is then trimmed from each end of the result.

    x : array-like, 1d
        Time series to filter
    f_range : (low, high), Hz
        Cutoff frequencies of bandpass filter
    fs : float, Hz
        Sampling rate
    w : float
        Length of the filter in terms of the number of cycles
        of the oscillation whose frequency is the low cutoff of the
        bandpass filter
    tw : float
        Transition width of the filter in normalized frequency space

    Returns
    -------
    x_filt : array-like, 1d
        Filtered time series, shorter than ``x`` by two filter lengths

    Raises
    ------
    ValueError
        If ``w`` is not positive, ``tw`` is outside [0, 1] or a cutoff lies
        outside (0, nyquist].
    RuntimeError
        If the filter is longer than the data, the transition bands do not
        fit into [0, nyquist], or the output contains nans.
    """

    if w <= 0:
        raise ValueError(
            'Number of cycles in a filter must be a positive number.')

    if tw < 0 or tw > 1:
        raise ValueError('Transition width must be between 0 and 1.')

    band = np.asarray(f_range, dtype=float)
    nyq = fs / 2.
    if np.any(band > nyq):
        raise ValueError('Filter frequencies must be below nyquist rate.')

    # A zero cutoff would make the filter length below infinite
    if np.any(band <= 0):
        raise ValueError('Filter frequencies must be positive.')

    # The scipy filter-design functions need an integer number of taps
    Ntaps = int(np.floor(w * fs / band[0]))
    if len(x) < Ntaps:
        raise RuntimeError(
            'Length of filter is longer than data. '
            'Provide more data or a shorter filter.')

    # Characterize desired filter: unit gain inside the band, zero outside,
    # with linear ramps whose width is set by ``tw``
    f = [0, (1 - tw) * band[0] / nyq, band[0] / nyq,
         band[1] / nyq, (1 + tw) * band[1] / nyq, 1]
    m = [0, 0, 1, 1, 0, 0]
    if np.any(np.diff(f) < 0):
        raise RuntimeError(
            'Invalid FIR filter parameters. '
            'Please decrease the transition width parameter.')

    # Perform filtering
    taps = firwin2(Ntaps, f, m)
    x_filt = filtfilt(taps, [1], x)

    if np.any(np.isnan(x_filt)):
        raise RuntimeError(
            'Filtered signal contains nans. Adjust filter parameters.')

    # Remove edge artifacts
    return _remove_edge(x_filt, Ntaps)


def _morlet(M, w=5.0, s=1.0):
    """Return a complex Morlet wavelet of ``M`` points.

    Same formula as the former ``scipy.signal.morlet(M, w, s)`` (with its
    default ``complete=True``), which recent SciPy releases no longer ship.
    """
    t = np.linspace(-s * 2 * np.pi, s * 2 * np.pi, M)
    wavelet = np.exp(1j * w * t)
    wavelet -= np.exp(-0.5 * (w ** 2))  # correction term for zero mean
    wavelet *= np.exp(-0.5 * (t ** 2)) * np.pi ** (-0.25)
    return wavelet


def morletf(x, f0, fs=1000, w=3, s=1, M=None, norm='sss'):
    """
    NOTE: This function is not currently ready to be interfaced with pacpy
    This is because the frequency input is not a range, which is a big
    assumption in how the api is currently designed

    Convolve a signal with a complex wavelet
    The real part is the filtered signal
    Taking np.abs() of output gives the analytic amplitude
    Taking np.angle() of output gives the analytic phase

    x : array
        Time series to filter
    f0 : float
        Center frequency of bandpass filter
    fs : float
        Sampling rate
    w : float
        Length of the filter in terms of the number of
        cycles of the oscillation with frequency f0
    s : float
        Scaling factor for the morlet wavelet
    M : integer
        Length of the filter. Overrides the f0 and w inputs
    norm : string
        Normalization method
        'sss' - divide by the sqrt of the sum of squares of points
        'amp' - divide by the sum of amplitudes divided by 2
                ('abs' is accepted as an alias)

    Returns
    -------
    x_trans : array
        Complex time series

    Raises
    ------
    ValueError
        If ``w`` is not positive or ``norm`` is not a known method.
    """

    if w <= 0:
        raise ValueError(
            'Number of cycles in a filter must be a positive number.')

    if M is None:
        M = 2 * s * w * fs / f0
    # The wavelet is sampled on M points, so M has to be an integer
    M = int(M)

    morlet_f = _morlet(M, w=w, s=s)

    if norm == 'sss':
        morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2))
    elif norm in ('amp', 'abs'):
        # 'amp' is the documented name; 'abs' is what the code used to accept
        morlet_f = morlet_f / np.sum(np.abs(morlet_f)) * 2
    else:
        raise ValueError('Not a valid wavelet normalization method.')

    x_filtR = np.convolve(x, np.real(morlet_f), mode='same')
    x_filtI = np.convolve(x, np.imag(morlet_f), mode='same')

    # NOTE: unlike firf/firfls, edge samples are deliberately not trimmed
    return x_filtR + 1j * x_filtI


def _remove_edge(x, N):
    """
    Remove N samples from each end of a time series (edge artifacts)

    x : array
        time series to remove edge artifacts from
    N : int
        length of filter, i.e. number of samples dropped at each end
    """
    N = int(N)
    if N == 0:
        # x[0:-0] would be empty; nothing to trim means "keep everything"
        return x
    return x[N:-N]
35 | bias_thresh : float | None 36 | Threshold (relative to the largest component) below which components 37 | will be discarded during decomposition of the bias function. ``None`` 38 | (the default) keeps all non-zero values; to keep all values, pass 39 | ``thresh=None`` and ``max_components=None``. 40 | return_data : bool 41 | Whether to return the denoised data along with the denoising matrix. 42 | 43 | Returns 44 | ------- 45 | dss_mat : array of shape (n_dss_components, n_channels) 46 | The denoising matrix. Apply to data via ``np.dot(dss_mat, ep)``, where 47 | ``ep`` is an epoch of shape (n_channels, n_samples). 48 | dss_data : array of shape (n_trials, n_dss_components, n_samples) 49 | The denoised data. Note that the DSS components are orthogonal virtual 50 | channels and may be fewer in number than the number of channels in the 51 | input Epochs object. Returned only if ``return_data`` is ``True``. 52 | 53 | References 54 | ---------- 55 | .. [1] Särelä, Jaakko, and Valpola, Harri (2005). Denoising source 56 | separation. Journal of Machine Learning Research 6: 233–72. 57 | 58 | .. [2] de Cheveigné, Alain, and Simon, Jonathan Z. (2008). Denoising based 59 | on spatial filtering. Journal of Neuroscience Methods, 171(2): 331-339. 60 | """ 61 | if isinstance(data, (Epochs, EpochsArray)): 62 | data_cov = compute_covariance(data).data 63 | bias_cov = np.cov(data.average().pick_types(eeg=True, ref_meg=False). 
def _dss(data_cov, bias_cov, data_max_components=None, data_thresh=None,
         bias_max_components=None, bias_thresh=None):
    """Compute a DSS matrix from a data covariance and a bias covariance

    Works directly on covariance matrices, so any bias function can be
    supplied (the public ``dss`` function always uses the covariance of the
    evoked response as bias).

    Returns an array of shape (n_dss_components, n_channels), scaled so that
    every component has unit power with respect to ``data_cov``.
    """
    # PCA of the data; the inverse root eigenvalues whiten the components
    eigval, eigvec = _pca(data_cov, data_max_components, data_thresh)
    whiten = np.sqrt(1 / eigval)
    whitened_vecs = whiten * eigvec
    # express the bias covariance in the whitened PCA space of the data
    bias_white = np.dot(np.dot(whitened_vecs.T, bias_cov), eigvec) * whiten
    # rotate the whitened space so its axes are ordered by bias power
    _, rotation = _pca(bias_white, bias_max_components, bias_thresh)
    # channels -> whitened PCA space -> bias-ordered (DSS) space
    unmixing = np.dot(whitened_vecs, rotation)
    # rescale each DSS component to unit power on the data
    power = np.diag(np.dot(np.dot(unmixing.T, data_cov), unmixing))
    return (np.sqrt(1 / power) * unmixing).T


def _pca(cov, max_components=None, thresh=0):
    """Eigen-decompose a covariance matrix, largest components first

    Parameters
    ----------
    cov : array-like
        Covariance matrix
    max_components : int | None
        Upper bound on the number of components returned. ``None`` (the
        default) applies no bound, leaving the selection to ``thresh``.
    thresh : float | None
        Components are kept only when their eigenvalue magnitude, divided by
        the largest retained one, is strictly greater than this value. The
        default of 0 therefore drops exactly-zero components; ``None``
        disables the test.

    Returns
    -------
    eigval : array
        1-dimensional array of eigenvalue magnitudes, in decreasing order.
    eigvec : array
        2-dimensional array with the matching eigenvectors as columns.
    """
    if thresh is not None and (thresh < 0 or thresh > 1):
        raise ValueError('Threshold must be between 0 and 1 (or None).')
    raw_val, raw_vec = np.linalg.eigh(cov)
    magnitude = np.abs(raw_val)
    # indices of the components from largest to smallest magnitude
    order = np.argsort(magnitude)[::-1]
    if max_components is not None:
        order = order[:max_components]
    eigval = magnitude[order]
    eigvec = raw_vec[:, order]
    if thresh is not None:
        keep = eigval / eigval.max() > thresh
        eigval = eigval[keep]
        eigvec = eigvec[:, keep]
    return eigval, eigvec
If an rEOG channel is 27 | available as well as saccade data, the accuracy of the estimation of 28 | the weights can be improved. By default, no rEOG channel is assumed to 29 | be present. 30 | eog_channels : str | list of str | None 31 | The names of the EOG channels to use. By default, all EOG channels are 32 | used. 33 | picks : list of int | None 34 | Indices of the channels in the Raw instance for which to apply the EOG 35 | correction procedure. By default, the correction is applied to EEG 36 | channels only. 37 | copy : bool 38 | If True, a copy of the Raw instance will be made before applying the 39 | EOG correction procedure. Defaults to False, which will perform the 40 | operation in-place. 41 | 42 | 43 | References 44 | ---------- 45 | [1] Croft, R. J., & Barry, R. J. (2000). Removal of ocular artifact from 46 | the EEG: a review. Clinical Neurophysiology, 30(1), 5-19. 47 | http://doi.org/10.1016/S0987-7053(00)00055-1 48 | """ 49 | # Handle defaults for EOG channels parameter 50 | if eog_channels is None: 51 | eog_picks = pick_types(raw.info, meg=False, ref_meg=False, eog=True) 52 | eog_channels = [raw.ch_names[ch] for ch in eog_picks] 53 | elif isinstance(eog_channels, str): 54 | eog_channels = [eog_channels] 55 | 56 | # Make sure the REOG channel is part of the EOG channel list 57 | if reog is not None: 58 | if reog not in eog_channels: 59 | eog_channels += [reog] 60 | 61 | # Default picks 62 | if picks is None: 63 | picks = pick_types(raw.info, meg=False, ref_meg=False, eeg=True) 64 | 65 | if copy: 66 | raw = raw.copy() 67 | 68 | # Compute channel indices for the EOG channels 69 | raw_eog_ind = [raw.ch_names.index(ch) for ch in eog_channels] 70 | ev_eog_ind = [blink_epochs.ch_names.index(ch) for ch in eog_channels] 71 | 72 | blink_evoked = [ 73 | blink_epochs[cl].average(range(blink_epochs.info['nchan'])) 74 | for cl in blink_epochs.event_id.keys() 75 | ] 76 | blink_data = np.hstack([ev.data for ev in blink_evoked]) 77 | 78 | if saccade_epochs is 
None: 79 | # Calculate EOG weights 80 | v = np.vstack(( 81 | np.ones(blink_data.shape[1]), 82 | blink_data[ev_eog_ind] 83 | )).T 84 | weights = lstsq(v, blink_data.T)[0][1:] 85 | else: 86 | saccade_evoked = [ 87 | saccade_epochs[cl].average(range(saccade_epochs.info['nchan'])) 88 | for cl in saccade_epochs.event_id.keys() 89 | ] 90 | saccade_data = np.hstack([ev.data for ev in saccade_evoked]) 91 | 92 | if reog is None: 93 | # If no rEOG data is present, just concatenate the saccade data 94 | # to the blink data and treat it as one 95 | blink_sac_data = np.c_[blink_data, saccade_data] 96 | v = np.vstack(( 97 | np.ones(blink_sac_data.shape[1]), 98 | blink_sac_data[np.r_[ev_eog_ind]] 99 | )).T 100 | weights = lstsq(v, blink_sac_data.T)[0][1:] 101 | else: 102 | # If rEOG data is present, use the saccade data to compute the 103 | # weights for all non-rEOG channels. The blink data will be used 104 | # for the rEOG channel weight. 105 | 106 | # Isolate the rEOG channel from the other EOG channels 107 | raw_reog_ind = raw.ch_names.index(reog) 108 | raw_non_reog_ind = list(raw_eog_ind) 109 | raw_non_reog_ind.remove(raw_reog_ind) 110 | ev_reog_ind = blink_epochs.ch_names.index(reog) 111 | ev_non_reog_ind = list(ev_eog_ind) 112 | ev_non_reog_ind.remove(ev_reog_ind) 113 | 114 | # Compute non-rEOG weights on the saccade data 115 | v1 = np.vstack(( 116 | np.ones(saccade_data.shape[1]), 117 | saccade_data[ev_non_reog_ind, :], 118 | )).T 119 | weights_sac = lstsq(v1, saccade_data.T)[0][1:] 120 | 121 | # Remove saccades from blink data 122 | blink_data -= weights_sac.T.dot(blink_data[ev_non_reog_ind, :]) 123 | 124 | # Compute rEOG weights on the blink data 125 | v2 = np.vstack(( 126 | np.ones(blink_data.shape[1]), 127 | blink_data[ev_reog_ind, :] 128 | )).T 129 | weights_blink = lstsq(v2, blink_data.T)[0][[1]] 130 | 131 | # Remove saccades from rEOG channel in raw data 132 | raw._data[raw_reog_ind, :] -= np.dot( 133 | weights_sac[:, ev_reog_ind].T, raw._data[raw_non_reog_ind, :]) 
134 | 135 | # Compile the EOG weights and make sure to put them in the right 136 | # order. 137 | ind = list(range(len(eog_channels))) 138 | REOG_ind = eog_channels.index('REOG') 139 | del ind[REOG_ind] 140 | ind.append(REOG_ind) 141 | weights = np.vstack((weights_sac, weights_blink))[ind] 142 | 143 | # Create a mapping between the picked channels of the raw instance and the 144 | # EOG weights 145 | weight_names = blink_epochs.ch_names 146 | weight_ch_ind = [weight_names.index(raw.ch_names[ch]) for ch in picks] 147 | 148 | # Remove EOG from raw channels 149 | raw._data[picks, :] -= np.dot(weights[:, weight_ch_ind].T, 150 | raw._data[raw_eog_ind, :]) 151 | 152 | return raw, weights[:, weight_ch_ind] 153 | -------------------------------------------------------------------------------- /examples/connectivity/plot_pac.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================================== 3 | Compute phase-amplitude coupling measures between signals 4 | ============================================================== 5 | Simulate phase-amplitude coupling between two signals, and computes 6 | several PAC metrics between them. Calculates PAC for all timepoints, as well 7 | as for time-locked PAC responses. 8 | References 9 | ---------- 10 | [1] Canolty RT, Edwards E, Dalal SS, Soltani M, Nagarajan SS, Kirsch HE, 11 | Berger MS, Barbaro NM, Knight RT. "High gamma power is phase-locked to 12 | theta oscillations in human neocortex." Science. 2006. 13 | [2] Tort ABL, Komorowski R, Eichenbaum H, Kopell N. Measuring phase-amplitude 14 | coupling between neuronal oscillations of different frequencies. Journal of 15 | Neurophysiology. 2010. 
"""
==============================================================
Compute phase-amplitude coupling measures between signals
==============================================================

Simulate phase-amplitude coupling between two signals, and compute
several PAC metrics between them. Calculates PAC for all timepoints, as well
as for time-locked PAC responses.

References
----------
[1] Canolty RT, Edwards E, Dalal SS, Soltani M, Nagarajan SS, Kirsch HE,
    Berger MS, Barbaro NM, Knight RT. "High gamma power is phase-locked to
    theta oscillations in human neocortex." Science. 2006.
[2] Tort ABL, Komorowski R, Eichenbaum H, Kopell N. Measuring phase-amplitude
    coupling between neuronal oscillations of different frequencies. Journal of
    Neurophysiology. 2010.
"""
# Author: Chris Holdgraf
#
# License: BSD (3-clause)
import mne
import numpy as np
from matplotlib import pyplot as plt
from mne_incubator.connectivity import (simulate_pac_signal,
                                        phase_amplitude_coupling)
import logging

print(__doc__)
np.random.seed(1337)
logger = logging.getLogger('mne')
logger.setLevel(50)

###############################################################################
# Phase-amplitude coupling (PAC) is a technique to determine if the
# amplitude of a high-frequency signal is locked to the phase
# of a low-frequency signal. The phase_amplitude_coupling function
# calculates PAC between pairs of signals for one or multiple
# time windows. In this example, we'll simulate two signals. One
# of the signals has an amplitude that is locked to the phase of
# the other. We'll calculate PAC for a number of time points, and
# in both directions to show how PAC responds.

# Define parameters for our simulated signal
sfreq = 1000.
f_phase = 5
f_amp = 40
frac_pac = .99  # This is the fraction of PAC to use
mag_ph = 4
mag_am = 1

# These are the times where PAC is active in our simulated signal
n_secs = 20.
time = np.arange(0, n_secs, 1. / sfreq)
event_times = np.arange(1, 18, 4)
event_dur = 2.

# Create a time mask that defines when PAC is active
msk_pac_times = np.zeros_like(time).astype(bool)
for i_time in event_times:
    msk_pac_times += mne.utils._time_mask(time, i_time, i_time + event_dur)

# Now simulate two signals. First, a low-frequency phase
# that modulates high-frequency amplitude
_, lo_pac, hi_pac = simulate_pac_signal(time, f_phase, f_amp, mag_ph, mag_am,
                                        frac_pac=frac_pac,
                                        mask_pac_times=msk_pac_times)

# Now two signals with no relationship between them
_, lo_none, hi_none = simulate_pac_signal(time, f_phase, f_amp, mag_ph,
                                          mag_am, frac_pac=0,
                                          mask_pac_times=msk_pac_times)

# Finally we'll mix them up.
# The low-frequency phase of signal A...
signal_a = lo_pac + hi_none
# Modulates the high-frequency amplitude of signal B. But not the reverse.
signal_b = lo_none + hi_pac


# To standardize the scales of each PAC metric
def normalize_pac(pac):
    """Divide `pac` by its mean so different metrics share a scale."""
    return pac / np.mean(pac)


# We'll visualize these signals. A on the left, B on the right
# The top row is a combination of the middle and bottom row, so each column
# must list exactly the components that were summed into its combined signal.
labels = ['Combined Signal', 'Lo-Freq signal', 'Hi-freq signal']
data = [[signal_a, lo_pac, hi_none],
        [signal_b, lo_none, hi_pac]]
fig, axs = plt.subplots(3, 2, figsize=(10, 5))
for axcol, i_data in zip(axs.T, data):
    for ax, i_sig, i_label in zip(axcol, i_data, labels):
        ax.plot(time, i_sig)
        ax.set_title(i_label, fontsize=20)
_ = plt.setp(axs, xlim=[8, 12])
plt.tight_layout()

# Create a raw array from the simulated data
info = mne.create_info(['pac_hi', 'pac_lo'], sfreq, 'eeg')
raw = mne.io.RawArray([signal_a, signal_b], info)

# The PAC function needs a lower and upper bound for each frequency.
# These are the bands used by every single-band PAC call below.
f_phase_bound = (f_phase - .1, f_phase + .1)
f_amp_bound = (f_amp - .5, f_amp + .5)

# First we'll calculate PAC for the entire timeseries.
# We'll use a few PAC metrics to compare.
iter_pac_funcs = [['mi_tort', 'ozkurt'], ['plv']]
win_size = 1.  # In seconds
step_size = .1
pac_times = np.array(
    [(i, i + win_size)
     for i in np.arange(0, np.max(time) - win_size, step_size)])

# Here we specify indices to calculate PAC in both directions
ixs = np.array([[0, 1],
                [1, 0]])
fig, axs = plt.subplots(2, 1, figsize=(10, 5))
ylim = 0
all_pac = []
for pac_funcs in iter_pac_funcs:
    pac, pac_freqs = phase_amplitude_coupling(
        raw, f_phase_bound, f_amp_bound, ixs,
        pac_func=pac_funcs, tmin=pac_times[:, 0], tmax=pac_times[:, 1],
        n_cycles_ph=3, n_cycles_am=3)
    pac = pac.squeeze()
    if len(pac_funcs) == 1:
        # Restore the metric axis that squeeze() removed
        pac = pac[np.newaxis, ...]
    for i in range(pac.shape[0]):
        pac[i] = normalize_pac(pac[i])
    # Now plot timeseries of each
    for i_pac, i_name in zip(pac, pac_funcs):
        for i_pac_ix, ax in zip(i_pac.squeeze(), axs):
            ax.plot(pac_times.mean(-1), i_pac_ix, label=i_name)
    ylim = max(ylim, np.max(pac))  # For proper scale
axs[0].legend()
axs[0].set_title('PAC: low-freq A to high-freq B', fontsize=20)
axs[1].set_title('PAC: low-freq B to high-freq A', fontsize=20)
_ = plt.setp(axs, ylim=[0, ylim + .1])
plt.tight_layout()


# We can also calculate event-locked PAC
# by supplying a list of event indices
ev = np.array(event_times) * sfreq
ev = ev.astype(int)

pac_funcs = ['mi_tort', 'ozkurt']
colors = ['b', 'r']
win_size = 1.
ev_tmin = -2.
ev_tmax = 3.
pac_times = np.array([(i, i + win_size)
                      for i in np.arange(ev_tmin, ev_tmax, step_size)])
pac, pac_freqs = phase_amplitude_coupling(
    raw, f_phase_bound, f_amp_bound, ixs,
    pac_func=pac_funcs, tmin=pac_times[:, 0], tmax=pac_times[:, 1],
    events=ev, concat_epochs=False)
pac = pac.squeeze()
for i in range(pac.shape[0]):
    pac[i] = normalize_pac(pac[i])

# This allows us to calculate the stability of PAC across epochs
fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 5))
for i_pac, i_pac_name, color in zip(pac, pac_funcs, colors):
    for i_pac_mn, ax in zip(i_pac.mean(0), axs):
        ax.plot(pac_times[:, 0], i_pac_mn, color=color, label=i_pac_name)
        ax.axvline(0, ls='--', color='k', lw=1)

axs[0].legend()
axs[0].set_title('Time-locked PAC: Signal A to Signal B')
axs[1].set_title('Time-locked PAC: Signal B to Signal A')

plt.tight_layout()

# Finally, we can also calculate PAC with multiple frequencies
freqs_phase = np.array([(i-.1, i+.1)
                        for i in np.arange(3, 12, 1)])
freqs_amp = np.array([(i-.1, i+.1)
                      for i in np.arange(f_amp-20, f_amp+20, 5)])
# We can set a variable number of cycles per freq to control bandwidth
cycles_phase = 5.
cycles_amp = freqs_amp.mean(-1) / 10.
pac, pac_freqs = phase_amplitude_coupling(
    raw, freqs_phase, freqs_amp, ixs,
    pac_func=['ozkurt'], tmin=.2, tmax=event_dur - .2,
    events=ev, concat_epochs=True, n_cycles_ph=cycles_phase,
    n_cycles_am=cycles_amp)
pac = pac.squeeze()
pac = normalize_pac(pac)

f, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, i_pac in zip(axs, pac):
    comod = i_pac.reshape([-1, len(freqs_amp)]).T
    ax.pcolormesh(freqs_phase[:, 0], freqs_amp[:, 0], comod,
                  vmin=0, vmax=np.max(pac) + .1)
print(pac.max())
_ = plt.setp(axs, xlim=[freqs_phase[:, 0].min(), freqs_phase[:, 0].max()],
             ylim=[freqs_amp[:, 0].min(), freqs_amp[:, 0].max()],
             xlabel='Frequency Phase (Hz)', ylabel='Frequency Amplitude (Hz)')
axs[0].set_title('Comodulogram: Signal A to Signal B')
axs[1].set_title('Comodulogram: Signal B to Signal A')
plt.tight_layout()
plt.show(block=True)
class SensorNoiseSuppression(object):
    """Apply the sensor noise suppression (SNS) algorithm

    This algorithm (from [1]_) will replace the data from each channel by
    its regression on the subspace formed by the other channels.

    .. note:: ``info['bads']`` is never modified or reset by this class.
              The data of bad channels is, however, replaced by its
              projection on the good channels when the operator is applied.

    Parameters
    ----------
    n_neighbors : int
        Number of neighbors (based on correlation) to include in the
        projection.
    reject : dict | str | None
        Rejection parameters based on peak-to-peak amplitude.
        See :class:`mne.Epochs` for details.
        This is only used during the covariance-fitting phase.
    flat : dict | None
        Rejection parameters based on flatness of signal.
        See :class:`mne.Epochs` for details.
        This is only used during the covariance-fitting phase.
    scalings : dict | None (default None)
        Defaults to ``dict(mag=1e15, grad=1e13, eeg=1e6)``.
        These defaults will scale data channels to the same unit
        (which should be roughly unity).
    verbose : bool, str, int, or None
        If not None, override default verbose level (see mne.verbose).

    References
    ----------
    .. [1] De Cheveigné A, Simon JZ. Sensor noise suppression. Journal of
           Neuroscience Methods 168: 195–202, 2008.
    """
    @verbose
    def __init__(self, n_neighbors, reject=None, flat=None, scalings=None,
                 verbose=None):
        self._n_neighbors = int(n_neighbors)
        if self._n_neighbors < 1:
            raise ValueError('n_neighbors must be positive')
        self._reject = reject
        self._flat = flat
        self._scalings = _check_scalings_user(scalings)
        self.verbose = verbose

    @verbose
    def fit(self, raw, verbose=None):
        """Fit the SNS operator

        Parameters
        ----------
        raw : Instance of Raw
            The raw data to fit.
        verbose : bool, str, int, or None
            If not None, override default verbose level (see mne.verbose).

        Returns
        -------
        sns : Instance of SensorNoiseSuppression
            The modified instance.

        Raises
        ------
        TypeError
            If `raw` is not an instance of Raw.
        ValueError
            If `n_neighbors` exceeds the number of good data channels
            minus one.

        Notes
        -----
        In the resulting operator, bad channels will be reconstructed by
        using the good channels.
        """
        logger.info('Processing data with sensor noise suppression algorithm')
        logger.info(' Loading raw data')
        if not isinstance(raw, BaseRaw):
            raise TypeError('raw must be an instance of Raw, got %s'
                            % type(raw))
        good_picks = _pick_data_channels(raw.info, exclude='bads')
        if self._n_neighbors > len(good_picks) - 1:
            raise ValueError('n_neighbors must be at most len(good_picks) '
                             '- 1 (%s)' % (len(good_picks) - 1,))
        logger.info(' Loading data')
        picks = _pick_data_channels(raw.info, exclude=())
        # The following lines are equivalent to this, but require less mem use:
        # data_cov = np.cov(orig_data)
        # data_corrs = np.corrcoef(orig_data) ** 2
        logger.info(' Computing covariance for %s good channels'
                    % len(good_picks))
        data_cov = compute_raw_covariance(
            raw, picks=picks, reject=self._reject, flat=self._flat,
            verbose=False if verbose is None else verbose)['data']
        # Positions of the good channels within `picks` (rows of data_cov)
        good_subpicks = np.searchsorted(picks, good_picks)
        del good_picks
        # scale the norms so everything is close enough to unity for our checks
        picks_list = _picks_by_type(pick_info(raw.info, picks), exclude=())
        _apply_scaling_cov(data_cov, picks_list, self._scalings)
        data_norm = np.diag(data_cov).copy()
        eps = np.finfo(np.float32).eps
        pos_mask = data_norm >= eps
        data_norm[pos_mask] = 1. / data_norm[pos_mask]
        data_norm[~pos_mask] = 0
        # normalize
        data_corrs = data_cov * data_cov
        data_corrs *= data_norm
        data_corrs *= data_norm[:, np.newaxis]
        del data_norm
        operator = np.zeros((len(picks), len(picks)))
        logger.info(' Assembling spatial operator')
        for ii in range(len(picks)):
            # For each channel, the set of other signals is orthogonalized by
            # applying PCA to obtain an orthogonal basis of the subspace
            # spanned by the other channels.
            ranked = good_subpicks[
                np.argsort(data_corrs[ii][good_subpicks])[::-1]]
            # Exclude the channel itself *before* truncating. Relying on it
            # being among the top n_neighbors + 1 correlations fails for a
            # good but flat channel, whose correlations are all zero.
            idx = ranked[ranked != ii][:self._n_neighbors].tolist()
            # We have already effectively thresholded by zeroing out components
            eigval, eigvec = _pca(data_cov[np.ix_(idx, idx)], thresh=None)
            # Some of the eigenvalues could be zero, don't let it blow up
            norm = np.zeros(len(eigval))
            use_mask = eigval > eps
            norm[use_mask] = 1. / np.sqrt(eigval[use_mask])
            eigvec *= norm
            del eigval
            # The channel is projected on this basis and replaced by its
            # projection
            operator[ii, idx] = np.dot(eigvec,
                                       np.dot(data_cov[ii][idx], eigvec))
            # Equivalently (and less efficiently):
            # eigvec = linalg.block_diag([1.], eigvec)
            # idx = np.concatenate(([ii], idx))
            # corr = np.dot(np.dot(eigvec.T, data_cov[np.ix_(idx, idx)]),
            #               eigvec)
            # operator[ii, idx[1:]] = np.dot(corr[0, 1:], eigvec[1:, 1:].T)
            if operator[ii, ii] != 0:
                # Sanity check: a channel must never predict itself
                raise RuntimeError('SNS operator has a non-zero diagonal '
                                   'entry for channel index %s' % ii)
        # scale our results back (the ratio of channel scales is what matters)
        _apply_scaling_array(operator.T, picks_list, self._scalings)
        _undo_scaling_array(operator, picks_list, self._scalings)
        logger.info('Done')
        self._operator = operator
        self._used_chs = [raw.ch_names[pick] for pick in picks]
        return self

    @property
    def operator(self):
        """The operator matrix

        Returns
        -------
        operator : ndarray, shape (n_data_ch, n_data_ch)
            A copy of the spatial operator computed by :meth:`fit` for the
            data channels (good and bad) of the fitted instance.
        """
        return self._operator.copy()

    def apply(self, inst):
        """Apply the operator

        Parameters
        ----------
        inst : instance of Raw
            The data on which to apply the operator. Must be preloaded.

        Returns
        -------
        inst : instance of Raw
            The input instance with cleaned data (operates inplace).

        Raises
        ------
        TypeError
            If `inst` is not an instance of Raw.
        RuntimeError
            If the data is not preloaded, or the data channels of `inst` do
            not match the channels used to fit the operator.
        """
        if isinstance(inst, BaseRaw):
            if not inst.preload:
                raise RuntimeError('raw data must be loaded, use '
                                   'raw.load_data() or preload=True')
            # Process in chunks of 10000 samples to bound temporary memory
            offsets = np.concatenate([np.arange(0, len(inst.times), 10000),
                                      [len(inst.times)]])
            info = inst.info
            # NOTE(review): the operator rows follow the channel order seen
            # in fit(); this assumes pick_channels returns the same order
            # for `inst` -- confirm for instances with reordered channels.
            picks = pick_channels(info['ch_names'], self._used_chs)
            data_chs = [info['ch_names'][pick]
                        for pick in _pick_data_channels(info, exclude=())]
            missing = set(data_chs) - set(self._used_chs)
            if len(missing) > 0:
                raise RuntimeError('Not all data channels of inst were used '
                                   'to construct the operator: %s'
                                   % sorted(missing))
            missing = set(self._used_chs) - set(info['ch_names'][pick]
                                                for pick in picks)
            if len(missing) > 0:
                raise RuntimeError('Not all channels originally used to '
                                   'construct the operator are present: %s'
                                   % sorted(missing))
            for start, stop in zip(offsets[:-1], offsets[1:]):
                time_sl = slice(start, stop)
                inst._data[picks, time_sl] = np.dot(self._operator,
                                                    inst._data[picks, time_sl])
        else:
            # XXX Eventually this could support Evoked and Epochs, too
            raise TypeError('Only Raw instances are currently supported, got '
                            '%s' % type(inst))
        return inst
def test_phase_amplitude_coupling():
    """Test phase amplitude coupling."""
    f_band_lo = [f_phase - 1, f_phase + 1]
    f_band_hi = [f_amp - 1, f_amp + 1]
    ixs_pac = [0, 1]
    ixs_no_pac = [1, 0]
    n_events = events.shape[0]
    # Keyword sets shared by several calls below
    per_event_windows = dict(tmin=event_times, tmax=event_times + event_dur)
    locked = dict(events=events, tmin=0, tmax=event_dur)
    two_windows = dict(events=events, tmin=[0, -1], tmax=[event_dur, -.5])

    # Epochs input is rejected
    with pytest.raises(ValueError):
        phase_amplitude_coupling(epochs, f_band_lo, f_band_hi, ixs_pac)

    # Testing Raw
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_no_pac, pac_func=pac_func)
    assert conn.mean() < min_pac
    assert conn.shape == (1, 1, 1, 1)

    # Testing Raw + multiple times
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_pac, pac_func=pac_func,
        **per_event_windows)
    assert conn.mean() > max_pac
    assert conn.shape == (1, 1, 1, event_times.shape[0])
    # Difference in number of tmin / tmax
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo, f_band_hi, ixs_pac, pac_func=pac_func,
            tmin=event_times[1:], tmax=event_times + event_dur)

    # Testing Raw + multiple frequency pairs
    # Increasing n_cycles for better freq resolution
    f_band_lo_mult = [f_band_lo, (10, 12)]
    f_band_hi_mult = [f_band_hi, (14, 16)]
    conn, fbands = phase_amplitude_coupling(
        raw, f_band_lo_mult, f_band_hi_mult, ixs_pac, pac_func=pac_func,
        n_cycles_ph=4, n_cycles_am=4, **per_event_windows)
    # Make sure shapes are right
    assert conn.shape[2] == 4
    assert len(fbands) == 4
    assert conn[:, 0, 0, :].mean() > max_pac  # pac freqs
    # Loosening the min value for frequency because of freq spillage
    assert conn[:, 0, 1, :].mean() < .1  # Non-pac freqs

    # Testing multiple n_cycles
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo_mult, f_band_hi_mult, ixs_pac, pac_func=pac_func,
            n_cycles_ph=[3, 4, 5], n_cycles_am=4)
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo_mult, f_band_hi_mult, ixs_pac, pac_func=pac_func,
            n_cycles_ph=[[3, 4, 5]], n_cycles_am=4, **per_event_windows)

    # Testing Raw + multiple PAC
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_no_pac, pac_func=['ozkurt', 'glm'])
    assert conn.shape[0] == 2

    # Mixing hi-freq phase and hi-freq amplitude metrics
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo, f_band_hi, ixs_no_pac,
            pac_func=['ozkurt', 'plv'])

    # Testing Raw + Epochs
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_pac, pac_func=pac_func, **locked)
    assert conn.mean() > max_pac
    assert conn.shape == (n_events, 1, 1, 1)

    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_no_pac, pac_func=pac_func, **locked)
    assert conn.mean() < min_pac

    # Testing Raw + Epochs + concatenating epochs
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_pac, pac_func=pac_func,
        concat_epochs=True, **locked)
    assert conn.mean() > max_pac
    assert conn.shape == (1, 1, 1, 1)

    # Testing Raw + Epochs + multiple times
    # First time window should have PAC, second window doesn't
    # Testing hi end at .3 because ozkurt seems to peak here
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_pac, pac_func=pac_func, **two_windows)
    assert conn[..., 0].mean() > max_pac
    assert conn[..., 1].mean() < min_pac
    assert conn.shape == (n_events, 1, 1, 2)

    # Same times but non-pac ixs
    conn, _ = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_no_pac, pac_func=pac_func,
        **two_windows)
    assert conn[..., 0].mean() < min_pac
    assert conn[..., 1].mean() < min_pac
    assert conn.shape == (n_events, 1, 1, 2)

    # Check return data and scale func
    conn, _, data_phase, data_amp = phase_amplitude_coupling(
        raw, f_band_lo, f_band_hi, ixs_no_pac, pac_func=pac_func,
        return_data=True, scale_amp_func=scale)
    # Make sure amp has been scaled
    assert np.abs(data_amp.mean()) < 1e-7
    assert np.abs(data_amp.std() - 1) < 1e-7
    # Make sure we have phases
    assert_allclose(data_phase.max(), np.pi, rtol=1e-2)
    assert_allclose(data_phase.min(), -np.pi, rtol=1e-2)

    # Check that arrays don't work
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw._data, f_band_lo, f_band_hi, [0, 1], pac_func=pac_func)
    # Make sure ixs at least length 2
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo, f_band_hi, [0], pac_func=pac_func)
    # f-band only has 1 value
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo, [1], [0, 1], pac_func=pac_func)
    # Wrong pac func
    with pytest.raises(ValueError):
        phase_amplitude_coupling(
            raw, f_band_lo, f_band_hi, [0, 1], pac_func='blah')
def phase_amplitude_coupling(inst, f_phase, f_amp, ixs, pac_func='ozkurt',
                             events=None, tmin=None, tmax=None,
                             n_cycles_ph=3, n_cycles_am=3,
                             scale_amp_func=None, return_data=False,
                             concat_epochs=False, n_jobs=1, verbose=None):
    """ Compute phase-amplitude coupling between pairs of signals using pacpy.

    Parameters
    ----------
    inst : an instance of Raw
        The data used to calculate PAC. Any other type (including Epochs)
        raises a ValueError; use `events` to split the data into epochs.
    f_phase : array, dtype float, shape (n_bands_phase, 2,)
        The frequency ranges to use for the phase carrier. PAC will be
        calculated between n_bands_phase * n_bands_amp frequencies.
    f_amp : array, dtype float, shape (n_bands_amp, 2,)
        The frequency ranges to use for the phase-modulated amplitude.
        PAC will be calculated between n_bands_phase * n_bands_amp frequencies.
    ixs : array-like, shape (n_ch_pairs x 2)
        The indices for low/high frequency channels. PAC will be estimated
        between n_ch_pairs of channels. Indices correspond to rows of `data`.
    pac_func : {'plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt'} |
               list of strings
        The function for estimating PAC. Corresponds to functions in
        `pacpy.pac`. Defaults to 'ozkurt'. If multiple frequency bands are used
        then `plv` cannot be calculated.
    events : array, shape (n_events, 3) | array, shape (n_events,) | None
        MNE events array. To be supplied if data is 2D and output should be
        split by events. In this case, `tmin` and `tmax` must be provided. If
        `ndim == 1`, it is assumed to be event indices, and all events will be
        grouped together.
    tmin : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the start time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include before
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `min(inst.times)`.
    tmax : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the stop time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include after
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `max(inst.times)`.
    n_cycles_ph : float, int | array of floats, shape (n_bands_phase,)
        The number of cycles to be included in the window for each band-pass
        filter for phase. Defaults to 3.
    n_cycles_am : float, int | array of floats, shape (n_bands_amp,)
        The number of cycles to be included in the window for each band-pass
        filter for amplitude. Defaults to 3.
    scale_amp_func : None | function
        If not None, will be called on each amplitude signal in order to scale
        the values. Function must accept an N-D input and will operate on the
        last dimension. E.g., `sklearn.preprocessing.scale`.
        Defaults to no scaling.
    return_data : bool
        If False, output will be `(pac_out, freq_pac)`. If True, output will
        be `(pac_out, freq_pac, phase_signal, amp_signal)`.
    concat_epochs : bool
        If True, epochs will be concatenated before calculating PAC values. If
        epochs are relatively short, this is a good idea in order to improve
        stability of the PAC metric.
    n_jobs : int
        Number of jobs to run in parallel. Defaults to 1.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see `mne.verbose`).

    Returns
    -------
    pac_out : array, list of arrays, dtype float,
              shape([n_pac_funcs], n_epochs, n_channel_pairs,
                    n_freq_pairs, n_pac_windows).
        The computed phase-amplitude coupling between each pair of data sources
        given in ixs. If multiple pac metrics are specified, there will be one
        array per metric in the output list. If n_pac_funcs is 1, then the
        first dimension will be dropped.
    freq_pac : sequence
        Second value produced by `_phase_amplitude_coupling`; presumably the
        phase/amplitude frequency-band pair for each entry along the
        n_freq_pairs axis -- TODO confirm against the helper.
    [phase_signal] : array, shape (n_phase_signals, n_times,)
        Only returned if `return_data` is True. The phase timeseries of the
        phase signals (first column of `ixs`).
    [amp_signal] : array, shape (n_amp_signals, n_times,)
        Only returned if `return_data` is True. The amplitude timeseries of the
        amplitude signals (second column of `ixs`).

    Raises
    ------
    ValueError
        If `inst` is not an instance of Raw.

    References
    ----------
    [1] This function uses the PacPy module developed by the Voytek lab.
        https://github.com/voytekresearch/pacpy
    """
    if not isinstance(inst, BaseRaw):
        raise ValueError('Must supply Raw as input')
    sfreq = inst.info['sfreq']
    data = inst[:][0]
    out = _phase_amplitude_coupling(data, sfreq, f_phase, f_amp, ixs,
                                    pac_func=pac_func, events=events,
                                    tmin=tmin, tmax=tmax,
                                    n_cycles_ph=n_cycles_ph,
                                    n_cycles_am=n_cycles_am,
                                    scale_amp_func=scale_amp_func,
                                    return_data=return_data,
                                    concat_epochs=concat_epochs,
                                    n_jobs=n_jobs, verbose=verbose)
    # Hand back exactly what the helper produced (2 or 4 values). Deciding
    # the arity here with `return_data is True` would mis-unpack the result
    # for truthy values that are not the literal True.
    return tuple(out)
145 | f_amp : array, dtype float, shape (n_bands_amp, 2,) 146 | The frequency ranges to use for the phase-modulated amplitude. 147 | PAC will be calculated between n_bands_phase * n_bands_amp frequencies. 148 | ixs : array-like, shape (n_ch_pairs x 2) 149 | The indices for low/high frequency channels. PAC will be estimated 150 | between n_ch_pairs of channels. Indices correspond to rows of `data`. 151 | pac_func : {'plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt'} | 152 | list of strings 153 | The function for estimating PAC. Corresponds to functions in 154 | `pacpy.pac`. Defaults to 'ozkurt'. If multiple frequency bands are used 155 | then `plv` cannot be calculated. 156 | events : array, shape (n_events, 3) | array, shape (n_events,) | None 157 | MNE events array. To be supplied if data is 2D and output should be 158 | split by events. In this case, `tmin` and `tmax` must be provided. If 159 | `ndim == 1`, it is assumed to be event indices, and all events will be 160 | grouped together. 161 | tmin : float | list of floats, shape (n_pac_windows,) | None 162 | If `events` is not provided, it is the start time to use in `inst`. 163 | If `events` is provided, it is the time (in seconds) to include before 164 | each event index. If a list of floats is given, then PAC is calculated 165 | for each pair of `tmin` and `tmax`. Defaults to `min(inst.times)`. 166 | tmax : float | list of floats, shape (n_pac_windows,) | None 167 | If `events` is not provided, it is the stop time to use in `inst`. 168 | If `events` is provided, it is the time (in seconds) to include after 169 | each event index. If a list of floats is given, then PAC is calculated 170 | for each pair of `tmin` and `tmax`. Defaults to `max(inst.n_times)`. 171 | n_cycles_ph : float, int | array of floats, shape (n_bands_phase,) 172 | The number of cycles to be included in the window for each band-pass 173 | filter for phase. Defaults to 3. 
174 | n_cycles_am : float, int | array of floats, shape (n_bands_amp,) 175 | The number of cycles to be included in the window for each band-pass 176 | filter for amplitude. Defaults to 3. 177 | scale_amp_func : None | function 178 | If not None, will be called on each amplitude signal in order to scale 179 | the values. Function must accept an N-D input and will operate on the 180 | last dimension. E.g., `sklearn.preprocessing.scale`. 181 | Defaults to no scaling. 182 | return_data : bool 183 | If False, output will be `[pac_out]`. If True, output will be, 184 | `[pac_out, phase_signal, amp_signal]`. 185 | concat_epochs : bool 186 | If True, epochs will be concatenated before calculating PAC values. If 187 | epochs are relatively short, this is a good idea in order to improve 188 | stability of the PAC metric. 189 | n_jobs : int 190 | Number of jobs to run in parallel. Defaults to 1. 191 | verbose : bool, str, int, or None 192 | If not None, override default verbose level (see `mne.verbose`). 193 | 194 | Returns 195 | ------- 196 | pac_out : array, list of arrays, dtype float, 197 | shape([n_pac_funcs], n_epochs, n_channel_pairs, 198 | n_freq_pairs, n_pac_windows). 199 | The computed phase-amplitude coupling between each pair of data sources 200 | given in ixs. If multiple pac metrics are specified, there will be one 201 | array per metric in the output list. If n_pac_funcs is 1, then the 202 | first dimension will be dropped. 203 | [phase_signal] : array, shape (n_phase_signals, n_times,) 204 | Only returned if `return_data` is True. The phase timeseries of the 205 | phase signals (first column of `ixs`). 206 | [amp_signal] : array, shape (n_amp_signals, n_times,) 207 | Only returned if `return_data` is True. The amplitude timeseries of the 208 | amplitude signals (second column of `ixs`). 
209 | """ 210 | from ..externals.pacpy import pac as ppac 211 | pac_func = np.atleast_1d(pac_func) 212 | for i_func in pac_func: 213 | if i_func not in _pac_funcs: 214 | raise ValueError("PAC function %s is not supported" % i_func) 215 | n_pac_funcs = pac_func.shape[0] 216 | ixs = np.array(ixs, ndmin=2) 217 | n_ch_pairs = ixs.shape[0] 218 | tmin = 0 if tmin is None else tmin 219 | tmin = np.atleast_1d(tmin) 220 | n_pac_windows = len(tmin) 221 | tmax = (data.shape[-1] - 1) / float(sfreq) if tmax is None else tmax 222 | tmax = np.atleast_1d(tmax) 223 | f_phase = np.atleast_2d(f_phase) 224 | f_amp = np.atleast_2d(f_amp) 225 | n_cycles_ph = np.atleast_1d(n_cycles_ph) 226 | n_cycles_am = np.atleast_1d(n_cycles_am) 227 | if n_cycles_ph.shape[0] == 1: 228 | n_cycles_ph = np.repeat(n_cycles_ph, f_phase.shape[0]) 229 | if n_cycles_am.shape[0] == 1: 230 | n_cycles_am = np.repeat(n_cycles_am, f_amp.shape[0]) 231 | 232 | if data.ndim != 2: 233 | raise ValueError('Data must be shape (n_channels, n_times)') 234 | if ixs.shape[1] != 2: 235 | raise ValueError('Indices must have have a 2nd dimension of length 2') 236 | if f_phase.shape[-1] != 2 or f_amp.shape[-1] != 2: 237 | raise ValueError('Frequencies must be specified w/ a low/hi tuple') 238 | if len(tmin) != len(tmax): 239 | raise ValueError('tmin and tmax have differing lengths') 240 | if any(i_f.shape[0] > 1 and 'plv' in pac_func for i_f in (f_amp, f_phase)): 241 | raise ValueError('If calculating PLV, must use a single pair of freqs') 242 | for icyc, i_f in zip([n_cycles_ph, n_cycles_am], [f_phase, f_amp]): 243 | if icyc.shape[0] != i_f.shape[0]: 244 | raise ValueError("n_cycles must match n_freq_bands") 245 | if icyc.ndim > 1: 246 | raise ValueError("n_cycles must be 1-d, not {}d".format(icyc.ndim)) 247 | 248 | logger.info('Pre-filtering data and extracting phase/amplitude...') 249 | hi_phase = np.unique([i_func in _hi_phase_funcs for i_func in pac_func]) 250 | if len(hi_phase) != 1: 251 | raise ValueError("Can't mix pac 
funcs that use both hi-freq phase/amp") 252 | hi_phase = bool(hi_phase[0]) 253 | data_ph, data_am, ix_map_ph, ix_map_am = _pre_filter_ph_am( 254 | data, sfreq, ixs, f_phase, f_amp, hi_phase=hi_phase, 255 | scale_amp_func=scale_amp_func, n_cycles_ph=n_cycles_ph, 256 | n_cycles_am=n_cycles_am) 257 | 258 | # So we know how big the PAC output will be 259 | if events is None: 260 | n_epochs = 1 261 | elif concat_epochs is True: 262 | if events.ndim == 1: 263 | n_epochs = 1 264 | else: 265 | n_epochs = np.unique(events[:, -1]).shape[0] 266 | else: 267 | n_epochs = events.shape[0] 268 | 269 | # Iterate through each pair of frequencies 270 | ixs_freqs = product(range(data_ph.shape[1]), range(data_am.shape[1])) 271 | ixs_freqs = np.atleast_2d(list(ixs_freqs)) 272 | 273 | freq_pac = np.array([[f_phase[ii], f_amp[jj]] for ii, jj in ixs_freqs]) 274 | n_f_pairs = len(ixs_freqs) 275 | pac = np.zeros([n_pac_funcs, n_epochs, n_ch_pairs, 276 | n_f_pairs, n_pac_windows]) 277 | for i_f_pair, (ix_f_ph, ix_f_am) in enumerate(ixs_freqs): 278 | # Second dimension is frequency 279 | i_f_data_ph = data_ph[:, ix_f_ph, ...] 280 | i_f_data_am = data_am[:, ix_f_am, ...] 
281 | 282 | # Redefine indices to match the new data arrays 283 | ixs_new = [(ix_map_ph[i], ix_map_am[j]) for i, j in ixs] 284 | i_f_data_ph = mne.io.RawArray( 285 | i_f_data_ph, mne.create_info(i_f_data_ph.shape[0], sfreq)) 286 | i_f_data_am = mne.io.RawArray( 287 | i_f_data_am, mne.create_info(i_f_data_am.shape[0], sfreq)) 288 | 289 | # Turn into Epochs if we have defined events 290 | if events is not None: 291 | i_f_data_ph = _raw_to_epochs_mne(i_f_data_ph, events, tmin, tmax) 292 | i_f_data_am = _raw_to_epochs_mne(i_f_data_am, events, tmin, tmax) 293 | 294 | # Data is either Raw or Epochs 295 | pbar = ProgressBar(n_epochs) 296 | for itime, (i_tmin, i_tmax) in enumerate(zip(tmin, tmax)): 297 | # Pull times of interest 298 | with warnings.catch_warnings(): # To suppress a depracation 299 | warnings.simplefilter("ignore") 300 | # Not sure how to do this w/o copying 301 | i_t_data_am = i_f_data_am.copy().crop(i_tmin, i_tmax) 302 | i_t_data_ph = i_f_data_ph.copy().crop(i_tmin, i_tmax) 303 | 304 | if concat_epochs is True: 305 | # Iterate through each event type and hstack 306 | con_data_ph = [] 307 | con_data_am = [] 308 | for i_ev in i_t_data_am.event_id.keys(): 309 | con_data_ph.append(np.hstack(i_t_data_ph[i_ev]._data)) 310 | con_data_am.append(np.hstack(i_t_data_am[i_ev]._data)) 311 | i_t_data_ph = np.vstack(con_data_ph) 312 | i_t_data_am = np.vstack(con_data_am) 313 | else: 314 | # Just pull all epochs separately 315 | i_t_data_ph = i_t_data_ph._data 316 | i_t_data_am = i_t_data_am._data 317 | # Now make sure that inputs to the loop are ep x chan x time 318 | if i_t_data_am.ndim == 2: 319 | i_t_data_ph = i_t_data_ph[np.newaxis, ...] 320 | i_t_data_am = i_t_data_am[np.newaxis, ...] 
321 | # Loop through epochs (or epoch grps), each index pair, and funcs 322 | data_iter = zip(i_t_data_ph, i_t_data_am) 323 | for iep, (ep_ph, ep_am) in enumerate(data_iter): 324 | for iix, (i_ix_ph, i_ix_am) in enumerate(ixs_new): 325 | for ix_func, i_pac_func in enumerate(pac_func): 326 | func = getattr(ppac, i_pac_func) 327 | pac[ix_func, iep, iix, i_f_pair, itime] = func( 328 | ep_ph[i_ix_ph], ep_am[i_ix_am], 329 | f_phase, f_amp, filterfn=False) 330 | pbar.update_with_increment_value(1) 331 | if pac.shape[0] == 1: 332 | pac = pac[0] 333 | if return_data: 334 | return pac, freq_pac, data_ph, data_am 335 | else: 336 | return pac, freq_pac 337 | 338 | 339 | def _raw_to_epochs_mne(raw, events, tmin, tmax): 340 | """Convert Raw data to Epochs w/ some time checks.""" 341 | events = np.atleast_1d(events) 342 | if events.ndim == 1: 343 | events = np.vstack([events, np.zeros_like(events), 344 | np.ones_like(events)]).T 345 | if events.ndim != 2: 346 | raise ValueError('events have incorrect number of dimensions') 347 | if events.shape[-1] != 3: 348 | raise ValueError('events have incorrect number of columns') 349 | # Convert to Epochs using the event times 350 | tmin_all = np.min(tmin) 351 | tmax_all = np.max(tmax) + (1. 
def _pre_filter_ph_am(data, sfreq, ixs, f_ph, f_am, n_cycles_ph=3,
                      n_cycles_am=3, hi_phase=False, scale_amp_func=None,
                      kws_filt=None):
    """Filter for phase/amp only once for each channel.

    Parameters
    ----------
    data : array, shape (n_channels, n_times)
        The raw signals.
    sfreq : float
        The sampling frequency of the data.
    ixs : array, shape (n_ch_pairs, 2)
        Channel index pairs; column 0 holds phase channels and column 1
        holds amplitude channels.
    f_ph : array, shape (n_bands_phase, 2)
        Band edges used to extract phase.
    f_am : array, shape (n_bands_amp, 2)
        Band edges used to extract amplitude.
    n_cycles_ph, n_cycles_am : float | array of floats
        Filter length, in cycles, for each phase / amplitude band.
    hi_phase : bool
        If True, the amplitude envelope is band-passed again in the first
        phase band and its phase is returned instead of the envelope
        (what PLV needs).
    scale_amp_func : None | function
        If not None, called as ``scale_amp_func(x, axis=-1)`` on each
        amplitude channel.
    kws_filt : dict | None
        Currently unused; kept so the signature stays stable.

    Returns
    -------
    out_ph : array, shape (n_unique_phase_channels, n_bands_phase, n_times)
        Phase angle of each unique phase channel in each phase band.
    out_am : array, shape (n_unique_amp_channels, n_bands_amp, n_times)
        Amplitude envelope (or its phase when `hi_phase` is True).
    ix_map_ph, ix_map_am : dict
        Map from original channel index to row in `out_ph` / `out_am`.
    """
    from ..externals.pacpy.pac import _range_sanity

    # Each channel is filtered once even if it appears in several pairs
    ix_ph = np.atleast_1d(np.unique(ixs[:, 0]))
    ix_am = np.atleast_1d(np.unique(ixs[:, 1]))

    # Sanity-check every band before doing any expensive filtering
    for i_f_ph in f_ph:
        _range_sanity(i_f_ph, f_am[0])
    for i_f_am in f_am:
        _range_sanity(f_ph[0], i_f_am)

    # Outputs are (n_chan, n_freqs, n_times). Fancy indexing copies, so the
    # caller's `data` is left untouched.
    out_ph = _filter_and_hilbert(data[ix_ph, :], sfreq, f_ph, n_cycles_ph)
    out_ph = np.angle(out_ph)

    out_am = _filter_and_hilbert(data[ix_am, :], sfreq, f_am, n_cycles_am)
    out_am = np.abs(out_am)
    if hi_phase is True:
        # We assume f_ph has len(1), multiple freqs not supported w/ plv
        out_am = _filter_and_hilbert(out_am, sfreq, f_ph, n_cycles_ph,
                                     inplace=True)
        out_am = np.angle(out_am)

    # New index mapping for unique channels
    ix_map_ph = dict((ix, i) for i, ix in enumerate(ix_ph))
    ix_map_am = dict((ix, i) for i, ix in enumerate(ix_am))

    if scale_amp_func is not None:
        for ii in range(ix_am.shape[0]):
            out_am[ii] = scale_amp_func(out_am[ii], axis=-1)
    return out_ph, out_am, ix_map_ph, ix_map_am


def _raw_to_epochs_array(x, sfreq, events, tmin, tmax):
    """Aux function to create epochs from a 2D array.

    Parameters
    ----------
    x : array, shape (n_signals, n_times)
        The continuous data.
    sfreq : float
        The sampling frequency of the data.
    events : array of int, shape (n_events,)
        Sample indices to time-lock to. Any integer dtype is accepted.
    tmin, tmax : float
        Window start / stop in seconds relative to each event.

    Returns
    -------
    epochs : array, shape (n_events_kept, n_signals, n_window)
        One window of `x` per retained event.
    times : array, shape (n_window,)
        Time in seconds of each window sample relative to its event.
    msk_keep : array of bool, shape (n_events,)
        Mask over the events as passed in; False where the window would
        have extended beyond the data and the event was dropped.

    Raises
    ------
    ValueError
        If `events` is not 1D, is not of an integer dtype, or if no event
        window fits inside the data.
    """
    events = np.asarray(events)
    if events.ndim != 1:
        raise ValueError('events must be 1D')
    # `dtype != int` only matched the platform default integer and wrongly
    # rejected other integer widths (e.g. int32 on 64-bit Linux).
    if not np.issubdtype(events.dtype, np.integer):
        raise ValueError('events must be of dtype int')

    # Check that events won't be cut off
    n_times = x.shape[-1]
    min_ix = 0 - sfreq * tmin
    max_ix = n_times - sfreq * tmax
    msk_keep = np.logical_and(events > min_ix, events < max_ix)

    if not msk_keep.all():
        logger.info('Some event windows extend beyond data limits,'
                    ' and will be cut off...')
        events = events[msk_keep]
    if events.size == 0:
        # Fail with a clear message instead of inside np.concatenate
        raise ValueError('No event windows fit within the data limits')

    # Pull events from the raw data
    offset_min, offset_max = [int(i_tlim * sfreq) for i_tlim in [tmin, tmax]]
    epochs = []
    for ix in events:
        # int() keeps the arithmetic signed even for unsigned event dtypes
        ix = int(ix)
        epochs.append(x[np.newaxis, :, ix + offset_min:ix + offset_max])
    epochs = np.concatenate(epochs, axis=0)
    times = np.arange(epochs.shape[-1]) / float(sfreq) + tmin
    return epochs, times, msk_keep


def _filter_and_hilbert(data, sfreq, frequencies, n_cycles, inplace=False):
    """Band-pass each channel in each band, then take the analytic signal.

    Parameters
    ----------
    data : array, shape (n_channels, n_times) | (n_channels, n_freqs, n_times)
        The signals to filter. The 3D form is only used when `inplace` is
        True.
    sfreq : float
        The sampling frequency of the data.
    frequencies : array, shape (n_freqs, 2)
        Band edges of each band-pass filter.
    n_cycles : float | array of floats, shape (n_freqs,)
        Filter length, in cycles, for each band. A scalar or length-1 value
        is applied to every band.
    inplace : bool
        If True, the filtered signals are written back into `data`, which
        is therefore modified; the filter input is always ``data[ii, 0]``.

    Returns
    -------
    out : array of complex, shape (n_channels, n_freqs, n_times)
        The analytic signal of each filtered channel / band.
    """
    if inplace is True:
        # Assume data is (n_chan, n_freqs, n_times)
        n_channels, n_freqs, n_times = data.shape
        out = data
    else:
        n_freqs = frequencies.shape[0]
        n_channels, n_times = data.shape
        data = data[:, np.newaxis, :]  # To ensure shapes are always consistent
        out = np.zeros([n_channels, n_freqs, n_times])

    # Accept a bare number (the default used by the callers' signatures)
    n_cycles = np.atleast_1d(n_cycles)
    if n_cycles.shape[0] == 1:
        n_cycles = np.repeat(n_cycles, n_freqs)

    for ii in range(n_channels):
        for jj, f_range, n_cyc in zip(range(n_freqs), frequencies, n_cycles):
            out[ii, jj] = _band_pass_pac(data[ii, 0], f_range, sfreq,
                                         n_cycles=n_cyc)
    # Zero-pad to the next power of two for the FFT inside hilbert, then trim
    # back to the original length
    n_hil = int(2 ** np.ceil(np.log2(n_times)))
    out = hilbert(out, N=n_hil)[..., :n_times]
    return out
The phase of each 456 | frequency will be averaged together. 457 | freqs_amp : np.array 458 | The frequencies to use in amplitude calculation. 459 | ix_ph : int 460 | The index of the signal to be used for phase calculation. 461 | ix_amp : int 462 | The index of the signal to be used for amplitude calculation. 463 | tmin : float 464 | The time to include before each phase peak. 465 | tmax : float 466 | The time to include after each phase peak. 467 | mask_times : np.array, dtype bool, shape (inst.n_times,) 468 | If `inst` is an instance of Raw, this will only include times contained 469 | in `mask_times`. Defaults to using all times. 470 | 471 | Returns 472 | ------- 473 | data_am : np.array 474 | The mean amplitude values for the frequencies specified in `freqs_amp`, 475 | time-locked to peaks of the low-frequency phase. 476 | data_ph : np.array 477 | The mean low-frequency signal, phase-locked to low-frequency phase 478 | peaks. 479 | times : np.array 480 | The times before / after each phase peak. 
481 | """ 482 | sfreq = inst.info['sfreq'] 483 | # Pull the amplitudes/phases using Morlet 484 | data_ph, data_am = _pull_data(inst, ix_ph, ix_amp) 485 | angle_ph, band_ph, amp = _extract_phase_and_amp( 486 | data_ph, data_am, sfreq, freqs_phase, freqs_amp) 487 | 488 | angle_ph = angle_ph.mean(0) # Mean across freq bands 489 | band_ph = band_ph.mean(0) 490 | 491 | # Find peaks in the phase for time-locking 492 | phase_peaks, vals = peak_finder(angle_ph) 493 | ixmin, ixmax = [t * sfreq for t in [tmin, tmax]] 494 | # Remove peaks w/o buffer 495 | phase_peaks = phase_peaks[(phase_peaks > np.abs(ixmin)) * 496 | (phase_peaks < len(angle_ph) - ixmax)] 497 | 498 | if mask_times is not None: 499 | # Set datapoints outside out times to nan so we can drop later 500 | if len(mask_times) != angle_ph.shape[-1]: 501 | raise ValueError('mask_times must be == in length to data') 502 | band_ph[..., ~mask_times] = np.nan 503 | 504 | data_ph, times, msk_window = _raw_to_epochs_array( 505 | band_ph[np.newaxis, :], sfreq, phase_peaks, tmin, tmax) 506 | data_am, times, msk_window = _raw_to_epochs_array( 507 | amp, sfreq, phase_peaks, tmin, tmax) 508 | data_ph = data_ph.squeeze() 509 | data_am = data_am.squeeze() 510 | 511 | # Drop any peak events where there was a nan 512 | keep_rows = np.where(~np.isnan(data_ph).any(-1))[0] 513 | data_ph = data_ph[keep_rows, ...] 514 | data_am = data_am[keep_rows, ...] 515 | 516 | # Average across phase peak events 517 | data_am = data_am.mean(0) 518 | data_ph = data_ph.mean(0) 519 | return data_am, data_ph, times 520 | 521 | 522 | def phase_binned_amplitude(inst, freqs_phase, freqs_amp, 523 | ix_ph, ix_amp, n_bins=20, mask_times=None): 524 | """Calculate amplitude of one signal in sub-ranges of phase for another. 525 | 526 | Parameters 527 | ---------- 528 | inst : instance of mne.Epochs | mne.io.Raw 529 | The data to be used in phase locking computation. 
530 | freqs_phase : np.array, shape (n_freqs_phase,) 531 | The frequencies to use in phase calculation. The phase of each 532 | frequency will be averaged together. 533 | freqs_amp : np.array, shape (n_freqs_amp,) 534 | The frequencies to use in amplitude calculation. The amplitude 535 | of each frequency will be averaged together. 536 | ix_ph : int 537 | The index of the signal to be used for phase calculation. 538 | ix_amp : int 539 | The index of the signal to be used for amplitude calculation. 540 | n_bins : int 541 | The number of bins to use when grouping amplitudes. Each bin will 542 | have size `(2 * np.pi) / n_bins`. 543 | mask_times : np.array, dtype bool, shape (inst.n_times,) 544 | Remove timepoints where `mask_times` is False. 545 | Defaults to using all times. 546 | 547 | Returns 548 | ------- 549 | amp_binned : np.array, shape (n_bins,) 550 | The mean amplitude of freqs_amp at each phase bin 551 | bins_phase : np.array, shape (n_bins + 1,) 552 | The bins used in the calculation. There is one extra bin because 553 | bins represent the left/right edges of each bin, not the center value. 
# For the viz functions
def _extract_phase_and_amp(data_ph, data_am, sfreq, freqs_phase,
                           freqs_amp, scale=True):
    """Extract the phase and amplitude of two signals for PAC viz.

    Parameters
    ----------
    data_ph : array, shape (n_epochs, n_times)
        The signal that phase is extracted from.
    data_am : array, shape (n_epochs, n_times)
        The signal that amplitude is extracted from.
    sfreq : float
        The sampling frequency of the data.
    freqs_phase : array
        Morlet frequencies used for the phase signal.
    freqs_amp : array
        Morlet frequencies used for the amplitude signal.
    scale : bool
        If True, scale the amplitude along time with
        `sklearn.preprocessing.scale` so low frequencies don't dominate.

    Returns
    -------
    angle_ph : array
        Phase angle of the phase signal, epochs stacked along the last axis.
    band_ph_stacked : array
        Real part of the phase-band signal, stacked the same way.
    amp : array
        Power of the amplitude signal, stacked the same way.
    """
    # Morlet transform to get complex representation
    band_ph = _compute_tfr([data_ph], freqs_phase, sfreq, method='morlet')[0]
    # The amplitude must come from the amplitude signal (this used to be
    # computed from `data_ph`, silently ignoring `data_am`).
    band_amp = _compute_tfr([data_am], freqs_amp, sfreq, method='morlet')[0]

    # Calculate the phase/amplitude of relevant signals across epochs
    band_ph_stacked = np.hstack(np.real(band_ph))
    angle_ph = np.hstack(np.angle(band_ph))
    amp = np.hstack(np.abs(band_amp) ** 2)

    # Scale the amplitude for viz so low freqs don't dominate highs
    if scale:
        # Imported under another name: importing it as `scale` shadowed the
        # flag and meant the scaling branch could never run.
        from sklearn.preprocessing import scale as _sk_scale
        amp = _sk_scale(amp, axis=1)
    return angle_ph, band_ph_stacked, amp


def _pull_data(inst, ix_ph, ix_amp, events=None, tmin=None, tmax=None):
    """Pull phase/amplitude channels from either Raw or Epochs instances.

    Parameters
    ----------
    inst : instance of Epochs | Raw
        The object to pull data from.
    ix_ph : int
        Index of the channel used for phase.
    ix_amp : int
        Index of the channel used for amplitude.
    events, tmin, tmax : None
        Unused; kept so the signature stays stable.

    Returns
    -------
    data_ph, data_am : array, shape (n_epochs, n_times)
        The two channels. For Raw input a leading axis of length 1 is added.

    Raises
    ------
    ValueError
        If `inst` is neither Epochs nor Raw.
    """
    if isinstance(inst, BaseEpochs):
        data = inst.get_data()  # fetch once; both channels index into it
        data_ph = data[:, ix_ph, :]
        data_am = data[:, ix_amp, :]
    elif isinstance(inst, BaseRaw):
        data = inst[[ix_ph, ix_amp], :][0].squeeze()
        data_ph, data_am = [i[np.newaxis, ...] for i in data]
    else:
        # Previously fell through to an UnboundLocalError on return
        raise ValueError('Must supply Raw or Epochs as input')
    return data_ph, data_am


def _band_pass_pac(x, f_range, sfreq=1000, n_cycles=3):
    """
    Band-pass filter a signal using PacPy for PAC coupling.

    This is a version of the firf function in PacPy, minus edge removal.
    Its docstring is below
    ----
    Filter signal with an FIR filter
    *Like fir1 in MATLAB

    x : array-like, 1d
        Time series to filter.
    f_range : (low, high), Hz
        Cutoff frequencies of bandpass filter.
    sfreq : float, Hz
        Sampling rate.
    n_cycles : float
        Length of the filter in terms of the number of cycles
        of the oscillation whose frequency is the low cutoff of the
        bandpass filter.

    Returns
    -------
    x_filt : array-like, 1d
        Filtered time series.

    Raises
    ------
    ValueError
        If `n_cycles` is not positive, or a cutoff is not strictly between
        0 and the Nyquist rate.
    RuntimeError
        If the filter is longer than the data, or the output contains NaNs.
    """
    if n_cycles <= 0:
        raise ValueError(
            'Number of cycles in a filter must be a positive number.')

    # `np.float` was removed in NumPy 1.24; the builtin is equivalent
    nyq = float(sfreq) / 2
    cutoffs = np.array(f_range)
    if np.any(cutoffs > nyq):
        raise ValueError('Filter frequencies must be below nyquist rate.')

    # A zero cutoff used to slip through and divide by zero just below
    if np.any(cutoffs <= 0):
        raise ValueError('Filter frequencies must be positive.')

    n_taps = int(np.floor(n_cycles * sfreq / f_range[0]))
    if len(x) < n_taps:
        raise RuntimeError(
            'Length of filter is longer than data. '
            'Provide more data or a shorter filter.')

    # Imported only once the arguments are known to be valid
    from ..externals.pacpy.filt import firwin, filtfilt

    # Perform filtering
    taps = firwin(n_taps, cutoffs / nyq, pass_zero=False)
    x_filt = filtfilt(taps, [1], x)

    if np.isnan(x_filt).any():
        raise RuntimeError(
            'Filtered signal contains nans. Adjust filter parameters.')

    # Unlike PacPy's firf, the edges are deliberately NOT trimmed
    return x_filt
def _x_sanity(lo=None, hi=None):
    """Validate the low / high frequency time series.

    Parameters
    ----------
    lo, hi : array-like | None
        The time series to check. None skips the checks for that input.

    Raises
    ------
    ValueError
        If either input contains NaNs, or if both are given and differ in
        number of elements.
    """
    if lo is not None and np.any(np.isnan(lo)):
        raise ValueError("lo contains NaNs")

    if hi is not None and np.any(np.isnan(hi)):
        raise ValueError("hi contains NaNs")

    if (hi is not None) and (lo is not None):
        # np.size also works for plain sequences, which have no `.size`
        if np.size(lo) != np.size(hi):
            raise ValueError("lo and hi must be the same length")


def _range_sanity(f_lo=None, f_hi=None):
    """Validate the low / high frequency band edges.

    Parameters
    ----------
    f_lo, f_hi : (low, high) | None
        Band edges in Hz. None skips the checks for that band.

    Raises
    ------
    ValueError
        If a band does not have exactly two elements, or if any element of
        it is negative.
    """
    for name, f_range in (("f_lo", f_lo), ("f_hi", f_hi)):
        if f_range is None:
            continue
        if len(f_range) != 2:
            raise ValueError("%s must contain two elements" % name)
        # Check both edges (only the first used to be inspected). Zero is
        # still let through, as before, despite the wording of the message.
        if np.any(np.asarray(f_range) < 0):
            raise ValueError("Elements in %s must be > 0" % name)
def _trim_edges(lo, hi):
    """
    Remove extra edge artifact from the signal with the shorter filter
    so that its time series is identical to that of the filtered signal
    with a longer filter.

    Parameters
    ----------
    lo, hi : array-like, 1d
        The two filtered time series.

    Returns
    -------
    lo, hi : array-like, 1d
        The inputs, with the longer one trimmed symmetrically to the length
        of the shorter one.

    Raises
    ------
    ValueError
        If the lengths differ by an odd number of samples.
    """
    n_lo, n_hi = len(lo), len(hi)
    if n_lo == n_hi:
        return lo, hi  # Die early if there's nothing to do.

    n_diff = abs(n_hi - n_lo)
    if n_diff % 2 != 0:
        raise ValueError(
            'Difference in filtered signal lengths should be even')

    # Integer division replaces np.int(n_diff / 2); the np.int alias was
    # removed in NumPy 1.24.
    half = n_diff // 2
    if n_lo < n_hi:
        hi = hi[half:-half]
    else:
        lo = lo[half:-half]

    return lo, hi
172 | filter_kwargs : dict 173 | Keyword parameters to pass to `filterfn(.)` 174 | Nbins : int 175 | Number of bins to split up the low frequency oscillation cycle 176 | 177 | Returns 178 | ------- 179 | pac : scalar 180 | PAC value 181 | 182 | Usage 183 | ----- 184 | >>> import numpy as np 185 | >>> from scipy.signal import hilbert 186 | >>> from pacpy.pac import mi_tort 187 | >>> t = np.arange(0, 10, .001) # Define time array 188 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 189 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 190 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 191 | >>> mi_tort(lo, hi, (4,8), (80,150)) # Calculate PAC 192 | 0.34898478944110811 193 | """ 194 | 195 | # Arg check 196 | _x_sanity(lo, hi) 197 | if np.logical_or(Nbins < 2, Nbins != int(Nbins)): 198 | raise ValueError( 199 | 'Number of bins in the low frequency oscillation cycle' 200 | 'must be an integer >1.') 201 | 202 | # Filter setup 203 | if filterfn is None: 204 | filterfn = firf 205 | 206 | if filter_kwargs is None: 207 | filter_kwargs = {} 208 | 209 | # Filter then hilbert 210 | if filterfn is not False: 211 | _range_sanity(f_lo, f_hi) 212 | lo = filterfn(lo, f_lo, fs, **filter_kwargs) 213 | hi = filterfn(hi, f_hi, fs, **filter_kwargs) 214 | 215 | hi = np.abs(hilbert(hi)) 216 | lo = np.angle(hilbert(lo)) 217 | 218 | # Make arrays the same size 219 | lo, hi = _trim_edges(lo, hi) 220 | 221 | # Convert the phase time series from radians to degrees 222 | phadeg = np.degrees(lo) 223 | 224 | # Calculate PAC 225 | binsize = 360 / Nbins 226 | phase_lo = np.arange(-180, 180, binsize) 227 | mean_amp = np.zeros(len(phase_lo)) 228 | for b in range(len(phase_lo)): 229 | phaserange = np.logical_and(phadeg >= phase_lo[b], 230 | phadeg < (phase_lo[b] + binsize)) 231 | mean_amp[b] = np.mean(hi[phaserange]) 232 | 233 | p_j = np.zeros(len(phase_lo)) 234 | for b in range(len(phase_lo)): 235 | p_j[b] = mean_amp[b] / sum(mean_amp) 236 
| 237 | h = -np.sum(p_j * np.log10(p_j)) 238 | h_max = np.log10(Nbins) 239 | pac = (h_max - h) / h_max 240 | 241 | return pac 242 | 243 | 244 | def _ols(y, X): 245 | """Custom OLS (to minimize outside dependecies)""" 246 | 247 | dummy = np.repeat(1.0, X.shape[0]) 248 | X = np.hstack([X, dummy[:, np.newaxis]]) 249 | 250 | beta_hat, resid, _, _ = np.linalg.lstsq(X, y) 251 | y_hat = np.dot(X, beta_hat) 252 | 253 | return y_hat, beta_hat 254 | 255 | 256 | def glm(lo, hi, f_lo, f_hi, fs=1000, filterfn=None, filter_kwargs=None): 257 | """ 258 | Calculate PAC using the generalized linear model (GLM) method 259 | 260 | Parameters 261 | ---------- 262 | lo : array-like, 1d 263 | The low frequency time-series to use as the phase component 264 | hi : array-like, 1d 265 | The high frequency time-series to use as the amplitude component 266 | f_lo : (low, high), Hz 267 | The low frequency filtering range 268 | f_high : (low, high), Hz 269 | The low frequency filtering range 270 | fs : float 271 | The sampling rate (default = 1000Hz) 272 | filterfn : functional 273 | The filtering function, `filterfn(x, f_range, filter_kwargs)` 274 | 275 | False activates 'EXPERT MODE'. 276 | - DO NOT USE THIS FLAG UNLESS YOU KNOW WHAT YOU ARE DOING! 277 | - In expert mode the user needs to filter the data AND apply the 278 | hilbert transform. 279 | - This requires that 'lo' be the phase time series of the low-bandpass 280 | filtered signal, and 'hi' be the amplitude time series of the high- 281 | bandpass of the original signal. 
282 | filter_kwargs : dict 283 | Keyword parameters to pass to `filterfn(.)` 284 | 285 | Returns 286 | ------- 287 | pac : scalar 288 | PAC value 289 | 290 | Usage 291 | ----- 292 | >>> import numpy as np 293 | >>> from scipy.signal import hilbert 294 | >>> from pacpy.pac import glm 295 | >>> t = np.arange(0, 10, .001) # Define time array 296 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 297 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 298 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 299 | >>> glm(lo, hi, (4,8), (80,150)) # Calculate PAC 300 | 0.69090396896138917 301 | """ 302 | 303 | # Arg check 304 | _x_sanity(lo, hi) 305 | 306 | # Filter series 307 | if filterfn is None: 308 | filterfn = firf 309 | 310 | if filter_kwargs is None: 311 | filter_kwargs = {} 312 | 313 | # Filter then hilbert 314 | if filterfn is not False: 315 | _range_sanity(f_lo, f_hi) 316 | lo = filterfn(lo, f_lo, fs, **filter_kwargs) 317 | hi = filterfn(hi, f_hi, fs, **filter_kwargs) 318 | 319 | hi = np.abs(hilbert(hi)) 320 | lo = np.angle(hilbert(lo)) 321 | 322 | # Make arrays the same size 323 | lo, hi = _trim_edges(lo, hi) 324 | 325 | # First prepare GLM 326 | y = hi 327 | X_pre = np.vstack((np.cos(lo), np.sin(lo))) 328 | X = X_pre.T 329 | y_hat, beta_hat = _ols(y, X) 330 | resid = y - y_hat 331 | 332 | # Calculate PAC from GLM residuals 333 | pac = 1 - np.sum(resid ** 2) / np.sum( 334 | (hi - np.mean(hi)) ** 2) 335 | 336 | return pac 337 | 338 | 339 | def mi_canolty(lo, hi, f_lo, f_hi, fs=1000, filterfn=None, filter_kwargs=None, 340 | n_surr=100): 341 | """ 342 | Calculate PAC using the modulation index (MI) method defined in Canolty, 343 | 2006 344 | 345 | Parameters 346 | ---------- 347 | lo : array-like, 1d 348 | The low frequency time-series to use as the phase component 349 | hi : array-like, 1d 350 | The high frequency time-series to use as the amplitude component 351 | f_lo : (low, high), Hz 352 | The low 
frequency filtering range 353 | f_hi : (low, high), Hz 354 | The low frequency filtering range 355 | fs : float 356 | The sampling rate (default = 1000Hz) 357 | filterfn : functional 358 | The filtering function, `filterfn(x, f_range, filter_kwargs)` 359 | 360 | False activates 'EXPERT MODE'. 361 | - DO NOT USE THIS FLAG UNLESS YOU KNOW WHAT YOU ARE DOING! 362 | - In expert mode the user needs to filter the data AND apply the 363 | hilbert transform. 364 | - This requires that 'lo' be the phase time series of the low-bandpass 365 | filtered signal, and 'hi' be the amplitude time series of the high- 366 | bandpass of the original signal. 367 | filter_kwargs : dict 368 | Keyword parameters to pass to `filterfn(.)` 369 | n_surr : int 370 | Number of surrogate tests to run to calculate normalized MI 371 | 372 | Returns 373 | ------- 374 | pac : scalar 375 | PAC value 376 | 377 | Usage 378 | ----- 379 | >>> import numpy as np 380 | >>> from scipy.signal import hilbert 381 | >>> from pacpy.pac import mi_canolty 382 | >>> t = np.arange(0, 10, .001) # Define time array 383 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 384 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 385 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 386 | >>> mi_canolty(lo, hi, (4,8), (80,150)) # Calculate PAC 387 | 1.1605177063713188 388 | """ 389 | 390 | # Arg check 391 | _x_sanity(lo, hi) 392 | 393 | # Filter series 394 | if filterfn is None: 395 | filterfn = firf 396 | 397 | if filter_kwargs is None: 398 | filter_kwargs = {} 399 | 400 | # Filter then hilbert 401 | if filterfn is not False: 402 | _range_sanity(f_lo, f_hi) 403 | lo = filterfn(lo, f_lo, fs, **filter_kwargs) 404 | hi = filterfn(hi, f_hi, fs, **filter_kwargs) 405 | 406 | hi = np.abs(hilbert(hi)) 407 | lo = np.angle(hilbert(lo)) 408 | 409 | # Make arrays the same size 410 | lo, hi = _trim_edges(lo, hi) 411 | 412 | # Calculate modulation index 413 | pac = np.abs(np.mean(hi 
* np.exp(1j * lo))) 414 | 415 | # Calculate surrogate MIs 416 | pacS = np.zeros(n_surr) 417 | np.random.seed(0) 418 | for s in range(n_surr): 419 | loS = np.roll(lo, np.random.randint(len(lo))) 420 | pacS[s] = np.abs(np.mean(hi * np.exp(1j * loS))) 421 | 422 | # Return z-score of observed PAC compared to null distribution 423 | return (pac - np.mean(pacS)) / np.std(pacS) 424 | 425 | 426 | def ozkurt(lo, hi, f_lo, f_hi, fs=1000, filterfn=None, filter_kwargs=None): 427 | """ 428 | Calculate PAC using the method defined in Ozkurt & Schnitzler, 2011 429 | 430 | Parameters 431 | ---------- 432 | lo : array-like, 1d 433 | The low frequency time-series to use as the phase component 434 | hi : array-like, 1d 435 | The high frequency time-series to use as the amplitude component 436 | f_lo : (low, high), Hz 437 | The low frequency filtering range 438 | f_hi : (low, high), Hz 439 | The low frequency filtering range 440 | fs : float 441 | The sampling rate (default = 1000Hz) 442 | filterfn : functional 443 | The filtering function, `filterfn(x, f_range, filter_kwargs)` 444 | 445 | False activates 'EXPERT MODE'. 446 | - DO NOT USE THIS FLAG UNLESS YOU KNOW WHAT YOU ARE DOING! 447 | - In expert mode the user needs to filter the data AND apply the 448 | hilbert transform. 449 | - This requires that 'lo' be the phase time series of the low-bandpass 450 | filtered signal, and 'hi' be the amplitude time series of the high- 451 | bandpass of the original signal. 
452 | filter_kwargs : dict 453 | Keyword parameters to pass to `filterfn(.)` 454 | 455 | Returns 456 | ------- 457 | pac : scalar 458 | PAC value 459 | 460 | Usage 461 | ----- 462 | >>> import numpy as np 463 | >>> from scipy.signal import hilbert 464 | >>> from pacpy.pac import ozkurt 465 | >>> t = np.arange(0, 10, .001) # Define time array 466 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 467 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 468 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 469 | >>> ozkurt(lo, hi, (4,8), (80,150)) # Calculate PAC 470 | 0.48564417921240238 471 | """ 472 | 473 | # Arg check 474 | _x_sanity(lo, hi) 475 | 476 | # Filter series 477 | if filterfn is None: 478 | filterfn = firf 479 | 480 | if filter_kwargs is None: 481 | filter_kwargs = {} 482 | 483 | # Filter then hilbert 484 | if filterfn is not False: 485 | _range_sanity(f_lo, f_hi) 486 | lo = filterfn(lo, f_lo, fs, **filter_kwargs) 487 | hi = filterfn(hi, f_hi, fs, **filter_kwargs) 488 | 489 | hi = np.abs(hilbert(hi)) 490 | lo = np.angle(hilbert(lo)) 491 | 492 | # Make arrays the same size 493 | lo, hi = _trim_edges(lo, hi) 494 | 495 | # Calculate PAC 496 | pac = np.abs(np.sum(hi * np.exp(1j * lo))) / \ 497 | (np.sqrt(len(lo)) * np.sqrt(np.sum(hi**2))) 498 | return pac 499 | 500 | 501 | def otc(x, f_hi, f_step, fs=1000, 502 | w=3, event_prc=95, t_modsig=None, t_buffer=.01): 503 | """ 504 | Calculate the oscillation-triggered coupling measure of phase-amplitude 505 | coupling from Dvorak, 2014. 
506 | 507 | Parameters 508 | ---------- 509 | x : array-like, 1d 510 | The time series 511 | f_hi : (low, high), Hz 512 | The low frequency filtering range 513 | f_step : float, Hz 514 | The width of each frequency bin in the time-frequency representation 515 | fs : float 516 | Sampling rate 517 | w : float 518 | Length of the filter in terms of the number of cycles of the 519 | oscillation whose frequency is the center of the bandpass filter 520 | event_prc : float (in range 0-100) 521 | The percentile threshold of the power signal of an oscillation 522 | for an event to be declared 523 | t_modsig : (min, max) 524 | Time (seconds) around an event to extract to define the modulation 525 | signal 526 | t_buffer : float 527 | Minimum time (seconds) in between high frequency events 528 | 529 | Returns 530 | ------- 531 | pac : float 532 | phase-amplitude coupling value 533 | tf : 2-dimensional array 534 | time-frequency representation of input signal 535 | a_events : array 536 | samples at which a high frequency event occurs 537 | mod_sig : array 538 | modulation signal (see Dvorak, 2014) 539 | 540 | Usage 541 | ----- 542 | >>> import numpy as np 543 | >>> from scipy.signal import hilbert 544 | >>> from pacpy.pac import otc 545 | >>> t = np.arange(0, 10, .001) # Define time array 546 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 547 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 548 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 549 | >>> pac, _, _, _ = otc(lo + hi, (80,150), 4) # Calculate PAC 550 | >>> print pac 551 | 2.1324570402314196 552 | """ 553 | 554 | # Arg check 555 | _x_sanity(x, None) 556 | _range_sanity(None, f_hi) 557 | # Set default time range for modulatory signal 558 | if t_modsig is None: 559 | t_modsig = (-1, 1) 560 | if f_step <= 0: 561 | raise ValueError('Frequency band width must be a positive number.') 562 | if t_modsig[0] > t_modsig[1]: 563 | raise ValueError('Invalid time 
range for modulation signal.') 564 | 565 | # Calculate the time-frequency representation 566 | f0s = np.arange(f_hi[0], f_hi[1], f_step) 567 | tf = _morletT(x, f0s, w=w, fs=fs) 568 | 569 | # Find the high frequency activity event times 570 | F = len(f0s) 571 | a_events = np.zeros(F, dtype=object) 572 | for f in range(F): 573 | a_events[f] = _peaktimes( 574 | zscore(np.abs(tf[f])), prc=event_prc, t_buffer=t_buffer) 575 | 576 | # Calculate the modulation signal 577 | samp_modsig = np.arange(t_modsig[0] * fs, t_modsig[1] * fs) 578 | samp_modsig = samp_modsig.astype(int) 579 | S = len(samp_modsig) 580 | mod_sig = np.zeros([F, S]) 581 | 582 | # For each frequency in the time-frequency representation, calculate a 583 | # modulation signal 584 | for f in range(F): 585 | # Exclude high frequency events that are too close to the signal 586 | # boundaries to extract an entire modulation signal 587 | mask = np.ones(len(a_events[f]), dtype=bool) 588 | mask[a_events[f] <= samp_modsig[-1]] = False 589 | mask[a_events[f] >= (len(x) - samp_modsig[-1])] = False 590 | a_events[f] = a_events[f][mask] 591 | 592 | # Calculate the average LFP around each high frequency event 593 | E = len(a_events[f]) 594 | for e in range(E): 595 | cur_ecog = x[a_events[f][e] + samp_modsig] 596 | mod_sig[f] = mod_sig[f] + cur_ecog / E 597 | 598 | # Calculate modulation strength, the range of the modulation signal 599 | mod_strength = np.zeros(F) 600 | for f in range(F): 601 | mod_strength = np.max(mod_sig[f]) - np.min(mod_sig[f]) 602 | 603 | # Calculate PAC 604 | pac = np.max(mod_strength) 605 | 606 | return pac, tf, a_events, mod_sig 607 | 608 | 609 | def _peaktimes(x, prc=95, t_buffer=.01, fs=1000): 610 | """ 611 | Calculate event times for which the power signal x peaks 612 | 613 | Parameters 614 | ---------- 615 | x : array 616 | Time series of power 617 | prc : float (in range 0-100) 618 | The percentile threshold of x for an event to be declares 619 | t_buffer : float 620 | Minimum time (seconds) 
in between events 621 | fs : float 622 | Sampling rate 623 | """ 624 | if np.logical_or(prc < 0, prc >= 100): 625 | raise ValueError('Percentile threshold must be between 0 and 100.') 626 | 627 | samp_buffer = np.int(np.round(t_buffer * fs)) 628 | hi = x > np.percentile(x, prc) 629 | event_intervals = _chunk_time(hi, samp_buffer=samp_buffer) 630 | E = np.int(np.size(event_intervals) / 2) 631 | events = np.zeros(E, dtype=object) 632 | 633 | for e in range(E): 634 | temp = x[np.arange(event_intervals[e][0], event_intervals[e][1] + 1)] 635 | events[e] = event_intervals[e][0] + np.argmax(temp) 636 | 637 | return events 638 | 639 | 640 | def _chunk_time(x, samp_buffer=0): 641 | """ 642 | Define continuous chunks of integers 643 | 644 | Parameters 645 | ---------- 646 | x : array 647 | Array of integers 648 | samp_buffer : int 649 | Minimum number of samples between chunks 650 | 651 | Returns 652 | ------- 653 | chunks : array (#chunks x 2) 654 | List of the sample bounds for each chunk 655 | """ 656 | if samp_buffer < 0: 657 | raise ValueError( 658 | 'Buffer between signal peaks must be a positive number') 659 | if samp_buffer != int(samp_buffer): 660 | raise ValueError('Number of samples must be an integer') 661 | 662 | if type(x[0]) == np.bool_: 663 | Xs = np.arange(len(x)) 664 | x = Xs[x] 665 | X = len(x) 666 | 667 | cur_start = x[0] 668 | cur_samp = x[0] 669 | Nchunk = 0 670 | chunks = [] 671 | for i in range(1, X): 672 | if x[i] > (cur_samp + samp_buffer + 1): 673 | if Nchunk == 0: 674 | chunks = [cur_start, cur_samp] 675 | else: 676 | chunks = np.vstack([chunks, [cur_start, cur_samp]]) 677 | 678 | Nchunk = Nchunk + 1 679 | cur_start = x[i] 680 | 681 | cur_samp = x[i] 682 | 683 | # Add final row to chunk 684 | if Nchunk == 0: 685 | chunks = [[cur_start, cur_samp]] 686 | else: 687 | chunks = np.vstack([chunks, [cur_start, cur_samp]]) 688 | 689 | return chunks 690 | 691 | 692 | def _morletT(x, f0s, w=3, fs=1000, s=1): 693 | """ 694 | Calculate the time-frequency 
representation of the signal 'x' over the 695 | frequencies in 'f0s' using morlet wavelets 696 | 697 | Parameters 698 | ---------- 699 | x : array 700 | time series 701 | f0s : array 702 | frequency axis 703 | w : float 704 | Length of the filter in terms of the number of cycles 705 | of the oscillation whose frequency is the center of the 706 | bandpass filter 707 | Fs : float 708 | Sampling rate 709 | s : float 710 | Scaling factor 711 | 712 | Returns 713 | ------- 714 | mwt : 2-D array 715 | time-frequency representation of signal x 716 | """ 717 | if w <= 0: 718 | raise ValueError( 719 | 'Number of cycles in a filter must be a positive number.') 720 | 721 | T = len(x) 722 | F = len(f0s) 723 | mwt = np.zeros([F, T], dtype=complex) 724 | for f in range(F): 725 | mwt[f] = morletf(x, f0s[f], fs=fs, w=w, s=s) 726 | 727 | return mwt 728 | 729 | 730 | def comodulogram(lo, hi, p_range, a_range, dp, da, fs=1000, 731 | pac_method='mi_tort', 732 | filterfn=None, filter_kwargs=None): 733 | """ 734 | Calculate PAC for many small frequency bands 735 | 736 | Parameters 737 | ---------- 738 | lo : array-like, 1d 739 | The low frequency time-series to use as the phase component 740 | hi : array-like, 1d 741 | The high frequency time-series to use as the amplitude component 742 | p_range : (low, high), Hz 743 | The low frequency filtering range 744 | a_range : (low, high), Hz 745 | The high frequency filtering range 746 | dp : float, Hz 747 | Width of the low frequency filtering range for each PAC calculation 748 | da : float, Hz 749 | Width of the high frequency filtering range for each PAC calculation 750 | fs : float 751 | The sampling rate (default = 1000Hz) 752 | pac_method : string 753 | Method to calculate PAC. 
754 | 'mi_tort' - See Tort, 2008 755 | 'plv' - See Penny, 2008 756 | 'glm' - See Penny, 2008 757 | 'mi_canolty' - See Canolty, 2006 758 | 'ozkurt' - See Ozkurt & Schnitzler, 2011 759 | filterfn : function 760 | The filtering function, `filterfn(x, f_range, filter_kwargs)` 761 | filter_kwargs : dict 762 | Keyword parameters to pass to `filterfn(.)` 763 | 764 | Returns 765 | ------- 766 | comod : array-like, 2d 767 | Matrix of phase-amplitude coupling values for each combination of the 768 | phase frequency bin and the amplitude frequency bin 769 | 770 | Usage 771 | ----- 772 | >>> import numpy as np 773 | >>> from scipy.signal import hilbert 774 | >>> from pacpy.pac import comodulogram 775 | >>> t = np.arange(0, 10, .001) # Define time array 776 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 777 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 778 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 779 | >>> comod = comodulogram(lo, hi, (5,25), (75,175), 10, 50) # Calculate PAC 780 | >>> print comod 781 | [[ 0.32708628 0.32188585] 782 | [ 0.3295994 0.32439953]] 783 | """ 784 | 785 | # Arg check 786 | _x_sanity(lo, hi) 787 | _range_sanity(p_range, a_range) 788 | if dp <= 0: 789 | raise ValueError('Width of lo frequqnecy range must be positive') 790 | if da <= 0: 791 | raise ValueError('Width of hi frequqnecy range must be positive') 792 | 793 | # method check 794 | method2fun = {'plv': plv, 'mi_tort': mi_tort, 'mi_canolty': mi_canolty, 795 | 'ozkurt': ozkurt, 'glm': glm} 796 | pac_fun = method2fun.get(pac_method, None) 797 | if pac_fun == None: 798 | raise ValueError('PAC method given is invalid.') 799 | 800 | # Calculate palette frequency parameters 801 | f_phases = np.arange(p_range[0], p_range[1], dp) 802 | f_amps = np.arange(a_range[0], a_range[1], da) 803 | P = len(f_phases) 804 | A = len(f_amps) 805 | 806 | # Calculate PAC for every combination of P and A 807 | comod = np.zeros((P, A)) 808 | for p in 
range(P): 809 | f_lo = (f_phases[p], f_phases[p] + dp) 810 | 811 | for a in range(A): 812 | f_hi = (f_amps[a], f_amps[a] + da) 813 | 814 | comod[p, a] = pac_fun(lo, hi, f_lo, f_hi, fs=fs, 815 | filterfn=filterfn, filter_kwargs=filter_kwargs) 816 | 817 | return comod 818 | 819 | 820 | def pa_series(lo, hi, f_lo, f_hi, fs=1000, filterfn=None, filter_kwargs=None): 821 | """ 822 | Calculate the phase and amplitude time series 823 | 824 | Parameters 825 | ---------- 826 | lo : array-like, 1d 827 | The low frequency time-series to use as the phase component 828 | hi : array-like, 1d 829 | The high frequency time-series to use as the amplitude component 830 | f_lo : (low, high), Hz 831 | The low frequency filtering range 832 | f_hi : (low, high), Hz 833 | The low frequency filtering range 834 | fs : float 835 | The sampling rate (default = 1000Hz) 836 | filterfn : function 837 | The filtering function, `filterfn(x, f_range, filter_kwargs)` 838 | filter_kwargs : dict 839 | Keyword parameters to pass to `filterfn(.)` 840 | 841 | Returns 842 | ------- 843 | pha : array-like, 1d 844 | Time series of phase 845 | amp : array-like, 1d 846 | Time series of amplitude 847 | 848 | Usage 849 | ----- 850 | >>> import numpy as np 851 | >>> from scipy.signal import hilbert 852 | >>> from pacpy.pac import pa_series 853 | >>> t = np.arange(0, 10, .001) # Define time array 854 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 855 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 856 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 857 | >>> pha, amp = pa_series(lo, hi, (4,8), (80,150)) 858 | >>> print pha 859 | [ 1.57079633 1.60849544 1.64619455 ..., 1.45769899 1.4953981 1.53309721] 860 | """ 861 | 862 | # Arg check 863 | _x_sanity(lo, hi) 864 | _range_sanity(f_lo, f_hi) 865 | 866 | # Filter setup 867 | if filterfn is None: 868 | filterfn = firf 869 | filter_kwargs = {} 870 | 871 | # Filter 872 | xlo = filterfn(lo, f_lo, fs, 
**filter_kwargs) 873 | xhi = filterfn(hi, f_hi, fs, **filter_kwargs) 874 | 875 | # Calculate phase time series and amplitude time series 876 | pha = np.angle(hilbert(xlo)) 877 | amp = np.abs(hilbert(xhi)) 878 | 879 | # Make arrays the same size 880 | pha, amp = _trim_edges(pha, amp) 881 | 882 | return pha, amp 883 | 884 | 885 | def pa_dist(pha, amp, Nbins=10): 886 | """ 887 | Calculate distribution of amplitude over a cycle of phases 888 | 889 | Parameters 890 | ---------- 891 | pha : array 892 | Phase time series 893 | amp : array 894 | Amplitude time series 895 | Nbins : int 896 | Number of phase bins in the distribution, 897 | uniformly distributed between -pi and pi. 898 | 899 | Returns 900 | ------- 901 | dist : array 902 | Average amplitude in each phase bins 903 | phase_bins : array 904 | The boundaries to each phase bin. Note the length is 1 + len(dist) 905 | 906 | Usage 907 | ----- 908 | >>> import numpy as np 909 | >>> from scipy.signal import hilbert 910 | >>> from pacpy.pac import pa_series, pa_dist 911 | >>> t = np.arange(0, 10, .001) # Define time array 912 | >>> lo = np.sin(t * 2 * np.pi * 6) # Create low frequency carrier 913 | >>> hi = np.sin(t * 2 * np.pi * 100) # Create modulated oscillation 914 | >>> hi[np.angle(hilbert(lo)) > -np.pi*.5] = 0 # Clip to 1/4 of cycle 915 | >>> pha, amp = pa_series(lo, hi, (4,8), (80,150)) 916 | >>> phase_bins, dist = pa_dist(pha, amp) 917 | >>> print dist 918 | [ 7.21154110e-01 8.04347122e-01 4.49207087e-01 2.08747058e-02 919 | 8.03854240e-05 3.45166617e-05 3.45607343e-05 3.51091029e-05 920 | 7.73644631e-04 1.63514941e-01] 921 | """ 922 | if np.logical_or(Nbins < 2, Nbins != int(Nbins)): 923 | raise ValueError( 924 | 'Number of bins in the low frequency oscillation cycle must be an integer >1.') 925 | if len(pha) != len(amp): 926 | raise ValueError( 927 | 'Phase and amplitude time series must be of same length.') 928 | 929 | phase_bins = np.linspace(-np.pi, np.pi, int(Nbins + 1)) 930 | dist = np.zeros(int(Nbins)) 
931 | 932 | for b in range(int(Nbins)): 933 | t_phase = np.logical_and(pha >= phase_bins[b], 934 | pha < phase_bins[b + 1]) 935 | dist[b] = np.mean(amp[t_phase]) 936 | 937 | return phase_bins[:-1], dist 938 | --------------------------------------------------------------------------------