├── .gitignore ├── README.md ├── matlab ├── extract_bursts.m ├── extract_bursts_single_trial.m ├── fwhm_burst_norm.m ├── gaus2d.m └── overlap.m └── python ├── burst_detection.py ├── pca_analysis_tutorial.ipynb ├── superlet_burst_detection_example.ipynb └── tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Burst Detection Algorithm - usage notes 2 | 3 | library accompanying the paper Szul et al. (2023) "Diverse beta burst waveform motifs characterize movement-related cortical dynamics" Progress in Neurobiology 4 | 5 | 6 | 7 | ## PYTHON 8 | 9 | ### DEPENDENCIES 10 | 11 | MNE-Python (only for filtering) 12 | 13 | scipy 14 | 15 | numpy 16 | 17 | ### INSTALLATION 18 | 19 | 1. install dependencies with `pip` or `conda install` 20 | 21 | 2. 
add folder to the PYTHONPATH or put it in the same directory as the script 22 | 23 | 24 | ### USAGE 25 | 26 | ```python 27 | from burst_detection import extract_bursts, extract_bursts_single_trial 28 | ``` 29 | 30 | 31 | 32 | 33 | ```python 34 | bursts = extract_bursts( 35 | raw_trials, tf, times, search_freqs, 36 | band_lims, aperiodic_spectrum, sfreq, w_size=.26 37 | ) 38 | ``` 39 | 40 | 41 | 42 | 43 | ```python 44 | bursts_single_trial = extract_bursts_single_trial( 45 | raw_trial, tf, times, search_freqs, 46 | band_lims, aperiodic_spectrum, sfreq, w_size=.26 47 | ) 48 | ``` 49 | 50 | >Single-trial burst extraction from a single-trial time course and a single-trial 51 | time-frequency decomposition. The average ERF is not regressed out; a 7-10 Hz high-pass filter 52 | of the time-domain trial may remove evoked-response-related distortions of the burst shape. 53 | 54 | 55 | 56 | #### Arguments: 57 | 58 | `raw_trials: raw data for each trial (trial x time)` 59 | 60 | >Array of trials from a single channel in time domain. 61 | 62 | 63 | \ 64 | `tf: time-frequency decomposition for each trial (trial x freq x time)` 65 | 66 | >Array of time-frequency data from a single channel. If the targeted range 67 | of frequencies is 13-30 Hz, it is recommended to select the TF with at least 3 68 | Hz buffer (10-33 Hz). Recommended use of superlet (Moca et al., 2021) 69 | transformed data. 70 | 71 | \ 72 | `times: time steps` 73 | 74 | >The time points of `raw_trials` and `tf` must be identical. 75 | 76 | 77 | \ 78 | `search_freqs: frequency limits to search within for bursts (should be wider than band_lims)` 79 | 80 | >List of frequencies corresponding to the selected ones in `tf`. 81 | 82 | 83 | \ 84 | `band_lims: keep bursts whose peak frequency falls within these limits` 85 | 86 | >The actual target frequency range. Bursts with peak frequency beyond this range 87 | will be discarded.
88 | 89 | 90 | \ 91 | `aperiodic_spectrum: aperiodic spectrum` 92 | 93 | >Assuming the PSD was calculated based on the TF data averaged over time, the 94 | frequency resolution of the PSD and TF should be the same. Provide one value per frequency in `search_freqs` (freq x 1). 95 | 96 | 97 | \ 98 | `sfreq: sampling rate` 99 | 100 | >Assuming the sampling rate is the same for `raw_trials` and `tf` 101 | 102 | 103 | \ 104 | `w_size=.26: size of the burst window in time domain (in seconds)` 105 | 106 | >Optional; defaults to 0.26 s in Python (`w_size`) and 0.2 s in MATLAB (`'win_size'` name-value pair). Window size based on lagged coherence of the MEG data. 107 | 108 | 109 | 110 | #### Output: 111 | 112 | The function returns a dictionary with time-frequency burst features, and time 113 | domain waveforms. 114 | 115 | ```python 116 | { 117 | 'trial': [], 118 | 'waveform': [], 119 | 'peak_freq': [], 120 | 'peak_amp_iter': [], 121 | 'peak_amp_base': [], 122 | 'peak_time': [], 123 | 'peak_adjustment': [], 124 | 'fwhm_freq': [], 125 | 'fwhm_time': [], 126 | 'polarity': [], 127 | 'waveform_times': [] 128 | } 129 | ``` 130 | 131 | 132 | `trial` index of a trial where the burst was detected 133 | 134 | `waveform` array of burst waveforms [burst x time] 135 | 136 | `peak_freq` peak frequency of the burst 137 | 138 | `peak_amp_iter` relative TF amplitude of the burst during the iterations 139 | 140 | `peak_amp_base` absolute TF amplitude of the burst (with the aperiodic 141 | spectrum subtracted) 142 | 143 | `peak_time` TF peak time of the burst 144 | 145 | `peak_adjustment` adjustment of the peak in seconds (shift from the TF peak to the aligned phase minimum) 146 | 147 | `fwhm_freq` frequency span of the burst (FWHM) 148 | 149 | `fwhm_time` duration of the burst (FWHM) 150 | 151 | `polarity` 0 - the polarity was not flipped, 1 - polarity was flipped 152 | 153 | `waveform_times` 1d array containing the timepoints for a waveform 154 | 155 | 156 | ## MATLAB 157 | 158 | ### DEPENDENCIES 159 | 160 | Fieldtrip 161 | 162 | 163 | ### USAGE 164 | ``` 165 | bursts = extract_bursts(raw_trials, tf, times, search_freqs, band_lims, aperiodic_spectrum, sfreq) 166 | ``` 167 | 168 | ``` 169 | bursts = 
extract_bursts_single_trial(raw_trial, tf, times, search_freqs, band_lims, aperiodic_spectrum, sfreq)
```

#### Arguments:

AS ABOVE

The optional window size is passed as a name-value pair, e.g. `extract_bursts(..., 'win_size', 0.25)` (default 0.2).

#### Output:

AS ABOVE
--------------------------------------------------------------------------------
/matlab/extract_bursts.m:
--------------------------------------------------------------------------------
function bursts=extract_bursts(raw_trials, tf, times, search_freqs,...
    band_lims, aperiodic_spectrum, sfreq, varargin)
% EXTRACT_BURSTS Extract bursts from epoched data
%   raw_trials: raw data for each trial (trial x time)
%   tf: time-frequency decomposition for each trial (trial x freq x time)
%   times: time steps
%   search_freqs: frequency limits to search within for bursts (should be
%       wider than band_lims)
%   band_lims: keep bursts whose peak frequency falls within these limits
%   aperiodic_spectrum: aperiodic spectrum (freq x 1)
%   sfreq: sampling rate
%   win_size: (optional name-value pair) window size in seconds used to
%       extract burst waveforms (default=0.2)
%   returns: struct with trial, waveform, peak frequency, relative peak
%       amplitude, absolute peak amplitude, peak time, peak adjustment, FWHM
%       in frequency, FWHM in time, and polarity for each detected burst,
%       plus waveform_times (waveform time axis relative to the burst peak)
% Optional parameters are used as follows:
%   extract_bursts(...,'win_size',0.25)
%
% For every trial, the across-trial mean (ERF) is regressed out of the raw
% data and the residual is passed to EXTRACT_BURSTS_SINGLE_TRIAL.

defaults = struct('win_size', 0.2);
params = struct(varargin{:});
for f = fieldnames(defaults)'
    if ~isfield(params, f{1})
        params.(f{1}) = defaults.(f{1});
    end
end

bursts=[];
bursts.trial=[];
bursts.waveform=[];
bursts.peak_freq=[];
bursts.peak_amp_iter=[];
bursts.peak_amp_base=[];
bursts.peak_time=[];
bursts.peak_adjustment=[];
bursts.fwhm_freq=[];
bursts.fwhm_time=[];
bursts.polarity=[];
bursts.waveform_times=[];

% Compute event-related signal. Average explicitly over trials (dim 1):
% plain mean() of a single-trial (1 x time) input would return a scalar.
erf = mean(raw_trials, 1);

% Iterate through trials
for t_idx=1:size(tf,1)
    tr_tf=squeeze(tf(t_idx,:,:));

    % Regress out ERF
    lm=fitlm(erf, raw_trials(t_idx,:));
    raw_trial=lm.Residuals.Raw';

    % Extract bursts for this trial
    trial_bursts=extract_bursts_single_trial(raw_trial, tr_tf, times,...
        search_freqs, band_lims, aperiodic_spectrum, sfreq, 'win_size',...
        params.win_size);

    % Append this trial's bursts to the output
    n_trial_bursts=length(trial_bursts.peak_time);
    bursts.trial(end+1:end+n_trial_bursts)=t_idx;
    bursts.waveform(end+1:end+n_trial_bursts,:)=trial_bursts.waveform;
    bursts.peak_freq(end+1:end+n_trial_bursts)=trial_bursts.peak_freq;
    bursts.peak_amp_iter(end+1:end+n_trial_bursts)=trial_bursts.peak_amp_iter;
    bursts.peak_amp_base(end+1:end+n_trial_bursts)=trial_bursts.peak_amp_base;
    bursts.peak_time(end+1:end+n_trial_bursts)=trial_bursts.peak_time;
    bursts.peak_adjustment(end+1:end+n_trial_bursts)=trial_bursts.peak_adjustment;
    bursts.fwhm_freq(end+1:end+n_trial_bursts)=trial_bursts.fwhm_freq;
    bursts.fwhm_time(end+1:end+n_trial_bursts)=trial_bursts.fwhm_time;
    bursts.polarity(end+1:end+n_trial_bursts)=trial_bursts.polarity;
    if ~isempty(trial_bursts.waveform_times)
        bursts.waveform_times = trial_bursts.waveform_times;
    end
end

end
--------------------------------------------------------------------------------
/matlab/extract_bursts_single_trial.m:
--------------------------------------------------------------------------------
function bursts=extract_bursts_single_trial(raw_trial, tf, times,...
    search_freqs, band_lims, aperiodic_spectrum, sfreq, varargin)
% EXTRACT_BURSTS_SINGLE_TRIAL Extract bursts from a single trial
%   raw_trial: raw data for trial (time)
%   tf: time-frequency decomposition for trial (freq x time)
%   times: time steps
%   search_freqs: frequencies of the rows of tf; bursts are searched within
%       this range (should be wider than band_lims)
%   band_lims: keep bursts whose peak frequency falls within these limits
%   aperiodic_spectrum: aperiodic spectrum (freq x 1)
%   sfreq: sampling rate
%   win_size: (optional name-value pair) window size in seconds used to
%       extract burst waveforms (default=0.2)
%   returns: struct with waveform, peak frequency, relative peak amplitude,
%       absolute peak amplitude, peak time, peak adjustment (s), FWHM in
%       frequency, FWHM in time, and polarity for each detected burst, plus
%       waveform_times (waveform time axis relative to the burst peak)
% Optional parameters are used as follows:
%   extract_bursts_single_trial(...,'win_size',0.25)
%
% The largest remaining TF peak is repeatedly located, described by its 2D
% FWHM and removed by subtracting a fitted 2D Gaussian, until no peak
% reaches twice the standard deviation of the remaining TF.

defaults = struct('win_size', .2);
params = struct(varargin{:});
for f = fieldnames(defaults)'
    if ~isfield(params, f{1})
        params.(f{1}) = defaults.(f{1});
    end
end

bursts=[];
bursts.waveform=[];
bursts.peak_freq=[];
bursts.peak_amp_iter=[];
bursts.peak_amp_base=[];
bursts.peak_time=[];
bursts.peak_adjustment=[];
bursts.fwhm_freq=[];
bursts.fwhm_time=[];
bursts.polarity=[];
bursts.waveform_times = [];

% Grid for computing 2D Gaussians
[x_idx, y_idx] = meshgrid(1:length(times), 1:length(search_freqs));

% Window size in points
wlen = round(params.win_size * sfreq);
half_wlen = round(wlen * .5);

% Subtract 1/f threshold
trial_tf = tf - repmat(aperiodic_spectrum,1,size(tf,2));
trial_tf(trial_tf < 0) = 0;

% Skip trial if no peaks above aperiodic
if all(trial_tf(:)==0)
    disp('All values equal 0 after aperiodic subtraction');
    return;
end

% TF for iterating
trial_tf_iter = trial_tf;

while true
    % Compute noise floor
    thresh = 2 * std(trial_tf_iter(:));

    % Find peak
    [~,I] = max(trial_tf_iter(:));
    [peak_freq_idx, peak_time_idx] = ind2sub(size(trial_tf_iter),I);
    peak_freq = search_freqs(peak_freq_idx);
    peak_amp_iter = trial_tf_iter(peak_freq_idx, peak_time_idx);
    peak_amp_base = trial_tf(peak_freq_idx, peak_time_idx);
    % Stop if no peak above threshold. Written as ~(x >= thresh) so that a
    % NaN amplitude/threshold also stops the loop; a non-positive peak would
    % give a zero-width Gaussian, so stop on it as well.
    if ~(peak_amp_iter >= thresh) || peak_amp_iter <= 0
        break
    end

    % Fit 2D Gaussian and subtract from TF
    [rloc, llec, uloc, dloc] = fwhm_burst_norm(trial_tf_iter,...
        [peak_freq_idx, peak_time_idx]);

    % Replace degenerate (NaN) FWHM arms with a default width. Both
    % dimensions are handled independently: if both were NaN and only one
    % were replaced, the fitted Gaussian would be NaN everywhere and would
    % corrupt the TF used for all following iterations.
    vert_isnan = any(isnan([uloc, dloc]));
    horiz_isnan = any(isnan([rloc, llec]));
    if vert_isnan
        v_sh = max(1, round((length(search_freqs) - peak_freq_idx) / 2));
        uloc = v_sh;
        dloc = v_sh;
    end
    if horiz_isnan
        h_sh = max(1, round((length(times) - peak_time_idx) / 2));
        rloc = h_sh;
        llec = h_sh;
    end
    hv_isnan = vert_isnan || horiz_isnan;

    % Compute FWHM and convert to SD (FWHM ~= 2.355 * SD)
    fwhm_f_idx = uloc + dloc;
    fwhm_f = (search_freqs(2)-search_freqs(1))*fwhm_f_idx;
    fwhm_t_idx = llec + rloc;
    fwhm_t = (times(2) - times(1))*fwhm_t_idx;
    sigma_t = (fwhm_t_idx) / 2.355;
    sigma_f = (fwhm_f_idx) / 2.355;
    % Subtract the fitted Gaussian for the next iteration
    z = peak_amp_iter * gaus2d(x_idx, y_idx, peak_time_idx,...
        peak_freq_idx, sigma_t, sigma_f);
    new_trial_TF_iter = trial_tf_iter - z;

    % If detected peak is within band limits and not degenerate
    if peak_freq>=band_lims(1) && peak_freq<=band_lims(2) && ~hv_isnan
        % Bandpass filter within frequency range of burst
        freq_range = [max([1, peak_freq_idx - dloc]),...
            min([length(search_freqs) , peak_freq_idx + uloc])];

        filtered = ft_preproc_bandpassfilter(raw_trial, sfreq,...
            search_freqs(freq_range), 6, 'but', 'twopass', 'reduce');

        % Hilbert transform
        analytic_signal = hilbert(filtered);
        % Get phase
        instantaneous_phase = mod(unwrap(angle(analytic_signal)), pi);

        % Find the local phase minimum closest to the TF peak. If no
        % minimum is found, the adjustment is set to inf so that no burst
        % is added.
        [~,min_phase_pts]= findpeaks(-1*instantaneous_phase);
        if isempty(min_phase_pts)
            adjustment=inf;
        else
            [~,min_idx]=min(abs(peak_time_idx - min_phase_pts));
            new_peak_time_idx = min_phase_pts(min_idx);
            adjustment = (new_peak_time_idx - peak_time_idx) * 1 / sfreq;
        end

        % Keep if adjustment less than 30ms
        if abs(adjustment) < .03

            % If burst won't be cutoff: the window below spans
            % new_peak_time_idx-half_wlen : new_peak_time_idx+half_wlen
            if new_peak_time_idx > half_wlen && new_peak_time_idx + half_wlen <= length(raw_trial)
                peak_time = times(new_peak_time_idx);

                overlapped=false;
                % Check for overlap with the FWHM extent of kept bursts
                for b_idx=1:length(bursts.peak_time)
                    o_t=bursts.peak_time(b_idx);
                    o_fwhm_t=bursts.fwhm_time(b_idx);
                    if overlap([peak_time-.5*fwhm_t, peak_time+.5*fwhm_t],...
                            [o_t-.5*o_fwhm_t, o_t+.5*o_fwhm_t])
                        overlapped=true;
                        break
                    end
                end

                if ~overlapped
                    % Get burst
                    burst = raw_trial(new_peak_time_idx - half_wlen:new_peak_time_idx + half_wlen);
                    % Remove DC offset
                    burst = burst - mean(burst);
                    bursts.waveform_times = times(new_peak_time_idx - half_wlen:new_peak_time_idx + half_wlen) - times(new_peak_time_idx);

                    % Flip if positive deflection (nearest extremum of the
                    % filtered signal is a maximum rather than a minimum)
                    [~,peak_idxs]= findpeaks(filtered);
                    peak_dists = abs(peak_idxs - new_peak_time_idx);
                    [~,trough_idxs]= findpeaks(-1*filtered);
                    trough_dists = abs(trough_idxs - new_peak_time_idx);

                    polarity=0;
                    if isempty(trough_dists) || (~isempty(peak_dists) && min(peak_dists) < min(trough_dists))
                        burst = burst*-1.0;
                        polarity=1;
                    end

                    bursts.waveform(end+1,:)=burst;
                    bursts.peak_freq(end+1)=peak_freq;
                    bursts.peak_amp_iter(end+1)=peak_amp_iter;
                    bursts.peak_amp_base(end+1)=peak_amp_base;
                    bursts.peak_time(end+1)=peak_time;
                    bursts.peak_adjustment(end+1)=adjustment;
                    bursts.fwhm_freq(end+1)=fwhm_f;
                    bursts.fwhm_time(end+1)=fwhm_t;
                    bursts.polarity(end+1)=polarity;
                end
            end
        end
    end

    trial_tf_iter = new_trial_TF_iter;
end
end
--------------------------------------------------------------------------------
/matlab/fwhm_burst_norm.m:
--------------------------------------------------------------------------------
function [right_loc, left_loc, up_loc, down_loc]=fwhm_burst_norm(tf, peak)
% FWHM_BURST_NORM Find two-dimensional FWHM
%   tf: TF spectrum (freq x time)
%   peak: peak of activity [freq, time] (indices into tf)
%   returns: right, left, up, down arm lengths (in TF bins) of the FWHM
%       around the peak; each dimension is made symmetric by taking the
%       smaller arm, and an arm is NaN if no half-maximum crossing is found
%       on either side of the peak along that dimension

half=tf(peak(1),peak(2))/2;

right_loc = NaN;
% Find right limit (values to right of peak less than half value at
% peak)
cand=find(tf(peak(1),peak(2):end)<=half);
% If any found, take the first one
14 | if ~isempty(cand) 15 | right_loc=cand(1); 16 | end 17 | 18 | up_loc = NaN; 19 | % Find up limit (values above peak less than half value at peak) 20 | cand=find(tf(peak(1):end, peak(2)) <= half); 21 | % If any found, take the first one 22 | if ~isempty(cand) 23 | up_loc=cand(1); 24 | end 25 | 26 | left_loc = NaN; 27 | % Find left limit (values below peak less than half value at peak) 28 | cand=find(tf(peak(1),1:peak(2)-1)<=half); 29 | % If any found, take the last one 30 | if ~isempty(cand) 31 | left_loc = peak(2)-cand(end); 32 | end 33 | 34 | down_loc = NaN; 35 | % Find down limit (values below peak less than half value at peak) 36 | cand=find(tf(1:peak(1)-1,peak(2))<=half); 37 | % If any found, take the last one 38 | if ~isempty(cand) 39 | down_loc = peak(1)-cand(end); 40 | end 41 | 42 | % Set arms equal if only one found 43 | if isnan(down_loc) 44 | down_loc = up_loc; 45 | end 46 | if isnan(up_loc) 47 | up_loc = down_loc; 48 | end 49 | if isnan(left_loc) 50 | left_loc = right_loc; 51 | end 52 | if isnan(right_loc) 53 | right_loc = left_loc; 54 | end 55 | 56 | % Use the minimum arm in each direction (forces Gaussian to be 57 | % symmetric in each dimension) 58 | horiz = min([left_loc, right_loc]); 59 | vert = min([up_loc, down_loc]); 60 | right_loc = horiz; 61 | left_loc = horiz; 62 | up_loc = vert; 63 | down_loc = vert; 64 | end -------------------------------------------------------------------------------- /matlab/gaus2d.m: -------------------------------------------------------------------------------- 1 | function z=gaus2d(x, y, mx, my, sx, sy) 2 | % GAUS2D Two-dimensional gaussian function 3 | % x: x grid 4 | % y: y grid 5 | % mx: mean in x dimension 6 | % my: mean in y dimension 7 | % sx: standard deviation in x dimension 8 | % sy: standard deviation in y dimension 9 | % returns: Two-dimensional Gaussian distribution 10 | z=exp(-((x - mx).^2. / (2. * sx^2.) + (y - my).^2. / (2. 
* sy^2.))); 11 | end 12 | -------------------------------------------------------------------------------- /matlab/overlap.m: -------------------------------------------------------------------------------- 1 | function o=overlap(a,b) 2 | % Find if two ranges overlap 3 | % a: first range [low, high] 4 | % b: second range [low, high] 5 | % returns: True if ranges overlap, false otherwise 6 | o=(a(1)<=b(1) && b(1)<=a(2)) || (b(1)<=a(1) && a(1)<=b(2)); 7 | end -------------------------------------------------------------------------------- /python/burst_detection.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | from mne.filter import filter_data 5 | from scipy.signal import hilbert, argrelextrema 6 | from scipy.stats import linregress 7 | 8 | 9 | def gaus2d(x=0, y=0, mx=0, my=0, sx=1, sy=1): 10 | """ 11 | Two-dimensional gaussian function 12 | :param x: x grid 13 | :param y: y grid 14 | :param mx: mean in x dimension 15 | :param my: mean in y dimension 16 | :param sx: standard deviation in x dimension 17 | :param sy: standard deviation in y dimension 18 | :return: Two-dimensional Gaussian distribution 19 | """ 20 | return np.exp(-((x - mx) ** 2. / (2. * sx ** 2.) + (y - my) ** 2. / (2. 
* sy ** 2.))) 21 | 22 | 23 | def overlap(a, b): 24 | """ 25 | Find if two ranges overlap 26 | :param a: first range [low, high] 27 | :param b: second range [low, high] 28 | :return: True if ranges overlap, false otherwise 29 | """ 30 | return a[0] <= b[0] <= a[1] or b[0] <= a[0] <= b[1] 31 | 32 | 33 | def fwhm_burst_norm(tf, peak): 34 | """ 35 | Find two-dimensional FWHM 36 | :param tf: TF spectrum 37 | :param peak: peak of activity [freq, time] 38 | :return: right, left, up, down limits for FWM 39 | """ 40 | right_loc = np.nan 41 | # Find right limit (values to right of peak less than half value at peak) 42 | cand = np.where(tf[peak[0], peak[1]:] <= tf[peak] / 2)[0] 43 | # If any found, take the first one 44 | if len(cand): 45 | right_loc = cand[0] 46 | 47 | up_loc = np.nan 48 | # Find up limit (values above peak less than half value at peak) 49 | cand = np.where(tf[peak[0]:, peak[1]] <= tf[peak] / 2)[0] 50 | # If any found, take the first one 51 | if len(cand): 52 | up_loc = cand[0] 53 | 54 | left_loc = np.nan 55 | # Find left limit (values below peak less than half value at peak) 56 | cand = np.where(tf[peak[0], :peak[1]] <= tf[peak] / 2)[0] 57 | # If any found, take the last one 58 | if len(cand): 59 | left_loc = peak[1] - cand[-1] 60 | 61 | down_loc = np.nan 62 | # Find down limit (values below peak less than half value at peak) 63 | cand = np.where(tf[:peak[0], peak[1]] <= tf[peak] / 2)[0] 64 | # If any found, take the last one 65 | if len(cand): 66 | down_loc = peak[0] - cand[-1] 67 | 68 | # Set arms equal if only one found 69 | if down_loc is np.nan: 70 | down_loc = up_loc 71 | if up_loc is np.nan: 72 | up_loc = down_loc 73 | if left_loc is np.nan: 74 | left_loc = right_loc 75 | if right_loc is np.nan: 76 | right_loc = left_loc 77 | 78 | # Use the minimum arm in each direction (forces Gaussian to be symmetric in each dimension) 79 | horiz = np.nanmin([left_loc, right_loc]) 80 | vert = np.nanmin([up_loc, down_loc]) 81 | right_loc = horiz 82 | left_loc = 
horiz 83 | up_loc = vert 84 | down_loc = vert 85 | return right_loc, left_loc, up_loc, down_loc 86 | 87 | 88 | def extract_bursts_single_trial(raw_trial, tf, times, search_freqs, band_lims, aperiodic_spectrum, sfreq, w_size=.26): 89 | """ 90 | Extract bursts from epoched data 91 | :param raw_trial: raw data for trial (time) 92 | :param tf: time-frequency decomposition for trial (freq x time) 93 | :param times: time steps 94 | :param search_freqs: frequency limits to search within for bursts (should be wider than band_lims) 95 | :param band_lims: keep bursts whose peak frequency falls within these limits 96 | :param aperiodic_spectrum: aperiodic spectrum 97 | :param sfreq: sampling rate 98 | :param w_size: window size to extract burst waveforms 99 | :return: disctionary with waveform, peak frequency, relative peak amplitude, absolute peak amplitude, peak 100 | time, peak adjustment, FWHM in frequency, FWHM in time, and polarity for each detected burst 101 | """ 102 | bursts = { 103 | 'waveform': [], 104 | 'peak_freq': [], 105 | 'peak_amp_iter': [], 106 | 'peak_amp_base': [], 107 | 'peak_time': [], 108 | 'peak_adjustment': [], 109 | 'fwhm_freq': [], 110 | 'fwhm_time': [], 111 | 'polarity': [], 112 | 'waveform_times': [] 113 | } 114 | 115 | # Grid for computing 2D Gaussians 116 | x_idx, y_idx = np.meshgrid(range(len(times)), range(len(search_freqs))) 117 | 118 | # Window size in points 119 | wlen = int(w_size * sfreq) 120 | half_wlen = int(wlen * .5) 121 | 122 | # Subtract 1/f 123 | trial_tf = tf - aperiodic_spectrum 124 | trial_tf[trial_tf < 0] = 0 125 | 126 | # Skip trial if no peaks above aperiodic 127 | if (trial_tf == 0).all(): 128 | print("All values equal 0 after aperiodic subtraction") 129 | return bursts 130 | 131 | # TF for iterating 132 | trial_tf_iter = copy.copy(trial_tf) 133 | 134 | while True: 135 | # Compute noise floor 136 | thresh = 2 * np.std(trial_tf_iter) 137 | 138 | # Find peak 139 | [peak_freq_idx, peak_time_idx] = 
np.unravel_index(np.argmax(trial_tf_iter), trial_tf.shape) 140 | peak_freq = search_freqs[peak_freq_idx] 141 | peak_amp_iter = trial_tf_iter[peak_freq_idx, peak_time_idx] 142 | peak_amp_base = trial_tf[peak_freq_idx, peak_time_idx] 143 | # Stop if no peak above threshold 144 | if peak_amp_iter < thresh: 145 | break 146 | 147 | # Fit 2D Gaussian and subtract from TF 148 | rloc, lloc, uloc, dloc = fwhm_burst_norm(trial_tf_iter, (peak_freq_idx, peak_time_idx)) 149 | 150 | # Detect degenerate Gaussian (limits not found) 151 | vert_isnan = any(np.isnan([uloc, dloc])) 152 | horiz_isnan = any(np.isnan([rloc, lloc])) 153 | if vert_isnan: 154 | v_sh = int((search_freqs.shape[0] - peak_freq_idx) / 2) 155 | if v_sh <= 0: 156 | v_sh = 1 157 | uloc = v_sh 158 | dloc = v_sh 159 | elif horiz_isnan: 160 | h_sh = int((times.shape[0] - peak_time_idx) / 2) 161 | if h_sh <= 0: 162 | h_sh = 1 163 | rloc = h_sh 164 | lloc = h_sh 165 | hv_isnan = any([vert_isnan, horiz_isnan]) 166 | 167 | # Compute FWHM and convert to SD 168 | fwhm_f_idx = uloc + dloc 169 | fwhm_f = (search_freqs[1] - search_freqs[0]) * fwhm_f_idx 170 | fwhm_t_idx = lloc + rloc 171 | fwhm_t = (times[1] - times[0]) * fwhm_t_idx 172 | sigma_t = fwhm_t_idx / 2.355 173 | sigma_f = fwhm_f_idx / 2.355 174 | # Fitted Gaussian 175 | z = peak_amp_iter * gaus2d(x_idx, y_idx, mx=peak_time_idx, my=peak_freq_idx, sx=sigma_t, sy=sigma_f) 176 | # Subtract fitted Gaussian for next iteration 177 | new_trial_tf_iter = trial_tf_iter - z 178 | 179 | # If detected peak is within band limits and not degenerate 180 | if all([peak_freq >= band_lims[0], peak_freq <= band_lims[1], not hv_isnan]): 181 | # Bandpass filter within frequency range of burst 182 | freq_range = [ 183 | np.max([0, peak_freq_idx - dloc]), 184 | np.min([len(search_freqs) - 1, peak_freq_idx + uloc]) 185 | ] 186 | filtered = filter_data(raw_trial.reshape(1, -1), sfreq, search_freqs[freq_range[0]], search_freqs[freq_range[1]], 187 | verbose=False) 188 | 189 | # Hilbert 
transform 190 | analytic_signal = hilbert(filtered) 191 | # Get phase 192 | instantaneous_phase = np.unwrap(np.angle(analytic_signal)) % math.pi 193 | 194 | # Find local phase minima with negative deflection closest to TF peak 195 | # If no minimum is found, the error is caught and no burst is added 196 | min_phase_pts = argrelextrema(instantaneous_phase.T, np.less)[0] 197 | new_peak_time_idx = peak_time_idx 198 | try: 199 | new_peak_time_idx = min_phase_pts[np.argmin(np.abs(peak_time_idx - min_phase_pts))] 200 | adjustment = (new_peak_time_idx - peak_time_idx) * 1 / sfreq 201 | except: 202 | adjustment = 1 203 | 204 | # Keep if adjustment less than 30ms 205 | if np.abs(adjustment) < .03: 206 | 207 | # If burst won't be cutoff 208 | if new_peak_time_idx >= half_wlen and new_peak_time_idx + half_wlen <= len(times): 209 | peak_time = times[new_peak_time_idx] 210 | 211 | overlapped = False 212 | # Check for overlap 213 | for b_idx in range(len(bursts['peak_time'])): 214 | o_t = bursts['peak_time'][b_idx] 215 | o_fwhm_t = bursts['fwhm_time'][b_idx] 216 | if overlap([peak_time - .5 * fwhm_t, peak_time + .5 * fwhm_t], 217 | [o_t - .5 * o_fwhm_t, o_t + .5 * o_fwhm_t]): 218 | overlapped = True 219 | break 220 | 221 | if not overlapped: 222 | # Get burst 223 | burst = raw_trial[new_peak_time_idx - half_wlen:new_peak_time_idx + half_wlen] 224 | # Remove DC offset 225 | burst = burst - np.mean(burst) 226 | bursts['waveform_times'] = times[new_peak_time_idx - half_wlen:new_peak_time_idx + half_wlen] - \ 227 | times[new_peak_time_idx] 228 | 229 | # Flip if positive deflection 230 | peak_dists = np.abs(argrelextrema(filtered.T, np.greater)[0] - new_peak_time_idx) 231 | trough_dists = np.abs(argrelextrema(filtered.T, np.less)[0] - new_peak_time_idx) 232 | 233 | polarity = 0 234 | if len(trough_dists) == 0 or ( 235 | len(peak_dists) > 0 and np.min(peak_dists) < np.min(trough_dists)): 236 | burst *= -1.0 237 | polarity = 1 238 | 239 | bursts['waveform'].append(burst) 240 | 
bursts['peak_freq'].append(peak_freq) 241 | bursts['peak_amp_iter'].append(peak_amp_iter) 242 | bursts['peak_amp_base'].append(peak_amp_base) 243 | bursts['peak_time'].append(peak_time) 244 | bursts['peak_adjustment'].append(adjustment) 245 | bursts['fwhm_freq'].append(fwhm_f) 246 | bursts['fwhm_time'].append(fwhm_t) 247 | bursts['polarity'].append(polarity) 248 | 249 | trial_tf_iter = new_trial_tf_iter 250 | 251 | bursts['waveform'] = np.array(bursts['waveform']) 252 | bursts['peak_freq'] = np.array(bursts['peak_freq']) 253 | bursts['peak_amp_iter'] = np.array(bursts['peak_amp_iter']) 254 | bursts['peak_amp_base'] = np.array(bursts['peak_amp_base']) 255 | bursts['peak_time'] = np.array(bursts['peak_time']) 256 | bursts['peak_adjustment'] = np.array(bursts['peak_adjustment']) 257 | bursts['fwhm_freq'] = np.array(bursts['fwhm_freq']) 258 | bursts['fwhm_time'] = np.array(bursts['fwhm_time']) 259 | bursts['polarity'] = np.array(bursts['polarity']) 260 | 261 | return bursts 262 | 263 | 264 | def extract_bursts(raw_trials, tf, times, search_freqs, band_lims, aperiodic_spectrum, sfreq, w_size=.26): 265 | """ 266 | Extract bursts from epoched data 267 | :param raw_trials: raw data for each trial (trial x time) 268 | :param tf: time-frequency decomposition for each trial (trial x freq x time) 269 | :param times: time steps 270 | :param search_freqs: frequency limits to search within for bursts (should be wider than band_lims) 271 | :param band_lims: keep bursts whose peak frequency falls within these limits 272 | :param aperiodic_spectrum: aperiodic spectrum 273 | :param sfreq: sampling rate 274 | :param w_size: window size to extract burst waveforms 275 | :return: disctionary with trial, waveform, peak frequency, relative peak amplitude, absolute peak amplitude, peak 276 | time, peak adjustment, FWHM in frequency, FWHM in time, and polarity for each detected burst 277 | """ 278 | bursts = { 279 | 'trial': [], 280 | 'waveform': [], 281 | 'peak_freq': [], 282 | 
'peak_amp_iter': [], 283 | 'peak_amp_base': [], 284 | 'peak_time': [], 285 | 'peak_adjustment': [], 286 | 'fwhm_freq': [], 287 | 'fwhm_time': [], 288 | 'polarity': [], 289 | 'waveform_times': [] 290 | } 291 | 292 | # Compute event-related signal 293 | erf = np.mean(raw_trials, axis=0) 294 | 295 | # Iterate through trials 296 | for t_idx, tr_tf in enumerate(tf): 297 | 298 | # Regress out ERF 299 | slope, intercept, r, p, se = linregress(erf, raw_trials[t_idx, :]) 300 | raw_trial = raw_trials[t_idx, :] - (intercept + slope * erf) 301 | 302 | trial_bursts=extract_bursts_single_trial(raw_trial, tr_tf, times, search_freqs, band_lims, aperiodic_spectrum, 303 | sfreq, w_size=w_size) 304 | 305 | n_trial_bursts=len(trial_bursts['peak_time']) 306 | bursts['trial'].extend([int(t_idx) for i in range(n_trial_bursts)]) 307 | bursts['waveform'].extend(trial_bursts['waveform']) 308 | bursts['peak_freq'].extend(trial_bursts['peak_freq']) 309 | bursts['peak_amp_iter'].extend(trial_bursts['peak_amp_iter']) 310 | bursts['peak_amp_base'].extend(trial_bursts['peak_amp_base']) 311 | bursts['peak_time'].extend(trial_bursts['peak_time']) 312 | bursts['peak_adjustment'].extend(trial_bursts['peak_adjustment']) 313 | bursts['fwhm_freq'].extend(trial_bursts['fwhm_freq']) 314 | bursts['fwhm_time'].extend(trial_bursts['fwhm_time']) 315 | bursts['polarity'].extend(trial_bursts['polarity']) 316 | if len(trial_bursts['waveform_times']): 317 | bursts['waveform_times'] = trial_bursts['waveform_times'] 318 | 319 | bursts['trial'] = np.array(bursts['trial']) 320 | bursts['waveform'] = np.array(bursts['waveform']) 321 | bursts['peak_freq'] = np.array(bursts['peak_freq']) 322 | bursts['peak_amp_iter'] = np.array(bursts['peak_amp_iter']) 323 | bursts['peak_amp_base'] = np.array(bursts['peak_amp_base']) 324 | bursts['peak_time'] = np.array(bursts['peak_time']) 325 | bursts['peak_adjustment'] = np.array(bursts['peak_adjustment']) 326 | bursts['fwhm_freq'] = np.array(bursts['fwhm_freq']) 327 | 
bursts['fwhm_time'] = np.array(bursts['fwhm_time']) 328 | bursts['polarity'] = np.array(bursts['polarity']) 329 | 330 | return bursts 331 | -------------------------------------------------------------------------------- /python/tests.py: -------------------------------------------------------------------------------- 1 | from burst_detection import fwhm_burst_norm 2 | import numpy as np 3 | import unittest 4 | 5 | 6 | class FWHM_testing(unittest.TestCase): 7 | def setUp(self): 8 | self.none = np.nan 9 | self.peak_value = 10 10 | self.peak_loc = (50,50) 11 | self.edge1_loc = (50,99) 12 | self.corner_loc = (99,99) 13 | self.square_loc = (40,60) 14 | self.empty = np.zeros((100,100)) 15 | self.single = np.zeros((100,100)) 16 | self.single[self.peak_loc] = self.peak_value 17 | self.square = np.zeros((100,100)) 18 | self.square[ 19 | self.square_loc[0]:self.square_loc[1], 20 | self.square_loc[0]:self.square_loc[1] 21 | ] = self.peak_value 22 | self.edge1 = np.zeros((100,100)) 23 | self.edge1[self.edge1_loc] = self.peak_value 24 | self.corner = np.zeros((100,100)) 25 | self.corner[self.corner_loc] = self.peak_value 26 | self.strip = np.zeros((100,100)) 27 | self.strip[50,:] = 10 28 | 29 | 30 | def test_empty(self): 31 | self.assertEqual( 32 | fwhm_burst_norm(self.empty, (self.peak_loc[0], self.peak_loc[1])), 33 | (0, 0, 0, 0) 34 | ) 35 | 36 | def test_single_peak(self): 37 | self.assertEqual( 38 | fwhm_burst_norm(self.single, (self.peak_loc[0], self.peak_loc[1])), 39 | (1, 1, 1, 1) 40 | ) 41 | 42 | def test_square_peak(self): 43 | self.assertEqual( 44 | fwhm_burst_norm(self.square, (self.peak_loc[0], self.peak_loc[1])), 45 | (10, 10, 10, 10) 46 | ) 47 | 48 | def test_edge_peak(self): 49 | self.assertEqual( 50 | fwhm_burst_norm(self.edge1, (self.edge1_loc[0], self.edge1_loc[1])), 51 | (1, 1, 1, 1) 52 | ) 53 | 54 | def test_corner_peak(self): 55 | self.assertEqual( 56 | fwhm_burst_norm(self.corner, (self.corner_loc[0], self.corner_loc[1])), 57 | (1, 1, 1, 1) 58 | ) 59 | 60 
| def test_strip_peak(self): 61 | self.assertEqual( 62 | fwhm_burst_norm(self.strip, (self.peak_loc[0], self.peak_loc[1])), 63 | (self.none, self.none, 1, 1) 64 | ) 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | --------------------------------------------------------------------------------