├── LICENSE
├── README.md
├── Spike_sorting.ipynb
└── images
    ├── Spike_sorting_11_0.png
    ├── Spike_sorting_13_0.png
    ├── Spike_sorting_17_0.png
    ├── Spike_sorting_19_0.png
    ├── Spike_sorting_21_0.png
    ├── Spike_sorting_3_0.png
    └── Spike_sorting_7_0.png

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Carsten Klein

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Using signal processing and K-means clustering to extract and sort neural events in Python

This is the Python Jupyter Notebook for the Medium articles ([X](https://towardsdatascience.com/using-signal-processing-to-extract-neural-events-in-python-964437dc7c0) and [Y](https://towardsdatascience.com/whos-talking-using-k-means-clustering-to-sort-neural-events-in-python-e7a8a76f316)) on how to use signal processing techniques and K-means clustering to sort spikes.

### Part I
Code to read data from a .ncs file and extract the spike channel from the raw broadband signal through bandpass filtering. Also includes a function to extract and align spikes from the signal.

### Part II
Code to perform PCA on the extracted spike waveforms, followed by a function that runs K-means clustering on the PCA data to determine the number of clusters in the data and to average the waveforms according to their cluster number.

First, we import the libraries for reading in the data and processing it.


```python
from scipy.signal import butter, lfilter
import numpy as np
import matplotlib.pyplot as plt

# Enable plots inside the Jupyter Notebook
%matplotlib inline
```

Before we can start looking at the data (https://www2.le.ac.uk/centres/csn/software) we need to write a function to import the .ncs data format. You can check out the organization of the data format on the company's web page (https://neuralynx.com/software/NeuralynxDataFileFormats.pdf). The information provided there is the basis for the import routine below.
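As a quick optional check before parsing the records, the first 16 kB of a .ncs file are a plain-text header that echoes the recording settings. A minimal sketch for peeking at it (the exact header key names, such as `SamplingFrequency`, vary with the Neuralynx software version, so treat the key name as an assumption):


```python
# Peek at the 16 kB ASCII header of the .ncs file (a sketch; header keys
# vary by Neuralynx software version).
header_size = 16 * 1024

with open('./UCLA_data/CSC4.Ncs', 'rb') as fid:
    header_text = fid.read(header_size).decode('latin-1').strip('\x00')

# Print any header line mentioning the sampling frequency (key name assumed)
for line in header_text.splitlines():
    if 'SamplingFrequency' in line:
        print(line)
```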

```python
# Define data path
data_file = './UCLA_data/CSC4.Ncs'

# The header is 16 kilobytes long
header_size = 16 * 1024

# Open file
fid = open(data_file, 'rb')

# Skip the header by shifting the position by the header size
fid.seek(header_size)

# Read data according to the Neuralynx record layout
data_format = np.dtype([('TimeStamp', np.uint64),
                        ('ChannelNumber', np.uint32),
                        ('SampleFreq', np.uint32),
                        ('NumValidSamples', np.uint32),
                        ('Samples', np.int16, 512)])

raw = np.fromfile(fid, dtype=data_format)

# Close file
fid.close()

# Get sampling frequency
sf = raw['SampleFreq'][0]

# Create data vector
data = raw['Samples'].ravel()

# Determine duration of recording in seconds
dur_sec = data.shape[0] / sf

# Create time vector
time = np.linspace(0, dur_sec, data.shape[0])

# Plot first second of data
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(time[0:sf], data[0:sf])
ax.set_title('Broadband; Sampling Frequency: {}Hz'.format(sf), fontsize=23)
ax.set_xlim(0, time[sf])
ax.set_xlabel('time [s]', fontsize=20)
ax.set_ylabel('amplitude [uV]', fontsize=20)
plt.show()
```


![png](images/Spike_sorting_3_0.png)


# Part I

## Bandpass filter the data
As we can see, the signal has strong 60Hz noise in it. The function below bandpass filters the signal to exclude the 60Hz domain.


```python
def filter_data(data, low, high, sf, order=2):
    # Determine Nyquist frequency
    nyq = sf / 2

    # Set bands
    low = low / nyq
    high = high / nyq

    # Calculate coefficients
    b, a = butter(order, [low, high], btype='band')

    # Filter signal
    filtered_data = lfilter(b, a, data)

    return filtered_data
```

Using the above function, we can compare the raw data with the filtered signal.


```python
spike_data = filter_data(data, low=500, high=9000, sf=sf)

# Plot signals
fig, ax = plt.subplots(2, 1, figsize=(15, 5))
ax[0].plot(time[0:sf], data[0:sf])
ax[0].set_xticks([])
ax[0].set_title('Broadband', fontsize=23)
ax[0].set_xlim(0, time[sf])
ax[0].set_ylabel('amplitude [uV]', fontsize=16)
ax[0].tick_params(labelsize=12)

ax[1].plot(time[0:sf], spike_data[0:sf])
ax[1].set_title('Spike channel [0.5 to 9kHz]', fontsize=23)
ax[1].set_xlim(0, time[sf])
ax[1].set_xlabel('time [s]', fontsize=20)
ax[1].set_ylabel('amplitude [uV]', fontsize=16)
ax[1].tick_params(labelsize=12)
plt.show()
```


![png](images/Spike_sorting_7_0.png)


## Extract spikes from the filtered signal
Now that we have a clean spike channel we can identify and extract spikes. The following function does that for us. It takes five input arguments:

1. the filtered data
2. the number of samples (the window) to extract from the signal around each spike
3. the threshold factor; the detection threshold is mean(|signal|) * tf
4. an offset, expressed in number of samples, which shifts the maximum peak away from the window center
5. an upper threshold; windows containing data points above this limit are discarded to avoid extracting artifacts
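As an aside, a mean-based threshold can be inflated by the spikes themselves. A widely used robust alternative estimates the noise standard deviation from the median of the absolute signal (Quian Quiroga et al., 2004). A minimal sketch, if you want to experiment with swapping it into the function below:


```python
def robust_threshold(data, tf=5):
    """Spike detection threshold based on a robust noise estimate.

    For zero-mean noise, median(|x|) / 0.6745 approximates the noise
    standard deviation while staying largely insensitive to the spikes.
    """
    sigma_n = np.median(np.abs(data)) / 0.6745
    return tf * sigma_n
```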

```python
def get_spikes(data, spike_window=80, tf=5, offset=10, max_thresh=350):

    # Calculate threshold based on the mean absolute signal
    thresh = np.mean(np.abs(data)) * tf

    # Find positions where the threshold is crossed
    pos = np.where(data > thresh)[0]
    pos = pos[pos > spike_window]

    # Extract potential spikes and align them to the maximum
    spike_samp = []
    # Start with zero rows so that wave_form rows stay aligned with spike_samp
    wave_form = np.empty([0, spike_window*2])
    for i in pos:
        if i < data.shape[0] - (spike_window+1):
            # Data from the position where the threshold is crossed to the end of the window
            tmp_waveform = data[i:i+spike_window*2]

            # Check if data in window is below upper threshold (artifact rejection)
            if np.max(tmp_waveform) < max_thresh:
                # Find the sample with the maximum data point in the window
                tmp_samp = np.argmax(tmp_waveform) + i

                # Re-center the window on the maximum sample and shift it by the offset
                tmp_waveform = data[tmp_samp-(spike_window-offset):tmp_samp+(spike_window+offset)]

                # Append data (skip windows truncated at the end of the recording)
                if tmp_waveform.shape[0] == spike_window*2:
                    spike_samp = np.append(spike_samp, tmp_samp)
                    wave_form = np.append(wave_form, tmp_waveform.reshape(1, spike_window*2), axis=0)

    # Remove duplicates
    ind = np.where(np.diff(spike_samp) > 1)[0]
    spike_samp = spike_samp[ind]
    wave_form = wave_form[ind]

    return spike_samp, wave_form
```

Now we run the function on our filtered spike channel and plot 100 randomly selected waveforms that were extracted.


```python
spike_samp, wave_form = get_spikes(spike_data, spike_window=50, tf=8, offset=20)

np.random.seed(10)
fig, ax = plt.subplots(figsize=(15, 5))

for i in range(100):
    spike = np.random.randint(0, wave_form.shape[0])
    ax.plot(wave_form[spike, :])

ax.set_xlim([0, 90])
ax.set_xlabel('# sample', fontsize=20)
ax.set_ylabel('amplitude [uV]', fontsize=20)
ax.set_title('spike waveforms', fontsize=23)
plt.show()
```


![png](images/Spike_sorting_11_0.png)


# Part II

## Reducing the number of dimensions with PCA
To cluster the waveforms we need some features to work with. A feature could be, for example, the peak amplitude of the spike or the width of the waveform. Another option is to use the principal components of the waveforms. Principal component analysis (PCA) is a dimensionality-reduction method that is sensitive to the scaling of its input, so the data should be normalized first. Here we will use scikit-learn for both the normalization and the PCA. We will not go into the details of PCA here, since the focus is on the clustering.
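One way to sanity-check the `n_components=12` choice in the next cell is to inspect the explained variance after fitting. A minimal sketch, meant to be run after the cell below (it assumes the fitted `pca` object from that cell):


```python
# Assumes the fitted `pca` object from the cell below. The per-component
# explained variance shows how many components carry most of the
# waveform structure.
print(pca.explained_variance_ratio_)
print('first 3 components explain {:.1%} of the variance'.format(
    pca.explained_variance_ratio_[:3].sum()))
```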

```python
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA

# Apply min-max scaling
scaler = MinMaxScaler()
dataset_scaled = scaler.fit_transform(wave_form)

# Do PCA
pca = PCA(n_components=12)
pca_result = pca.fit_transform(dataset_scaled)

# Plot the 1st principal component against the 2nd and use the 3rd for color
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(pca_result[:, 0], pca_result[:, 1], c=pca_result[:, 2])
ax.set_xlabel('1st principal component', fontsize=20)
ax.set_ylabel('2nd principal component', fontsize=20)
ax.set_title('first 3 principal components', fontsize=23)

fig.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()
```


![png](images/Spike_sorting_13_0.png)


The way we will implement K-means is quite straightforward. First, we choose K random datapoints from our sample. These datapoints represent the cluster centers, and their number equals the number of clusters. Next, we calculate the Euclidean distance between each of these random cluster centers and every other datapoint, and assign each datapoint to the cluster center closest to it. Obviously, doing all of this with random datapoints will not give us a good clustering result. So we start over again, but this time we don't use random datapoints as cluster centers. Instead, we calculate the actual cluster centers based on the previous assignments and repeat the process… and again… and again. With every iteration fewer datapoints switch clusters, and we arrive at a (hopefully global) optimum.


```python
def k_means(data, num_clus=3, steps=200):

    # Convert data to a Numpy array
    cluster_data = np.array(data)

    # Initialize by randomly selecting points in the data
    center_init = np.random.randint(0, cluster_data.shape[0], num_clus)

    # Create a list with center coordinates
    center_init = cluster_data[center_init, :]

    # Repeat the assignment/update steps
    for _ in range(steps):

        # Calculate the distance of each datapoint to every cluster center
        distance = []
        for center in center_init:
            tmp_distance = np.sqrt(np.sum((cluster_data - center)**2, axis=1))

            # Add small random noise to avoid ties between distances to centroids
            tmp_distance = tmp_distance + np.abs(np.random.randn(len(tmp_distance))*0.0001)
            distance.append(tmp_distance)

        # Assign each point to the cluster with the minimum distance
        _, cluster = np.where(np.transpose(distance == np.min(distance, axis=0)))

        # Find the center of mass of each cluster
        center_init = []
        for i in range(num_clus):
            center_init.append(cluster_data[cluster == i, :].mean(axis=0).tolist())

    return cluster, center_init, distance
```

So how should we choose the number of clusters? One way is to run our K-means function many times with different cluster numbers, as in the cell below. The resulting plot shows the average intra-cluster distance, i.e. the average Euclidean distance between the datapoints of a cluster and their cluster center. From the plot we can see that beyond 4 to 6 clusters the intra-cluster distance no longer decreases strongly.
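As a cross-check on the hand-rolled implementation, scikit-learn's `KMeans` exposes the same elbow quantity through its `inertia_` attribute (the sum of squared distances of the samples to their closest center). A minimal sketch, assuming `pca_result` from above; its elbow should roughly agree with the plot produced below:


```python
from sklearn.cluster import KMeans

# Inertia per cluster count, as an independent elbow estimate
inertias = []
for k in range(1, 16):
    km = KMeans(n_clusters=k, n_init=10, random_state=0).fit(pca_result)
    inertias.append(km.inertia_)
print(inertias)
```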

```python
max_num_clusters = 15

average_distance = []
for run in range(20):
    tmp_average_distance = []
    for num_clus in range(1, max_num_clusters + 1):
        cluster, centers, distance = k_means(pca_result, num_clus)
        tmp_average_distance.append(np.mean([np.mean(distance[x][cluster == x])
                                             for x in range(num_clus)], axis=0))
    average_distance.append(tmp_average_distance)

fig, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.plot(range(1, max_num_clusters + 1), np.mean(average_distance, axis=0))
ax.set_xlim([1, max_num_clusters])
ax.set_xlabel('number of clusters', fontsize=20)
ax.set_ylabel('average intra-cluster distance', fontsize=20)
ax.set_title('Elbow point', fontsize=23)
plt.show()
```


![png](images/Spike_sorting_17_0.png)


So, let's see what we get if we run the K-means algorithm with 6 clusters. We could probably also go with 4, but let's check the results with 6 first.


```python
num_clus = 6
cluster, centers, distance = k_means(pca_result, num_clus)

# Plot the result
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
ax[0].scatter(pca_result[:, 0], pca_result[:, 1], c=cluster)
ax[0].set_xlabel('1st principal component', fontsize=20)
ax[0].set_ylabel('2nd principal component', fontsize=20)
ax[0].set_title('clustered data', fontsize=23)

time = np.linspace(0, wave_form.shape[1]/sf, wave_form.shape[1])*1000
for i in range(num_clus):
    cluster_mean = wave_form[cluster == i, :].mean(axis=0)
    cluster_std = wave_form[cluster == i, :].std(axis=0)

    ax[1].plot(time, cluster_mean, label='Cluster {}'.format(i))
    ax[1].fill_between(time, cluster_mean - cluster_std, cluster_mean + cluster_std, alpha=0.15)

ax[1].set_title('average waveforms', fontsize=23)
ax[1].set_xlim([0, time[-1]])
ax[1].set_xlabel('time [ms]', fontsize=20)
ax[1].set_ylabel('amplitude [uV]', fontsize=20)

plt.legend()
plt.show()
```


![png](images/Spike_sorting_19_0.png)


Looking at the above results, it seems we chose too many clusters. The waveforms plot indicates that we may have extracted spikes from three different sources. Clusters 0, 1, 3 and 4 look like they have the same origin, while clusters 2 and 5 seem to be separate neurons. So, let's combine clusters 0, 1, 3 and 4.
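One compact way to express this merge is to remap the labels up front (a sketch; the cell below instead computes the group statistics directly from the original labels):


```python
# Map the six K-means labels onto three putative units:
# clusters 0, 1, 3, 4 -> unit 0; cluster 2 -> unit 1; cluster 5 -> unit 2.
label_map = {0: 0, 1: 0, 3: 0, 4: 0, 2: 1, 5: 2}
merged_cluster = np.array([label_map[c] for c in cluster])

for unit in range(3):
    print('unit {}: {} spikes'.format(unit, np.sum(merged_cluster == unit)))
```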

```python
combine_clusters = [0, 1, 3, 4]

# Boolean mask selecting all spikes that belong to the combined clusters
combined_mask = np.isin(cluster, combine_clusters)
combined_waveforms_mean = wave_form[combined_mask, :].mean(axis=0)
combined_waveforms_std = wave_form[combined_mask, :].std(axis=0)

cluster2_waveform_mean = wave_form[cluster == 2, :].mean(axis=0)
cluster2_waveform_std = wave_form[cluster == 2, :].std(axis=0)

cluster5_waveform_mean = wave_form[cluster == 5, :].mean(axis=0)
cluster5_waveform_std = wave_form[cluster == 5, :].std(axis=0)

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.plot(time, combined_waveforms_mean, label='Cluster 1')
ax.fill_between(time, combined_waveforms_mean - combined_waveforms_std,
                combined_waveforms_mean + combined_waveforms_std, alpha=0.15)

ax.plot(time, cluster2_waveform_mean, label='Cluster 2')
ax.fill_between(time, cluster2_waveform_mean - cluster2_waveform_std,
                cluster2_waveform_mean + cluster2_waveform_std, alpha=0.15)

ax.plot(time, cluster5_waveform_mean, label='Cluster 3')
ax.fill_between(time, cluster5_waveform_mean - cluster5_waveform_std,
                cluster5_waveform_mean + cluster5_waveform_std, alpha=0.15)

ax.set_title('average waveforms', fontsize=23)
ax.set_xlim([0, time[-1]])
ax.set_xlabel('time [ms]', fontsize=20)
ax.set_ylabel('amplitude [uV]', fontsize=20)

plt.legend()
plt.show()
```


![png](images/Spike_sorting_21_0.png)
--------------------------------------------------------------------------------
/images/Spike_sorting_11_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_11_0.png
--------------------------------------------------------------------------------
/images/Spike_sorting_13_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_13_0.png
--------------------------------------------------------------------------------
/images/Spike_sorting_17_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_17_0.png
--------------------------------------------------------------------------------
/images/Spike_sorting_19_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_19_0.png
--------------------------------------------------------------------------------
/images/Spike_sorting_21_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_21_0.png
--------------------------------------------------------------------------------
/images/Spike_sorting_3_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_3_0.png
--------------------------------------------------------------------------------
/images/Spike_sorting_7_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/spike_sorting/7873e9631ed101a58ebe7f20bbcefd5f23397fb6/images/Spike_sorting_7_0.png
--------------------------------------------------------------------------------