├── sort-tools
    ├── loadRawElecDatWin.m
    ├── compVpredictionSprse.m
    ├── compWprojW.m
    ├── samefilt.m
    ├── circxcorr.m
    ├── estimWaveforms.m
    ├── loadWhiteElecDatWin.m
    ├── estimSps_BinaryPursuit.m
    ├── compWhitening.m
    └── runBinaryPursuit.m
├── LICENSE
├── README.md
├── RUNME_bpspikesorting.m
├── step5_analyzePerformance_simData.m
├── step3_reestimWaveforms.m
├── step2_WhitenData.m
├── step1_estimWaveforms.m
├── step4_BPspikesort.m
├── setSpikeSortParams.m
└── script0_simulateDataForTesting.m


/sort-tools/loadRawElecDatWin.m:
--------------------------------------------------------------------------------
1 | function Ydat = loadRawElecDatWin(twin,filenamestring)
2 | % Ydat = loadRawElecDatWin(twin,filenamestring)
3 | %
4 | % Loads electrode data for all electrodes in given time window
5 | %
6 | % Input:
7 | %   twin = [t0,t1]. Loads all samples from index (t0+1) to (t1);
8 | %   filenamestring = filename to load
9 | %
10 | % Note: this simple version loads a single file, but for longer experiments, users should
11 | % write their own function so that the user passes in the desired time index range, and the function is
12 | % clever about loading the relevant data files and stitching them together (if needed).
13 |
14 | % Load everything stored in the file and unpack it into one array
15 | % (presumably the file holds a single [nsamps x nelec] variable -- TODO confirm)
16 | Ydat = struct2array(load(filenamestring));
17 | % Keep only rows (t0+1):t1 of the loaded array
18 | Ydat = Ydat(twin(1)+1:twin(2),:);
19 |
--------------------------------------------------------------------------------
/sort-tools/compVpredictionSprse.m:
--------------------------------------------------------------------------------
1 | function Vpred = compVpredictionSprse(xsp, W)
2 | % Vpred = compVpredictionSprse(xsp, W)
3 | %
4 | % Computes the sparse binary spike trains in xsp convolved with the spike waveforms in W
5 | %
6 | % INPUT:
7 | %   xsp [nt x nneur] - sparse binary matrix, each column is a spike train
8 | %   W [nw x nelec x nneur] - tensor of spike waveforms (nw must be even)
9 | %
10 | % OUTPUT:
11 | %   Vpred [nt x nelec] - convolution of xsp with W
12 | %
13 | % NOTES:
14 | %   - A spike in bin t adds its waveform to bins (t-nw/2) through (t+nw/2-1);
15 | %     whatever falls outside bins 1..nt is discarded.
16 | %   - An odd nw raises an error up front (it previously failed only at the final
17 | %     trimming step, with a non-integer index error).
18 | %
19 | % jw pillow 8/18/2014
20 |
21 |
22 | nw = size(W,1);        % length of spike waveform
23 | if mod(nw,2) ~= 0
24 |     error('compVpredictionSprse:oddWaveformLength', ...
25 |         'Waveform length size(W,1) must be even (got %d).', nw);
26 | end
27 |
28 | nc = size(xsp,2);      % number of cells
29 | wwid = nw/2;           % 1/2 length of spike waveform
30 | iirel = (0:nw-1)';     % relative time indices
31 | slen = size(xsp,1);    % number of time samples
32 |
33 | Vpred = zeros(slen+nw,size(W,2)); % allocate memory (wwid bins of padding at beginning and end)
34 |
35 | for j = 1:nc % loop over neurons
36 |     isp = find(xsp(:,j)); % find the spike times
37 |     for i = 1:length(isp)
38 |         ii = isp(i)+iirel; % rows (in padded coordinates) covered by this spike
39 |         Vpred(ii,:) = Vpred(ii,:)+W(:,:,j);
40 |     end
41 | end
42 | Vpred = Vpred(wwid+1:end-wwid,:); % remove padding at the beginning and end
43 |
44 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Jonathan Pillow
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to
the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/sort-tools/compWprojW.m:
--------------------------------------------------------------------------------
1 | function [wproj,wnorm] = compWprojW(W)
2 | % [wproj,wnorm] = compWprojW(W)
3 | %
4 | % Correlates every spike waveform with every other waveform at all relative
5 | % time lags, summing over electrodes ('full' in time, 'valid' across electrodes).
6 | %
7 | % INPUT:
8 | %   W [ntime x nelectrodes x ncells] - tensor of spike waveforms
9 | %
10 | % OUTPUT:
11 | %   wproj [2*ntime-1 x ncells x ncells] - waveforms projected onto each other;
12 | %                                         row ntime is the zero-lag term
13 | %   wnorm [ncells x 1] - squared L2 norm of each waveform
14 | %
15 | % jw pillow 8/18/2014
16 |
17 |
18 | [nw,~,nc] = size(W);
19 | nlags = 2*nw-1;
20 |
21 | wproj = zeros(nlags,nc,nc);
22 | for jcell = 1:nc
23 |     Wj = W(:,:,jcell);
24 |     for icell = jcell:nc
25 |         % gram(a,b) = dot product (over electrodes) of row a of w(jcell) with row b of w(icell)
26 |         gram = Wj*W(:,:,icell)';
27 |
28 |         % slide the columns of gram (last column first) down the lag axis and accumulate
29 |         xc = zeros(nlags,1);
30 |         for shift = 0:nw-1
31 |             rows = shift+(1:nw);
32 |             xc(rows) = xc(rows) + gram(:,nw-shift);
33 |         end
34 |         wproj(:,icell,jcell) = xc;
35 |
36 |         % the mirrored cell pair is the same sequence reversed in time
37 |         if icell ~= jcell
38 |             wproj(:,jcell,icell) = flipud(xc);
39 |         end
40 |     end
41 | end
42 |
43 | % zero-lag term of each waveform with itself
44 | wnorm = diag(squeeze(wproj(nw,:,:)));
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | BinaryPursuitSpikeSorting
2 | =========================
3 |
4 | Detect synchronous and overlapped spikes in extracellular recordings
5 | using the Binary Pursuit algorithm, as described in
6 | [Pillow et al 2013](http://pillowlab.princeton.edu/pubs/abs_Pillow_PLOSONE13.html).
7 |
8 | You can find a blog post describing the basic geometry of the problem
9 | and our algorithm for solving it
10 | [here](https://pillowlab.wordpress.com/tag/spike-sorting/).
11 |
12 | Downloading the repository
13 | ------------
14 |
15 | - **From command line:**
16 |
17 |     ```git clone https://github.com/jpillow/BinaryPursuitSpikeSorting.git```
18 |
19 | - **In browser:** click to
20 | [Download ZIP](https://github.com/jpillow/BinaryPursuitSpikeSorting/archive/master.zip)
21 | and then unzip archive
22 |
23 |
24 | Example Script
25 | --------------
26 | Open ``RUNME_bpspikesorting.m`` to see it in action using a simulated
27 | dataset
28 |
29 | ## Reference
30 |
31 | - J. W. Pillow, J. Shlens, E. J. Chichilnisky, & E. P. Simoncelli
32 | (2013).
33 | [A model-based spike sorting algorithm for removing correlation artifacts in multi-neuron recordings](http://pillowlab.princeton.edu/pubs/abs_Pillow_PLOSONE13.html) PLoS ONE 8(5) 1-14.
34 |
--------------------------------------------------------------------------------
/sort-tools/samefilt.m:
--------------------------------------------------------------------------------
1 | function [G,G1,G2] = samefilt(A, B, str)
2 | % [G,G1,G2] = samefilt(A, B, flag)
3 | %
4 | % Filters matrix A with each of the (same-width) filters contained in B.
5 | %
6 | % Input: A = tall matrix
7 | %        B = 3-tensor with filters B(:,:,j) the same width as A
8 | %        flag = set to 'conv' for B to be vertically flipped;
9 | %               default is 'filt' (optional)
10 | %
11 | % Output:
12 | %   G = same height as A and width is size(B,3). (Vertical
13 | %       dimension clipped as with 'same' flag under conv2.m)
14 | %   G1,G2 = pieces before and after if 'full' temporal conv2 desired
15 | %
16 | % Notes:
17 | %   - Convolution performed using matrix multiplication:
18 | %     best when B is short but fat
19 | %   - Errors if flag is neither 'conv' nor 'filt', or if A and B differ in width
20 | %
21 | % jw pillow 8/19/2014
22 |
23 | if (nargin <= 2)
24 |     str = 'filt';
25 | end
26 |
27 | if strcmp(str, 'conv')
28 |     B = B(end:-1:1,:,:); % flip filters vertically (plain indexing instead of deprecated flipdim)
29 | elseif ~strcmp(str, 'filt')
30 |     error('Unrecognized string: samefilt.m');
31 | end
32 |
33 | am = size(A,1);
34 | [bm, bn, nflt] = size(B);
35 | if size(A,2) ~= bn
36 |     error('samefilt: A (width %d) and the filters in B (width %d) must have the same width', size(A,2), bn);
37 | end
38 | nn = am+bm-1;            % height of the 'full' output
39 | npre = ceil((bm-1)/2);   % rows clipped from the top to get the 'same' output
40 | npost = floor((bm-1)/2); % rows clipped from the bottom
41 |
42 | % Do convolution
43 | G = zeros(nn,nflt);
44 | for i = 1:nflt
45 |     yy = A*B(:,:,i)'; % yy(:,k) = A projected onto row k of filter i
46 |     for j = 1:bm
47 |         % shift each projection to its lag and accumulate
48 |         G(j:j+am-1,i) = G(j:j+am-1,i) + yy(:,bm-j+1);
49 |     end
50 | end
51 |
52 | if nargout > 1
53 |     G1 = G(1:npre,:);
54 | end
55 | if nargout > 2
56 |     G2 = G(nn-npost+1:nn,:);
57 | end
58 |
59 | G = G(npre+1:nn-npost,:);
60 |
--------------------------------------------------------------------------------
/RUNME_bpspikesorting.m:
--------------------------------------------------------------------------------
1 | % Example script to generate a simulated dataset and then illustrate the steps in our
2 | % spike sorting algorithm.
3 |
4 | % 0. Generate a simulated dataset
5 | script0_simulateDataForTesting;
6 |
7 | %% 1. Estimate spike waveforms for each neuron using spike times provided
8 | step1_estimWaveforms
9 |
10 | %% 2. Estimate the temporal and spatial noise covariance and whiten the raw data
11 | step2_WhitenData;
12 |
13 | %% 3. Re-estimate spike waveforms using whitened data
14 | step3_reestimWaveforms;
15 |
16 | %% 4. Run binary pursuit: identify the spike times given waveforms and whitened data
17 | step4_BPspikesort;
18 |
19 | %% 5. Compare simulated and estimated spike trains (ONLY RELEVANT FOR SIMULATED DATA)
20 | step5_analyzePerformance_simData;
21 |
22 |
23 | % ===============================================================================
24 | % NOTE: To run this on your own data, you'll need to:
25 | %
26 | % (1) specify the number of neurons and identify
27 | % enough spike times from each of them to get a reasonable first-pass estimate of the
28 | % spike waveforms (e.g., using clustering).
29 | %
30 | % (2) Specify a few dataset and processing dependent parameters (e.g., number of
31 | % electrodes, number of neurons, sample rate, number of time bins in spike waveforms,
32 | % path to the raw data, a function for loading time windowed chunks of data) in the
33 | % general script 'setSpikeSortParams.m'
34 | %
35 | % I recommend opening up each of the step_x scripts above to get a sense of what's
36 | % happening in each step. Please feel free to email me with comments, questions,
37 | % suggestions and bug reports (pillow@princeton.edu).
38 | % ===============================================================================
39 |
--------------------------------------------------------------------------------
/step5_analyzePerformance_simData.m:
--------------------------------------------------------------------------------
1 | % step5_analyzePerformance_simData.m
2 | % --------------------
3 | % Script to examine performance on simulated data
4 | %
5 | % Note that you can vary performance on simulated data by changing the 'nsesig'
6 | % parameter in the script 'script0_simulateDataForTesting.m', which controls SNR
7 |
8 | % Set path and loads relevant data structures: 'sdat', 'dirlist', 'filelist'
9 | setSpikeSortParams;
10 |
11 | % ---- Load true and estimated spike times (sparse nsamps x ncell arrays) ------------
12 | Xtrue = struct2array(load('dat/simdata/Xsp_true.mat')); % true spikes
13 | Xhat = struct2array(load(filelist.Xhat,'Xhat')); % estimated spikes
14 |
15 | % Set tolerance in spike time for counting a hit or a miss
16 | tol = 3; % in bins
17 |
18 | % Compute Hits for each time shift:
19 | % row j of Xhit counts, per cell, the bins t where Xtrue(t) and Xhat(t+tol+1-j) both spike
20 | slen = sdat.nsamps;
21 | Xhit = zeros(2*tol+1,sdat.ncell);
22 | for j = 1:(2*tol+1)
23 |     ii1 = max(1,j-tol):min(slen,slen-tol+j-1);   % rows of Xtrue
24 |     ii2 = max(1,tol-j+2):min(slen,slen-j+tol+1); % matching (shifted) rows of Xhat
25 |     Xhit(j,:) = sum(Xtrue(ii1,:).*Xhat(ii2,:));
26 | end
27 |
28 | %% Report Quantitative Performance
29 |
30 | % Make fig showing p(Hit) vs. time shift
31 | nHits = sum(Xhit); % hits per cell, pooled over all shifts within tolerance
32 | FracHits = Xhit./repmat(sum(Xtrue),2*tol+1,1);
33 | clf;
34 | plot(-tol:tol, FracHits);
35 | title('P(Hit) vs. time shift');
36 | xlabel('time shift (bins)'); ylabel('P(hit)');
37 |
38 | % Report Misses
39 | nMiss = sum(Xtrue)-nHits;
40 | fracMiss = nMiss./sum(Xtrue);
41 | fprintf('\n--- Misses ----\n');
42 | for j = 1:sdat.ncell
43 |     fprintf('cell %d: %d (frac = %.3f)\n', j,nMiss(j), fracMiss(j));
44 | end
45 |
46 | % Report FAs
47 | nFA = sum(Xhat)-sum(Xhit);
48 | fracFA = nFA./sum(Xhat);
49 | fprintf('\n--- False Alarms ----\n');
50 | for j = 1:sdat.ncell
51 |     fprintf('cell %d: %d (frac = %.3f)\n', j,nFA(j), fracFA(j));
52 | end
53 |
--------------------------------------------------------------------------------
/sort-tools/circxcorr.m:
--------------------------------------------------------------------------------
1 | function xc = circxcorr(x1, x2, maxlag, str)
2 | % xc = circxcorr(x1, x2, maxlag, str); %% cross-correlation of x1,x2
3 | % or
4 | % xc = circxcorr(x1,maxlag,str); %% auto-correlation
5 | %
6 | % Computes the correlation of x1 and x2 with circular boundary conditions (in Fourier domain)
7 | %
8 | % Inputs:
9 | %   x1, x2 = signals of equal length (omit x2 for the auto-correlation of x1)
10 | %   maxlag = (optional) scalar; if given, only lags -maxlag:maxlag are returned,
11 | %            in that order. Otherwise all lags are returned, starting at lag 0.
12 | %   str    = (optional) scaling: 'none', 'biased' (divide by length) or
13 | %            'unbiased' (divide by length-1)
14 | %
15 | % jw pillow 8/19/2014
16 |
17 |
18 | % parse inputs
19 | slen = length(x1);
20 | switch nargin
21 |     case 1 % Full auto-correlation
22 |         nargs=1; LAGflag=0; str=[];
23 |     case 2
24 |         if isnumeric(x2) && (length(x2)==1) % 1 signal, with maxlag
25 |             nargs=1; LAGflag=1; maxlag=x2; str=[];
26 |         elseif ischar(x2) % 1 signal, no lag, scaling
27 |             nargs=1; LAGflag=0; str=x2;
28 |         elseif length(x2) == slen % 2 signals
29 |             nargs=2; LAGflag=0; str=[];
30 |         else
31 |             error('unrecognized inputs');
32 |         end
33 |     case 3
34 |         if isnumeric(x2) && (length(x2)==1) % 1 signal, with maxlag and scaling
35 |             nargs = 1; LAGflag=1; str=maxlag; maxlag=x2;
36 |         elseif isnumeric(maxlag) % 2 signals, with maxlag
37 |             nargs=2; LAGflag=1; str=[];
38 |         elseif ischar(maxlag) % 2 signals, no lag, scaling
39 |             nargs=2; LAGflag=0; str=maxlag;
40 |         else
41 |             error('unrecognized inputs');
42 |         end
43 |     case 4
44 |         nargs = 2; LAGflag=1;
45 |     otherwise
46 |         error('unrecognized inputs');
47 | end
48 |
49 | if (nargs == 1)
50 |     xc = ifft(abs(fft(x1)).^2);
51 | else
52 |     xc = ifft(conj(fft(x1)).*fft(x2));
53 | end
54 |
55 | if (LAGflag)
56 |     ii = [slen-maxlag+1:slen, 1:maxlag+1]; % negative lags wrap around to the end
57 |     xc = xc(ii);
58 | end
59 |
60 | if ~isempty(str)
61 |     if strcmp(str,'unbiased')
62 |         xc = xc./(slen-1);
63 |     elseif strcmp(str,'biased')
64 |         xc = xc./slen;
65 |     elseif strcmp(str,'none')
66 |         % leave unscaled
67 |     else
68 |         error('unrecognized SCALE OPTION (str) for circxcorr.m');
69 |     end
70 | end
71 |
72 |
--------------------------------------------------------------------------------
/step3_reestimWaveforms.m:
--------------------------------------------------------------------------------
1 | % step3_reestimWaveforms.m
2 | % -------------------------
3 | % Estimate spike waveforms from initial subset of spikes (after whitening)
4 |
5 | % Set path and loads relevant data structures: 'sdat', 'dirlist', 'filelist'
6 | setSpikeSortParams;
7 |
8 | % ---- Load initial estimate of spike times (sparse nsamps x ncell array) ------------
9 | Xsp = struct2array(load(filelist.initspikes)); % loads variable 'Xsp_init'
10 |
11 | % --- Set some params governing block size for estimating waveforms ------
12 | nsamps = sdat.nsamps; % Number of total samples
13 | nsampsPerW = sdat.nsampsPerW; % number of samples to use for each estimate
14 | nWblocks = ceil(nsamps/nsampsPerW); % number of blocks for estimating waveforms
15 |
16 |
17 | %% Estimate waveform independently on each block of data (sdat.nsecsPerW seconds / block)
18 | for blocknum = 1:nWblocks
19 |     fprintf('Step 3: estimating whitened waveforms (block %d/%d)\n', blocknum,nWblocks);
20 |
21 |     % --- Set time window and load relevant data ----
22 |     twin = [(blocknum-1)*nsampsPerW, min(blocknum*nsampsPerW,nsamps)]; % time window
23 |     Y = loadwhitenedY(twin);
24 |
25 |     % --- Estimate spike waveform ----
26 |     [W,wsigs] = estimWaveforms(Xsp(twin(1)+1:twin(2),:),Y,sdat.nw);
27 |
28 |     % --- Prune waveforms ---
29 |     % Estimate the electrode support for each waveform
30 |     % (which electrodes carry non-zero signal)
31 |     % Not implemented for now.
32 |
33 |     % --- Save out waveforms ---
34 |     savename = sprintf(filelist.Wwht, blocknum);
35 |     save(savename, 'W', 'twin');
36 |
37 | end
38 |
39 | %% Make plot of estimated waveforms (from the last block processed)
40 |
41 | % NOTE: if support of waveform extends outside plotted window, consider shifting spike
42 | % times or increase sdat.nw
43 | for j = (1:sdat.ncell)
44 |     subplot(sdat.ncell,1,j);
45 |     plot(1:sdat.nw, W(:,:,j),'b'); axis tight;
46 |     ylabel(sprintf('cell %d',j));
47 | end
48 | xlabel('time (bins)');
49 |
--------------------------------------------------------------------------------
/sort-tools/estimWaveforms.m:
--------------------------------------------------------------------------------
1 | function [What,wwsigs]=estimWaveforms(X, Y, nw)
2 | % [What,wwsigs]=estimWaveforms(X, Y, nw)
3 | %
4 | % Computes estimate of spike waveform for cells from spike trains and
5 | % electrode data, using (correct) least-squares regression
6 | %
7 | % Input:
8 | % ------
9 | %   X [nsamps x ncells] - each column holds spike train of single neuron
10 | %   Y [nsamps x nelec] - raw electrode data
11 | %   nw [1 x 1] - number of time bins in the spike waveform (must be even)
12 | %
13 | % Output:
14 | % -------
15 | %   What [nw x ne x ncells] - estimated waveforms for each cell
16 | %   wwsigs [ncells x 1] - approximate posterior stdev of each neuron's waveform coefficients
17 | %
18 | % Notes:
19 | %   - Waveform bin k is aligned to the sample (k-1-nw/2) bins after the spike,
20 | %     i.e. the waveform covers bins (t-nw/2) to (t+nw/2-1) around a spike at t.
21 | %
22 | % jw pillow 8/18/2014
23 |
24 |
25 | [nt,nc] = size(X); % number of time bins and number of cells
26 | ne = size(Y,2); % number of electrodes
27 | if mod(nw,2) ~= 0
28 |     error('estimWaveforms:oddWaveformLength', 'nw must be even (got %g).', nw);
29 | end
30 | if size(Y,1) ~= nt
31 |     error('estimWaveforms:sizeMismatch', ...
32 |         'X and Y must have the same number of rows (got %d and %d).', nt, size(Y,1));
33 | end
34 | nw2 = nw/2;
35 |
36 | % Compute blocks for covariance matrix XX and cross-covariance XY
37 | XXblocks = zeros(nc*nw,nc);
38 | XY = zeros(nc*nw,ne);
39 | for jj = 1:nw
40 |     inds = ((jj-1)*nc+1):(jj*nc);
41 |     XXblocks(inds,:) = X(1:end-jj+1,:)'*X(jj:end,:); % spike train covariance (lag jj-1)
42 |     XY(inds,:) = X(max(1,nw2-jj+2):min(nt,nt-jj+nw2+1),:)'*...
43 |         Y(max(1,jj-nw2):min(nt,nt+jj-nw2-1),:); % cross-covariance
44 | end
45 |
46 | % Insert blocks into covariance matrix XX
47 | XX = zeros(nc*nw,nc*nw);
48 | for jj = 1:nw
49 |     inds1 = ((jj-1)*nc+1):(nc*nw);
50 |     inds2 = ((jj-1)*nc+1):(jj*nc);
51 |     XX(inds1,inds2) = XXblocks(1:(nw-jj+1)*nc,:); % below diagonal blocks
52 |     XX(inds2,inds1) = XXblocks(1:(nw-jj+1)*nc,:)'; % above diagonal blocks
53 | end
54 | What = XX\XY; % do regression
55 | What = permute(reshape(What,[nc,nw,ne]),[2,3,1]); % reshape into tensor
56 |
57 | % If desired, compute approximate posterior stdev for each waveform (function of # spikes)
58 | if nargout > 1
59 |     wwsigs = sqrt(1./diag(XX(1:nc,1:nc)));
60 | end
61 | % Note: the "correct" formula should be diag(inv(Xcov)), but this is
62 | % close, ignoring edge effects, and much faster;
--------------------------------------------------------------------------------
/step2_WhitenData.m:
--------------------------------------------------------------------------------
1 | % step2_WhitenData.m
2 | %
3 | % Compute whitening filters with current residuals and apply to raw electrode data
4 |
5 | % Set path and loads relevant data structures: 'sdat','dirlist', 'filelist'
6 | setSpikeSortParams;
7 |
8 | % --- Load initial estimate of spike times (sparse nsamps x ncell array) ----------
9 | Xsp = struct2array(load(filelist.initspikes)); % loads variable 'Xsp_init'
10 |
11 | % --- Set params governing block size ------
12 | nsamps = sdat.nsamps; % Number of total samples
13 | nsampsPerW = sdat.nsampsPerW; % number of samples to use for each estimate
14 | nWblocks = ceil(nsamps/nsampsPerW); % number of blocks for estimating waveforms
15 |
16 | % --- Set params governing the whitening ------
17 | nxc_t = 16; % # time bins for temporal whitening filter
18 | nxc_x = 5; % # time bins to use while spatial whitening (not too big, <= 9)
19 |
20 | % --- Determine number of "chunks" to divide whitened data into (for BP sort) -----
21 | nsPerChunk = sdat.nsampsPerBPchunk; % number of samples per chunk
22 | nchunksPerBlock = nsampsPerW/nsPerChunk; % number of chunks per (full) "waveform block"
23 | if nchunksPerBlock ~= round(nchunksPerBlock)
24 |     % chunk files are numbered consecutively across blocks, so blocks must hold whole chunks
25 |     error('step2_WhitenData: sdat.nsampsPerW must be an integer multiple of sdat.nsampsPerBPchunk');
26 | end
27 |
28 |
29 | % Loop over blocks
30 | for blocknum = 1:nWblocks
31 |     fprintf('Step 2: whitening residuals (block %d/%d)\n',blocknum,nWblocks);
32 |
33 |     % --- Set time window and load relevant data ----
34 |     twin = [(blocknum-1)*nsampsPerW, min(blocknum*nsampsPerW,nsamps)]; % time window
35 |     Ydat = loadrawY(twin); % load the relevant block of electrode data
36 |     load(sprintf(filelist.Wraw, blocknum),'W'); % load spike waveform
37 |
38 |     %--- Compute whitening filters & whiten electrode data ------------------
39 |     [Ywht,tfilts,xfilts] = ...
40 |         compWhitening(Xsp(twin(1)+1:twin(2),:),Ydat,W,nxc_t,nxc_x);
41 |
42 |     % Save out whitened data in chunks. The last block may be shorter than
43 |     % nsampsPerW, so count chunks from the data actually returned and let the
44 |     % final chunk be partial.
45 |     nsThisBlock = size(Ywht,1);
46 |     for ichunknum = 1:ceil(nsThisBlock/nsPerChunk)
47 |         chunknum = (blocknum-1)*nchunksPerBlock+ichunknum;
48 |         savename = sprintf(filelist.Ywht, chunknum);
49 |         Y = Ywht(((ichunknum-1)*nsPerChunk+1):min(ichunknum*nsPerChunk,nsThisBlock),:);
50 |         save(savename, 'Y');
51 |     end
52 |
53 | end
54 |
--------------------------------------------------------------------------------
/step1_estimWaveforms.m:
--------------------------------------------------------------------------------
1 | % step1_estimWaveforms.m
2 | % -------------------------
3 | % Estimate spike waveforms from initial subset of spikes found (before whitening)
4 |
5 | % Set path and loads relevant data structures: 'sdat', 'dirlist', 'filelist'
6 | setSpikeSortParams;
7 |
8 | % Is simulation? (for comparisons to ground truth)
9 | isSIM = 1; % Set to true only when running demo code with simulation data
10 |
11 | % ---- Load initial estimate of spike times (sparse nsamps x ncell array) -----
12 | Xsp = struct2array(load(filelist.initspikes)); % loads variable 'Xsp_init'
13 |
14 | % --- Set some params governing block size for estimating waveforms ------
15 | nsamps = sdat.nsamps; % Number of total samples
16 | nsampsPerW = sdat.nsampsPerW; % number of samples to use for each estimate
17 | nWblocks = ceil(nsamps/nsampsPerW); % number of blocks for estimating waveforms
18 |
19 |
20 | %% Estimate waveform independently on each block of data (sdat.nsecsPerW seconds / block)
21 | for blocknum = 1:nWblocks
22 |     fprintf('Step 1: estimating pre-whitened waveforms (block %d/%d)\n', blocknum,nWblocks);
23 |
24 |     % --- Set time window and load electrode data ---
25 |     twin = [(blocknum-1)*nsampsPerW, min(blocknum*nsampsPerW,nsamps)]; % time window
26 |     Ydat = loadrawY(twin); % load the relevant block of electrode data
27 |
28 |     % --- Estimate spike waveform ----
29 |     W = estimWaveforms(Xsp(twin(1)+1:twin(2),:),Ydat,sdat.nw);
30 |
31 |     % --- Save out waveforms ---
32 |     savename = sprintf(filelist.Wraw, blocknum);
33 |     save(savename, 'W', 'twin');
34 |
35 | end
36 |
37 | %% Make plot of estimated waveforms (from the last block processed)
38 |
39 | % NOTE: if support of waveform extends outside plotted window, consider shifting spike
40 | % times or increase sdat.nw
41 | for j = (1:sdat.ncell)
42 |     subplot(sdat.ncell,1,j);
43 |     plot(1:sdat.nw, W(:,:,j),'b');
44 |     ylabel(sprintf('cell %d',j));
45 | end
46 | xlabel('time (bins)');
47 |
48 | % === Compare estimated and true waveform (SIMULATION ONLY) ===
49 | if isSIM
50 |     ww = load('dat/simdata/W_true.mat'); % load true waveforms
51 |     for j = (1:sdat.ncell)
52 |         subplot(sdat.ncell,1,j); hold on;
53 |         plot(1:size(ww.W,1),ww.W(:,:,j),'r--'); hold off;
54 |     end
55 | end
56 |
--------------------------------------------------------------------------------
/sort-tools/loadWhiteElecDatWin.m:
--------------------------------------------------------------------------------
1 | function Ydat = loadWhiteElecDatWin(twin,filenamestring,nsampsperchunk)
2 | % Ydat = loadWhiteElecDatWin(twin,filenamestring,nsampsperchunk)
3 | %
4 | % Loads whitened electrode data for all electrodes in given time window
5 | %
6 | % Input:
7 | %   twin = [t0,t1]. Loads all samples from index (t0+1) to (t1);
8 | %   filenamestring = filename to load (contains '%d' for chunk number)
9 | %   nsampsperchunk = number of samples per file
10 | %
11 | % Note: the data are read from consecutively numbered chunk files of nsampsperchunk
12 | % samples each; every chunk file overlapping the requested window is loaded, and the
13 | % pieces are trimmed and stitched together (if needed).
14 | %
15 | % jw pillow 8/18/2014
16 |
17 | slen = diff(twin); % number of time bins (rows in matrix)
18 | ichunk1 = floor(twin(1)/nsampsperchunk)+1; % index for first chunk
19 | ichunkN = ceil(twin(2)/nsampsperchunk); % index for last chunk
20 | i1 = twin(1)-((ichunk1-1)*nsampsperchunk)+1; % index of first sample in 1st chunk
21 | iN = twin(2)-((ichunkN-1)*nsampsperchunk); % index of last sample in last chunk
22 |
23 | % Load first chunk
24 | Ydat = struct2array(load(sprintf(filenamestring,ichunk1)));
25 |
26 | % --- Determine if we need multiple chunks -------
27 | if ichunk1 == ichunkN
28 |
29 |     % Remove rows if necessary
30 |     if (i1>1) || (iN < nsampsperchunk)
31 |         Ydat = Ydat(i1:iN,:);
32 |     end
33 |
34 | else % ---- Concatenate multiple chunks together ----
35 |
36 |     % Remove initial rows, if necessary
37 |     if i1>1
38 |         Ydat = Ydat(i1:end,:);
39 |     end
40 |
41 |     % Allocate space for remaining samples
42 |     [n1,nelec] = size(Ydat); % number of samples and number of electrodes
43 |     Ydat = [Ydat; zeros(slen-n1,nelec)];
44 |
45 |     % Load additional (complete, interior) chunks
46 |     for ichunk = ichunk1+1:ichunkN-1
47 |         Ychnk = struct2array(load(sprintf(filenamestring,ichunk)));
48 |         Ydat(n1+(ichunk-ichunk1-1)*nsampsperchunk+(1:nsampsperchunk),:) = Ychnk;
49 |     end
50 |
51 |     % Load last chunk (keeping only its first iN samples)
52 |     Ychnk = struct2array(load(sprintf(filenamestring,ichunkN)));
53 |     Ydat(n1+(ichunkN-ichunk1-1)*nsampsperchunk+1:end,:) = Ychnk(1:iN,:);
54 |
55 | end
56 |
--------------------------------------------------------------------------------
/step4_BPspikesort.m:
--------------------------------------------------------------------------------
1 | % step4_BPspikesort.m
2 | % --------------------
3 | % Estimate spike times on whitened data using Binary Pursuit
4 |
5 | % Set path and loads relevant data structures: 'sdat', 'dirlist', 'filelist'
6 | setSpikeSortParams;
7 | fprintf('Step 4: running Binary Pursuit to estimate spike times\n');
8 |
9 | % ---- Load initial estimate of spike times (sparse nsamps x ncell array) ------------
10 | X0 = struct2array(load(filelist.initspikes)); % loads variable 'Xsp_init'
11 |
12 | % --- Set some params governing block size for estimating waveforms ------
13 | nsamps = sdat.nsamps; % Number of total samples
14 | nsampsPerW = sdat.nsampsPerW; % number of samples to use for each estimate
15 | nWblocks = ceil(nsamps/nsampsPerW); % number of blocks for estimating waveforms
16 | blksize = sdat.nsampsPerBPchunk; % number of samples to process at once for BP. (default 10K)
17 |
18 | % Set prior probability of a spike for each neuron (using base rate in initial sort)
19 | pspike0 = mean(X0); % prior probability of a spike for each neuron (one entry per column of X0)
20 |
21 | % -- Do sorting (loadwhitenedY is passed as a function handle, not called here) ----
22 | [Xhat,nrefviols] = estimSps_BinaryPursuit(loadwhitenedY,filelist.Wwht,...
23 |     nsampsPerW,X0,pspike0,blksize,sdat.minISIsamps);
24 |
25 | % -- Report refractory period violations (if necessary) ----
26 | if sum(nrefviols) == 0
27 |     fprintf('No refractory period violations (%.1fms)\n',sdat.minISIms);
28 | else
29 |     fprintf('Number spikes pruned due to refractory violations (%.1fms)\n',sdat.minISIms);
30 |     for j = 1:sdat.ncell
31 |         if nrefviols(j)>0
32 |             fprintf('cell %d: %d\n', j, nrefviols(j));
33 |         end
34 |     end
35 | end
36 |
37 | % -- Save out ------
38 | save(filelist.Xhat,'Xhat','nrefviols');
39 |
40 | % ============= NOTES ================
41 | %
42 | % Note 1: If getting too many spikes out (i.e., unrealistically many spikes), try dividing
43 | % pspike0 by 10, 10^2, 10^3, etc.
44 |
45 | % Note 2: for an assessment of the reliability of each neuron's spikes (i.e., sorting
46 | % accuracy), try dividing pspike0 by 2 or multiplying it by 2, and see how many more or
47 | % fewer spikes you get. For a highly reliable sort (i.e., where the posterior is strongly
48 | % determined by the likelihood), you should get nearly the same answer, regardless of the
49 | % prior.
50 |
51 |
--------------------------------------------------------------------------------
/sort-tools/estimSps_BinaryPursuit.m:
--------------------------------------------------------------------------------
1 | function [Xhat,nrefviols] = estimSps_BinaryPursuit(loadYfun,wfile,nperblck,X0,pspike,nperchnk,minISI)
2 | % [Xhat,nrefviols] = estimSps_BinaryPursuit(loadYfun,wfile,nperblck,X0,pspike,nperchnk,minISI)
3 | %
4 | % Estimates spike train over a large chunk of data by calling
5 | % runBinaryPursuit.m on many, smaller chunks
6 | %
7 | % INPUTS
8 | %   loadYfun - function for loading electrode data (takes single argument 'ywin' = [t0,t1])
9 | %   wfile - filename for spike waveforms (with '%d' for block #)
10 | %   nperblck - number of samples per block (each block has its own waveform file)
11 | %   X0 [nsamps x ncells] - initial guess at spike trains (sparse)
12 | %   pspike - prior probability of a spike in a bin (positive, < 1); passed straight
13 | %            through to runBinaryPursuit (step4_BPspikesort passes one value per cell)
14 | %   nperchnk - number of samples in a "chunk" to process at once for BP
15 | %   minISI - minimum ISI in # samples
16 | %
17 | % OUTPUTS
18 | %   Xhat = estimated binary spike train (same size as X0)
19 | %   nrefviols - number of spikes removed (post hoc) due to refractory violation,
20 | %               accumulated over chunks
21 | %
22 | % NOTES
23 | %   - If nsamps is not a multiple of nperchnk, the final (shorter) chunk is
24 | %     processed as well; it was previously skipped, leaving zeros in Xhat.
25 | %   - Every chunk after the first is loaded with nw/2 extra leading samples, and
26 | %     the estimates for those leading samples are discarded.
27 | %
28 | % jw pillow 8/18/2014
29 |
30 |
31 | verbose = 10; % report progress mod this value
32 |
33 | slen = size(X0,1); % total number of samples in spike train data matrix
34 | nchnks = ceil(slen/nperchnk); % number of chunks to use for processing data (last may be partial)
35 |
36 | % Load 1st-block spike waveforms
37 | wnum = 1;
38 | load(sprintf(wfile, wnum),'W');
39 | nw2 = size(W,1)/2; % half-length of spike waveform in samples
40 |
41 | % pre-compute the convolution of the waveforms with themselves
42 | [wProj,wNorm] = compWprojW(W);
43 |
44 | % load first chunk of electrode data
45 | ywin = [0,min(nperchnk,slen)];
46 | Y = loadYfun(ywin);
47 |
48 | % Run BP on first chunk
49 | ii = ywin(1)+1:ywin(2); % indices to use
50 | Xhat = X0*0; % Initialize Xhat to zeros
51 | [Xhat(ii,:),nrefviols] = runBinaryPursuit(X0(ii,:),Y,W,pspike,wProj,wNorm,minISI);
52 |
53 | for ichunk = 2:nchnks
54 |
55 |     % Report progress (if desired)
56 |     if mod(ichunk,verbose)==0
57 |         fprintf('Estimating spikes (chunk %d of %d)\n', ichunk, nchnks);
58 |
59 |     end
60 |
61 |     % Determine indices for this chunk of data
62 |     i1 = (ichunk-1)*nperchnk; % index of the last bin before this chunk
63 |     i2 = min(ichunk*nperchnk,slen); % index of last bin in this chunk
64 |
65 |     % Load waveforms, if necessary
66 |     wnumneeded = ceil(i2/nperblck);
67 |     if wnum ~= wnumneeded
68 |         wnum = wnumneeded;
69 |         load(sprintf(wfile, wnum),'W');
70 |         [wProj,wNorm] = compWprojW(W); % pre-compute waveform convolutions
71 |     end
72 |
73 |     % load electrode data (with nw2 extra samples from the preceding chunk)
74 |     ywin = [i1-nw2,i2]; % indices to process
75 |     Y = loadYfun(ywin);
76 |
77 |     % perform BP sorting
78 |     [xsrt,nref] = runBinaryPursuit(X0(ywin(1)+1:ywin(2),:),Y,W,pspike,wProj,wNorm,minISI);
79 |     Xhat(i1+1:i2,:) = xsrt(nw2+1:end,:); % insert chunk into Xhat (dropping the overlap)
80 |     nrefviols = nrefviols+nref; % count if any refractory period violations removed
81 |
82 | end
83 |
84 |
85 |
--------------------------------------------------------------------------------
/setSpikeSortParams.m:
--------------------------------------------------------------------------------
1 | % PATH SETTING
2 | addpath sort-tools/
3 |
4 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
5 | % PARAMS GOVERNING THE RECORDING DATA (sdat = DAT STRUCT)
6 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
7 |
8 | % CHANGE THESE FOR YOUR DATASET
9 | sdat.ne = 6; % number of electrodes
10 | sdat.ncell = 4; % number of neurons
11 | sdat.nsecs = 120; % number of total seconds in the recording
12 | sdat.samprate = 20e3; % sample rate
13 | sdat.nsamps = sdat.nsecs*sdat.samprate; % number of total samples
14 |
15 |
16 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
17 | % SET PARAMS GOVERNING PROCESSING
18 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
19 |
20 | % CAN LEAVE THESE FIXED, BUT CHANGE TO OPTIMIZE PERFORMANCE
21 | sdat.nsampsPerFile= 20000; % num samples per saved (processed) electrode-data file
22 | sdat.nsecsPerW = 60; % number of seconds' data to use for each estimate of spike waveform
23 | sdat.nsampsPerW = sdat.nsecsPerW*sdat.samprate; % num samples per waveform estimate
24 | sdat.nw = 30; % number time bins in spike waveform (MUST BE EVEN)
25 | sdat.nsampsPerBPchunk = 10000; % number samples in a chunk for binary pursuit sorting
26 | sdat.minISIms = 0.5; % minimum ISI, in milliseconds
27 | sdat.minISIsamps = round(sdat.minISIms*sdat.samprate/1000);
28 |
29 |
30 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
31 | % WORKING DIRECTORIES %
32 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
33 |
34 | % Directory where raw data to be loaded from
35 | dirlist.rawdat = 'dat/simdata/'; % CHANGE THIS FOR YOUR DATASET
36 |
37 | % Directories where intermediate processing data to be stored
38 | dirlist.procdat = 'dat/procdat/'; % directory for processed data (to be created)
39 | dirlist.W = [dirlist.procdat 'Wraw/']; % raw waveform estimates (pre-whitening)
40 | dirlist.Wwht = [dirlist.procdat 'Wwht/']; % waveform estimates after whitening
41 | dirlist.Ywht = [dirlist.procdat 'Ywht/']; % whitened electrode data (saved in chunks)
42 | dirlist.tspEstim = [dirlist.procdat 'tspEstim/']; % spike train estimates
43 |
44 | % -------- Check that all dirs have been created ------------
45 | dirfields = fieldnames(dirlist);
46 | for jj = 1:length(dirfields)
47 |     dirname = dirlist.(dirfields{jj}); % dynamic field names (may break older matlab versions)
48 |     if exist(dirname,'dir') ~= 7 % exist(...,'dir') returns 7 for a folder (replaces deprecated isdir)
49 |         fprintf('SETSPIKESORTPARAMS: making directory ''%s''\n', dirname);
50 |         mkdir(dirname);
51 |     end
52 | end
53 |
54 |
55 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
56 | % WORKING FILENAMES
57 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
58 |
59 | % NAMES FOR RAW DAT FILES (CHANGE THESE)
60 | filelist.initspikes = [dirlist.rawdat, 'Xsp_init.mat']; % initial spike train estimate (sparse nsamps x ncell array)
61 | filelist.Ydat = [dirlist.rawdat, 'Y.mat']; % raw electrode data
62 |
63 | % NAMES FOR PROCESSED DATA FILES (can leave)
64 | filelist.Ywht = [dirlist.Ywht, 'Y_chunk%d.mat']; % whitened electrode data ('%d' = chunk number)
65 | filelist.Wraw = [dirlist.W, 'Wraw_%d.mat']; % initial (pre-whitening) estimates of Waveform
66 | % NOTE(review): the whitened waveforms below are written to dirlist.W, so the
67 | % dirlist.Wwht directory created above stays empty. Left unchanged so that
68 | % previously saved files are still found; confirm whether dirlist.Wwht was intended.
69 | filelist.Wwht = [dirlist.W, 'Wwht_%d.mat']; % estimates of Waveform w/ whitened data
70 | filelist.Xhat = [dirlist.tspEstim, 'Xhat.mat']; % estimated spike trains (output of binary pursuit)
71 |
72 | % Function for loading raw electrode data (WRITE THIS FOR YOUR OWN DATA!)
73 | loadrawY = @(twin)(loadRawElecDatWin(twin,filelist.Ydat));
74 |
75 | % Function for loading whitened electrode data (no need to rewrite)
76 | loadwhitenedY = @(twin)(loadWhiteElecDatWin(twin,filelist.Ywht,sdat.nsampsPerBPchunk));
77 |
78 |
--------------------------------------------------------------------------------
/sort-tools/compWhitening.m:
--------------------------------------------------------------------------------
1 | function [Ywht, tfilts,xfilts] = compWhiteningFilts(X,Y,W,nxc_t,nxc_x)
2 | % [tfilts,xfilts] = compWhiteningFilts(X,Y,W,nxc_t,nxc_x)
3 | %
4 | % Compute filters that whitens y residuals
5 | %
6 | % INPUTS:
7 | % ------
8 | % X [nsamps x ncells] - each column holds spike train of single neuronjj
9 | % Y [nsamps x nelec] - raw electrode data
10 | % W [nw x nelec x ncells] - tensor of spike waveforms
11 | % nxc_t [1x1]
12 | % nxc_x [1x1]
13 | %
14 | % OUTPUTS:
15 | % --------
16 | % Ywht [nsamps x nelec] - whitened electrode data
17 | % tfilts [nxc_t x ncells] - temporal whitening filters
18 | % xfilts [nxc_x x nelec x nelec] spatial whitening filters
19 | %
20 | %
21 | % NOTES:
22 | % "Full" covariance too large to
% invert, so instead we proceed by:
%   1. Whitening each electrode in time (using nxc_t-tap filter)
%   2. Whiten across electrodes (using nxc_x timebins per electrode)
%
% Algorithm:
%   1. Compute the electrode residuals (raw electrodes minus spike
%      train convolved with waveforms)
%   2. Compute auto-correlations for each electrode
%   3. Solve for temporal-whitening filter for each electrode
%   4. Temporally whiten electrode residuals
%   5. Compute spatially-whitening filters

slen = size(X,1); % number of time samples (rows of the sparse spike-train matrix)
ne = size(W,2); % number of electrodes

% === 1. Compute temporal whitening filts =================
ypred = compVpredictionSprse(X,W); % predicted electrode data based on spikes and W
yresid = Y-ypred; % residuals

% Method 1 (cost indep of nxc_t, and faster than xcorr)
tfilts = zeros(nxc_t,ne); % temporal whitening filters (one column per electrode)

% Column of the whitening matrix used as the filter.  ceil() keeps the index
% an integer when nxc_t is odd (nxc_t/2 alone is then not a valid subscript)
% and matches the ceil(nxc_x/2) used for the spatial filters below; for even
% nxc_t it selects the same column as before.
ictr_t = ceil(nxc_t/2);
for j = 1:ne
    % Compute autocovariance
    xc = circxcorr(yresid(:,j),nxc_t-1,'none');
    yxc = flipud(xc(1:nxc_t))/slen; % autocovariance

    % Compute whitening filt
    M = sqrtm(inv(toeplitz(yxc))); % whitening matrix
    tfilts(:,j) = M(:,ictr_t);

end
% % Check that it worked:
% plot(xcorr(conv2(yresid(:,j),tfilts(:,j),'valid'),nxc_t)/slen);


% === 2. Compute spatial whitening filters ====================
% (Will be faster if only neighboring electrodes are used)

% Compute temporally whitened residuals
yresid_wht = zeros(slen,ne); % whitened residuals
for j = 1:ne
    yresid_wht(:,j) = conv2(yresid(:,j), tfilts(:,j), 'same');
end

% Compute spatial cross-correlation(s): slice j holds lag (j-1), computed
% with a circular shift of the residuals
xxc = zeros(ne,ne,nxc_x);
xxc(:,:,1) = yresid_wht'*yresid_wht;
for j = 2:nxc_x
    xxc(:,:,j) = circshift(yresid_wht,j-1)'*yresid_wht;
end
xxc = xxc/slen;

% Insert into big covariance matrix: one nxc_x-by-nxc_x Toeplitz block per
% electrode pair; only blocks with i >= j are built directly, the (i,j)
% block is filled in as the transpose of the (j,i) block
M = zeros(ne*nxc_x);
for j = 1:ne
    for i = j:ne
        if i == j
            jj = (j-1)*nxc_x+1:j*nxc_x;
            M(jj,jj) = toeplitz(squeeze(xxc(j,j,:)));
        else
            jj = (j-1)*nxc_x+1:j*nxc_x;
            ii = (i-1)*nxc_x+1:i*nxc_x;
            M(jj,ii) = toeplitz(squeeze(xxc(j,i,:)),squeeze(xxc(i,j,:)));
            M(ii,jj) = M(jj,ii)';
        end
    end
end
% Compute filters: take the center column of each electrode's block of the
% inverse matrix square root and reshape it to [nxc_x x ne]
Q = sqrtm(inv(M));
xfilts = zeros(nxc_x,ne,ne);
for j = 1:ne
    xfilts(:,:,j) = reshape(Q(:,(j-1)*nxc_x+ceil(nxc_x/2)),[],ne);
end

% % ====================================
% % OPTIONAL: check that it worked
% for j = 1:ne
%     ywht(:,j) = conv2(yresid_wht,fliplr(xfilts(:,:,j)),'valid');
% end
% subplot(121);
% imagesc(cov(ywht));
% subplot(122);
% plot(xcorr(ywht(:,1:2),5));
% % ====================================

% === 3. Now whiten the raw data ==========================
Ywht = zeros(slen,ne);
for ii = 1:ne % Temporally whiten
    Ywht(:,ii) = conv2(Y(:,ii),tfilts(:,ii),'same');
end
Ywht = samefilt(Ywht,xfilts,'conv'); % spatially whiten

-------------------------------------------------------------------------------- /script0_simulateDataForTesting.m: -------------------------------------------------------------------------------- 1 | % Generate some fake data to illustrate the use of binary pursuit spike sorting code. 2 | 3 | fprintf('script0_simulateDataForTesting: generating dataset for spike sorting...\n'); 4 | 5 | % Set params for simulated dataset using those in the 'setSpikeSortParams.m' script 6 | setSpikeSortParams; % (LOOK HERE FOR DETAILS). 7 | 8 | % Set parameters governing the simulated spike trains 9 | nwt = 30; % # time samples in spike waveforms 10 | sprate = 100; % mean spike rate (unrealistically high to observe lots of simultaneous spks) 11 | nsesig = .1; % marginal stdv of additive noise (arbitrary units) 12 | 13 | 14 | %% 2. Generate spike trains % ---------------------- 15 | Xsp = double(sparse((rand(sdat.nsamps,sdat.ncell) < sprate/sdat.samprate))); % each column is spk train 16 | 17 | % remove spikes too close together 18 | minisi = nwt*1.5; % smallest allowed interspike interval 19 | for j = 1:sdat.ncell 20 | tsp = find(Xsp(:,j)); % spike times 21 | isi = [nwt; diff(tsp)]; % interspike intervals 22 | kk = find(isi 0) && (nstp <= maxSteps); 65 | 66 | % % Uncomment to print reports on progress 67 | % % -------------------------------------- 68 | % if mod(nstp,modreport) == 0 % Report output only every 100 bins 69 | % if (xx(ispk,cellsp) == 1) 70 | % fprintf('step %d: REMOVED sp, cell %d, bin %d, Dlogli=%.3f\n',... 71 | % nstp,cellsp,ispk,dlogli(ispk,cellsp)); 72 | % else 73 | % fprintf(1, 'step %d: inserted cell %d, bin %d, Dlogli=%.3f\n',...
74 | % nstp,cellsp,ispk,dlogli(ispk,cellsp)); 75 | % end 76 | % end % ---------------------------------- 77 | 78 | % ---------------------------------------- 79 | % 2A. Insert or remove spike in location where logli is most improved 80 | if xx(ispk,cellsp) == 0 % Insert spike -------------- 81 | 82 | xx(ispk,cellsp) = 1; 83 | dloglictrbin = -dlogli(ispk,cellsp); % dLogli for this bin 84 | inds = ispk-nw2:ispk+nw2-1; 85 | rr(inds,:) = rr(inds,:)- W(:,:,cellsp); % remove waveform 86 | 87 | % update dlogli for all bins within +/- nw 88 | if ((ispk+iirel(1)) >= nw2+1) && ((ispk+iirel(end)) <= slen-nw2+1) 89 | iichg = ispk+iirel; 90 | dlogli(iichg,:)= dlogli(iichg,:)-wproj(:,:,cellsp); 91 | 92 | else % update only for relevant range of indices 93 | iichg = ispk+iirel; 94 | ii = find((iichg>=(nw2+1)) & (iichg<=(slen-nw2+1))); 95 | iichg = iichg(ii); 96 | dlogli(iichg,:)= dlogli(iichg,:)-wproj(ii,:,cellsp); 97 | end 98 | dlogli(ispk,cellsp) = dloglictrbin; % set for center bin 99 | 100 | else % Remove spike ---------------------------------- 101 | 102 | xx(ispk,cellsp) = 0; 103 | dloglictrbin = -dlogli(ispk,cellsp); % dLogli for this bin 104 | inds = ispk-nw2:ispk+nw2-1; 105 | rr(inds,:) = rr(inds,:) + W(:,:,cellsp); % add waveform back to residuals 106 | 107 | % update dlogli for all bins within +/- nw 108 | if ((ispk+iirel(1)) >= nw2+1) && ((ispk+iirel(end))<=slen-nw2+1) 109 | iichg = ispk+iirel; 110 | dlogli(iichg,:)= dlogli(iichg,:)+wproj(:,:,cellsp); 111 | else % update only for relevant range of indices 112 | iichg = ispk+iirel; 113 | ii = find((iichg>=(nw2+1)) & (iichg<=(slen-nw2+1))); 114 | iichg = iichg(ii); 115 | dlogli(iichg,:)= dlogli(iichg,:)+wproj(ii,:,cellsp); 116 | end 117 | dlogli(ispk,cellsp) = dloglictrbin; % set for center bin 118 | 119 | end 120 | 121 | % ---------------------------------------- 122 | % 2B. Do some index arithmetic to max maximum dlogli for each cell 123 | % (Big speedup from searching for the max over all bins for each cell). 
124 | 125 | % Find any cells whose prev max was in the region just changed 126 | iimaxchg = find((iimx>=iichg(1)) & (iimx<=iichg(end))); 127 | [mx0,iinw] = max(dlogli(:,iimaxchg)); 128 | mxvals(iimaxchg) = mx0; 129 | iimx(iimaxchg) = iinw; 130 | 131 | % Now see if new maxima arose in region of dlogli that was just altered 132 | [mxvals0,iimx0] = max(dlogli(iichg,:)); 133 | 134 | % Combine to get new maximum & position 135 | [mxvals,OneOrTwo] = max([mxvals0; mxvals]); % max of [changed-bin ; other-bin]; 136 | iflip = find(OneOrTwo == 1); 137 | iimx(iflip) = iichg(iimx0(iflip)); 138 | 139 | % Find next bin to adjust 140 | [mx,cellsp] = max(mxvals); 141 | ispk = iimx(cellsp); 142 | 143 | nstp = nstp+1; 144 | end 145 | 146 | % Notify if MaxSteps exceeded 147 | if nstp > maxSteps 148 | fprintf('estimSps_binary_chun: max # passes exceeded (dlogli=%.3f)\n', dlogli(ispk,cellsp)); 149 | end 150 | 151 | % ---------------------------------------------------------- 152 | % 3. Finally, remove spikes that violate refractory period 153 | nrefviols = zeros(1,nc); % initialize counter 154 | for jcell = 1:nc 155 | isis = diff(find(xx(:,jcell))); 156 | while any(isis= nw2+1) && ((ispk+iirel(end))<=slen-nw2+1) 174 | iichg = ispk+iirel; 175 | dlogli(iichg,:)= dlogli(iichg,:)+wproj(:,:,jcell); 176 | else % update only for relevant range of indices 177 | iichg = ispk+iirel; 178 | ii = find((iichg>=(nw2+1)) & (iichg<=(slen-nw2+1))); 179 | iichg = iichg(ii); 180 | dlogli(iichg,:)= dlogli(iichg,:)+wproj(ii,:,jcell); 181 | end 182 | dlogli(ispk,jcell) = dloglictrbin; % set for center bin 183 | 184 | % Recompute ISIs 185 | isis = diff(find(xx(:,jcell))); 186 | 187 | end 188 | end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function G = validfilt_sprse(A,B)
% G = validfilt_sprse(A,B);
%
% Filters the columns of A with each frame B(:,:,k) and keeps only the rows
% where the frame lies entirely inside A ("valid" part).
%
% INPUT:
%   A - matrix with one row per time sample (may be sparse)
%   B - array whose frames B(:,:,k) have the same number of columns as A
%
% OUTPUT:
%   G - one column per frame of B, with
%         G(m,k) = sum over l and c of A(m+l-1,c)*B(l,c,k),
%       for m = 1 .. size(A,1)-size(B,1)+1 (no rows if A is shorter than B)

nrowsA = size(A,1);
[nrowsB,~,nframes] = size(B);
nfull = nrowsA+nrowsB-1; % number of rows before trimming

% Accumulate the untrimmed result one frame at a time
G = zeros(nfull,nframes);
for k = 1:nframes
    proj = A*B(:,:,k)'; % column l: rows of A projected onto row l of frame k
    for l = nrowsB:-1:1
        rows = (nrowsB-l) + (1:nrowsA); % column l enters shifted down by (nrowsB-l)
        G(rows,k) = G(rows,k) + proj(:,l);
    end
end

% Drop the nrowsB-1 partially-overlapping rows at each end
G = G(nrowsB:nfull-(nrowsB-1),:);

--------------------------------------------------------------------------------