├── README.md ├── desktop.ini ├── klustakwik.exe ├── klustakwik.zip ├── phy_config.py └── plugins ├── CustomActionPlugin.py ├── Readme.md.txt ├── SplitShortISI.py ├── SplitShortISI_v2.py ├── good_labels.py ├── mahalanobis_v2.py ├── raw_data_filter.py ├── readme.txt ├── recluster.py ├── recluster_v2.py ├── shortISI_v3.py └── tempdir.py /README.md: -------------------------------------------------------------------------------- 1 | # Plugins to Phy2 2 | These plugins add features to Phy2. Originally created for Phy1 by Peter Petersen and made compatible with Phy2 by Thomas Hainmueller. Newly updated implementations by Mingze Dou. 3 | 4 | ## Features 5 | 6 | ### Newly Updated Features 7 | * **ImprovedISIAnalysis** (`alt+i`): Detects ISI conflicts using multiple metrics 8 | * **StableMahalanobisDetection** (`alt+x`): Outlier detection with interactive visualization 9 | * **ReclusterUMAP** (`alt+k`): Modern reclustering using UMAP and template matching 10 | * **GoodLabelsPlugin**: Improved cluster organization that sorts by quality (good > mua > noise). Provides better workflow organization despite Phy's real-time update limitations. 11 | 12 | ### Legacy Features 13 | * **Reclustering** (`alt+shift+k`, `alt+shift+t`): KlustaKwik 2.0 based reclustering 14 | * **Mahalanobis Distance** (`alt+shift+x`): Outlier detection (default threshold prompt: 14 std) 15 | * **K-means Clustering** (`alt+shift+q`): Basic clustering with an adjustable cluster count 16 | * **ISI Violation** (`alt+shift+i`): Visualize refractory period violations 17 | 18 | ## Dependencies 19 | ```bash 20 | pip install pandas numpy scipy scikit-learn umap-learn 21 | ``` 22 | 23 | ## Installation 24 | 1. Copy the plugin files from the `plugins` directory of this repository to your Phy user configuration directory, inside a `plugins` subfolder. 25 | - On Windows, this is typically `C:\Users\*YourUserName*\.phy\plugins`. 26 | - If the `plugins` folder doesn't exist, you will need to create it. 27 | 2. Copy other supporting files (e.g., `phy_config.py`, `klustakwik.exe` if used) from the root of this repository to your Phy user configuration directory (e.g., `~/.phy/` or `C:\Users\*YourUserName*\.phy`). 28 | 3. Install dependencies (see [Dependencies](#dependencies) section above). 29 | 4. Copy `tempdir.py` (from the `plugins` directory of this repository) to `*YourPhyDirectory*/phy/utils`.
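To verify the dependencies resolve before launching Phy, a minimal import check can help (note that `scikit-learn` is imported as `sklearn` and `umap-learn` as `umap`):

```python
# Minimal dependency check for the plugins in this repository.
import pandas, numpy, scipy, sklearn, umap  # umap-learn installs the `umap` module
print("All plugin dependencies found.")
```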
30 | 31 | ## Authors 32 | - Original Phy1: Peter Petersen 33 | - Phy2 compatibility: Thomas Hainmueller 34 | - New implementations: Mingze Dou 35 | 36 | ## How to Cite 37 | [![DOI](https://zenodo.org/badge/126424002.svg)](https://zenodo.org/badge/latestdoi/126424002) 38 | -------------------------------------------------------------------------------- /desktop.ini: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/petersenpeter/phy2-plugins/9044cf42c7f74f9e6bbcc624fb47f4342107277b/desktop.ini -------------------------------------------------------------------------------- /klustakwik.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/petersenpeter/phy2-plugins/9044cf42c7f74f9e6bbcc624fb47f4342107277b/klustakwik.exe -------------------------------------------------------------------------------- /klustakwik.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/petersenpeter/phy2-plugins/9044cf42c7f74f9e6bbcc624fb47f4342107277b/klustakwik.zip -------------------------------------------------------------------------------- /phy_config.py: -------------------------------------------------------------------------------- 1 | # You can also put your plugins in ~/.phy/plugins/. 2 | 3 | from phy import IPlugin 4 | 5 | try: 6 | import phycontrib 7 | except ImportError: 8 | pass 9 | 10 | c = get_config() 11 | 12 | # Plugin directories 13 | c.Plugins.dirs = [r'~/.phy/plugins/'] 14 | 15 | # Configure GUI plugins 16 | c.TemplateGUI.plugins = [ 17 | 'CustomActionPlugin', 18 | 'GoodLabelsPlugin', 19 | 'RawDataFilterPlugin', 20 | 'SplitShortISI', 21 | 'Recluster', 22 | 23 | 'StableMahalanobisDetection', 24 | 'ReclusterUMAP', 25 | 'ImprovedISIAnalysis' 26 | ] -------------------------------------------------------------------------------- /plugins/CustomActionPlugin.py: -------------------------------------------------------------------------------- 1 | # import from plugins/action_status_bar.py 2 | """Add custom actions to the GUI. 3 | 4 | Ctrl+C selects the first unsorted cluster that contains at least one spike. 5 | 6 | Ctrl+V moves the selected cluster to the end of the cluster list (by re-splitting 7 | its spikes), and Ctrl+B does the same for the selected similar cluster. 8 | 9 | """ 10 | 11 | from phy import IPlugin, connect 12 | import numpy as np 13 | import logging 14 | 15 | logger = logging.getLogger('phy') 16 | 17 | class CustomActionPlugin(IPlugin): 18 | def attach_to_controller(self, controller): 19 | @connect 20 | def on_gui_ready(sender, gui): 21 | 22 | @controller.supervisor.actions.add(shortcut='ctrl+c') 23 | def select_first_unsorted(): 24 | 25 | # All cluster view methods are called with a callback function because of the 26 | # asynchronous nature of Python-Javascript interactions in Qt5.
27 | @controller.supervisor.cluster_view.get_ids 28 | def find_unsorted(cluster_ids): 29 | """This function is called when the ordered list of cluster ids is returned 30 | by the Javascript view.""" 31 | groups = controller.supervisor.cluster_meta.get('group', list(range(max(cluster_ids) + 1))) # +1 so the largest cluster id is a valid index 32 | for ii in cluster_ids: 33 | if groups[ii] is None or groups[ii] == 'unsorted': 34 | s = controller.supervisor.clustering.spikes_in_clusters([ii]) 35 | if len(s)>0: 36 | firstclu = ii 37 | break 38 | 39 | if 'firstclu' in locals(): 40 | controller.supervisor.select(firstclu) 41 | 42 | return 43 | 44 | 45 | @controller.supervisor.actions.add(shortcut='ctrl+v') 46 | def move_selected_to_end(): 47 | 48 | logger.warn("Moving selected cluster to end") 49 | selected = controller.supervisor.selected 50 | s = controller.supervisor.clustering.spikes_in_clusters(selected) 51 | outliers2 = np.ones(len(s),dtype=int) 52 | controller.supervisor.actions.split(s,outliers2) 53 | 54 | 55 | @controller.supervisor.actions.add(shortcut='ctrl+b') 56 | def move_similar_to_end(): 57 | 58 | logger.warn("Moving selected similar cluster to end") 59 | sim = controller.supervisor.selected_similar 60 | s = controller.supervisor.clustering.spikes_in_clusters(sim) 61 | outliers2 = np.ones(len(s),dtype=int) 62 | controller.supervisor.actions.split(s,outliers2) 63 | 64 | -------------------------------------------------------------------------------- /plugins/Readme.md.txt: -------------------------------------------------------------------------------- 1 | Updated version of some plugins for Phy2, originally created by Peter Petersen, plus one visualization for spikes 2 | violating the refractory period and a filter for the raw traces. 3 | 4 | To 'install', copy 'tempdir.py' into "*YourPhyDirectory*/phy/utils", copy 'phy_config.py' into 5 | "Users/*YourUserName*/.phy" and the remaining files into "Users/*YourUserName*/.phy/plugins". 6 | So far only tested once on a Windows machine. 7 | 8 | Thomas Hainmueller, 11/22/2019 -------------------------------------------------------------------------------- /plugins/SplitShortISI.py: -------------------------------------------------------------------------------- 1 | """Show how to write a custom split action.""" 2 | 3 | from phy import IPlugin, connect 4 | import numpy as np 5 | import logging 6 | 7 | logger = logging.getLogger('phy') 8 | 9 | 10 | 11 | class SplitShortISI(IPlugin): 12 | def attach_to_controller(self, controller): 13 | @connect 14 | def on_gui_ready(sender, gui): 15 | #@gui.edit_actions.add(shortcut='alt+i') 16 | @controller.supervisor.actions.add(shortcut='alt+i') 17 | def VisualizeShortISI(): 18 | """Split all spikes with an interspike interval of less than 1.5 ms into a separate 19 | cluster. THIS IS FOR VISUALIZATION ONLY, it will show you where potential noise 20 | spikes may be located. Re-merge the clusters again afterwards and cut the cluster with 21 | another method!""" 22 | 23 | logger.info('Detecting spikes with ISI less than 1.5 ms') 24 | 25 | # Selected clusters across the cluster view and similarity view. 26 | cluster_ids = controller.supervisor.selected 27 | 28 | # Get the amplitudes, using the same controller method as what the amplitude view 29 | # is using. 30 | # Note that we need load_all=True to load all spikes from the selected clusters, 31 | # instead of just the selection of them chosen for display.
32 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 33 | 34 | # We get the spike ids and the corresponding spike template amplitudes. 35 | # NOTE: in this example, we only consider the first selected cluster. 36 | spike_ids = bunchs[0].spike_ids 37 | spike_times = controller.model.spike_times[spike_ids] 38 | dspike_times = np.diff(spike_times) 39 | 40 | labels = np.ones(len(dspike_times),'int64') 41 | labels[dspike_times<.0015]=2 42 | labels = np.append(labels,1) # include the last spike so that len(labels) == len(spike_ids) 43 | 44 | # We perform the clustering algorithm, which returns an integer for each 45 | # subcluster. 46 | #labels = k_means(y.reshape((-1, 1))) 47 | assert spike_ids.shape == labels.shape 48 | 49 | # We split according to the labels. 50 | controller.supervisor.actions.split(spike_ids, labels) 51 | logger.info('Split short ISI spikes from main cluster') 52 | -------------------------------------------------------------------------------- /plugins/SplitShortISI_v2.py: -------------------------------------------------------------------------------- 1 | """Show how to write a custom split action.""" 2 | 3 | from phy import IPlugin, connect 4 | import numpy as np 5 | import logging 6 | 7 | logger = logging.getLogger('phy') 8 | 9 | 10 | 11 | class SplitShortISI(IPlugin): 12 | def attach_to_controller(self, controller): 13 | @connect 14 | def on_gui_ready(sender, gui): 15 | @controller.supervisor.actions.add(shortcut='alt+shift+i') 16 | #@gui.edit_actions.add(shortcut='alt+i') 17 | def VisualizeShortISI(): 18 | """Split all spikes with an interspike interval of less than 1.5 ms into a separate 19 | cluster. THIS IS FOR VISUALIZATION ONLY, it will show you where potential noise 20 | spikes may be located. Re-merge the clusters again afterwards and cut the cluster with 21 | another method!""" 22 | 23 | logger.info('Detecting spikes with ISI less than 1.5 ms') 24 | 25 | # Selected clusters across the cluster view and similarity view. 26 | cluster_ids = controller.supervisor.selected 27 | 28 | # Get the amplitudes, using the same controller method as what the amplitude view 29 | # is using. 30 | # Note that we need load_all=True to load all spikes from the selected clusters, 31 | # instead of just the selection of them chosen for display. 32 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 33 | 34 | # We get the spike ids and the corresponding spike template amplitudes. 35 | # NOTE: in this example, we only consider the first selected cluster. 36 | spike_ids = bunchs[0].spike_ids 37 | spike_times = controller.model.spike_times[spike_ids] 38 | dspike_times = np.diff(spike_times) 39 | 40 | labels = np.ones(len(dspike_times),'int64') 41 | labels[dspike_times<.0015]=2 42 | labels = np.append(labels,1) # include the last spike so that len(labels) == len(spike_ids) 43 | 44 | # We perform the clustering algorithm, which returns an integer for each 45 | # subcluster. 46 | #labels = k_means(y.reshape((-1, 1))) 47 | assert spike_ids.shape == labels.shape 48 | 49 | # We split according to the labels.
50 | controller.supervisor.actions.split(spike_ids, labels) 51 | logger.info('Split short ISI spikes from main cluster') 52 | -------------------------------------------------------------------------------- /plugins/good_labels.py: -------------------------------------------------------------------------------- 1 | from phy import IPlugin 2 | 3 | ''' 4 | Since phy lacks native event hooks or refresh triggers, making group_order update in real time requires workarounds that 5 | are not ideal for performance or reliability. The current implementation, where the metric is set up once at 6 | initialization, works well for sorting clusters in a static way, but it won't reflect updates in real time. 7 | Achieving true real-time updates would need phy to support a more event-driven or refreshable design. 8 | ''' 9 | 10 | class GoodLabelsPlugin(IPlugin): 11 | def attach_to_controller(self, controller): 12 | """ 13 | Attach the plugin to the controller. Sorts clusters with good first, then mua, then noise. 14 | """ 15 | 16 | def group_order(cluster_id): 17 | """Return a numeric value for sorting (good=1, mua=2, noise=3)""" 18 | metadata = controller.model.metadata or {} 19 | groups = metadata.get('group', {}) 20 | label = groups.get(cluster_id, None) 21 | 22 | order_dict = { 23 | 'good': 1, # good units first 24 | 'mua': 2, # then mua 25 | 'noise': 3, # then noise 26 | None: 4 # unlabeled last 27 | } 28 | return order_dict.get(label, 4) 29 | 30 | # Register the metric 31 | controller.cluster_metrics['group_order'] = group_order -------------------------------------------------------------------------------- /plugins/mahalanobis_v2.py: -------------------------------------------------------------------------------- 1 | from phy import IPlugin, connect 2 | import logging 3 | import numpy as np 4 | from scipy.stats import chi2 5 | from sklearn.preprocessing import StandardScaler 6 | import warnings 7 | import matplotlib.pyplot as plt 8 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 9 | from matplotlib.figure import Figure 10 | from qtpy import QtWidgets, QtCore 11 | import seaborn as sns 12 | 13 | logger = logging.getLogger('phy') 14 | 15 | 16 | class StableMahalanobisDetection(IPlugin): 17 | def __init__(self): 18 | super(StableMahalanobisDetection, self).__init__() 19 | self._shortcuts_created = False 20 | self.current_distances = None 21 | self.current_threshold = None 22 | self.plot_window = None 23 | self._spike_ids = None 24 | self._n_features = None 25 | self._feature_structure = None 26 | 27 | def attach_to_controller(self, controller): 28 | def get_feature_dimensions(features_arr): 29 | """Analyze the feature array structure to get actual dimensions""" 30 | try: 31 | # Get the feature dimensions from the model 32 | feature_shape = controller.model._load_features().data.shape 33 | if len(feature_shape) == 3: # (n_spikes, n_channels, n_pcs) 34 | n_channels = feature_shape[1] 35 | n_pcs = feature_shape[2] 36 | logger.info(f"Feature structure: {n_channels} channels with {n_pcs} PCs each") 37 | return n_channels * n_pcs 38 | else: 39 | logger.warn(f"Unexpected feature shape: {feature_shape}") 40 | return features_arr.shape[1] 41 | except Exception as e: 42 | logger.error(f"Error getting feature dimensions: {str(e)}") 43 | return features_arr.shape[1] 44 | 45 | def prepare_features(spike_ids): 46 | """Prepare feature matrix from spike data with proper dimensionality""" 47 | try: 48 | # Load features with original structure 49 | features =
controller.model._load_features().data[spike_ids] 50 | 51 | # Log feature shape information 52 | logger.info(f"Original feature shape: {features.shape}") 53 | 54 | # Reshape maintaining actual structure 55 | features_flat = np.reshape(features, (features.shape[0], -1)) 56 | 57 | # Get actual feature dimensions 58 | self._n_features = get_feature_dimensions(features) 59 | logger.info(f"Using {self._n_features} feature dimensions for Mahalanobis distance") 60 | 61 | return features_flat 62 | 63 | except Exception as e: 64 | logger.error(f"Error preparing features: {str(e)}") 65 | return None 66 | 67 | def stable_mahalanobis(X): 68 | """Compute Mahalanobis distances with numerical stability safeguards""" 69 | if X is None or len(X) == 0: 70 | logger.error("Empty or invalid feature matrix") 71 | return None 72 | 73 | try: 74 | scaler = StandardScaler() 75 | X_scaled = scaler.fit_transform(X) 76 | cov = np.cov(X_scaled, rowvar=False) 77 | n_features = cov.shape[0] 78 | cov += np.eye(n_features) * 1e-6 79 | 80 | try: 81 | U, s, Vt = np.linalg.svd(cov) 82 | s[s < 1e-8] = 1e-8 83 | inv_cov = (U / s) @ Vt 84 | mu = np.mean(X_scaled, axis=0) 85 | diff = X_scaled - mu 86 | distances = np.sqrt(np.sum(diff @ inv_cov * diff, axis=1)) 87 | return np.nan_to_num(distances, nan=np.inf) 88 | 89 | except np.linalg.LinAlgError as e: 90 | logger.error(f"SVD failed: {str(e)}, falling back to diagonal covariance") 91 | inv_cov = np.diag(1.0 / np.diag(cov)) 92 | mu = np.mean(X_scaled, axis=0) 93 | diff = X_scaled - mu 94 | return np.sqrt(np.sum(diff @ inv_cov * diff, axis=1)) 95 | 96 | except Exception as e: 97 | logger.error(f"Error in Mahalanobis distance calculation: {str(e)}") 98 | return None 99 | 100 | def calculate_robust_threshold(distances): 101 | """Calculate default threshold based on chi-square distribution""" 102 | if distances is None or len(distances) == 0 or self._n_features is None: 103 | return None 104 | # Use 99.99% chi-square threshold as default (very conservative) 105 | return np.sqrt(chi2.ppf(0.9999, self._n_features)) 106 | 107 | def suggest_thresholds(distances): 108 | """Suggest thresholds with focus on empirical distribution""" 109 | if distances is None or len(distances) == 0: 110 | return {} 111 | 112 | try: 113 | # Calculate empirical thresholds 114 | empirical_thresholds = { 115 | 'pct_99': np.percentile(distances, 99), 116 | 'pct_999': np.percentile(distances, 99.9), 117 | 'pct_9999': np.percentile(distances, 99.99) # More conservative 118 | } 119 | 120 | # Add chi-square thresholds if dimensionality is available 121 | if self._n_features is not None: 122 | p = self._n_features 123 | chi2_thresh_999 = np.sqrt(chi2.ppf(0.999, p)) # More conservative (0.1% false positive rate) 124 | chi2_thresh_9999 = np.sqrt(chi2.ppf(0.9999, p)) # Very conservative (0.01% false positive rate) 125 | empirical_thresholds['χ²_999'] = chi2_thresh_999 126 | empirical_thresholds['χ²_9999'] = chi2_thresh_9999 127 | 128 | return empirical_thresholds 129 | 130 | except Exception as e: 131 | logger.error(f"Error calculating threshold suggestions: {str(e)}") 132 | return {} 133 | 134 | def plot_distribution(distances, threshold=None): 135 | """Create distribution plot with optional theoretical comparison""" 136 | if distances is None or len(distances) == 0: 137 | logger.error("No valid distances to plot") 138 | return 139 | 140 | if self.plot_window is None: 141 | self.plot_window = QtWidgets.QMainWindow() 142 | self.plot_window.setWindowTitle('Mahalanobis Distance Distribution') 143 | 144 | # Create widgets 
and layout 145 | widget = QtWidgets.QWidget() 146 | self.plot_window.setCentralWidget(widget) 147 | layout = QtWidgets.QVBoxLayout(widget) 148 | 149 | # Create figure 150 | fig = Figure(figsize=(12, 7)) 151 | canvas = FigureCanvas(fig) 152 | ax = fig.add_subplot(111) 153 | fig.subplots_adjust(right=0.85, bottom=0.15) 154 | 155 | # Plot range 156 | max_dist = np.max(distances) 157 | q99_9 = np.percentile(distances, 99.9) 158 | plot_max = min(max_dist, q99_9 * 1.2) 159 | 160 | # Plot empirical distribution 161 | n_bins = min(100, int(np.sqrt(len(distances)))) 162 | sns.histplot(distances, ax=ax, bins=n_bins, stat='density') 163 | ax.set_xlim(0, plot_max) 164 | 165 | # Add theoretical comparison if dimensions are known 166 | if self._n_features is not None: 167 | x = np.linspace(0, plot_max, 200) 168 | chi_density = 2 * x * chi2.pdf(x ** 2, self._n_features) 169 | ax.plot(x, chi_density, 'r--', alpha=0.3, 170 | label=f'χ² ({self._n_features} df)') 171 | 172 | # Labels and formatting 173 | ax.set_xlabel('Mahalanobis Distance', fontsize=10) 174 | ax.set_ylabel('Density', fontsize=10) 175 | ax.tick_params(labelsize=9) 176 | 177 | # Plot thresholds 178 | suggestions = suggest_thresholds(distances) 179 | colors = ['r', 'g', 'b', 'purple'] 180 | for (name, value), color in zip(suggestions.items(), colors): 181 | if value <= plot_max: 182 | n_spikes = np.sum(distances > value) 183 | ax.axvline(x=value, color=color, linestyle='--', 184 | label=f'{name}: {value:.1f}\n({n_spikes} spikes, {n_spikes / len(distances) * 100:.1f}%)') 185 | 186 | if threshold is not None and threshold <= plot_max: 187 | n_spikes = np.sum(distances > threshold) 188 | ax.axvline(x=threshold, color='black', linestyle='-', 189 | label=f'Current: {threshold:.1f}\n({n_spikes} spikes, {n_spikes / len(distances) * 100:.1f}%)') 190 | 191 | ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=9) 192 | 193 | # Create input widgets 194 | form_layout = QtWidgets.QFormLayout() 195 | threshold_input = QtWidgets.QLineEdit() 196 | threshold_input.setPlaceholderText('Enter threshold value') 197 | if threshold is not None: 198 | threshold_input.setText(str(threshold)) 199 | 200 | def on_return_pressed(): 201 | apply_button.click() 202 | 203 | threshold_input.returnPressed.connect(on_return_pressed) 204 | 205 | form_layout.addRow("Threshold:", threshold_input) 206 | 207 | # Button layout 208 | button_layout = QtWidgets.QHBoxLayout() 209 | button_layout.setSpacing(10) 210 | 211 | # Add preset buttons 212 | for name, value in suggestions.items(): 213 | preset_button = QtWidgets.QPushButton(f'Use {name}') 214 | preset_button.setMinimumWidth(100) 215 | preset_button.clicked.connect(lambda checked, v=value: threshold_input.setText(f"{v:.2f}")) 216 | button_layout.addWidget(preset_button) 217 | 218 | # Preview and Apply buttons 219 | preview_button = QtWidgets.QPushButton('Preview Selection') 220 | apply_button = QtWidgets.QPushButton('Apply Threshold') 221 | preview_button.setMinimumWidth(120) 222 | apply_button.setMinimumWidth(120) 223 | 224 | def on_preview(): 225 | try: 226 | new_threshold = float(threshold_input.text()) 227 | n_outliers = np.sum(distances > new_threshold) 228 | QtWidgets.QMessageBox.information( 229 | self.plot_window, 'Preview', 230 | f'This threshold would mark {n_outliers} spikes ({n_outliers / len(distances) * 100:.2f}%) as outliers.\n' 231 | f'Maximum distance: {np.max(distances):.1f}\n' 232 | f'99.9th percentile: {np.percentile(distances, 99.9):.1f}\n' 233 | f'99th percentile: {np.percentile(distances, 99):.1f}' 
234 | ) 235 | except ValueError: 236 | logger.error("Invalid threshold value") 237 | 238 | def on_apply(): 239 | try: 240 | new_threshold = float(threshold_input.text()) 241 | if not new_threshold > 0: 242 | logger.error("Threshold must be positive") 243 | return 244 | 245 | self.current_threshold = new_threshold 246 | n_outliers = np.sum(distances > new_threshold) 247 | 248 | if n_outliers > len(distances) * 0.1: 249 | reply = QtWidgets.QMessageBox.question( 250 | self.plot_window, 'Warning', 251 | f'This threshold would mark {n_outliers} spikes ({n_outliers / len(distances) * 100:.1f}%) as outliers. Continue?', 252 | QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No 253 | ) 254 | if reply == QtWidgets.QMessageBox.No: 255 | return 256 | 257 | perform_outlier_detection(new_threshold, self.current_distances) 258 | self.plot_window.close() 259 | 260 | except ValueError: 261 | logger.error("Invalid threshold value") 262 | except Exception as e: 263 | logger.error(f"Error applying threshold: {str(e)}") 264 | 265 | preview_button.clicked.connect(on_preview) 266 | apply_button.clicked.connect(on_apply) 267 | 268 | # Add widgets to layout 269 | layout.addWidget(canvas) 270 | layout.addLayout(form_layout) 271 | layout.addLayout(button_layout) 272 | layout.addSpacing(10) 273 | button_row = QtWidgets.QHBoxLayout() 274 | button_row.addWidget(preview_button) 275 | button_row.addWidget(apply_button) 276 | layout.addLayout(button_row) 277 | 278 | # Set window size and prepare to show 279 | self.plot_window.resize(1600, 900) 280 | 281 | # Create timer to select text after window is shown 282 | def select_text(): 283 | threshold_input.setFocus() 284 | threshold_input.selectAll() 285 | 286 | # Use QTimer to ensure window is fully shown 287 | timer = QtCore.QTimer() 288 | timer.singleShot(100, select_text) 289 | 290 | # Show window 291 | self.plot_window.show() 292 | 293 | def perform_outlier_detection(threshold, distances): 294 | """Perform outlier detection with given threshold""" 295 | if distances is None or self._spike_ids is None: 296 | logger.warn("No distances or spike IDs available") 297 | return 298 | 299 | try: 300 | outliers = distances > threshold 301 | n_outliers = np.sum(outliers) 302 | 303 | # Log results 304 | logger.info(f"Analysis with threshold {threshold}:") 305 | logger.info(f"- Detected {n_outliers} outliers ({n_outliers / len(distances) * 100:.1f}%)") 306 | logger.info(f"- Maximum distance: {np.max(distances):.1f}") 307 | logger.info(f"- 99.9th percentile: {np.percentile(distances, 99.9):.1f}") 308 | logger.info(f"- 99th percentile: {np.percentile(distances, 99):.1f}") 309 | logger.info(f"- Median distance: {np.median(distances):.1f}") 310 | 311 | # Sort and display top distances 312 | sorted_dist = np.sort(distances)[-10:] 313 | logger.info(f"Top 10 distances: {', '.join(f'{d:.1f}' for d in sorted_dist)}") 314 | 315 | # Prepare for split 316 | if n_outliers > 0: 317 | labels = np.ones(len(distances), dtype=int) 318 | labels[outliers] = 2 319 | controller.supervisor.actions.split(self._spike_ids, labels) 320 | else: 321 | logger.info("No outliers detected at current threshold") 322 | 323 | except Exception as e: 324 | logger.error(f"Error in outlier detection: {str(e)}") 325 | 326 | @connect 327 | def on_gui_ready(sender, gui): 328 | if self._shortcuts_created: 329 | return 330 | self._shortcuts_created = True 331 | 332 | @controller.supervisor.actions.add(shortcut='alt+x') 333 | def stable_mahalanobis_outliers(): 334 | """ 335 | Stable Mahalanobis Outlier Detection with 
visualization (Alt+X) 336 | """ 337 | try: 338 | # Get selected clusters and spikes 339 | cluster_ids = controller.supervisor.selected 340 | if not cluster_ids: 341 | logger.warn("No clusters selected!") 342 | return 343 | 344 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 345 | self._spike_ids = bunchs[0].spike_ids 346 | 347 | # Prepare features 348 | features = prepare_features(self._spike_ids) 349 | if features is None: 350 | return 351 | 352 | # Minimum spikes check 353 | if features.shape[0] < features.shape[1] * 2: 354 | logger.warn(f"Warning: Need at least {features.shape[1] * 2} spikes!") 355 | return 356 | 357 | # Compute distances 358 | with warnings.catch_warnings(): 359 | warnings.simplefilter("ignore") 360 | distances = stable_mahalanobis(features) 361 | 362 | if distances is None: 363 | return 364 | 365 | # Store current distances 366 | self.current_distances = distances 367 | 368 | # Calculate initial threshold using chi-square distribution 369 | initial_threshold = calculate_robust_threshold(distances) 370 | if initial_threshold is None: 371 | return 372 | 373 | # Log distribution analysis 374 | logger.info("\nDistribution Analysis:") 375 | logger.info(f"Number of dimensions: {self._n_features}") 376 | logger.info(f"Expected mean distance (sqrt(p)): {np.sqrt(self._n_features):.2f}") 377 | logger.info(f"Observed mean distance: {np.mean(distances):.2f}") 378 | logger.info(f"Observed median distance: {np.median(distances):.2f}") 379 | 380 | # Check for substantial deviation from theoretical expectation 381 | expected_mean = np.sqrt(self._n_features) 382 | observed_mean = np.mean(distances) 383 | if abs(observed_mean - expected_mean) / expected_mean > 0.5: 384 | logger.warn(f"Substantial deviation from theoretical expectation:") 385 | logger.warn(f"This might indicate non-normal features or other irregularities.") 386 | 387 | # Show distribution plot 388 | plot_distribution(distances, initial_threshold) 389 | 390 | except Exception as e: 391 | logger.error(f"Error in stable_mahalanobis_outliers: {str(e)}") 392 | logger.error("Stack trace:", exc_info=True) -------------------------------------------------------------------------------- /plugins/raw_data_filter.py: -------------------------------------------------------------------------------- 1 | """Show how to add a custom raw data filter for the TraceView and Waveform View 2 | 3 | Use Alt+R in the GUI to toggle the filter. 4 | 5 | """ 6 | 7 | import numpy as np 8 | from scipy.signal import butter, filtfilt#lfilter 9 | 10 | from phy import IPlugin 11 | 12 | 13 | class RawDataFilterPlugin(IPlugin): 14 | def attach_to_controller(self, controller): 15 | b, a = butter(3, 150.0 / controller.model.sample_rate * 2.0, 'high') 16 | 17 | @controller.raw_data_filter.add_filter 18 | def high_pass(arr, axis=0): 19 | arr = filtfilt(b, a, arr, axis=axis) 20 | #arr = np.flip(arr, axis=axis) 21 | #arr = lfilter(b, a, arr, axis=axis) 22 | #arr = np.flip(arr, axis=axis) 23 | return arr 24 | -------------------------------------------------------------------------------- /plugins/readme.txt: -------------------------------------------------------------------------------- 1 | Updated Plugins for Phy2 2 | By Mingze Dou, 2025 3 | 4 | Original by Peter Petersen, Phy2 port by Thomas Hainmueller 5 | 6 | NEW FEATURES: 7 | 1. ImprovedISIAnalysis (alt+i): Better spike conflict detection 8 | 2. StableMahalanobisDetection (alt+x): Interactive outlier removal 9 | 3. 
ReclusterUMAP (alt+k): Modern clustering with UMAP 10 | 4. GoodLabelsPlugin: Sorts clusters by quality 11 | 12 | REQUIRES: 13 | pip install pandas numpy scipy scikit-learn umap-learn -------------------------------------------------------------------------------- /plugins/recluster.py: -------------------------------------------------------------------------------- 1 | from phy import IPlugin, connect 2 | 3 | import logging 4 | import os 5 | import numpy as np 6 | 7 | import platform 8 | 9 | from pathlib import Path 10 | from subprocess import Popen 11 | 12 | from phy.utils.tempdir import TemporaryDirectory 13 | from scipy.cluster.vq import kmeans2, whiten 14 | 15 | #logger = logging.getLogger(__name__) 16 | logger = logging.getLogger('phy') 17 | 18 | ##try: 19 | ## from klusta.launch import cluster2 20 | ##except ImportError: # pragma: no cover 21 | ## logger.warn("Package klusta not installed: the KwikGUI will not work.") 22 | # Not used 23 | try: 24 | import pandas as pd 25 | except ImportError: # pragma: no cover 26 | logger.warn("Package pandas not installed.") 27 | try: 28 | from phy.utils.config import phy_config_dir 29 | except ImportError: # pragma: no cover 30 | logger.warn("phy_config_dir not available.") 31 | 32 | class Recluster(IPlugin): 33 | def attach_to_controller(self, controller): 34 | @connect 35 | #@controller.supervisor.connect 36 | #def on_create_cluster_views(): 37 | def on_gui_ready(sender,gui): 38 | @controller.supervisor.actions.add(shortcut='alt+shift+k') 39 | def Recluster_Local_PCAs(): 40 | def write_fet(fet, filepath): 41 | with open(filepath, 'w') as fd: 42 | #header line: number of features 43 | fd.write('%i\n' % fet.shape[1]) 44 | #next lines: one feature vector per line 45 | for x in range(0,fet.shape[0]): 46 | fet[x,:].tofile(fd, sep="\t", format="%i") 47 | fd.write ("\n") 48 | #np.savetxt(fd, fet[0], fmt="%i", delimiter=' ') 49 | 50 | def read_clusters(filename_clu): 51 | clusters = load_text(filename_clu, np.int64) 52 | return process_clusters(clusters) 53 | def process_clusters(clusters): 54 | return clusters[1:] 55 | def load_text(filepath, dtype, skiprows=0, delimiter=' '): 56 | if not filepath: 57 | raise IOError("The filepath is empty.") 58 | with open(filepath, 'r') as f: 59 | for _ in range(skiprows): 60 | f.readline() 61 | x = pd.read_csv(f, header=None, 62 | sep=delimiter).values.astype(dtype).squeeze() 63 | return x 64 | 65 | """Relaunch KlustaKwik on selected clusters.""" 66 | # Selected clusters. 67 | cluster_ids = controller.supervisor.selected 68 | #spike_ids = controller.selector.select_spikes(cluster_ids) 69 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 70 | spike_ids = bunchs[0].spike_ids 71 | logger.info("Running KlustaKwik on %d spikes.", len(spike_ids)) 72 | # s = controller.supervisor.clustering.spikes_in_clusters(cluster_ids) 73 | data3 = controller.model._load_features().data[spike_ids] 74 | fet2 = np.reshape(data3,(data3.shape[0],data3.shape[1]*data3.shape[2])) 75 | 76 | dtype = np.int64 77 | factor = 2.**60 78 | #dtype = np.int32 79 | #factor = 2.**31 80 | factor = factor/np.abs(fet2).max() 81 | fet2 = (fet2 * factor).astype(dtype) 82 | # logger.warn(str(fet2[0,:])) 83 | 84 | # Run KK2 in a temporary directory to avoid side effects. 
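# NOTE: despite the comment above, the code below writes its .fet/.clu files into the
# current working directory, and the TemporaryDirectory helper imported at the top of
# this file is never used; a sketch of wrapping the KlustaKwik run in a temporary
# directory is given after tempdir.py at the end of this listing.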
85 | # n = 10 86 | # spike_times = controller.model.spike_times[spike_ids]*controller.model.sample_rate 87 | 88 | #spike_times = convert_dtype(spike_times, np.int32) 89 | # times = np.expand_dims(spike_times, axis =1) 90 | 91 | # fet = 1000*np.concatenate((fet2,times),axis = 1) 92 | fet = fet2 93 | 94 | name = 'tempClustering' 95 | shank = 3 96 | mainfetfile = os.path.join(name + '.fet.' + str(shank)) 97 | write_fet(fet, mainfetfile) 98 | if platform.system() == 'Windows': 99 | program = os.path.join(phy_config_dir(),'klustakwik.exe') 100 | else: 101 | program = '~/klustakwik/KlustaKwik' 102 | cmd = [program, name, str(shank)] 103 | cmd +=["-UseDistributional",'0',"-MaxPossibleClusters",'20',"-MinClusters",'20'] #,"-MinClusters",'2',"-MaxClusters",'12' ,"-MaxClusters",'12',"-MaxClusters",'12' 104 | 105 | # Run KlustaKwik 106 | p = Popen(cmd) 107 | p.wait() 108 | # Read back the clusters 109 | spike_clusters = read_clusters(name + '.clu.' + str(shank)) 110 | controller.supervisor.actions.split(spike_ids, spike_clusters) 111 | logger.warn("Reclustering complete!") 112 | 113 | @controller.supervisor.actions.add(shortcut='alt+shift+t') 114 | def Recluster_Global_PCAs(): 115 | def write_fet(fet, filepath): 116 | with open(filepath, 'w') as fd: 117 | #header line: number of features 118 | fd.write('%i\n' % fet.shape[1]) 119 | #next lines: one feature vector per line 120 | for x in range(0,fet.shape[0]): 121 | fet[x,:].tofile(fd, sep="\t", format="%i") 122 | fd.write ("\n") 123 | #np.savetxt(fd, fet[0], fmt="%i", delimiter=' ') 124 | 125 | def read_clusters(filename_clu): 126 | clusters = load_text(filename_clu, np.int64) 127 | return process_clusters(clusters) 128 | def process_clusters(clusters): 129 | return clusters[1:] 130 | def load_text(filepath, dtype, skiprows=0, delimiter=' '): 131 | if not filepath: 132 | raise IOError("The filepath is empty.") 133 | with open(filepath, 'r') as f: 134 | for _ in range(skiprows): 135 | f.readline() 136 | x = pd.read_csv(f, header=None, 137 | sep=delimiter).values.astype(dtype).squeeze() 138 | return x 139 | 140 | """Relaunch KlustaKwik on selected clusters.""" 141 | # Selected clusters. 142 | cluster_ids = controller.supervisor.selected 143 | #spike_ids = controller.selector.select_spikes(cluster_ids) 144 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 145 | spike_ids = bunchs[0].spike_ids 146 | logger.info("Running KlustaKwik on %d spikes.", len(spike_ids)) 147 | # s = controller.supervisor.clustering.spikes_in_clusters(cluster_ids) 148 | data3 = controller.model._load_features().data[spike_ids] 149 | fet2 = np.reshape(data3,(data3.shape[0],data3.shape[1]*data3.shape[2])) 150 | 151 | dtype = np.int64 152 | factor = 2.**60 153 | #dtype = np.int32 154 | #factor = 2.**31 155 | factor = factor/np.abs(fet2).max() 156 | fet2 = (fet2 * factor).astype(dtype) 157 | # logger.warn(str(fet2[0,:])) 158 | 159 | # Run KK2 in a temporary directory to avoid side effects. 160 | # n = 10 161 | # spike_times = controller.model.spike_times[spike_ids]*controller.model.sample_rate 162 | 163 | #spike_times = convert_dtype(spike_times, np.int32) 164 | # times = np.expand_dims(spike_times, axis =1) 165 | 166 | # fet = 1000*np.concatenate((fet2,times),axis = 1) 167 | fet = fet2 168 | 169 | name = 'tempClustering' 170 | shank = 3 171 | mainfetfile = os.path.join(name + '.fet.' 
+ str(shank)) 172 | write_fet(fet, mainfetfile) 173 | if platform.system() == 'Windows': 174 | program = os.path.join(phy_config_dir(),'klustakwik.exe') 175 | else: 176 | program = '~/klustakwik/KlustaKwik' 177 | cmd = [program, name, str(shank)] 178 | cmd +=["-UseDistributional",'0'] # ,"-MinClusters",'2',"-MaxClusters",'12' 179 | 180 | # Run KlustaKwik. 181 | p = Popen(cmd) 182 | p.wait() 183 | # Read back the clusters 184 | spike_clusters = read_clusters(name + '.clu.' + str(shank)) 185 | controller.supervisor.actions.split(spike_ids, spike_clusters) 186 | logger.warn("Reclustering complete!") 187 | 188 | @controller.supervisor.actions.add(shortcut='alt+shift+q', prompt=True, prompt_default=lambda: 2) 189 | def K_means_clustering(kmeanclusters): 190 | """Select number of clusters. 191 | 192 | Example: `2` 193 | 194 | """ 195 | logger.warn("Running K-means clustering") 196 | 197 | cluster_ids = controller.supervisor.selected 198 | #spike_ids = controller.selector.select_spikes(cluster_ids) 199 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 200 | spike_ids = bunchs[0].spike_ids 201 | s = controller.supervisor.clustering.spikes_in_clusters(cluster_ids) 202 | data = controller.model._load_features() 203 | data3 = data.data[spike_ids] 204 | data2 = np.reshape(data3,(data3.shape[0],data3.shape[1]*data3.shape[2])) 205 | whitened = whiten(data2) 206 | clusters_out,label = kmeans2(whitened,kmeanclusters) 207 | controller.supervisor.actions.split(s,label) 208 | logger.warn("K means clustering complete") 209 | 210 | #@controller.supervisor.actions.add(shortcut='alt+x') 211 | @controller.supervisor.actions.add(shortcut='alt+shift+x', prompt=True, prompt_default=lambda: 14) 212 | def MahalanobisDist(thres_in): 213 | """Select threshold in STDs. 
214 | 215 | Example: `14` 216 | 217 | """ 218 | logger.warn("Removing outliers by Mahalanobis distance") 219 | 220 | def MahalanobisDistCalc2(x, y): 221 | covariance_xy = np.cov(x,y, rowvar=0) 222 | inv_covariance_xy = np.linalg.inv(covariance_xy) 223 | xy_mean = np.mean(x),np.mean(y) 224 | x_diff = np.array([x_i - xy_mean[0] for x_i in x]) 225 | y_diff = np.array([y_i - xy_mean[1] for y_i in y]) 226 | diff_xy = np.transpose([x_diff, y_diff]) 227 | md = [] 228 | for i in range(len(diff_xy)): 229 | md.append(np.sqrt(np.dot(np.dot(np.transpose(diff_xy[i]),inv_covariance_xy),diff_xy[i]))) 230 | return md 231 | 232 | def MahalanobisDistCalc(X, Y): # squared Mahalanobis distance of Y's rows w.r.t. X (port of MATLAB's mahal) 233 | rx = X.shape[0] 234 | cx = X.shape[1] 235 | ry = Y.shape[0] 236 | cy = Y.shape[1] 237 | 238 | m = np.mean(X, axis=0) 239 | M = np.tile(m,(ry,1)) 240 | C = X - np.tile(m,(rx,1)) 241 | Q, R = np.linalg.qr(C) 242 | ri,ri2,ri3,ri4 = np.linalg.lstsq(np.transpose(R),np.transpose(Y-M),rcond=None) 243 | d = np.sum(ri*ri, axis=0)*(rx-1) # squared distances, scaled as in MATLAB's mahal 244 | return d 245 | 246 | cluster_ids = controller.supervisor.selected 247 | #spike_ids = controller.selector.select_spikes(cluster_ids) 248 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 249 | spike_ids = bunchs[0].spike_ids 250 | s = controller.supervisor.clustering.spikes_in_clusters(cluster_ids) 251 | data = controller.model._load_features() 252 | data3 = data.data[spike_ids] 253 | data2 = np.reshape(data3,(data3.shape[0],data3.shape[1]*data3.shape[2])) 254 | if data2.shape[0] < data2.shape[1]: 255 | logger.warn("Error: Not enough spikes in the cluster!") 256 | return 257 | 258 | MD = MahalanobisDistCalc(data2,data2) 259 | #threshold = 16**2 260 | threshold = thres_in**2 261 | outliers = np.where(MD > threshold)[0] 262 | outliers2 = np.ones(len(s),dtype=int) 263 | outliers2[outliers] = 2 264 | logger.info("Outliers detected: %d.", len(outliers)) 265 | controller.supervisor.actions.split(s,outliers2) 266 | -------------------------------------------------------------------------------- /plugins/recluster_v2.py: -------------------------------------------------------------------------------- 1 | from phy import IPlugin, connect 2 | import logging 3 | import numpy as np 4 | from scipy.spatial.distance import cdist 5 | from sklearn.preprocessing import StandardScaler 6 | from sklearn.decomposition import PCA 7 | from sklearn.mixture import GaussianMixture 8 | from sklearn.cluster import MiniBatchKMeans 9 | import umap 10 | 11 | logger = logging.getLogger('phy') 12 | 13 | 14 | class ReclusterUMAP(IPlugin): 15 | """ 16 | Modern spike sorting plugin with optimized performance and intelligent merging 17 | """ 18 | 19 | def __init__(self): 20 | super(ReclusterUMAP, self).__init__() 21 | self._shortcuts_created = False 22 | self._umap_reducer = None 23 | self._last_n_spikes = None 24 | 25 | def attach_to_controller(self, controller): 26 | def prepareFeatures(spikeIds): 27 | data = controller.model._load_features().data[spikeIds] 28 | features = np.reshape(data, (data.shape[0], -1)) 29 | return features 30 | 31 | def compute_template_correlation(features, labels): 32 | """Compute correlation between cluster templates""" 33 | n_clusters = len(np.unique(labels[labels > 0])) 34 | correlations = np.zeros((n_clusters, n_clusters)) 35 | 36 | for i in range(n_clusters): 37 | template_i = np.mean(features[labels == i + 1], axis=0) 38 | for j in range(i + 1, n_clusters): 39 | template_j = np.mean(features[labels == j + 1], axis=0) 40 | # Correlation between templates 41 | corr =
np.corrcoef(template_i, template_j)[0, 1] 42 | correlations[i, j] = correlations[j, i] = corr 43 | 44 | return correlations 45 | 46 | def check_spatial_consistency(features, labels, threshold=0.6): 47 | """Check if clusters are spatially consistent""" 48 | n_clusters = len(np.unique(labels[labels > 0])) 49 | spatial_consistent = np.zeros((n_clusters, n_clusters), dtype=bool) 50 | 51 | # Assuming features contain channel information (when called on the 2-D UMAP embedding below, the "channels" are simply the embedding dimensions) 52 | for i in range(n_clusters): 53 | spikes_i = features[labels == i + 1] 54 | channels_i = np.var(spikes_i, axis=0).argsort()[-4:] # Top 4 channels 55 | 56 | for j in range(i + 1, n_clusters): 57 | spikes_j = features[labels == j + 1] 58 | channels_j = np.var(spikes_j, axis=0).argsort()[-4:] 59 | 60 | # Check channel overlap 61 | common_channels = len(set(channels_i) & set(channels_j)) 62 | spatial_consistent[i, j] = spatial_consistent[j, i] = ( 63 | common_channels >= len(channels_i) * threshold 64 | ) 65 | 66 | return spatial_consistent 67 | 68 | def merge_similar_clusters(features, labels, template_threshold=0.9, spatial_threshold=0.6): 69 | """Merge clusters based on template similarity and spatial consistency""" 70 | while True: 71 | n_clusters = len(np.unique(labels[labels > 0])) 72 | if n_clusters <= 2: # Don't merge if only 2 clusters remain 73 | break 74 | 75 | # Compute similarity matrices 76 | correlations = compute_template_correlation(features, labels) 77 | spatial_consistent = check_spatial_consistency(features, labels, spatial_threshold) 78 | 79 | # Find most similar pair that's spatially consistent 80 | max_corr = template_threshold 81 | merge_pair = None 82 | 83 | for i in range(n_clusters): 84 | for j in range(i + 1, n_clusters): 85 | if (correlations[i, j] > max_corr and spatial_consistent[i, j]): 86 | max_corr = correlations[i, j] 87 | merge_pair = (i, j) 88 | 89 | if merge_pair is None: 90 | break 91 | 92 | # Perform merge 93 | i, j = merge_pair 94 | new_labels = np.zeros_like(labels) 95 | for k in range(n_clusters): 96 | if k == i: 97 | new_labels[labels == k + 1] = i + 1 98 | elif k == j: 99 | new_labels[labels == j + 1] = i + 1 100 | elif k > j: 101 | new_labels[labels == k + 1] = k 102 | else: 103 | new_labels[labels == k + 1] = k + 1 104 | 105 | labels = new_labels 106 | logger.info(f"Merged clusters {i + 1} and {j + 1} (correlation: {max_corr:.3f})") 107 | 108 | return labels 109 | 110 | def fastClustering(embedding, target_clusters=4): 111 | """Fast clustering with intelligent merging""" 112 | # Initial over-clustering (guarded so tiny selections still yield at least two clusters) 113 | initial_clusters = max(2, min(target_clusters * 3, len(embedding) // 50)) 114 | 115 | # Initial clustering 116 | kmeans = MiniBatchKMeans( 117 | n_clusters=initial_clusters, 118 | batch_size=1000, 119 | random_state=42 120 | ) 121 | initial_labels = kmeans.fit_predict(embedding) + 1 # Make labels 1-based 122 | 123 | # Merge similar clusters 124 | final_labels = merge_similar_clusters( 125 | embedding, 126 | initial_labels, 127 | template_threshold=0.9, 128 | spatial_threshold=0.6 129 | ) 130 | 131 | return final_labels 132 | 133 | @connect 134 | def on_gui_ready(sender, gui): 135 | if self._shortcuts_created: 136 | return 137 | self._shortcuts_created = True 138 | 139 | @controller.supervisor.actions.add(shortcut='alt+k', prompt=True, prompt_default=lambda: 4) 140 | def umapGmmClustering(target_clusters): 141 | """Fast UMAP clustering with intelligent merging (Alt+K); despite the historical GMM name, it over-clusters with MiniBatchKMeans and then merges similar clusters""" 142 | try: 143 | target_clusters = int(target_clusters) 144 | if target_clusters < 2: 145 | logger.warn("Need at least 2 clusters, using 2") 146 |
target_clusters = 2 147 | 148 | clusterIds = controller.supervisor.selected 149 | if not clusterIds: 150 | logger.warn("No clusters selected!") 151 | return 152 | 153 | bunchs = controller._amplitude_getter(clusterIds, name='template', load_all=True) 154 | spikeIds = bunchs[0].spike_ids 155 | n_spikes = len(spikeIds) 156 | logger.info(f"Processing {n_spikes} spikes with target {target_clusters} clusters") 157 | 158 | # Feature preparation 159 | features = prepareFeatures(spikeIds) 160 | scaler = StandardScaler() 161 | featuresScaled = scaler.fit_transform(features) 162 | 163 | # Dimensionality reduction 164 | pca = PCA(n_components=min(30, featuresScaled.shape[1])) 165 | featuresPca = pca.fit_transform(featuresScaled) 166 | 167 | # UMAP reduction 168 | if (self._umap_reducer is None or 169 | self._last_n_spikes is None or 170 | abs(self._last_n_spikes - n_spikes) > n_spikes * 0.2): 171 | self._umap_reducer = umap.UMAP( 172 | n_neighbors=max(2, min(30, n_spikes // 100)), # UMAP requires n_neighbors >= 2 173 | min_dist=0.2, 174 | n_components=2, 175 | random_state=42, 176 | n_jobs=-1, 177 | metric='euclidean', 178 | low_memory=True 179 | ) 180 | self._last_n_spikes = n_spikes 181 | 182 | embedding = self._umap_reducer.fit_transform(featuresPca) 183 | 184 | # Clustering with merging 185 | labels = fastClustering(embedding, target_clusters) 186 | n_clusters = len(np.unique(labels)) 187 | 188 | logger.info(f"Created {n_clusters} clusters after merging") 189 | controller.supervisor.actions.split(spikeIds, labels) 190 | 191 | except Exception as e: 192 | logger.error(f"Error in umapGmmClustering: {str(e)}") 193 | 194 | # Keep the existing templateBasedSplit function unchanged 195 | @controller.supervisor.actions.add(shortcut='alt+t', prompt=True, prompt_default=lambda: 0.85) 196 | def templateBasedSplit(similarityThreshold): 197 | """Template-based spike sorting (Alt+T)""" 198 | # ...
[rest of the existing templateBasedSplit code remains unchanged] -------------------------------------------------------------------------------- /plugins/shortISI_v3.py: -------------------------------------------------------------------------------- 1 | from phy import IPlugin, connect 2 | import numpy as np 3 | import logging 4 | from scipy.stats import zscore 5 | from sklearn.preprocessing import StandardScaler 6 | from scipy.spatial.distance import cdist 7 | 8 | logger = logging.getLogger('phy') 9 | 10 | 11 | class ImprovedISIAnalysis(IPlugin): 12 | """More reliable spike analysis using combined metrics""" 13 | 14 | def __init__(self): 15 | super(ImprovedISIAnalysis, self).__init__() 16 | self._shortcuts_created = False 17 | 18 | def attach_to_controller(self, controller): 19 | def get_waveform_features(spike_ids): 20 | """Extract key waveform features""" 21 | # Get waveforms 22 | data = controller.model._load_features().data[spike_ids] 23 | return np.reshape(data, (data.shape[0], -1)) 24 | 25 | def analyze_suspicious_spikes(spike_times, spike_amps, waveforms, isi_threshold=0.0015): 26 | """ 27 | Analyze spikes with multiple metrics: 28 | - ISI violations 29 | - Amplitude changes 30 | - Waveform changes 31 | """ 32 | n_spikes = len(spike_times) 33 | suspicious = np.zeros(n_spikes, dtype=bool) 34 | 35 | # Find ISI violations 36 | isi_prev = np.diff(spike_times, prepend=spike_times[0] - 1) 37 | isi_next = np.diff(spike_times, append=spike_times[-1] + 1) 38 | 39 | # Look for changes in nearby spikes 40 | for i in range(n_spikes): 41 | if isi_prev[i] < isi_threshold or isi_next[i] < isi_threshold: 42 | # For spikes with short ISI, check for: 43 | 44 | # 1. Amplitude changes 45 | amp_window = slice(max(0, i - 1), min(n_spikes, i + 2)) 46 | amp_variation = np.std(spike_amps[amp_window]) 47 | 48 | # 2. Waveform changes 49 | wave_window = slice(max(0, i - 1), min(n_spikes, i + 2)) 50 | waves = waveforms[wave_window] 51 | wave_distances = cdist(waves, waves, metric='correlation') 52 | wave_variation = np.mean(wave_distances) 53 | 54 | # Mark as suspicious if there are significant changes 55 | if (amp_variation > np.std(spike_amps) * 1.5 or 56 | wave_variation > 0.1): # Correlation distance threshold 57 | suspicious[i] = True 58 | 59 | return suspicious 60 | 61 | @connect 62 | def on_gui_ready(sender, gui): 63 | if self._shortcuts_created: 64 | return 65 | self._shortcuts_created = True 66 | 67 | @controller.supervisor.actions.add(shortcut='alt+i') 68 | def analyze_spike_patterns(): 69 | """ 70 | Analyze spike patterns using multiple metrics: 71 | - ISI violations 72 | - Amplitude changes 73 | - Waveform changes 74 | Only splits when multiple criteria suggest different units. 
75 | """ 76 | try: 77 | # Get selected clusters 78 | cluster_ids = controller.supervisor.selected 79 | if not cluster_ids: 80 | logger.warn("No clusters selected!") 81 | return 82 | 83 | # Get spike data 84 | bunchs = controller._amplitude_getter(cluster_ids, name='template', load_all=True) 85 | spike_ids = bunchs[0].spike_ids 86 | spike_times = controller.model.spike_times[spike_ids] 87 | 88 | # Get spike amplitudes 89 | spike_amps = bunchs[0].amplitudes 90 | 91 | # Get waveform features 92 | waveforms = get_waveform_features(spike_ids) 93 | 94 | # Analyze spikes 95 | suspicious = analyze_suspicious_spikes( 96 | spike_times, 97 | spike_amps, 98 | waveforms 99 | ) 100 | 101 | # Prepare labels 102 | labels = np.ones(len(spike_ids), dtype=int) 103 | labels[suspicious] = 2 104 | 105 | # Count suspicious spikes 106 | n_suspicious = np.sum(suspicious) 107 | 108 | if n_suspicious > 0: 109 | # Log analysis results 110 | logger.info(f"Found {n_suspicious} suspicious spikes " 111 | f"({n_suspicious / len(spike_ids) * 100:.1f}%) " 112 | f"with notable physical changes") 113 | 114 | # Only split if we found enough suspicious spikes 115 | if n_suspicious >= 10 and n_suspicious <= len(spike_ids) * 0.5: 116 | controller.supervisor.actions.split(spike_ids, labels) 117 | logger.info("Split suspicious spikes for manual review") 118 | else: 119 | logger.info("Too few or too many suspicious spikes for reliable splitting") 120 | else: 121 | logger.info("No suspicious spikes found") 122 | 123 | except Exception as e: 124 | logger.error(f"Error in analyze_spike_patterns: {str(e)}") -------------------------------------------------------------------------------- /plugins/tempdir.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Temporary directory used in unit tests.""" 4 | 5 | #------------------------------------------------------------------------------ 6 | # Imports 7 | #------------------------------------------------------------------------------ 8 | 9 | import warnings as _warnings 10 | import os as _os 11 | from tempfile import mkdtemp 12 | 13 | #from ..ext import six 14 | import six 15 | 16 | #------------------------------------------------------------------------------ 17 | # Temporary directory 18 | #------------------------------------------------------------------------------ 19 | 20 | class TemporaryDirectory(object): 21 | """Create and return a temporary directory. This has the same 22 | behavior as mkdtemp but can be used as a context manager. For 23 | example: 24 | with TemporaryDirectory() as tmpdir: 25 | ... 26 | Upon exiting the context, the directory and everything contained 27 | in it are removed. 
28 | """ 29 | def __init__(self, suffix="", prefix="tmp", dir=None): 30 | self._closed = False 31 | self.name = None # Handle mkdtemp raising an exception 32 | self.name = mkdtemp(suffix, prefix, dir) 33 | 34 | def __repr__(self): 35 | return "<{} {!r}>".format(self.__class__.__name__, self.name) 36 | 37 | def __enter__(self): 38 | return self.name 39 | 40 | def cleanup(self, _warn=False): 41 | if self.name and not self._closed: 42 | try: 43 | self._rmtree(self.name) 44 | except (TypeError, AttributeError) as ex: 45 | # Issue #10188: Emit a warning on stderr 46 | # if the directory could not be cleaned 47 | # up due to missing globals 48 | if "None" not in str(ex): 49 | raise 50 | print("ERROR: {!r} while cleaning up {!r}".format(ex, self,), 51 | file=_sys.stderr) 52 | return 53 | self._closed = True 54 | if _warn: 55 | self._warn("Implicitly cleaning up {!r}".format(self), 56 | ResourceWarning) 57 | 58 | def __exit__(self, exc, value, tb): 59 | self.cleanup() 60 | 61 | def __del__(self): 62 | # Issue a ResourceWarning if implicit cleanup needed 63 | self.cleanup(_warn=True) 64 | 65 | # XXX (ncoghlan): The following code attempts to make 66 | # this class tolerant of the module nulling out process 67 | # that happens during CPython interpreter shutdown 68 | # Alas, it doesn't actually manage it. See issue #10188 69 | _listdir = staticmethod(_os.listdir) 70 | _path_join = staticmethod(_os.path.join) 71 | _isdir = staticmethod(_os.path.isdir) 72 | _islink = staticmethod(_os.path.islink) 73 | _remove = staticmethod(_os.remove) 74 | _rmdir = staticmethod(_os.rmdir) 75 | _warn = _warnings.warn 76 | 77 | def _rmtree(self, path): 78 | # Essentially a stripped down version of shutil.rmtree. We can't 79 | # use globals because they may be None'ed out at shutdown. 80 | for name in self._listdir(path): 81 | fullname = self._path_join(path, name) 82 | try: 83 | isdir = self._isdir(fullname) and not self._islink(fullname) 84 | except OSError: 85 | isdir = False 86 | if isdir: 87 | self._rmtree(fullname) 88 | else: 89 | try: 90 | self._remove(fullname) 91 | except OSError: 92 | pass 93 | try: 94 | self._rmdir(path) 95 | except OSError: 96 | pass 97 | --------------------------------------------------------------------------------