├── .github └── workflows │ └── publish-to-pypi.yml ├── .gitignore ├── LICENSE ├── README.md ├── TODO.md ├── pyproject.toml ├── screenshot.png └── spikeinterface_gui ├── __init__.py ├── backend_panel.py ├── backend_qt.py ├── basescatterview.py ├── controller.py ├── crosscorrelogramview.py ├── curation_tools.py ├── curationview.py ├── img └── si.png ├── isiview.py ├── layout_presets.py ├── main.py ├── mainsettingsview.py ├── mergeview.py ├── myqt.py ├── ndscatterview.py ├── probeview.py ├── similarityview.py ├── spikeamplitudeview.py ├── spikedepthview.py ├── spikelist.py ├── tests ├── debug_views.py ├── iframe │ ├── iframe_server.py │ ├── iframe_test.html │ └── iframe_test_README.md ├── test_controller.py ├── test_curation_tools.py ├── test_mainwindow_panel.py ├── test_mainwindow_qt.py └── testingtools.py ├── tracemapview.py ├── traceview.py ├── unitlist.py ├── utils_panel.py ├── utils_qt.py ├── version.py ├── view_base.py ├── viewlist.py ├── waveformheatmapview.py └── waveformview.py /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Release to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | deploy: 10 | 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.11 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.11 19 | - name: Install dependencies for testing 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install setuptools wheel twine build 23 | pip install -e . 
24 | # pip install PySide6 25 | # pip install .[test] 26 | #- name: Test core with pytest 27 | # run: | 28 | # pytest spikeinterface-gui 29 | - name: Package and Upload 30 | env: 31 | TWINE_USERNAME: __token__ 32 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 33 | run: | 34 | python -m build --sdist --wheel 35 | twine upload dist/* 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | spikeinterface_gui/tests/*/* 2 | my_dataset/* 3 | 4 | 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 SpikeInterface 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spikeinterface-gui 2 | 3 | GUI for the `SortingAnalyzer` object from spikeinterface. 4 | 5 | This is a cross-platform interactive viewer to inspect the final results 6 | and quality of any spike sorter supported by spikeinterface 7 | (kilosort, spykingcircus, tridesclous, mountainsort, yass, ironclust, herdingspikes, hdsort, klusta...) 8 | 9 | This interactive GUI offers several views that dynamically refresh other views. 10 | This allows us to very quickly check the strengths and weaknesses of any sorter output. 11 | 12 | This can be used as a replacement for [phy](https://github.com/cortex-lab/phy). 13 | 14 | This viewer has 2 modes: 15 | * **mode=desktop** : this is a local desktop app internally using Qt, fast and easy when the data is local 16 | * **mode=web** : this is a web app internally using Panel, useful when the data is remote 17 | 18 | 19 |  20 | 21 | 22 | ## main usage 23 | 24 | The main idea is to make one or several units visible and visually inspect whether they should be merged or removed. 25 | To control this visibility: 26 | * ctrl + double click on a unit in *probeview* 27 | * check the visible box in the *unitlist* 28 | * double click on one unit in *unitlist* to make only this unit visible 29 | * move one of the ROIs in the *probeview* 30 | 31 | Views can be reorganized by moving docks by clicking in the title bar of a dock. 32 | Any dock (view) can be closed, and put back with a right click in the title bar of any dock. 33 | 34 | Every view has a **?** button which opens the contextual help. **These in-place docs are the most important things to read** (but they contain typos). 35 | 36 | When some units are visible, the related spike list can be refreshed. 37 | Then selecting spike per spike can also refresh some views.
38 | This enables a very quick and convenient spike-per-spike jump on traces. 39 | 40 | Channel visibility can be handled with one of the ROIs in the probeview. 41 | 42 | Shortcuts: many shortcuts are available, please read the **?** button in each view. 43 | 44 | ## curation mode 45 | 46 | By default this tool is a viewer only. But you can turn it into a tool for manual curation using 47 | the `curation=True` option. 48 | This tool supports the [curation format from spikeinterface](https://spikeinterface.readthedocs.io/en/latest/modules/curation.html#manual-curation). 49 | This format enables you to: 50 | 1. remove units 51 | 2. merge units 52 | 3. create manual labels 53 | 54 | When this mode is activated a new view is added on the top left to maintain the list of removals and merges. 55 | The curation format can be exported to json. 56 | 57 | 58 | 59 | ## Launch 60 | 61 | In order to use this viewer you will need to know a bit of [spikeinterface](https://spikeinterface.readthedocs.io/) 62 | 63 | ### Step 1 : create and compute SortingAnalyzer 64 | 65 | You first need to get a `SortingAnalyzer` object with spikeinterface.
66 | 67 | See help [here](https://spikeinterface.readthedocs.io) 68 | 69 | Note that: 70 | * some extensions are mandatory (unit_locations, templates) 71 | * some extensions are optional 72 | * the more extensions are computed, the more views are displayed 73 | 74 | 75 | Example: 76 | 77 | ```python 78 | import spikeinterface.full as si 79 | recording = si.read_XXXX('/path/to/my/recording') 80 | recording_filtered = si.bandpass_filter(recording) 81 | sorting = si.run_sorter('YYYYY', recording_filtered) 82 | 83 | 84 | job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") 85 | 86 | # make the SortingAnalyzer with necessary and some optional extensions 87 | sorting_analyzer = si.create_sorting_analyzer(sorting, recording_filtered, 88 | format="binary_folder", folder="/my_sorting_analyzer", 89 | **job_kwargs) 90 | sorting_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500) 91 | sorting_analyzer.compute("waveforms", **job_kwargs) 92 | sorting_analyzer.compute("templates", **job_kwargs) 93 | sorting_analyzer.compute("noise_levels") 94 | sorting_analyzer.compute("unit_locations", method="monopolar_triangulation") 95 | sorting_analyzer.compute("isi_histograms") 96 | sorting_analyzer.compute("correlograms", window_ms=100, bin_ms=5.)
97 | sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_global', whiten=True, **job_kwargs) 98 | sorting_analyzer.compute("quality_metrics", metric_names=["snr", "firing_rate"]) 99 | sorting_analyzer.compute("template_similarity") 100 | sorting_analyzer.compute("spike_amplitudes", **job_kwargs) 101 | 102 | ``` 103 | 104 | 105 | ### Step 2 : open the GUI 106 | 107 | With python: 108 | 109 | ```python 110 | from spikeinterface_gui import run_mainwindow 111 | # reload the SortingAnalyzer 112 | sorting_analyzer = si.load_sorting_analyzer("/my_sorting_analyzer") 113 | # open and run the Qt app 114 | run_mainwindow(sorting_analyzer, mode="desktop") 115 | # open and run the Web app 116 | run_mainwindow(sorting_analyzer, mode="web") 117 | ``` 118 | 119 | Or from spikeinterface: 120 | 121 | ```python 122 | import spikeinterface.widgets as sw 123 | sorting_analyzer = si.load_sorting_analyzer("/my_sorting_analyzer") 124 | sw.plot_sorting_summary(sorting_analyzer, backend="spikeinterface_gui") 125 | ``` 126 | 127 | 128 | With the command line: 129 | 130 | ```bash 131 | sigui /path/for/my/sorting_analyzer 132 | ``` 133 | 134 | 135 | The command line supports some options like *--notraces* or *--curation* or *--mode* 136 | ```bash 137 | sigui --mode=web --no-traces --curation /path/for/my/sorting_analyzer 138 | ``` 139 | 140 | 141 | 142 | ## With curation mode 143 | 144 | 145 | To open the viewer with curation mode use `curation=True`. 146 | 147 | This mode is pretty new and was implemented at the kind urging of friends. 148 | I hope that this could be a fair replacement for `phy`.
149 | 150 | 151 | ```python 152 | from spikeinterface_gui import run_mainwindow 153 | run_mainwindow(sorting_analyzer, curation=True) 154 | ``` 155 | 156 | 157 | ```python 158 | import spikeinterface.widgets as sw 159 | sw.plot_sorting_summary(sorting_analyzer, curation=True, backend="spikeinterface_gui") 160 | ``` 161 | 162 | The `curation_dict` can be saved inside the folder of the analyzer (for "binary_folder" or "zarr" format). 163 | Then it is auto-reloaded when the gui is re-opened. 164 | 165 | 166 | 167 | ## Install 168 | 169 | For beginners or Anaconda users please see our [installation tips](https://github.com/SpikeInterface/spikeinterface/tree/main/installation_tips) 170 | where we provide a yaml for Mac/Windows/Linux to help properly install `spikeinterface` and `spikeinterface-gui` for you in a dedicated 171 | conda environment. 172 | 173 | Otherwise, 174 | 175 | You first need to install **one** of these 3 packages (in order of preference): 176 | * `pip install PySide6` or 177 | * `pip install PyQt6` or 178 | * `pip install PyQt5` 179 | 180 | From pypi: 181 | 182 | ```bash 183 | pip install spikeinterface-gui 184 | ``` 185 | 186 | For Desktop you can do: 187 | 188 | ```bash 189 | pip install spikeinterface-gui[desktop] 190 | ``` 191 | 192 | For web you can do: 193 | 194 | ```bash 195 | pip install spikeinterface-gui[web] 196 | ``` 197 | 198 | 199 | From source: 200 | 201 | ```bash 202 | git clone https://github.com/SpikeInterface/spikeinterface-gui.git 203 | cd spikeinterface-gui 204 | pip install . 205 | ``` 206 | 207 | ## Credits 208 | 209 | Original author : Samuel Garcia, CNRS, Lyon, France 210 | 211 | This work is a port of the old `tridesclous.gui` submodule on top of 212 | [spikeinterface](https://github.com/SpikeInterface/spikeinterface).
213 | 214 | Main authors and maintainers: 215 | 216 | * qt side : Samuel Garcia, CNRS, Lyon, France 217 | * web side : Alessio Paolo Buccino, Allen Institute for Neural Dynamics, Seattle, USA 218 | 219 | 220 | ## Message from dictator 221 | 222 | Contrary to the spikeinterface package, for the development of this viewer 223 | all good practices of coding are deliberately put aside : no test, no CI, no auto formatting, no doc, ... 224 | Feel free to contribute, it is an open wild zone. Code anarchists are very welcome. 225 | So in this mess, persona non grata : pre-commit, black, pytest fixture, ... 226 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | ### Sam 4 | - [ ] remove compute 5 | - [x] QT settings not more dialog 6 | - [ ] remove custom docker 7 | - [x] simple layout description 8 | - [ ] unit list with depth 9 | 10 | 11 | ### Alessio 12 | - [x] handle similarity default compute 13 | - [x] spike list + spike selection 14 | - [x] curation view 15 | - [x] add settings / more columns to panel unit list 16 | - [x] unitlist: fix merge and delete with sorters 17 | - [ ] implement default color toggle button 18 | 19 | 20 | 21 | 22 | ### Panel Views 23 | - [ ] general 24 | - [x] fix settings cards 25 | - [ ] global settings (e.g.
color options, layout) 26 | - [x] unit list 27 | - [x] more columns to panel unit list 28 | - [x] add settings to select colums 29 | - [x] add sparsity and channel id 30 | - [x] visible only when clicking on select column 31 | - [x] make x-axis scrollable 32 | - [ ] probe 33 | - [x] fix pan to move circles 34 | - [ ] add option to resize circles 35 | - [ ] spike list 36 | - [x] add unit color 37 | - [x] fix segment index 38 | - [x] fix spike selection 39 | - [x] curation 40 | - [x] merge 41 | - [x] waveform 42 | - [x] zoom on wheel 43 | - [x] flatten mode 44 | - [x] trace 45 | - [x] fix multi-segment selection 46 | - [x] zoom on wheel 47 | - [x] fix spike at init 48 | - [ ] spike amplitudes 49 | - [x] add selection 50 | - [ ] add option to scatter decimate 51 | - [ ] NDscatter 52 | - [x] fix limits 53 | - [ ] add selection 54 | 55 | ### Discussion 56 | * panel.param or Pydantic? 57 | * plotly / bokeh? 58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "spikeinterface-gui" 3 | version = '0.10.0' 4 | authors = [ 5 | { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, 6 | ] 7 | 8 | description = "Qt GUI for spikeinterface" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: MIT License", 14 | "Operating System :: OS Independent", 15 | ] 16 | 17 | dependencies = [ 18 | "spikeinterface[full]>=0.102.0", 19 | "markdown" 20 | ] 21 | 22 | [project.urls] 23 | Homepage = "https://github.com/SpikeInterface/spikeinterface-gui" 24 | Repository = "https://github.com/SpikeInterface/spikeinterface-gui" 25 | 26 | 27 | [build-system] 28 | requires = ["setuptools>=62.0"] 29 | build-backend = "setuptools.build_meta" 30 | 31 | 32 | [tool.setuptools] 33 | packages = ["spikeinterface_gui"] 34 | package-dir = 
{"spikeinterface_gui" = "spikeinterface_gui"} 35 | 36 | [tool.setuptools.package-data] 37 | "spikeinterface_gui" = ["**/*.png"] 38 | 39 | 40 | [project.scripts] 41 | sigui = "spikeinterface_gui.main:run_mainwindow_cli" 42 | 43 | [project.optional-dependencies] 44 | 45 | desktop = [ 46 | "PySide6", 47 | "pyqtgraph", 48 | ] 49 | 50 | web = [ 51 | "panel", 52 | "bokeh", 53 | ] 54 | 55 | test = [ 56 | "pytest", 57 | "PySide6", 58 | ] 59 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeInterface/spikeinterface-gui/757f27a3ba63e5f06361b2928188c1209573d9e9/screenshot.png -------------------------------------------------------------------------------- /spikeinterface_gui/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some design notes: 3 | * controller is a layer between spikeinterface objects and every view 4 | * every view can notify some signals to other views that are centralized by the controller 5 | * views have settings 6 | * views have 2 implementations : qt (legacy) and panel (for the web) 7 | They need to implement the make_layout and the refresh for each backend (qt, panel). 8 | They do not inherit from qt or panel objects but contain qt or panel objects (design by composition). 9 | Internally, methods related to qt start with _qt_XXX. 10 | Internally, methods related to panel start with _panel_XXX.
11 | """ 12 | 13 | from .version import version as __version__ 14 | 15 | from .main import run_mainwindow 16 | 17 | -------------------------------------------------------------------------------- /spikeinterface_gui/backend_panel.py: -------------------------------------------------------------------------------- 1 | import param 2 | import panel as pn 3 | 4 | 5 | from .viewlist import possible_class_views 6 | from .layout_presets import get_layout_description 7 | 8 | # Used by views to emit/trigger signals 9 | class SignalNotifier(param.Parameterized): 10 | spike_selection_changed = param.Event() 11 | unit_visibility_changed = param.Event() 12 | channel_visibility_changed = param.Event() 13 | manual_curation_updated = param.Event() 14 | time_info_updated = param.Event() 15 | active_view_updated = param.Event() 16 | unit_color_changed = param.Event() 17 | 18 | def __init__(self, view=None): 19 | param.Parameterized.__init__(self) 20 | self.view = view 21 | 22 | def notify_spike_selection_changed(self): 23 | self.param.trigger("spike_selection_changed") 24 | 25 | def notify_unit_visibility_changed(self): 26 | self.param.trigger("unit_visibility_changed") 27 | 28 | def notify_channel_visibility_changed(self): 29 | self.param.trigger("channel_visibility_changed") 30 | 31 | def notify_manual_curation_updated(self): 32 | self.param.trigger("manual_curation_updated") 33 | 34 | def notify_time_info_updated(self): 35 | self.param.trigger("time_info_updated") 36 | 37 | def notify_active_view_updated(self): 38 | # this is used to keep an "active view" in the main window 39 | # when a view triggers this event, it self-declares it as active 40 | # and the other windows will be set as non-active 41 | # this is used in panel to be able to use the same shortcuts in multiple 42 | # views 43 | self.param.trigger("active_view_updated") 44 | 45 | def notify_unit_color_changed(self): 46 | self.param.trigger("unit_color_changed") 47 | 48 | 49 | class 
SignalHandler(param.Parameterized): 50 | def __init__(self, controller, parent=None): 51 | param.Parameterized.__init__(self) 52 | self.controller = controller 53 | self._active = True 54 | 55 | def activate(self): 56 | self._active = True 57 | 58 | def deactivate(self): 59 | self._active = False 60 | 61 | def connect_view(self, view): 62 | view.notifier.param.watch(self.on_spike_selection_changed, "spike_selection_changed") 63 | view.notifier.param.watch(self.on_unit_visibility_changed, "unit_visibility_changed") 64 | view.notifier.param.watch(self.on_channel_visibility_changed, "channel_visibility_changed") 65 | view.notifier.param.watch(self.on_manual_curation_updated, "manual_curation_updated") 66 | view.notifier.param.watch(self.on_time_info_updated, "time_info_updated") 67 | view.notifier.param.watch(self.on_active_view_updated, "active_view_updated") 68 | view.notifier.param.watch(self.on_unit_color_changed, "unit_color_changed") 69 | 70 | def on_spike_selection_changed(self, param): 71 | if not self._active: 72 | return 73 | 74 | for view in self.controller.views: 75 | if param.obj.view == view: 76 | continue 77 | view.on_spike_selection_changed() 78 | 79 | def on_unit_visibility_changed(self, param): 80 | if not self._active: 81 | return 82 | for view in self.controller.views: 83 | if param.obj.view == view: 84 | continue 85 | view.on_unit_visibility_changed() 86 | 87 | def on_channel_visibility_changed(self, param): 88 | if not self._active: 89 | return 90 | for view in self.controller.views: 91 | if param.obj.view == view: 92 | continue 93 | view.on_channel_visibility_changed() 94 | 95 | def on_manual_curation_updated(self, param): 96 | if not self._active: 97 | return 98 | for view in self.controller.views: 99 | if param.obj.view == view: 100 | continue 101 | view.on_manual_curation_updated() 102 | 103 | def on_time_info_updated(self, param): 104 | # time info is updated also when a view is not active 105 | if not self._active: 106 | return 107 | for 
view in self.controller.views: 108 | if param.obj.view == view: 109 | continue 110 | view.on_time_info_updated() 111 | 112 | def on_active_view_updated(self, param): 113 | if not self._active: 114 | return 115 | for view in self.controller.views: 116 | if param.obj.view == view: 117 | view._panel_view_is_active = True 118 | else: 119 | view._panel_view_is_active = False 120 | 121 | def on_unit_color_changed(self, param): 122 | if not self._active: 123 | return 124 | for view in self.controller.views: 125 | if param.obj.view == view: 126 | continue 127 | view.on_unit_color_changed() 128 | 129 | param_type_map = { 130 | "float": param.Number, 131 | "int": param.Integer, 132 | "bool": param.Boolean, 133 | "list": param.ObjectSelector, 134 | } 135 | 136 | class SettingsProxy: 137 | # this make the setting dict like (to mimic pyqtgraph) 138 | # for instance self.settings['my_params'] instead of self.settings.my_params 139 | # self.settings['my_params'] = value instead of self.settings.my_params = value 140 | def __init__(self, myparametrized): 141 | self._parameterized = myparametrized 142 | 143 | def __getitem__(self, key): 144 | return getattr(self._parameterized, key) 145 | 146 | def __setitem__(self, key, value): 147 | self._parameterized.param.update(**{key:value}) 148 | 149 | def keys(self): 150 | return list(p for p in self._parameterized.param if p != "name") 151 | 152 | 153 | def create_dynamic_parameterized(settings): 154 | """ 155 | Create a dynamic parameterized class based on the settings provided. 
156 | """ 157 | attributes = {} 158 | for setting_data in settings: 159 | if setting_data["type"] == "list": 160 | if "value" in setting_data: 161 | default = setting_data["value"] 162 | else: 163 | default = setting_data["limits"][0] 164 | attributes[setting_data["name"]] = param_type_map[setting_data["type"]]( 165 | objects=setting_data["limits"], doc=f"{setting_data['name']} parameter", default=default 166 | ) 167 | elif "value" in setting_data: 168 | attributes[setting_data["name"]] = param_type_map[setting_data["type"]]( 169 | setting_data["value"], doc=f"{setting_data['name']} parameter" 170 | ) 171 | MyParameterized = type("MyParameterized", (param.Parameterized,), attributes) 172 | return MyParameterized() 173 | 174 | 175 | def create_settings(view): 176 | # Create the class attributes dynamically 177 | settings = create_dynamic_parameterized(view._settings) 178 | 179 | view.settings = SettingsProxy(settings) 180 | 181 | def listen_setting_changes(view): 182 | for setting_data in view._settings: 183 | view.settings._parameterized.param.watch(view.on_settings_changed, setting_data["name"]) 184 | 185 | 186 | 187 | class PanelMainWindow: 188 | 189 | def __init__(self, controller, layout_preset=None): 190 | self.controller = controller 191 | self.layout_preset = layout_preset 192 | self.verbose = controller.verbose 193 | 194 | self.make_views() 195 | self.create_main_layout() 196 | 197 | # refresh all views wihtout notiying 198 | self.controller.signal_handler.deactivate() 199 | self.controller.signal_handler.activate() 200 | 201 | for view in self.views.values(): 202 | if view.is_view_visible(): 203 | view.refresh() 204 | 205 | def make_views(self): 206 | self.views = {} 207 | # this contains view layout + settings + compute 208 | self.view_layouts = {} 209 | for view_name, view_class in possible_class_views.items(): 210 | if 'panel' not in view_class._supported_backend: 211 | continue 212 | if not self.controller.check_is_view_possible(view_name): 213 | 
continue 214 | 215 | if view_name == 'curation' and not self.controller.curation: 216 | continue 217 | 218 | if view_name in ("trace", "tracemap") and not self.controller.with_traces: 219 | continue 220 | 221 | view = view_class(controller=self.controller, parent=None, backend='panel') 222 | self.views[view_name] = view 223 | 224 | info = pn.Column( 225 | pn.pane.Markdown(view_class._gui_help_txt), 226 | scroll=True, 227 | sizing_mode="stretch_both" 228 | ) 229 | 230 | tabs = [("📊", view.layout)] 231 | if view_class._settings is not None: 232 | settings = pn.Param(view.settings._parameterized, sizing_mode="stretch_height", 233 | name=f"{view_name.capitalize()} settings") 234 | if view_class._need_compute: 235 | compute_button = pn.widgets.Button(name="Compute", button_type="primary") 236 | compute_button.on_click(view.compute) 237 | settings = pn.Row(settings, compute_button) 238 | tabs.append(("⚙️", settings)) 239 | 240 | tabs.append(("ℹ️", info)) 241 | view_layout = pn.Tabs( 242 | *tabs, 243 | sizing_mode="stretch_both", 244 | dynamic=True, 245 | tabs_location="left", 246 | ) 247 | self.view_layouts[view_name] = view_layout 248 | 249 | 250 | def create_main_layout(self): 251 | 252 | pn.extension("gridstack") 253 | 254 | preset = get_layout_description(self.layout_preset) 255 | 256 | layout_zone = {} 257 | for zone, view_names in preset.items(): 258 | # keep only instanciated views 259 | view_names = [view_name for view_name in view_names if view_name in self.view_layouts.keys()] 260 | 261 | if len(view_names) == 0: 262 | layout_zone[zone] = None 263 | else: 264 | layout_zone[zone] = pn.Tabs( 265 | *((view_name, self.view_layouts[view_name]) for view_name in view_names if view_name in self.view_layouts), 266 | sizing_mode="stretch_both", 267 | dynamic=True, 268 | tabs_location="below", 269 | ) 270 | # Function to update visibility 271 | tabs = layout_zone[zone] 272 | tabs.param.watch(self.update_visibility, "active") 273 | # Simulate an event 274 | 
self.update_visibility(param.parameterized.Event( 275 | cls=None, what="value", type="changed", old=0, new=0, obj=tabs, name="active", 276 | )) 277 | 278 | # Create GridStack layout with resizable regions 279 | grid_per_zone = 2 280 | gs = pn.GridStack( 281 | sizing_mode='stretch_both', 282 | allow_resize=False, 283 | allow_drag=False, 284 | ) 285 | 286 | # Top modifications 287 | for zone in ['zone1', 'zone2', 'zone3', 'zone4']: 288 | view = layout_zone[zone] 289 | row_slice = slice(0, 2 * grid_per_zone) # First two rows 290 | 291 | if zone == 'zone1': 292 | if layout_zone.get('zone2') is None or len(layout_zone['zone2']) == 0: 293 | col_slice = slice(0, 2 * grid_per_zone) # Full width when merged 294 | else: 295 | col_slice = slice(0, grid_per_zone) # Half width when not merged 296 | elif zone == 'zone2': 297 | if layout_zone.get('zone1') is None or len(layout_zone['zone1']) == 0: 298 | col_slice = slice(0, 2 * grid_per_zone) # Full width when merged 299 | else: 300 | col_slice = slice(grid_per_zone, 2 * grid_per_zone) # Right half 301 | elif zone == 'zone3': 302 | if layout_zone.get('zone4') is None or len(layout_zone['zone4']) == 0: 303 | col_slice = slice(2 * grid_per_zone, 4 * grid_per_zone) # Full width when merged 304 | else: 305 | col_slice = slice(2 * grid_per_zone, 3 * grid_per_zone) # Left half 306 | elif zone == 'zone4': 307 | if layout_zone.get('zone3') is None or len(layout_zone['zone3']) == 0: 308 | col_slice = slice(2 * grid_per_zone, 4 * grid_per_zone) # Full width when merged 309 | else: 310 | col_slice = slice(3 * grid_per_zone, 4 * grid_per_zone) # Right half 311 | 312 | if view is not None and len(view) > 0: 313 | # Note: order of slices swapped to [row, col] 314 | gs[row_slice, col_slice] = view 315 | 316 | # Bottom 317 | for zone in ['zone5', 'zone6', 'zone7', 'zone8']: 318 | view = layout_zone[zone] 319 | row_slice = slice(2 * grid_per_zone, 4 * grid_per_zone) 320 | 321 | if zone == 'zone5': 322 | if layout_zone.get('zone6') is None or 
len(layout_zone['zone6']) == 0: 323 | col_slice = slice(0, 2 * grid_per_zone) 324 | else: 325 | col_slice = slice(0, grid_per_zone) 326 | elif zone == 'zone6': 327 | if layout_zone.get('zone5') is None or len(layout_zone['zone5']) == 0: 328 | col_slice = slice(0, 2 * grid_per_zone) 329 | else: 330 | col_slice = slice(grid_per_zone, 2 * grid_per_zone) 331 | elif zone == 'zone7': 332 | if layout_zone.get('zone8') is None or len(layout_zone['zone8']) == 0: 333 | col_slice = slice(2 * grid_per_zone, 4 * grid_per_zone) 334 | else: 335 | col_slice = slice(2 * grid_per_zone, 3 * grid_per_zone) 336 | elif zone == 'zone8': 337 | if layout_zone.get('zone7') is None or len(layout_zone['zone7']) == 0: 338 | col_slice = slice(2 * grid_per_zone, 4 * grid_per_zone) 339 | else: 340 | col_slice = slice(3 * grid_per_zone, 4 * grid_per_zone) 341 | 342 | if view is not None and len(view) > 0: 343 | gs[row_slice, col_slice] = view 344 | 345 | self.main_layout = gs 346 | 347 | def update_visibility(self, event): 348 | active = event.new 349 | tab_names = event.obj._names 350 | objects = event.obj.objects 351 | for i, (view_name, content) in enumerate(zip(tab_names, objects)): 352 | visible = (i == active) 353 | view = self.views[view_name] 354 | view._panel_view_is_visible = visible 355 | if visible: 356 | # Refresh the view if it is visible 357 | view.refresh() 358 | # we also set the current view as the panel active 359 | view.notify_active_view_updated() 360 | 361 | 362 | 363 | def start_server(mainwindow, address="localhost", port=0): 364 | 365 | pn.config.sizing_mode = "stretch_width" 366 | 367 | # mainwindow.main_layout.servable() 368 | # TODO alessio : find automatically a port when port = 0 369 | 370 | if address != "localhost": 371 | websocket_origin = f"{address}:{port}" 372 | else: 373 | websocket_origin = None 374 | 375 | server = pn.serve({"/": mainwindow.main_layout}, address=address, port=port, 376 | show=False, start=True, dev=True, 
autoreload=True,websocket_origin=websocket_origin, 377 | title="SpikeInterface GUI") 378 | -------------------------------------------------------------------------------- /spikeinterface_gui/backend_qt.py: -------------------------------------------------------------------------------- 1 | from .myqt import QT 2 | import pyqtgraph as pg 3 | import markdown 4 | 5 | 6 | import weakref 7 | 8 | from .viewlist import possible_class_views 9 | from .layout_presets import get_layout_description 10 | 11 | from .utils_qt import qt_style, add_stretch_to_qtoolbar 12 | 13 | 14 | # Used by views to emit/trigger signals 15 | class SignalNotifier(QT.QObject): 16 | spike_selection_changed = QT.pyqtSignal() 17 | unit_visibility_changed = QT.pyqtSignal() 18 | channel_visibility_changed = QT.pyqtSignal() 19 | manual_curation_updated = QT.pyqtSignal() 20 | time_info_updated = QT.pyqtSignal() 21 | unit_color_changed = QT.pyqtSignal() 22 | 23 | def __init__(self, parent=None, view=None): 24 | QT.QObject.__init__(self, parent=parent) 25 | self.view = view 26 | 27 | def notify_spike_selection_changed(self): 28 | self.spike_selection_changed.emit() 29 | 30 | def notify_unit_visibility_changed(self): 31 | self.unit_visibility_changed.emit() 32 | 33 | def notify_channel_visibility_changed(self): 34 | self.channel_visibility_changed.emit() 35 | 36 | def notify_manual_curation_updated(self): 37 | self.manual_curation_updated.emit() 38 | 39 | def notify_time_info_updated(self): 40 | self.time_info_updated.emit() 41 | 42 | def notify_unit_color_changed(self): 43 | self.unit_color_changed.emit() 44 | 45 | 46 | # Used by controler to handle/callback signals 47 | class SignalHandler(QT.QObject): 48 | def __init__(self, controller, parent=None): 49 | QT.QObject.__init__(self, parent=parent) 50 | self.controller = controller 51 | self._active = True 52 | 53 | def activate(self): 54 | self._active = True 55 | 56 | def deactivate(self): 57 | self._active = False 58 | 59 | def connect_view(self, view): 
60 | view.notifier.spike_selection_changed.connect(self.on_spike_selection_changed) 61 | view.notifier.unit_visibility_changed.connect(self.on_unit_visibility_changed) 62 | view.notifier.channel_visibility_changed.connect(self.on_channel_visibility_changed) 63 | view.notifier.manual_curation_updated.connect(self.on_manual_curation_updated) 64 | view.notifier.time_info_updated.connect(self.on_time_info_updated) 65 | view.notifier.unit_color_changed.connect(self.on_unit_color_changed) 66 | 67 | def on_spike_selection_changed(self): 68 | if not self._active: 69 | return 70 | for view in self.controller.views: 71 | if view.qt_widget == self.sender().parent(): 72 | # do not refresh it self 73 | continue 74 | view.on_spike_selection_changed() 75 | 76 | def on_unit_visibility_changed(self): 77 | 78 | if not self._active: 79 | return 80 | for view in self.controller.views: 81 | if view.qt_widget == self.sender().parent(): 82 | # do not refresh it self 83 | continue 84 | view.on_unit_visibility_changed() 85 | 86 | def on_channel_visibility_changed(self): 87 | if not self._active: 88 | return 89 | for view in self.controller.views: 90 | if view.qt_widget == self.sender().parent(): 91 | # do not refresh it self 92 | continue 93 | view.on_channel_visibility_changed() 94 | 95 | def on_manual_curation_updated(self): 96 | if not self._active: 97 | return 98 | for view in self.controller.views: 99 | if view.qt_widget == self.sender().parent(): 100 | # do not refresh it self 101 | continue 102 | view.on_manual_curation_updated() 103 | 104 | def on_time_info_updated(self): 105 | if not self._active: 106 | return 107 | for view in self.controller.views: 108 | if view.qt_widget == self.sender().parent(): 109 | # do not refresh it self 110 | continue 111 | view.on_time_info_updated() 112 | 113 | def on_unit_color_changed(self): 114 | if not self._active: 115 | return 116 | for view in self.controller.views: 117 | if view.qt_widget == self.sender().parent(): 118 | # do not refresh it 
self 119 | continue 120 | view.on_unit_color_changed() 121 | 122 | 123 | def create_settings(view, parent): 124 | view.settings = pg.parametertree.Parameter.create(name="settings", type='group', children=view._settings) 125 | 126 | # not that the parent is not the view (not Qt anymore) itself but the widget 127 | view.tree_settings = pg.parametertree.ParameterTree(parent=parent) 128 | view.tree_settings.header().hide() 129 | view.tree_settings.setParameters(view.settings, showTop=True) 130 | view.tree_settings.setWindowTitle(u'View options') 131 | # view.tree_settings.setWindowFlags(QT.Qt.Window) 132 | 133 | def listen_setting_changes(view): 134 | view.settings.sigTreeStateChanged.connect(view.on_settings_changed) 135 | 136 | 137 | class QtMainWindow(QT.QMainWindow): 138 | def __init__(self, controller, parent=None, layout_preset=None): 139 | QT.QMainWindow.__init__(self, parent) 140 | 141 | self.controller = controller 142 | self.verbose = controller.verbose 143 | self.layout_preset = layout_preset 144 | 145 | self.make_views() 146 | self.create_main_layout() 147 | 148 | # refresh all views wihtout notiying 149 | self.controller.signal_handler.deactivate() 150 | for view in self.views.values(): 151 | # refresh do not work because view are not yet visible at init 152 | view._refresh() 153 | self.controller.signal_handler.activate() 154 | 155 | # TODO sam : all veiws are always refreshed at the moment so this is useless. 
156 | # uncommen this when ViewBase.is_view_visible() work correctly 157 | # for view_name, dock in self.docks.items(): 158 | # dock.visibilityChanged.connect(self.views[view_name].refresh) 159 | 160 | def make_views(self): 161 | self.views = {} 162 | self.docks = {} 163 | for view_name, view_class in possible_class_views.items(): 164 | if 'qt' not in view_class._supported_backend: 165 | continue 166 | if not self.controller.check_is_view_possible(view_name): 167 | continue 168 | 169 | if view_name == 'curation' and not self.controller.curation: 170 | continue 171 | 172 | if view_name in ("trace", "tracemap") and not self.controller.with_traces: 173 | continue 174 | 175 | widget = ViewWidget(view_class) 176 | view = view_class(controller=self.controller, parent=widget, backend='qt') 177 | widget.set_view(view) 178 | dock = QT.QDockWidget(view_name) 179 | dock.setWidget(widget) 180 | # dock.visibilityChanged.connect(view.refresh) 181 | 182 | self.views[view_name] = view 183 | self.docks[view_name] = dock 184 | 185 | 186 | def create_main_layout(self): 187 | import warnings 188 | 189 | warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph") 190 | 191 | self.setDockNestingEnabled(True) 192 | 193 | preset = get_layout_description(self.layout_preset) 194 | 195 | widgets_zone = {} 196 | for zone, view_names in preset.items(): 197 | # keep only instantiated views 198 | view_names = [view_name for view_name in view_names if view_name in self.views.keys()] 199 | widgets_zone[zone] = view_names 200 | 201 | ## Handle left 202 | first_left = None 203 | for zone in ['zone1', 'zone5', 'zone2', 'zone6']: 204 | if len(widgets_zone[zone]) == 0: 205 | continue 206 | view_name = widgets_zone[zone][0] 207 | dock = self.docks[view_name] 208 | if len(widgets_zone[zone]) > 0 and first_left is None: 209 | self.addDockWidget(areas['left'], dock) 210 | first_left = view_name 211 | elif zone == 'zone5': 212 | self.splitDockWidget(self.docks[first_left], dock, 
orientations['vertical']) 213 | elif zone == 'zone2': 214 | self.splitDockWidget(self.docks[first_left], dock, orientations['horizontal']) 215 | elif zone == 'zone6': 216 | if len(widgets_zone['zone5']) > 0: 217 | z = widgets_zone['zone5'][0] 218 | self.splitDockWidget(self.docks[z], dock, orientations['horizontal']) 219 | else: 220 | self.splitDockWidget(self.docks[first_left], dock, orientations['vertical']) 221 | 222 | ## Handle right 223 | first_left = None 224 | for zone in ['zone3', 'zone7', 'zone4', 'zone8']: 225 | if len(widgets_zone[zone]) == 0: 226 | continue 227 | view_name = widgets_zone[zone][0] 228 | dock = self.docks[view_name] 229 | if len(widgets_zone[zone]) > 0 and first_left is None: 230 | self.addDockWidget(areas['right'], dock) 231 | first_left = view_name 232 | elif zone == 'zone7': 233 | self.splitDockWidget(self.docks[first_left], dock, orientations['vertical']) 234 | elif zone == 'zone4': 235 | self.splitDockWidget(self.docks[first_left], dock, orientations['horizontal']) 236 | elif zone == 'zone8': 237 | if len(widgets_zone['zone7']) > 0: 238 | z = widgets_zone['zone7'][0] 239 | self.splitDockWidget(self.docks[z], dock, orientations['horizontal']) 240 | else: 241 | self.splitDockWidget(self.docks[first_left], dock, orientations['vertical']) 242 | 243 | # make tabs 244 | for zone, view_names in widgets_zone.items(): 245 | n = len(widgets_zone[zone]) 246 | if n < 2: 247 | # no tab here 248 | continue 249 | view_name0 = widgets_zone[zone][0] 250 | for i in range(1, n): 251 | view_name = widgets_zone[zone][i] 252 | dock = self.docks[view_name] 253 | self.tabifyDockWidget(self.docks[view_name0], dock) 254 | # make visible the first of each zone 255 | self.docks[view_name0].raise_() 256 | 257 | 258 | class ViewWidget(QT.QWidget): 259 | def __init__(self, view_class, parent=None): 260 | QT.QWidget.__init__(self, parent=parent) 261 | 262 | self.layout = QT.QVBoxLayout() 263 | self.setLayout(self.layout) 264 | 
self.layout.setContentsMargins(4,4,4,4) 265 | self.layout.setSpacing(4) 266 | 267 | tb = self.view_toolbar = QT.QToolBar() 268 | self.layout.addWidget(self.view_toolbar) 269 | 270 | tb.setStyleSheet(qt_style) 271 | 272 | if view_class._settings is not None: 273 | but = QT.QPushButton('⚙ settings') 274 | tb.addWidget(but) 275 | but.clicked.connect(self.open_settings) 276 | # but.setStyleSheet(qt_style) 277 | 278 | if view_class._need_compute: 279 | but = QT.QPushButton('compute') 280 | tb.addWidget(but) 281 | but.clicked.connect(self.compute) 282 | 283 | but = QT.QPushButton('↻ refresh') 284 | tb.addWidget(but) 285 | but.clicked.connect(self.refresh) 286 | 287 | but = QT.QPushButton('?') 288 | tb.addWidget(but) 289 | but.clicked.connect(self.open_help) 290 | tooltip_html = markdown.markdown(view_class._gui_help_txt) 291 | but.setToolTip(tooltip_html) 292 | 293 | add_stretch_to_qtoolbar(tb) 294 | 295 | # TODO: make _qt method for all existing methods that don't start with _qt or _panel 296 | # skip = ['__init__', 'set_view', 'open_settings', 'compute', 'refresh', 'open_help', 297 | # 'on_spike_selection_changed', 'on_unit_visibility_changed', 298 | # 'on_channel_visibility_changed', 'on_manual_curation_updated'] 299 | # for name in dir(view_class): 300 | # if name.startswith('_qt_') or name.startswith('_panel_') or name in skip: 301 | # continue 302 | # if hasattr(view_class, name): 303 | # method = getattr(view_class, name) 304 | # if callable(method): 305 | # if name == "save_in_analyzer": 306 | # print(f'creating _qt_save_in_analyzer for {view_class}') 307 | # setattr(view_class, '_qt_' + name, method) 308 | 309 | 310 | def set_view(self, view): 311 | self._view = weakref.ref(view) 312 | if view._settings is not None: 313 | self.layout.addWidget(view.tree_settings) 314 | view.tree_settings.hide() 315 | 316 | self.layout.addLayout(view.layout) 317 | 318 | def open_settings(self): 319 | view = self._view() 320 | if not view.tree_settings.isVisible(): 321 | 
view.tree_settings.show() 322 | else: 323 | view.tree_settings.hide() 324 | 325 | def compute(self): 326 | view = self._view() 327 | if view._need_compute: 328 | view.compute() 329 | 330 | def open_help(self): 331 | view = self._view() 332 | but = self.sender() 333 | txt = view._gui_help_txt 334 | txt = markdown.markdown(txt) 335 | QT.QToolTip.showText(but.mapToGlobal(QT.QPoint()), txt, but) 336 | 337 | def refresh(self): 338 | view = self._view() 339 | view.refresh() 340 | 341 | 342 | areas = { 343 | 'right' : QT.Qt.RightDockWidgetArea, 344 | 'left' : QT.Qt.LeftDockWidgetArea, 345 | } 346 | 347 | orientations = { 348 | 'horizontal' : QT.Qt.Horizontal, 349 | 'vertical' : QT.Qt.Vertical, 350 | } 351 | -------------------------------------------------------------------------------- /spikeinterface_gui/basescatterview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib.path import Path as mpl_path 3 | 4 | from .view_base import ViewBase 5 | 6 | 7 | class BaseScatterView(ViewBase): 8 | _supported_backend = ['qt', 'panel'] 9 | _depend_on = None 10 | _settings = [ 11 | {'name': 'auto_decimate', 'type': 'bool', 'value' : True }, 12 | {'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 10_000 }, 13 | {'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05 }, 14 | {'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5 }, 15 | {'name': 'num_bins', 'type': 'int', 'value' : 400, 'step': 1 }, 16 | ] 17 | _need_compute = False 18 | 19 | def __init__(self, spike_data, y_label, controller=None, parent=None, backend="qt"): 20 | 21 | # compute data bounds 22 | assert len(spike_data) == len(controller.spikes), "spike_data must have the same length as spikes" 23 | assert spike_data.ndim == 1, "spike_data must be 1D" 24 | self.spike_data = spike_data 25 | self.y_label = y_label 26 | 27 | self._data_min = np.min(spike_data) 28 | self._data_max = np.max(spike_data) 29 | 
eps = (self._data_max - self._data_min) / 100.0 30 | self._data_max += eps 31 | self._max_count = None 32 | 33 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 34 | 35 | 36 | def get_unit_data(self, unit_id, seg_index=0): 37 | inds = self.controller.get_spike_indices(unit_id, seg_index=seg_index) 38 | spike_times = self.controller.spikes["sample_index"][inds] / self.controller.sampling_frequency 39 | spike_data = self.spike_data[inds] 40 | ptp = np.ptp(spike_data) 41 | hist_min, hist_max = [np.min(spike_data) - 0.2 * ptp, np.max(spike_data) + 0.2 * ptp] 42 | 43 | hist_count, hist_bins = np.histogram(spike_data, bins=np.linspace(hist_min, hist_max, self.settings['num_bins'])) 44 | 45 | if self.settings['auto_decimate'] and spike_times.size > self.settings['max_spikes_per_unit']: 46 | step = spike_times.size // self.settings['max_spikes_per_unit'] 47 | spike_times = spike_times[::step] 48 | spike_data = spike_data[::step] 49 | inds = inds[::step] 50 | 51 | return spike_times, spike_data, hist_count, hist_bins, inds 52 | 53 | def get_selected_spikes_data(self, seg_index=0): 54 | sl = self.controller.segment_slices[seg_index] 55 | spikes_in_seg = self.controller.spikes[sl] 56 | selected_indices = self.controller.get_indices_spike_selected() 57 | mask = np.isin(sl.start + np.arange(len(spikes_in_seg)), selected_indices) 58 | selected_spikes = spikes_in_seg[mask] 59 | spike_times = selected_spikes['sample_index'] / self.controller.sampling_frequency 60 | spike_data = self.spike_data[sl][mask] 61 | return (spike_times, spike_data) 62 | 63 | 64 | ## QT zone ## 65 | def _qt_make_layout(self): 66 | from .myqt import QT 67 | import pyqtgraph as pg 68 | from .utils_qt import add_stretch_to_qtoolbar 69 | 70 | self.layout = QT.QVBoxLayout() 71 | 72 | tb = self.qt_widget.view_toolbar 73 | self.combo_seg = QT.QComboBox() 74 | tb.addWidget(self.combo_seg) 75 | self.combo_seg.addItems([ f'Segment {seg_index}' for seg_index in 
range(self.controller.num_segments) ]) 76 | self.combo_seg.currentIndexChanged.connect(self.refresh) 77 | add_stretch_to_qtoolbar(tb) 78 | self.lasso_but = QT.QPushButton("select", checkable = True) 79 | 80 | tb.addWidget(self.lasso_but) 81 | self.lasso_but.clicked.connect(self.enable_disable_lasso) 82 | 83 | h = QT.QHBoxLayout() 84 | self.layout.addLayout(h) 85 | 86 | self.graphicsview = pg.GraphicsView() 87 | h.addWidget(self.graphicsview, 3) 88 | 89 | self.graphicsview2 = pg.GraphicsView() 90 | h.addWidget(self.graphicsview2, 1) 91 | 92 | self.initialize_plot() 93 | 94 | # Add lasso curve 95 | self.lasso = pg.PlotCurveItem(pen='#7FFF00') 96 | self.plot.addItem(self.lasso) 97 | 98 | # Add selection scatter 99 | brush = QT.QColor('white') 100 | brush.setAlpha(200) 101 | self.scatter_select = pg.ScatterPlotItem(pen=pg.mkPen(None), brush=brush, size=11, pxMode=True) 102 | self.plot.addItem(self.scatter_select) 103 | self.scatter_select.setZValue(1000) 104 | 105 | 106 | 107 | def initialize_plot(self): 108 | import pyqtgraph as pg 109 | from .utils_qt import ViewBoxHandlingLasso 110 | 111 | self.viewBox = ViewBoxHandlingLasso() 112 | self.viewBox.lasso_drawing.connect(self.on_lasso_drawing) 113 | self.viewBox.lasso_finished.connect(self.on_lasso_finished) 114 | self.viewBox.disableAutoRange() 115 | self.plot = pg.PlotItem(viewBox=self.viewBox) 116 | self.graphicsview.setCentralItem(self.plot) 117 | self.plot.hideButtons() 118 | 119 | self.viewBox2 = ViewBoxHandlingLasso() 120 | self.viewBox2.disableAutoRange() 121 | self.plot2 = pg.PlotItem(viewBox=self.viewBox2) 122 | self.graphicsview2.setCentralItem(self.plot2) 123 | self.plot2.hideButtons() 124 | self.plot2.setYLink(self.plot) 125 | 126 | 127 | self.scatter = pg.ScatterPlotItem(size=self.settings['scatter_size'], pxMode = True) 128 | self.plot.addItem(self.scatter) 129 | 130 | self._text_items = [] 131 | 132 | self.plot.setYRange(self._data_min,self._data_max, padding = 0.0) 133 | 134 | def 
_qt_on_spike_selection_changed(self): 135 | self.refresh() 136 | 137 | def _qt_refresh(self): 138 | from .myqt import QT 139 | import pyqtgraph as pg 140 | 141 | self.scatter.clear() 142 | self.plot2.clear() 143 | self.scatter_select.clear() 144 | 145 | if self.spike_data is None: 146 | return 147 | 148 | max_count = 1 149 | for unit_id in self.controller.get_visible_unit_ids(): 150 | 151 | spike_times, spike_data, hist_count, hist_bins, _ = self.get_unit_data(unit_id) 152 | 153 | # make a copy of the color 154 | color = QT.QColor(self.get_unit_color(unit_id)) 155 | color.setAlpha(int(self.settings['alpha']*255)) 156 | self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color) 157 | 158 | color = self.get_unit_color(unit_id) 159 | curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color) 160 | self.plot2.addItem(curve) 161 | 162 | max_count = max(max_count, np.max(hist_count)) 163 | 164 | self._max_count = max_count 165 | seg_index = self.combo_seg.currentIndex() 166 | time_max = self.controller.get_num_samples(seg_index) / self.controller.sampling_frequency 167 | 168 | self.plot.setXRange( 0., time_max, padding = 0.0) 169 | self.plot2.setXRange(0, self._max_count, padding = 0.0) 170 | 171 | spike_times, spike_data = self.get_selected_spikes_data() 172 | self.scatter_select.setData(spike_times, spike_data) 173 | 174 | def enable_disable_lasso(self, checked): 175 | if checked and len(self.controller.get_visible_unit_ids()) == 1: 176 | self.viewBox.lasso_active = checked 177 | else: 178 | self.viewBox.lasso_active = False 179 | self.lasso_but.setChecked(False) 180 | self.scatter_select.clear() 181 | 182 | def on_lasso_drawing(self, points): 183 | points = np.array(points) 184 | self.lasso.setData(points[:, 0], points[:, 1]) 185 | 186 | def on_lasso_finished(self, points): 187 | self.lasso.setData([], []) 188 | vertices = np.array(points) 189 | 190 | seg_index = 
self.combo_seg.currentIndex() 191 | sl = self.controller.segment_slices[seg_index] 192 | spikes_in_seg = self.controller.spikes[sl] 193 | fs = self.controller.sampling_frequency 194 | 195 | # Create mask for visible units 196 | visible_mask = np.zeros(len(spikes_in_seg), dtype=bool) 197 | for unit_index, unit_id in self.controller.iter_visible_units(): 198 | visible_mask |= (spikes_in_seg['unit_index'] == unit_index) 199 | 200 | # Only consider spikes from visible units 201 | visible_spikes = spikes_in_seg[visible_mask] 202 | if len(visible_spikes) == 0: 203 | # Clear selection if no visible spikes 204 | self.controller.set_indices_spike_selected([]) 205 | self.refresh() 206 | self.notify_spike_selection_changed() 207 | return 208 | 209 | spike_times = visible_spikes['sample_index'] / fs 210 | spike_data = self.spike_data[sl][visible_mask] 211 | 212 | points = np.column_stack((spike_times, spike_data)) 213 | inside = mpl_path(vertices).contains_points(points) 214 | 215 | # Clear selection if no spikes inside lasso 216 | if not np.any(inside): 217 | self.controller.set_indices_spike_selected([]) 218 | self.refresh() 219 | self.notify_spike_selection_changed() 220 | return 221 | 222 | # Map back to original indices 223 | visible_indices = np.nonzero(visible_mask)[0] 224 | selected_indices = sl.start + visible_indices[inside] 225 | self.controller.set_indices_spike_selected(selected_indices) 226 | self.refresh() 227 | self.notify_spike_selection_changed() 228 | 229 | 230 | 231 | ## Panel zone ## 232 | def _panel_make_layout(self): 233 | import panel as pn 234 | import bokeh.plotting as bpl 235 | from bokeh.models import ColumnDataSource, LassoSelectTool, Range1d 236 | from .utils_panel import _bg_color, slow_lasso 237 | 238 | self.lasso_tool = LassoSelectTool() 239 | 240 | self.segment_index = 0 241 | self.segment_selector = pn.widgets.Select( 242 | name="", 243 | options=[f"Segment {i}" for i in range(self.controller.num_segments)], 244 | value=f"Segment 
{self.segment_index}", 245 | ) 246 | self.segment_selector.param.watch(self._panel_change_segment, 'value') 247 | 248 | self.select_toggle_button = pn.widgets.Toggle(name="Select") 249 | self.select_toggle_button.param.watch(self._panel_on_select_button, 'value') 250 | 251 | self.y_range = Range1d(self._data_min, self._data_max) 252 | self.scatter_source = ColumnDataSource(data={"x": [], "y": [], "color": []}) 253 | self.scatter_fig = bpl.figure( 254 | sizing_mode="stretch_both", 255 | tools="reset,wheel_zoom", 256 | active_scroll="wheel_zoom", 257 | background_fill_color=_bg_color, 258 | border_fill_color=_bg_color, 259 | outline_line_color="white", 260 | y_range=self.y_range, 261 | styles={"flex": "1"} 262 | ) 263 | self.scatter = self.scatter_fig.scatter( 264 | "x", 265 | "y", 266 | source=self.scatter_source, 267 | size=self.settings['scatter_size'], 268 | color="color", 269 | fill_alpha=self.settings['alpha'], 270 | ) 271 | self.scatter_fig.toolbar.logo = None 272 | self.scatter_fig.add_tools(self.lasso_tool) 273 | self.scatter_fig.toolbar.active_drag = None 274 | self.scatter_fig.xaxis.axis_label = "Time (s)" 275 | self.scatter_fig.yaxis.axis_label = self.y_label 276 | time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency 277 | self.scatter_fig.x_range = Range1d(0., time_max) 278 | 279 | slow_lasso(self.scatter_source, self._on_panel_lasso_selected) 280 | 281 | self.hist_fig = bpl.figure( 282 | tools="reset,wheel_zoom", 283 | sizing_mode="stretch_both", 284 | background_fill_color=_bg_color, 285 | border_fill_color=_bg_color, 286 | outline_line_color="white", 287 | y_range=self.y_range, 288 | styles={"flex": "1"} # Make histogram narrower than scatter plot 289 | ) 290 | self.hist_fig.toolbar.logo = None 291 | self.hist_fig.yaxis.axis_label = self.y_label 292 | self.hist_fig.xaxis.axis_label = "Count" 293 | self.hist_fig.x_range = Range1d(0, 1000) # Initial x range for histogram 294 | 295 | self.layout = pn.Column( 
296 | pn.Row(self.segment_selector, self.select_toggle_button, sizing_mode="stretch_width"), 297 | pn.Row( 298 | pn.Column( 299 | self.scatter_fig, 300 | styles={"flex": "1"}, 301 | sizing_mode="stretch_both" 302 | ), 303 | pn.Column( 304 | self.hist_fig, 305 | styles={"flex": "0.3"}, 306 | sizing_mode="stretch_both" 307 | ), 308 | ) 309 | ) 310 | self.hist_lines = [] 311 | self.noise_harea = [] 312 | self.plotted_inds = [] 313 | 314 | def _panel_refresh(self): 315 | from bokeh.models import ColumnDataSource, Range1d 316 | 317 | # clear figures 318 | for renderer in self.hist_lines: 319 | self.hist_fig.renderers.remove(renderer) 320 | self.hist_lines = [] 321 | self.plotted_inds = [] 322 | 323 | max_count = 1 324 | xs = [] 325 | ys = [] 326 | colors = [] 327 | 328 | visible_unit_ids = self.controller.get_visible_unit_ids() 329 | for unit_id in visible_unit_ids: 330 | spike_times, spike_data, hist_count, hist_bins, inds = self.get_unit_data( 331 | unit_id, 332 | seg_index=self.segment_index 333 | ) 334 | color = self.get_unit_color(unit_id) 335 | xs.extend(spike_times) 336 | ys.extend(spike_data) 337 | colors.extend([color] * len(spike_times)) 338 | max_count = max(max_count, np.max(hist_count)) 339 | self.plotted_inds.extend(inds) 340 | 341 | hist_lines = self.hist_fig.line( 342 | "x", 343 | "y", 344 | source=ColumnDataSource( 345 | {"x":hist_count, 346 | "y":hist_bins[:-1], 347 | } 348 | ), 349 | line_color=color, 350 | line_width=2, 351 | ) 352 | self.hist_lines.append(hist_lines) 353 | 354 | self._max_count = max_count 355 | 356 | # Add scatter plot with correct alpha parameter 357 | self.scatter_source.data = { 358 | "x": xs, 359 | "y": ys, 360 | "color": colors 361 | } 362 | self.scatter.glyph.size = self.settings['scatter_size'] 363 | self.scatter.glyph.fill_alpha = self.settings['alpha'] 364 | 365 | # handle selected spikes 366 | self._panel_update_selected_spikes() 367 | 368 | # set y range to min and max of visible spike amplitudes plus a margin 369 | 
margin = 50 370 | all_amps = ys 371 | if len(all_amps) > 0: 372 | self.y_range.start = np.min(all_amps) - margin 373 | self.y_range.end = np.max(all_amps) + margin 374 | self.hist_fig.x_range.end = max_count 375 | 376 | def _panel_on_select_button(self, event): 377 | if self.select_toggle_button.value and len(self.controller.get_visible_unit_ids()) == 1: 378 | self.scatter_fig.toolbar.active_drag = self.lasso_tool 379 | else: 380 | self.scatter_fig.toolbar.active_drag = None 381 | self.scatter_source.selected.indices = [] 382 | self._on_panel_lasso_selected(None, None, None) 383 | 384 | def _panel_change_segment(self, event): 385 | self.segment_index = int(self.segment_selector.value.split()[-1]) 386 | time_max = self.controller.get_num_samples(self.segment_index) / self.controller.sampling_frequency 387 | self.scatter_fig.x_range.end = time_max 388 | self.refresh() 389 | 390 | def _on_panel_lasso_selected(self, attr, old, new): 391 | """ 392 | Handle selection changes in the scatter plot. 
393 | """ 394 | if self.select_toggle_button.value: 395 | selected = self.scatter_source.selected.indices 396 | if len(selected) == 0: 397 | self.controller.set_indices_spike_selected([]) 398 | self.notify_spike_selection_changed() 399 | return 400 | 401 | # Map back to original indices 402 | sl = self.controller.segment_slices[self.segment_index] 403 | spikes_in_seg = self.controller.spikes[sl] 404 | # Create mask for visible units 405 | visible_mask = np.zeros(len(spikes_in_seg), dtype=bool) 406 | for unit_index, unit_id in self.controller.iter_visible_units(): 407 | visible_mask |= (spikes_in_seg['unit_index'] == unit_index) 408 | 409 | # Map back to original indices 410 | visible_indices = np.nonzero(visible_mask)[0] 411 | selected_indices = sl.start + visible_indices[selected] 412 | self.controller.set_indices_spike_selected(selected_indices) 413 | self.notify_spike_selection_changed() 414 | 415 | 416 | def _panel_update_selected_spikes(self): 417 | # handle selected spikes 418 | selected_spike_indices = self.controller.get_indices_spike_selected() 419 | if len(selected_spike_indices) > 0: 420 | # map absolute indices to visible spikes 421 | sl = self.controller.segment_slices[self.segment_index] 422 | spikes_in_seg = self.controller.spikes[sl] 423 | visible_mask = np.zeros(len(spikes_in_seg), dtype=bool) 424 | for unit_index, unit_id in self.controller.iter_visible_units(): 425 | visible_mask |= (spikes_in_seg['unit_index'] == unit_index) 426 | visible_indices = sl.start + np.nonzero(visible_mask)[0] 427 | selected_indices = np.nonzero(np.isin(visible_indices, selected_spike_indices))[0] 428 | # set selected spikes in scatter plot 429 | if self.settings["auto_decimate"]: 430 | selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices)) 431 | self.scatter_source.selected.indices = list(selected_indices) 432 | else: 433 | self.scatter_source.selected.indices = [] 434 | 435 | 436 | def _panel_on_spike_selection_changed(self): 437 | # set 
selection in scatter plot 438 | selected_indices = self.controller.get_indices_spike_selected() 439 | if len(selected_indices) == 0: 440 | self.scatter_source.selected.indices = [] 441 | return 442 | elif len(selected_indices) == 1: 443 | selected_segment = self.controller.spikes[selected_indices[0]]['segment_index'] 444 | if selected_segment != self.segment_index: 445 | self.segment_selector.value = f"Segment {selected_segment}" 446 | self._panel_change_segment(None) 447 | # update selected spikes 448 | self._panel_update_selected_spikes() 449 | 450 | -------------------------------------------------------------------------------- /spikeinterface_gui/crosscorrelogramview.py: -------------------------------------------------------------------------------- 1 | from .view_base import ViewBase 2 | 3 | 4 | 5 | class CrossCorrelogramView(ViewBase): 6 | _supported_backend = ['qt', 'panel'] 7 | _depend_on = ["correlograms"] 8 | _settings = [ 9 | {'name': 'window_ms', 'type': 'float', 'value' : 50. 
}, 10 | {'name': 'bin_ms', 'type': 'float', 'value' : 1.0 }, 11 | {'name': 'display_axis', 'type': 'bool', 'value' : True }, 12 | {'name': 'max_visible', 'type': 'int', 'value' : 8 }, 13 | ] 14 | _need_compute = True 15 | 16 | def __init__(self, controller=None, parent=None, backend="qt"): 17 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 18 | 19 | self.ccg, self.bins = self.controller.get_correlograms() 20 | 21 | 22 | def _on_settings_changed(self): 23 | self.ccg = None 24 | self.refresh() 25 | 26 | def _compute(self): 27 | self.ccg, self.bins = self.controller.compute_correlograms( 28 | self.settings['window_ms'], self.settings['bin_ms']) 29 | 30 | ## Qt ## 31 | 32 | def _qt_make_layout(self): 33 | from .myqt import QT 34 | import pyqtgraph as pg 35 | 36 | self.layout = QT.QVBoxLayout() 37 | 38 | h = QT.QHBoxLayout() 39 | self.layout.addLayout(h) 40 | 41 | self.grid = pg.GraphicsLayoutWidget() 42 | self.layout.addWidget(self.grid) 43 | 44 | 45 | def _qt_refresh(self): 46 | import pyqtgraph as pg 47 | 48 | self.grid.clear() 49 | 50 | if self.ccg is None: 51 | return 52 | 53 | visible_unit_ids = self.controller.get_visible_unit_ids() 54 | visible_unit_ids = visible_unit_ids[:self.settings['max_visible']] 55 | 56 | n = len(visible_unit_ids) 57 | 58 | unit_ids = list(self.controller.unit_ids) 59 | 60 | for r in range(n): 61 | for c in range(r, n): 62 | 63 | i = unit_ids.index(visible_unit_ids[r]) 64 | j = unit_ids.index(visible_unit_ids[c]) 65 | count = self.ccg[i, j, :] 66 | 67 | plot = pg.PlotItem() 68 | if not self.settings['display_axis']: 69 | plot.hideAxis('bottom') 70 | plot.hideAxis('left') 71 | 72 | if r==c: 73 | unit_id = visible_unit_ids[r] 74 | color = self.get_unit_color(unit_id) 75 | else: 76 | color = (120,120,120,120) 77 | 78 | curve = pg.PlotCurveItem(self.bins, count, stepMode='center', fillLevel=0, brush=color, pen=color) 79 | plot.addItem(curve) 80 | self.grid.addItem(plot, row=r, col=c) 81 | 82 | ## panel ## 83 | 
84 | def _panel_make_layout(self): 85 | import panel as pn 86 | import bokeh.plotting as bpl 87 | from .utils_panel import _bg_color 88 | 89 | empty_fig = bpl.figure( 90 | sizing_mode="stretch_both", 91 | background_fill_color=_bg_color, 92 | border_fill_color=_bg_color, 93 | outline_line_color="white", 94 | ) 95 | self.empty_plot_pane = pn.pane.Bokeh(empty_fig, sizing_mode="stretch_both") 96 | 97 | self.layout = pn.Column( 98 | self.empty_plot_pane, 99 | sizing_mode="stretch_both", 100 | ) 101 | self.is_warning_active = False 102 | 103 | self.plots = [] 104 | 105 | def _panel_refresh(self): 106 | import panel as pn 107 | import bokeh.plotting as bpl 108 | from bokeh.layouts import gridplot 109 | from .utils_panel import _bg_color, insert_warning, clear_warning 110 | 111 | # clear previous plot 112 | self.plots = [] 113 | 114 | if self.ccg is None: 115 | return 116 | 117 | visible_unit_ids = self.controller.get_visible_unit_ids() 118 | 119 | # Show warning above the plot if too many visible units 120 | if len(visible_unit_ids) > self.settings['max_visible']: 121 | warning_msg = f"Only showing first {self.settings['max_visible']} units out of {len(visible_unit_ids)} visible units" 122 | insert_warning(self, warning_msg) 123 | self.is_warning_active = True 124 | return 125 | if self.is_warning_active: 126 | clear_warning(self) 127 | self.is_warning_active = False 128 | 129 | visible_unit_ids = visible_unit_ids[:self.settings['max_visible']] 130 | 131 | n = len(visible_unit_ids) 132 | unit_ids = list(self.controller.unit_ids) 133 | for r in range(n): 134 | row_plots = [] 135 | for c in range(r, n): 136 | 137 | i = unit_ids.index(visible_unit_ids[r]) 138 | j = unit_ids.index(visible_unit_ids[c]) 139 | count = self.ccg[i, j, :] 140 | 141 | # Create Bokeh figure 142 | p = bpl.figure( 143 | width=250, 144 | height=250, 145 | tools="pan,wheel_zoom,reset", 146 | active_drag="pan", 147 | active_scroll="wheel_zoom", 148 | background_fill_color=_bg_color, 149 | 
border_fill_color=_bg_color, 150 | outline_line_color="white", 151 | ) 152 | p.toolbar.logo = None 153 | 154 | # Get color from controller 155 | if r == c: 156 | unit_id = visible_unit_ids[r] 157 | color = self.get_unit_color(unit_id) 158 | fill_alpha = 0.7 159 | else: 160 | color = "lightgray" 161 | fill_alpha = 0.4 162 | 163 | p.quad( 164 | top=count, 165 | bottom=0, 166 | left=self.bins[:-1], 167 | right=self.bins[1:], 168 | fill_color=color, 169 | line_color=color, 170 | alpha=fill_alpha, 171 | ) 172 | 173 | row_plots.append(p) 174 | # Fill row with None for proper spacing 175 | full_row = [None] * r + row_plots + [None] * (n - len(row_plots)) 176 | self.plots.append(full_row) 177 | 178 | if len(self.plots) > 0: 179 | grid = gridplot(self.plots, toolbar_location="right", sizing_mode="stretch_both") 180 | self.layout[0] = pn.Column( 181 | grid, 182 | styles={'background-color': f'{_bg_color}'} 183 | ) 184 | else: 185 | self.layout[0] = self.empty_plot_pane 186 | 187 | 188 | 189 | CrossCorrelogramView._gui_help_txt = """ 190 | ## Correlograms View 191 | 192 | This view shows the auto-correlograms and cross-correlograms of the selected units. 
193 | """ 194 | -------------------------------------------------------------------------------- /spikeinterface_gui/curation_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | default_label_definitions = { 5 | "quality": { 6 | "label_options": ["good", "noise", "MUA"], 7 | "exclusive": True 8 | }, 9 | } 10 | 11 | 12 | empty_curation_data = { 13 | "manual_labels": [], 14 | "merge_unit_groups": [], 15 | "removed_units": [] 16 | } 17 | 18 | def adding_group(previous_groups, new_group): 19 | # this is to ensure that np.str_ types are rendered as str 20 | to_merge = [np.array(new_group).tolist()] 21 | unchanged = [] 22 | for c_prev in previous_groups: 23 | is_unaffected = True 24 | 25 | for c_new in new_group: 26 | if c_new in c_prev: 27 | is_unaffected = False 28 | to_merge.append(c_prev) 29 | break 30 | 31 | if is_unaffected: 32 | unchanged.append(c_prev) 33 | new_merge_group = [sum(to_merge, [])] 34 | new_merge_group.extend(unchanged) 35 | # Ensure the unicity 36 | new_merge_group = [list(set(gp)) for gp in new_merge_group] 37 | return new_merge_group 38 | -------------------------------------------------------------------------------- /spikeinterface_gui/curationview.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | from .view_base import ViewBase 5 | 6 | from spikeinterface.core.core_tools import check_json 7 | 8 | 9 | 10 | 11 | class CurationView(ViewBase): 12 | _supported_backend = ['qt', 'panel'] 13 | _need_compute = False 14 | 15 | def __init__(self, controller=None, parent=None, backend="qt"): 16 | self.active_table = "merge" 17 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 18 | 19 | # TODO: Cast unit ids to the correct type here 20 | def restore_units(self): 21 | if self.backend == 'qt': 22 | unit_ids = self._qt_get_delete_table_selection() 23 | else: 24 | 
unit_ids = self._panel_get_delete_table_selection() 25 | if unit_ids is not None: 26 | unit_ids = [self.controller.unit_ids.dtype.type(unit_id) for unit_id in unit_ids] 27 | self.controller.make_manual_restore(unit_ids) 28 | self.notify_manual_curation_updated() 29 | self.refresh() 30 | 31 | def unmerge_groups(self): 32 | if self.backend == 'qt': 33 | merge_indices = self._qt_get_merge_table_row() 34 | else: 35 | merge_indices = self._panel_get_merge_table_row() 36 | if merge_indices is not None: 37 | self.controller.make_manual_restore_merge(merge_indices) 38 | self.notify_manual_curation_updated() 39 | self.refresh() 40 | 41 | ## Qt 42 | def _qt_make_layout(self): 43 | from .myqt import QT 44 | import pyqtgraph as pg 45 | 46 | 47 | self.merge_info = {} 48 | self.layout = QT.QVBoxLayout() 49 | 50 | 51 | tb = self.qt_widget.view_toolbar 52 | if self.controller.curation_can_be_saved(): 53 | but = QT.QPushButton("Save in analyzer") 54 | tb.addWidget(but) 55 | but.clicked.connect(self.save_in_analyzer) 56 | but = QT.QPushButton("Export JSON") 57 | but.clicked.connect(self._qt_export_json) 58 | tb.addWidget(but) 59 | 60 | h = QT.QHBoxLayout() 61 | self.layout.addLayout(h) 62 | 63 | v = QT.QVBoxLayout() 64 | h.addLayout(v) 65 | v.addWidget(QT.QLabel("Merges")) 66 | self.table_merge = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, 67 | selectionBehavior=QT.QAbstractItemView.SelectRows) 68 | # self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu) 69 | v.addWidget(self.table_merge) 70 | 71 | self.table_merge.setContextMenuPolicy(QT.Qt.CustomContextMenu) 72 | self.table_merge.customContextMenuRequested.connect(self._qt_open_context_menu_merge) 73 | self.table_merge.itemSelectionChanged.connect(self._qt_on_item_selection_changed_merge) 74 | 75 | self.merge_menu = QT.QMenu() 76 | act = self.merge_menu.addAction('Remove merge group') 77 | act.triggered.connect(self.unmerge_groups) 78 | shortcut_unmerge = QT.QShortcut(self.qt_widget) 79 | 
shortcut_unmerge.setKey(QT.QKeySequence("ctrl+u")) 80 | shortcut_unmerge.activated.connect(self.unmerge_groups) 81 | 82 | 83 | v = QT.QVBoxLayout() 84 | h.addLayout(v) 85 | v.addWidget(QT.QLabel("Deleted")) 86 | self.table_delete = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, 87 | selectionBehavior=QT.QAbstractItemView.SelectRows) 88 | v.addWidget(self.table_delete) 89 | self.table_delete.setContextMenuPolicy(QT.Qt.CustomContextMenu) 90 | self.table_delete.customContextMenuRequested.connect(self._qt_open_context_menu_delete) 91 | self.table_delete.itemSelectionChanged.connect(self._qt_on_item_selection_changed_delete) 92 | 93 | 94 | self.delete_menu = QT.QMenu() 95 | act = self.delete_menu.addAction('Restore') 96 | act.triggered.connect(self.restore_units) 97 | shortcut_restore = QT.QShortcut(self.qt_widget) 98 | shortcut_restore.setKey(QT.QKeySequence("ctrl+r")) 99 | shortcut_restore.activated.connect(self.restore_units) 100 | 101 | def _qt_refresh(self): 102 | from .myqt import QT 103 | # Merged 104 | merged_units = self.controller.curation_data["merge_unit_groups"] 105 | self.table_merge.clear() 106 | self.table_merge.setRowCount(len(merged_units)) 107 | self.table_merge.setColumnCount(1) 108 | self.table_merge.setHorizontalHeaderLabels(["Merged groups"]) 109 | self.table_merge.setSortingEnabled(False) 110 | for ix, group in enumerate(merged_units): 111 | item = QT.QTableWidgetItem(str(group)) 112 | item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) 113 | self.table_merge.setItem(ix, 0, item) 114 | for i in range(self.table_merge.columnCount()): 115 | self.table_merge.resizeColumnToContents(i) 116 | 117 | ## deleted 118 | removed_units = self.controller.curation_data["removed_units"] 119 | self.table_delete.clear() 120 | self.table_delete.setRowCount(len(removed_units)) 121 | self.table_delete.setColumnCount(1) 122 | self.table_delete.setHorizontalHeaderLabels(["unit_id"]) 123 | self.table_delete.setSortingEnabled(False) 124 | for 
i, unit_id in enumerate(removed_units): 125 | color = self.get_unit_color(unit_id) 126 | pix = QT.QPixmap(16,16) 127 | pix.fill(color) 128 | icon = QT.QIcon(pix) 129 | item = QT.QTableWidgetItem( f'{unit_id}') 130 | item.setFlags(QT.Qt.ItemIsEnabled|QT.Qt.ItemIsSelectable) 131 | self.table_delete.setItem(i,0, item) 132 | item.setIcon(icon) 133 | item.unit_id = unit_id 134 | self.table_delete.resizeColumnToContents(0) 135 | 136 | 137 | 138 | def _qt_get_delete_table_selection(self): 139 | selected_items = self.table_delete.selectedItems() 140 | if len(selected_items) == 0: 141 | return None 142 | else: 143 | return [s.unit_id for s in selected_items] 144 | 145 | def _qt_get_merge_table_row(self): 146 | selected_items = self.table_merge.selectedItems() 147 | if len(selected_items) == 0: 148 | return None 149 | else: 150 | return [s.row() for s in selected_items] 151 | 152 | def _qt_open_context_menu_delete(self): 153 | self.delete_menu.popup(self.qt_widget.cursor().pos()) 154 | 155 | def _qt_open_context_menu_merge(self): 156 | self.merge_menu.popup(self.qt_widget.cursor().pos()) 157 | 158 | def _qt_on_item_selection_changed_merge(self): 159 | if len(self.table_merge.selectedIndexes()) == 0: 160 | return 161 | 162 | dtype = self.controller.unit_ids.dtype 163 | ind = self.table_merge.selectedIndexes()[0].row() 164 | visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind] 165 | visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids] 166 | self.controller.set_visible_unit_ids(visible_unit_ids) 167 | self.notify_unit_visibility_changed() 168 | 169 | def _qt_on_item_selection_changed_delete(self): 170 | if len(self.table_delete.selectedIndexes()) == 0: 171 | return 172 | ind = self.table_delete.selectedIndexes()[0].row() 173 | unit_id = self.controller.curation_data["removed_units"][ind] 174 | self.controller.set_all_unit_visibility_off() 175 | # convert to the correct type 176 | unit_id = self.controller.unit_ids.dtype.type(unit_id) 
177 | self.controller.set_visible_unit_ids([unit_id]) 178 | self.notify_unit_visibility_changed() 179 | 180 | def _qt_on_restore_shortcut(self): 181 | sel_rows = self._qt_get_selected_rows() 182 | self._qt_delete_unit() 183 | if len(sel_rows) > 0: 184 | self.table.clearSelection() 185 | self.table.setCurrentCell(min(sel_rows[-1] + 1, self.table.rowCount() - 1), 0) 186 | 187 | 188 | def on_manual_curation_updated(self): 189 | self.refresh() 190 | 191 | def on_unit_visibility_changed(self): 192 | pass 193 | 194 | def save_in_analyzer(self): 195 | self.controller.save_curation_in_analyzer() 196 | 197 | def _qt_export_json(self): 198 | from .myqt import QT 199 | fd = QT.QFileDialog(fileMode=QT.QFileDialog.AnyFile, acceptMode=QT.QFileDialog.AcceptSave) 200 | fd.setNameFilters(['JSON (*.json);']) 201 | fd.setDefaultSuffix('json') 202 | fd.setViewMode(QT.QFileDialog.Detail) 203 | if fd.exec_(): 204 | json_file = Path(fd.selectedFiles()[0]) 205 | with json_file.open("w") as f: 206 | curation_dict = check_json(self.controller.construct_final_curation()) 207 | json.dump(curation_dict, f, indent=4) 208 | 209 | # PANEL 210 | def _panel_make_layout(self): 211 | import pandas as pd 212 | import panel as pn 213 | 214 | from .utils_panel import KeyboardShortcut, KeyboardShortcuts, SelectableTabulator 215 | 216 | pn.extension("tabulator") 217 | 218 | # Create dataframe 219 | merge_df = pd.DataFrame({"merge_groups": []}) 220 | delete_df = pd.DataFrame({"deleted_unit_id": []}) 221 | 222 | # Create tables 223 | self.table_merge = SelectableTabulator( 224 | merge_df, 225 | show_index=False, 226 | disabled=True, 227 | sortable=False, 228 | formatters={"merge_groups": "plaintext"}, 229 | sizing_mode="stretch_width", 230 | # SelectableTabulator functions 231 | parent_view=self, 232 | # refresh_table_function=self.refresh, 233 | conditional_shortcut=self._conditional_refresh_merge, 234 | column_callbacks={"merge_groups": self._panel_on_merged_col}, 235 | ) 236 | self.table_delete = 
SelectableTabulator( 237 | delete_df, 238 | show_index=False, 239 | disabled=True, 240 | sortable=False, 241 | formatters={"deleted_unit_id": "plaintext"}, 242 | sizing_mode="stretch_width", 243 | # SelectableTabulator functions 244 | parent_view=self, 245 | # refresh_table_function=self.refresh, 246 | conditional_shortcut=self._conditional_refresh_delete, 247 | column_callbacks={"deleted_unit_id": self._panel_on_deleted_col}, 248 | ) 249 | 250 | self.table_delete.param.watch(self._panel_update_unit_visibility, "selection") 251 | self.table_merge.param.watch(self._panel_update_unit_visibility, "selection") 252 | 253 | # Create buttons 254 | save_button = pn.widgets.Button( 255 | name="Save in analyzer", 256 | button_type="primary", 257 | height=30 258 | ) 259 | save_button.on_click(self._panel_save_in_analyzer) 260 | 261 | download_button = pn.widgets.FileDownload( 262 | button_type="primary", 263 | filename="curation.json", 264 | callback=self._panel_generate_json, 265 | height=30 266 | ) 267 | 268 | restore_button = pn.widgets.Button( 269 | name="Restore", 270 | button_type="primary", 271 | height=30 272 | ) 273 | restore_button.on_click(self._panel_restore_units) 274 | 275 | remove_merge_button = pn.widgets.Button( 276 | name="Unmerge", 277 | button_type="primary", 278 | height=30 279 | ) 280 | remove_merge_button.on_click(self._panel_unmerge_groups) 281 | 282 | submit_button = pn.widgets.Button( 283 | name="Submit to parent", 284 | button_type="primary", 285 | height=30 286 | ) 287 | 288 | # Create layout 289 | buttons_save = pn.Row( 290 | save_button, 291 | download_button, 292 | submit_button, 293 | sizing_mode="stretch_width", 294 | ) 295 | save_sections = pn.Column( 296 | buttons_save, 297 | sizing_mode="stretch_width", 298 | ) 299 | buttons_curate = pn.Row( 300 | restore_button, 301 | remove_merge_button, 302 | sizing_mode="stretch_width", 303 | ) 304 | 305 | # shortcuts 306 | shortcuts = [ 307 | KeyboardShortcut(name="restore", key="r", ctrlKey=True), 308 
| KeyboardShortcut(name="unmerge", key="u", ctrlKey=True), 309 | ] 310 | shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) 311 | shortcuts_component.on_msg(self._panel_handle_shortcut) 312 | 313 | # Create main layout with proper sizing 314 | sections = pn.Row(self.table_merge, self.table_delete, sizing_mode="stretch_width") 315 | self.layout = pn.Column( 316 | save_sections, 317 | buttons_curate, 318 | sections, 319 | shortcuts_component, 320 | scroll=True, 321 | sizing_mode="stretch_both" 322 | ) 323 | 324 | # Add a custom JavaScript callback to the button that doesn't interact with Bokeh models 325 | submit_button.on_click(self._panel_submit_to_parent) 326 | 327 | # Add a hidden div to store the data 328 | self.data_div = pn.pane.HTML("", width=0, height=0, margin=0, sizing_mode="fixed") 329 | self.layout.append(self.data_div) 330 | 331 | 332 | def _panel_refresh(self): 333 | import pandas as pd 334 | # Merged 335 | merged_units = self.controller.curation_data["merge_unit_groups"] 336 | 337 | # for visualization, we make all row entries strings 338 | merged_units_str = [] 339 | for group in merged_units: 340 | # convert to string 341 | group = [str(unit_id) for unit_id in group] 342 | merged_units_str.append(" - ".join(group)) 343 | df = pd.DataFrame({"merge_groups": merged_units_str}) 344 | self.table_merge.value = df 345 | self.table_merge.selection = [] 346 | 347 | ## deleted 348 | removed_units = self.controller.curation_data["removed_units"] 349 | removed_units = [str(unit_id) for unit_id in removed_units] 350 | df = pd.DataFrame({"deleted_unit_id": removed_units}) 351 | self.table_delete.value = df 352 | self.table_delete.selection = [] 353 | 354 | def _panel_update_unit_visibility(self, event): 355 | unit_dtype = self.controller.unit_ids.dtype 356 | if self.active_table == "delete": 357 | visible_unit_ids = self.table_delete.value["deleted_unit_id"].values[self.table_delete.selection].tolist() 358 | visible_unit_ids = [unit_dtype.type(unit_id) 
for unit_id in visible_unit_ids] 359 | self.controller.set_visible_unit_ids(visible_unit_ids) 360 | elif self.active_table == "merge": 361 | merge_groups = self.table_merge.value["merge_groups"].values[self.table_merge.selection].tolist() 362 | # self.controller.set_all_unit_visibility_off() 363 | visible_unit_ids = [] 364 | for merge_group in merge_groups: 365 | merge_unit_ids = [unit_dtype.type(unit_id) for unit_id in merge_group.split(" - ")] 366 | visible_unit_ids.extend(merge_unit_ids) 367 | self.controller.set_visible_unit_ids(visible_unit_ids) 368 | self.notify_unit_visibility_changed() 369 | 370 | def _panel_restore_units(self, event): 371 | self.restore_units() 372 | 373 | def _panel_unmerge_groups(self, event): 374 | self.unmerge_groups() 375 | 376 | def _panel_save_in_analyzer(self, event): 377 | self.save_in_analyzer() 378 | 379 | def _panel_generate_json(self): 380 | # Get the path from the text input 381 | export_path = "curation.json" 382 | # Save the JSON file 383 | curation_dict = check_json(self.controller.construct_final_curation()) 384 | 385 | with open(export_path, "w") as f: 386 | json.dump(curation_dict, f, indent=4) 387 | 388 | return export_path 389 | 390 | def _panel_get_delete_table_selection(self): 391 | selected_items = self.table_delete.selection 392 | if len(selected_items) == 0: 393 | return None 394 | else: 395 | return self.table_delete.value["deleted_unit_id"].values[selected_items].tolist() 396 | 397 | def _panel_get_merge_table_row(self): 398 | selected_items = self.table_merge.selection 399 | if len(selected_items) == 0: 400 | return None 401 | else: 402 | return selected_items 403 | 404 | def _panel_handle_shortcut(self, event): 405 | if event.data == "restore": 406 | self.restore_units() 407 | elif event.data == "unmerge": 408 | self.unmerge_groups() 409 | 410 | def _panel_submit_to_parent(self, event): 411 | """Send the curation data to the parent window""" 412 | # Get the curation data and convert it to a JSON string 413 | 
curation_data = json.dumps(check_json(self.controller.construct_final_curation())) 414 | 415 | # Create a JavaScript snippet that will send the data to the parent window 416 | js_code = f""" 417 | 436 | """ 437 | 438 | # Update the hidden div with the JavaScript code 439 | self.data_div.object = js_code 440 | 441 | def _panel_on_deleted_col(self, row): 442 | self.active_table = "delete" 443 | self.table_merge.selection = [] 444 | 445 | def _panel_on_merged_col(self, row): 446 | self.active_table = "merge" 447 | self.table_delete.selection = [] 448 | 449 | def _conditional_refresh_merge(self): 450 | # Check if the view is active before refreshing 451 | if self.is_view_active() and self.active_table == "merge": 452 | return True 453 | else: 454 | return False 455 | 456 | def _conditional_refresh_delete(self): 457 | # Check if the view is active before refreshing 458 | if self.is_view_active() and self.active_table == "delete": 459 | return True 460 | else: 461 | return False 462 | 463 | 464 | CurationView._gui_help_txt = """ 465 | ## Curation View 466 | 467 | The curation view shows the current status of the curation process and allows the user to manually visualize, 468 | revert, and export the curation data. 469 | 470 | ### Controls 471 | - **save in analyzer**: Save the current curation state in the analyzer. 472 | - **export/download JSON**: Export the current curation state to a JSON file. 473 | - **restore**: Restore the selected unit from the deleted units table. 474 | - **unmerge**: Unmerge the selected merge group from the merged units table. 475 | - **submit to parent**: Submit the current curation state to the parent window (for use in web applications). 476 | - **press 'ctrl+r'**: Restore the selected units from the deleted units table. 477 | - **press 'ctrl+u'**: Unmerge the selected merge groups from the merged units table. 
478 | """ 479 | -------------------------------------------------------------------------------- /spikeinterface_gui/img/si.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeInterface/spikeinterface-gui/757f27a3ba63e5f06361b2928188c1209573d9e9/spikeinterface_gui/img/si.png -------------------------------------------------------------------------------- /spikeinterface_gui/isiview.py: -------------------------------------------------------------------------------- 1 | from .view_base import ViewBase 2 | 3 | 4 | class ISIView(ViewBase): 5 | _supported_backend = ['qt', 'panel'] 6 | _settings = [ 7 | {'name': 'window_ms', 'type': 'float', 'value' : 50. }, 8 | {'name': 'bin_ms', 'type': 'float', 'value' : 1.0 }, 9 | ] 10 | _depend_on = ["isi_histograms"] 11 | _need_compute = True 12 | 13 | 14 | def __init__(self, controller=None, parent=None, backend="qt"): 15 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 16 | self.isi_histograms, self.isi_bins = self.controller.get_isi_histograms() 17 | 18 | def compute(self): 19 | self.isi_histograms, self.isi_bins = self.controller.compute_isi_histograms( 20 | self.settings['window_ms'], self.settings['bin_ms']) 21 | self.refresh() 22 | 23 | def _on_settings_changed(self): 24 | self.isi_histograms, self.isi_bins = None, None 25 | self.refresh() 26 | 27 | ## QT ## 28 | 29 | def _qt_make_layout(self): 30 | import pyqtgraph as pg 31 | from .myqt import QT 32 | from .utils_qt import ViewBoxHandlingDoubleClick 33 | 34 | self.layout = QT.QVBoxLayout() 35 | 36 | self.graphicsview = pg.GraphicsView() 37 | self.layout.addWidget(self.graphicsview) 38 | 39 | self.viewBox = ViewBoxHandlingDoubleClick() 40 | 41 | self.plot = pg.PlotItem(viewBox=self.viewBox) 42 | self.graphicsview.setCentralItem(self.plot) 43 | self.plot.hideButtons() 44 | 45 | def _qt_refresh(self): 46 | import pyqtgraph as pg 47 | 48 | self.plot.clear() 49 | if 
self.isi_histograms is None: 50 | return 51 | 52 | n = 0 53 | for unit_index, unit_id in self.controller.iter_visible_units(): 54 | 55 | isi = self.isi_histograms[unit_index, :] 56 | 57 | qcolor = self.get_unit_color(unit_id) 58 | curve = pg.PlotCurveItem(self.isi_bins[:-1], isi, pen=pg.mkPen(qcolor, width=3)) 59 | self.plot.addItem(curve) 60 | 61 | ## Panel ## 62 | 63 | def _panel_make_layout(self): 64 | import panel as pn 65 | import bokeh.plotting as bpl 66 | from bokeh.models import Range1d 67 | from .utils_panel import _bg_color 68 | 69 | # Create Bokeh figure 70 | self.figure = bpl.figure( 71 | sizing_mode="stretch_both", 72 | tools="pan,box_zoom,wheel_zoom,reset", 73 | x_axis_label="Time (ms)", 74 | y_axis_label="Count", 75 | background_fill_color=_bg_color, 76 | border_fill_color=_bg_color, 77 | outline_line_color="white", 78 | styles={"flex": "1"} 79 | ) 80 | self.figure.toolbar.logo = None 81 | # Update plot ranges 82 | self.figure.x_range = Range1d(0, self.settings['window_ms']) 83 | self.figure.y_range = Range1d(0, 100) 84 | 85 | self.layout = pn.Column( 86 | self.figure, 87 | styles={"flex": "1"}, 88 | sizing_mode="stretch_both" 89 | ) 90 | 91 | def _panel_refresh(self): 92 | from bokeh.models import ColumnDataSource 93 | 94 | # this clear the figure 95 | self.figure.renderers = [] 96 | self.lines = {} 97 | 98 | y_max = 0 99 | for unit_index, unit_id in self.controller.iter_visible_units(): 100 | isi = self.isi_histograms[unit_index, :] 101 | source = ColumnDataSource({"x": self.isi_bins[:-1].tolist(), "y": isi.tolist()}) 102 | color = self.get_unit_color(unit_id) 103 | self.lines[unit_id] = self.figure.line( 104 | "x", 105 | "y", 106 | source=source, 107 | line_color=color, 108 | line_width=2, 109 | visible=True, 110 | ) 111 | y_max = max(y_max, isi.max()) 112 | self.figure.y_range.end = y_max * 1.1 113 | 114 | ISIView._gui_help_txt = """ 115 | ## ISI View 116 | 117 | This view shows the inter spike interval histograms for each unit. 
118 | """ -------------------------------------------------------------------------------- /spikeinterface_gui/layout_presets.py: -------------------------------------------------------------------------------- 1 | """ 2 | A preset need 8 zones like this: 3 | 4 | +-----------------+-----------------+ 5 | | [zone1 zone2] | [zone3 zone4] | 6 | +-----------------+-----------------+ 7 | | [zone5 zone6] | [zone7 zone8] | 8 | +-----------------+-----------------+ 9 | 10 | """ 11 | 12 | _presets = {} 13 | def get_layout_description(preset_name): 14 | if preset_name is None: 15 | preset_name = 'default' 16 | _presets[preset_name] 17 | return _presets[preset_name] 18 | 19 | 20 | default_layout = dict( 21 | zone1=['curation', 'spikelist'], 22 | zone2=['unitlist', 'mergelist'], 23 | zone3=['trace', 'tracemap', 'spikeamplitude', 'spikedepth'], 24 | zone4=[], 25 | zone5=['probe'], 26 | zone6=['ndscatter', 'similarity'], 27 | zone7=['waveform', 'waveformheatmap', ], 28 | zone8=['correlogram', 'isi', 'mainsettings'], 29 | ) 30 | _presets['default'] = default_layout 31 | 32 | 33 | # legacy layout for nostalgic people like me 34 | legacy_layout = dict( 35 | zone1=['curation', 'spikelist'], 36 | zone2=['unitlist', 'mergelist'], 37 | zone3=['trace', 'tracemap', 'waveform', 'waveformheatmap', 'isi', 'correlogram', 'spikeamplitude'], 38 | zone4=[], 39 | zone5=['probe'], 40 | zone6=['ndscatter', 'similarity'], 41 | zone7=[], 42 | zone8=[], 43 | ) 44 | _presets['legacy'] = legacy_layout 45 | 46 | # yep is for testing 47 | yep_layout = dict( 48 | zone1=['curation', 'spikelist'], 49 | zone2=['unitlist', 'mergelist'], 50 | zone3=['trace', 'tracemap', 'spikeamplitude'], 51 | zone4=['similarity'], 52 | zone5=['probe'], 53 | zone6=['ndscatter', ], 54 | zone7=['waveform', 'waveformheatmap', ], 55 | zone8=['correlogram', 'isi'], 56 | ) 57 | _presets['yep'] = yep_layout 58 | -------------------------------------------------------------------------------- /spikeinterface_gui/main.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from pathlib import Path 4 | import numpy as np 5 | import warnings 6 | 7 | from spikeinterface import load_sorting_analyzer, load 8 | from spikeinterface.core.core_tools import is_path_remote 9 | 10 | # this force the loding of spikeinterface sub module 11 | import spikeinterface.postprocessing 12 | import spikeinterface.qualitymetrics 13 | 14 | from spikeinterface_gui.controller import Controller 15 | 16 | 17 | def run_mainwindow( 18 | analyzer, 19 | mode="desktop", 20 | with_traces=True, 21 | curation=False, 22 | curation_dict=None, 23 | label_definitions=None, 24 | displayed_unit_properties=None, 25 | extra_unit_properties=None, 26 | skip_extensions=None, 27 | recording=None, 28 | start_app=True, 29 | layout_preset=None, 30 | address="localhost", 31 | port=0, 32 | verbose=False, 33 | ): 34 | """ 35 | Create the main window and start the QT app loop. 36 | 37 | Parameters 38 | ---------- 39 | analyzer: SortingAnalyzer 40 | The sorting analyzer object 41 | mode: 'desktop' | 'web' 42 | The GUI mode to use. 43 | 'desktop' will run a Qt app. 44 | 'web' will run a Panel app. 
45 | with_traces: bool, default: True 46 | If True, traces are displayed 47 | curation: bool, default: False 48 | If True, the curation panel is displayed 49 | curation_dict: dict | None, default: None 50 | The curation dictionary to start from an existing curation 51 | label_definitions: dict | None, default: None 52 | The label definitions to provide to the curation panel 53 | displayed_unit_properties: list | None, default: None 54 | The displayed unit properties in the unit table 55 | extra_unit_properties: list | None, default: None 56 | The extra unit properties in the unit table 57 | skip_extensions: list | None, default: None 58 | The list of extensions to skip when loading the sorting analyzer 59 | recording: RecordingExtractor | None, default: None 60 | The recording object to display traces. This can be used when the 61 | SortingAnalyzer is recordingless. 62 | start_qt_app: bool, default: True 63 | If True, the QT app loop is started 64 | layout_preset : str | None 65 | The name of the layout preset. None is default. 66 | address: str, default : "localhost" 67 | For "web" mode only. By default only on local machine. 68 | port: int, default: 0 69 | For "web" mode only. If 0 then the port is automatic. 
70 | verbose: bool, default: False 71 | If True, print some information in the console 72 | """ 73 | 74 | if mode == "desktop": 75 | backend = "qt" 76 | elif mode == "web": 77 | backend = "panel" 78 | else: 79 | raise ValueError(f"spikeinterface-gui wrong mode {mode}") 80 | 81 | 82 | if recording is not None: 83 | analyzer.set_temporary_recording(recording) 84 | 85 | if verbose: 86 | import time 87 | t0 = time.perf_counter() 88 | controller = Controller( 89 | analyzer, backend=backend, verbose=verbose, 90 | curation=curation, curation_data=curation_dict, 91 | label_definitions=label_definitions, 92 | with_traces=with_traces, 93 | displayed_unit_properties=displayed_unit_properties, 94 | extra_unit_properties=extra_unit_properties, 95 | skip_extensions=skip_extensions, 96 | ) 97 | if verbose: 98 | t1 = time.perf_counter() 99 | print('controller init time', t1 - t0) 100 | 101 | if backend == "qt": 102 | from spikeinterface_gui.myqt import QT, mkQApp 103 | from spikeinterface_gui.backend_qt import QtMainWindow 104 | 105 | # Suppress a known pyqtgraph warning 106 | warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph") 107 | warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*") 108 | 109 | 110 | app = mkQApp() 111 | 112 | win = QtMainWindow(controller, layout_preset=layout_preset) 113 | win.setWindowTitle('SpikeInterface GUI') 114 | this_file = Path(__file__).absolute() 115 | win.setWindowIcon(QT.QIcon(str(this_file.parent / 'img' / 'si.png'))) 116 | win.show() 117 | if start_app: 118 | app.exec() 119 | 120 | elif backend == "panel": 121 | import panel 122 | from .backend_panel import PanelMainWindow, start_server 123 | win = PanelMainWindow(controller, layout_preset=layout_preset) 124 | win.main_layout.servable(title='SpikeInterface GUI') 125 | if start_app: 126 | start_server(win, address=address, port=port) 127 | 128 | 129 | return win 130 | 131 | 132 | def run_mainwindow_cli(): 133 | argv = sys.argv[1:] 134 
| 135 | parser = argparse.ArgumentParser(description='spikeinterface-gui') 136 | parser.add_argument('analyzer_folder', help='SortingAnalyzer folder path', default=None, nargs='?') 137 | parser.add_argument('--mode', help='Mode desktop or web', default='desktop') 138 | parser.add_argument('--no-traces', help='Do not show traces', action='store_true', default=False) 139 | parser.add_argument('--curation', help='Enable curation panel', action='store_true', default=False) 140 | parser.add_argument('--recording', help='Path to a recording file (.json/.pkl) or folder that can be loaded with spikeinterface.load', default=None) 141 | parser.add_argument('--recording-base-folder', help='Base folder path for the recording (if .json/.pkl)', default=None) 142 | parser.add_argument('--verbose', help='Make the output verbose', action='store_true', default=False) 143 | 144 | args = parser.parse_args(argv) 145 | 146 | analyzer_folder = args.analyzer_folder 147 | if analyzer_folder is None: 148 | print('You must specify the analyzer folder like this: sigui /path/to/my/analyzer/folder') 149 | exit() 150 | if args.verbose: 151 | print('Loading analyzer...') 152 | analyzer = load_sorting_analyzer(analyzer_folder, load_extensions=not is_path_remote(analyzer_folder)) 153 | if args.verbose: 154 | print('Analyzer loaded') 155 | 156 | recording = None 157 | if args.recording is not None: 158 | try: 159 | if args.verbose: 160 | print('Loading recording...') 161 | recording_base_path = args.recording_base_path 162 | recording = load(args.recording, base_folder=recording_base_path) 163 | if args.verbose: 164 | print('Recording loaded') 165 | except Exception as e: 166 | print('Error when loading recording. Please check the path or the file format') 167 | if recording is not None: 168 | if analyzer.get_num_channels() != recording.get_num_channels(): 169 | print('Recording and analyzer have different number of channels. 
Slicing recording') 170 | channel_mask = np.isin(recording.channel_ids, analyzer.channel_ids) 171 | if np.sum(channel_mask) != analyzer.get_num_channels(): 172 | raise ValueError('The recording does not have the same channel ids as the analyzer') 173 | recording = recording.select_channels(recording.channel_ids[channel_mask]) 174 | 175 | run_mainwindow(analyzer, mode=args.mode, with_traces=not(args.no_traces), curation=args.curation, recording=recording, verbose=args.verbose) 176 | -------------------------------------------------------------------------------- /spikeinterface_gui/mainsettingsview.py: -------------------------------------------------------------------------------- 1 | from .view_base import ViewBase 2 | 3 | 4 | # this control controller.main_settings 5 | main_settings = [ 6 | {'name': 'max_visible_units', 'type': 'int', 'value' : 10 }, 7 | {'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit', 8 | 'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']}, 9 | ] 10 | 11 | 12 | class MainSettingsView(ViewBase): 13 | _supported_backend = ['qt', 'panel'] 14 | _settings = None 15 | _depend_on = [] 16 | _need_compute = False 17 | 18 | def __init__(self, controller=None, parent=None, backend="qt"): 19 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 20 | 21 | 22 | def on_max_visible_units_changed(self): 23 | max_visible = self.main_settings['max_visible_units'] 24 | self.controller.main_settings['max_visible_units'] = max_visible 25 | 26 | visible_ids = self.controller.get_visible_unit_ids() 27 | if len(visible_ids) > max_visible: 28 | visible_ids = visible_ids[:max_visible] 29 | self.controller.set_visible_unit_ids(visible_ids) 30 | self.notify_unit_visibility_changed() 31 | 32 | def on_change_color_mode(self): 33 | 34 | self.controller.main_settings['color_mode'] = self.main_settings['color_mode'] 35 | self.controller.refresh_colors() 36 | self.notify_unit_color_changed() 37 | 38 | # for view in 
self.controller.views: 39 | # view.refresh() 40 | 41 | ## QT zone 42 | def _qt_make_layout(self): 43 | from .myqt import QT 44 | import pyqtgraph as pg 45 | 46 | self.layout = QT.QVBoxLayout() 47 | 48 | txt = self.controller.get_information_txt() 49 | self.info_label = QT.QLabel(txt) 50 | self.layout.addWidget(self.info_label) 51 | 52 | self.main_settings = pg.parametertree.Parameter.create(name="main settings", type='group', children=main_settings) 53 | 54 | # not that the parent is not the view (not Qt anymore) itself but the widget 55 | self.tree_main_settings = pg.parametertree.ParameterTree(parent=self.qt_widget) 56 | self.tree_main_settings.header().hide() 57 | self.tree_main_settings.setParameters(self.main_settings, showTop=True) 58 | # self.tree_main_settings.setWindowTitle(u'Main settings') 59 | self.layout.addWidget(self.tree_main_settings) 60 | 61 | self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed) 62 | self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode) 63 | 64 | 65 | def _qt_refresh(self): 66 | pass 67 | 68 | 69 | ## panel zone 70 | def _panel_make_layout(self): 71 | import panel as pn 72 | from .backend_panel import create_dynamic_parameterized, SettingsProxy 73 | 74 | # Create method and arguments layout 75 | self.main_settings = SettingsProxy(create_dynamic_parameterized(main_settings)) 76 | self.main_settings_layout = pn.Param(self.main_settings._parameterized, sizing_mode="stretch_both", 77 | name=f"Main settings") 78 | self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units') 79 | self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode') 80 | self.layout = pn.Column(self.main_settings_layout, sizing_mode="stretch_both") 81 | 82 | def _panel_on_max_visible_units_changed(self, event): 83 | self.on_max_visible_units_changed() 84 | 85 | def 
_panel_on_change_color_mode(self, event): 86 | self.on_change_color_mode() 87 | 88 | def _panel_refresh(self): 89 | pass 90 | 91 | 92 | MainSettingsView._gui_help_txt = """ 93 | ## Main settings 94 | 95 | Overview and main controls 96 | """ -------------------------------------------------------------------------------- /spikeinterface_gui/mergeview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | 4 | from .view_base import ViewBase 5 | 6 | 7 | class MergeView(ViewBase): 8 | _supported_backend = ['qt', 'panel'] 9 | 10 | _settings = None 11 | 12 | _methods = [{"name": "method", "type": "list", "limits": ["similarity", "automerge"]}] 13 | 14 | _method_params = { 15 | "similarity": [ 16 | {"name": "similarity_threshold", "type": "float", "value": .9, "step": 0.01}, 17 | {"name": "similarity_method", "type": "list", "limits": ["l1", "l2", "cosine"]}, 18 | ], 19 | "automerge": [ 20 | {"name": "automerge_preset", "type": "list", "limits": [ 21 | 'similarity_correlograms', 22 | 'temporal_splits', 23 | 'x_contaminations', 24 | 'feature_neighbors' 25 | ]} 26 | ] 27 | } 28 | 29 | _need_compute = False 30 | 31 | def __init__(self, controller=None, parent=None, backend="qt"): 32 | if controller.has_extension("template_similarity"): 33 | similarity_ext = controller.analyzer.get_extension("template_similarity") 34 | similarity_method = similarity_ext.params["method"] 35 | self._method_params["similarity"][1]["value"] = similarity_method 36 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 37 | 38 | def get_potential_merges(self): 39 | method = self.method 40 | if self.controller.verbose: 41 | print(f"Computing potential merges using {method} method") 42 | if method == 'similarity': 43 | similarity_params = self.method_params['similarity'] 44 | similarity = self.controller.get_similarity(similarity_params['similarity_method']) 45 | if similarity is None: 46 | 
similarity = self.controller.compute_similarity(similarity_params['similarity_method']) 47 | th_sim = similarity > similarity_params['similarity_threshold'] 48 | unit_ids = self.controller.unit_ids 49 | self.proposed_merge_unit_groups = [[unit_ids[i], unit_ids[j]] for i, j in zip(*np.nonzero(th_sim)) if i < j] 50 | self.merge_info = {'similarity': similarity} 51 | elif method == 'automerge': 52 | automerge_params = self.method_params['automerge'] 53 | params = { 54 | 'preset': automerge_params['automerge_preset'] 55 | } 56 | self.proposed_merge_unit_groups, self.merge_info = self.controller.compute_auto_merge(**params) 57 | else: 58 | raise ValueError(f"Unknown method: {method}") 59 | if self.controller.verbose: 60 | print(f"Found {len(self.proposed_merge_unit_groups)} merge groups using {method} method") 61 | 62 | def get_table_data(self, include_deleted=False): 63 | """Get data for displaying in table""" 64 | if not self.proposed_merge_unit_groups: 65 | return [], [] 66 | 67 | max_group_size = max(len(g) for g in self.proposed_merge_unit_groups) 68 | potential_labels = {"similarity", "correlogram_diff", "templates_diff"} 69 | more_labels = [] 70 | for lbl in self.merge_info.keys(): 71 | if lbl in potential_labels: 72 | if max_group_size == 2: 73 | more_labels.append(lbl) 74 | else: 75 | more_labels.append([lbl + "_min", lbl + "_max"]) 76 | 77 | labels = [f"unit_id{i}" for i in range(max_group_size)] + more_labels + ["group_ids"] 78 | 79 | rows = [] 80 | unit_ids = list(self.controller.unit_ids) 81 | for group_ids in self.proposed_merge_unit_groups: 82 | if not include_deleted and self.controller.curation: 83 | deleted_unit_ids = self.controller.curation_data["removed_units"] 84 | if any(unit_id in deleted_unit_ids for unit_id in group_ids): 85 | continue 86 | 87 | row = {} 88 | # Add unit information 89 | for i, unit_id in enumerate(group_ids): 90 | row[f"unit_id{i}"] = unit_id 91 | # row[f"unit_id{i}_color"] = self.controller.get_unit_color(unit_id) 92 | 
row["group_ids"] = group_ids 93 | 94 | # Add metrics information 95 | for info_name in more_labels: 96 | values = [] 97 | for unit_id1, unit_id2 in itertools.combinations(group_ids, 2): 98 | unit_ind1 = unit_ids.index(unit_id1) 99 | unit_ind2 = unit_ids.index(unit_id2) 100 | values.append(self.merge_info[info_name][unit_ind1][unit_ind2]) 101 | 102 | if max_group_size == 2: 103 | row[info_name] = f"{values[0]:.2f}" 104 | else: 105 | min_, max_ = min(values), max(values) 106 | row[f"{info_name}_min"] = f"{min_:.2f}" 107 | row[f"{info_name}_max"] = f"{max_:.2f}" 108 | rows.append(row) 109 | return labels, rows 110 | 111 | def accept_group_merge(self, group_ids): 112 | self.controller.make_manual_merge_if_possible(group_ids) 113 | self.notify_manual_curation_updated() 114 | self.refresh() 115 | 116 | ### QT 117 | def _qt_get_selected_group_ids(self): 118 | inds = self.table.selectedIndexes() 119 | if len(inds) != self.table.columnCount(): 120 | row_ix = None 121 | else: 122 | row_ix = inds[0].row() 123 | if row_ix is None: 124 | return None, None 125 | item = self.table.item(row_ix, 0) 126 | group_ids = item.group_ids 127 | return row_ix, group_ids 128 | 129 | def _qt_on_accept_shorcut(self): 130 | row_ix, group_ids = self._qt_get_selected_group_ids() 131 | if group_ids is None: 132 | return 133 | self.accept_group_merge(group_ids) 134 | n_rows = self.table.rowCount() 135 | self.table.setCurrentCell(min(n_rows - 1, row_ix + 1), 0) 136 | 137 | def _qt_on_item_selection_changed(self): 138 | r = self._qt_get_selected_group_ids() 139 | if r is None: 140 | return 141 | row_ix, group_ids = r 142 | if group_ids is None: 143 | return 144 | 145 | self.controller.set_visible_unit_ids(group_ids) 146 | 147 | self.notify_unit_visibility_changed() 148 | 149 | def _qt_on_double_click(self, item): 150 | self.accept_group_merge(item.group_ids) 151 | 152 | def _qt_on_method_change(self): 153 | self.method = self.method_selector['method'] 154 | for method in self.method_params_selectors: 
155 | self.method_params_selectors[method].setVisible(method == self.method) 156 | 157 | 158 | def _qt_make_layout(self): 159 | from .myqt import QT 160 | import pyqtgraph as pg 161 | 162 | self.proposed_merge_unit_groups = [] 163 | 164 | # create method and arguments layout 165 | self.method_selector = pg.parametertree.Parameter.create(name="method", type='group', children=self._methods) 166 | method_select = pg.parametertree.ParameterTree(parent=None) 167 | method_select.header().hide() 168 | method_select.setParameters(self.method_selector, showTop=True) 169 | method_select.setWindowTitle(u'View options') 170 | method_select.setFixedHeight(50) 171 | self.method_selector.sigTreeStateChanged.connect(self._qt_on_method_change) 172 | 173 | self.merge_info = {} 174 | self.layout = QT.QVBoxLayout() 175 | self.layout.addWidget(method_select) 176 | 177 | self.method_params_selectors = {} 178 | self.method_params = {} 179 | for method, params in self._method_params.items(): 180 | method_params = pg.parametertree.Parameter.create(name="params", type='group', children=params) 181 | method_tree_settings = pg.parametertree.ParameterTree(parent=None) 182 | method_tree_settings.header().hide() 183 | method_tree_settings.setParameters(method_params, showTop=True) 184 | method_tree_settings.setWindowTitle(u'View options') 185 | method_tree_settings.setFixedHeight(100) 186 | self.method_params_selectors[method] = method_tree_settings 187 | self.method_params[method] = method_params 188 | self.layout.addWidget(method_tree_settings) 189 | self.method = self.method_selector['method'] 190 | self._qt_on_method_change() 191 | 192 | row_layout = QT.QHBoxLayout() 193 | 194 | but = QT.QPushButton('Calculate merges') 195 | but.clicked.connect(self._qt_calculate_potential_automerge) 196 | row_layout.addWidget(but) 197 | 198 | if self.controller.curation: 199 | self.include_deleted = QT.QCheckBox("Include deleted units") 200 | self.include_deleted.setChecked(False) 201 | 
row_layout.addWidget(self.include_deleted) 202 | 203 | self.layout.addLayout(row_layout) 204 | 205 | self.sorting_column = 2 206 | self.sorting_direction = QT.Qt.SortOrder.AscendingOrder 207 | 208 | self.table = QT.QTableWidget(selectionMode=QT.QAbstractItemView.SingleSelection, 209 | selectionBehavior=QT.QAbstractItemView.SelectRows) 210 | self.table.setContextMenuPolicy(QT.Qt.CustomContextMenu) 211 | self.layout.addWidget(self.table) 212 | self.table.itemSelectionChanged.connect(self._qt_on_item_selection_changed) 213 | 214 | shortcut_accept = QT.QShortcut(self.qt_widget) 215 | shortcut_accept.setKey(QT.QKeySequence('ctrl+a')) 216 | shortcut_accept.activated.connect(self._qt_on_accept_shorcut) 217 | 218 | self.refresh() 219 | 220 | def _qt_refresh(self): 221 | from .myqt import QT 222 | from .utils_qt import CustomItem 223 | 224 | self.table.clear() 225 | self.table.setSortingEnabled(False) 226 | 227 | include_deleted = self.include_deleted.isChecked() if self.controller.curation else False 228 | labels, rows = self.get_table_data(include_deleted=include_deleted) 229 | if "group_ids" in labels: 230 | labels.remove("group_ids") 231 | 232 | if not rows: 233 | self.table.setColumnCount(0) 234 | self.table.setRowCount(0) 235 | return 236 | 237 | self.table.setColumnCount(len(labels)) 238 | self.table.setHorizontalHeaderLabels(labels) 239 | self.table.setRowCount(len(rows)) 240 | 241 | for r, row in enumerate(rows): 242 | for c, label in enumerate(labels): 243 | if label.startswith("unit_id"): 244 | unit_id = row[label] 245 | n = self.controller.num_spikes[unit_id] 246 | name = f'{unit_id} n={n}' 247 | color = self.get_unit_color(unit_id) 248 | pix = QT.QPixmap(16, 16) 249 | pix.fill(color) 250 | icon = QT.QIcon(pix) 251 | item = QT.QTableWidgetItem(name) 252 | item.setData(QT.Qt.ItemDataRole.UserRole, unit_id) 253 | item.setFlags(QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable) 254 | self.table.setItem(r, c, item) 255 | item.setIcon(icon) 256 | item.group_ids = 
row.get("group_ids", []) 257 | elif "_color" not in label: 258 | value = row[label] 259 | item = CustomItem(value) 260 | self.table.setItem(r, c, item) 261 | 262 | for i in range(self.table.columnCount()): 263 | self.table.resizeColumnToContents(i) 264 | self.table.setSortingEnabled(True) 265 | 266 | def _qt_calculate_potential_automerge(self): 267 | self.get_potential_merges() 268 | self.refresh() 269 | 270 | def _qt_on_spike_selection_changed(self): 271 | pass 272 | 273 | def _qt_on_unit_visibility_changed(self): 274 | pass 275 | 276 | ## PANEL 277 | def _panel_make_layout(self): 278 | import panel as pn 279 | from .utils_panel import KeyboardShortcut, KeyboardShortcuts 280 | from .backend_panel import create_dynamic_parameterized, SettingsProxy 281 | 282 | pn.extension("tabulator") 283 | 284 | self.proposed_merge_unit_groups = [] 285 | 286 | # Create method and arguments layout 287 | method_settings = SettingsProxy(create_dynamic_parameterized(self._methods)) 288 | self.method_selector = pn.Param(method_settings._parameterized, sizing_mode="stretch_width", name="Method") 289 | for setting_data in self._methods: 290 | method_settings._parameterized.param.watch(self._panel_on_method_change, setting_data["name"]) 291 | 292 | self.method_params = {} 293 | self.method_params_selectors = {} 294 | for method, params in self._method_params.items(): 295 | method_params = SettingsProxy(create_dynamic_parameterized(params)) 296 | self.method_params[method] = method_params 297 | self.method_params_selectors[method] = pn.Param(method_params._parameterized, sizing_mode="stretch_width", 298 | name=f"{method.capitalize()} parameters") 299 | self.method = list(self.method_params.keys())[0] 300 | 301 | # shortcuts 302 | shortcuts = [ 303 | KeyboardShortcut(name="accept", key="a", ctrlKey=True), 304 | KeyboardShortcut(name="next", key="ArrowDown", ctrlKey=False), 305 | KeyboardShortcut(name="previous", key="ArrowUp", ctrlKey=False), 306 | ] 307 | shortcuts_component = 
KeyboardShortcuts(shortcuts=shortcuts) 308 | shortcuts_component.on_msg(self._panel_handle_shortcut) 309 | 310 | # Create data source and table 311 | self.table = None 312 | self.table_area = pn.pane.Placeholder("No merges computed yet.", height=400) 313 | 314 | self.caluculate_merges_button = pn.widgets.Button(name="Calculate merges", button_type="primary", sizing_mode="stretch_width") 315 | self.caluculate_merges_button.on_click(self._panel_calculate_merges) 316 | 317 | calculate_list = [self.caluculate_merges_button] 318 | 319 | if self.controller.curation: 320 | self.include_deleted = pn.widgets.Checkbox(name="Include deleted units", value=False) 321 | calculate_list.append(self.include_deleted) 322 | calculate_row = pn.Row(*calculate_list, sizing_mode="stretch_width") 323 | 324 | self.layout = pn.Column( 325 | # add params 326 | self.method_selector, 327 | self.method_params_selectors[self.method], 328 | calculate_row, 329 | self.table_area, 330 | shortcuts_component, 331 | scroll=True, 332 | sizing_mode="stretch_width", 333 | ) 334 | 335 | 336 | def _panel_refresh(self): 337 | """Update the table with current data""" 338 | import pandas as pd 339 | import panel as pn 340 | import matplotlib.colors as mcolors 341 | from .utils_panel import unit_formatter 342 | 343 | pn.extension("tabulator") 344 | # Create table 345 | include_deleted = self.include_deleted.value if self.controller.curation else False 346 | labels, rows = self.get_table_data(include_deleted=include_deleted) 347 | # set unmutable data 348 | data = {label: [] for label in labels} 349 | for row in rows: 350 | for label in labels: 351 | if label.startswith("unit_id"): 352 | unit_id = row[label] 353 | data[label].append({"id": unit_id, "color": mcolors.to_hex(self.controller.get_unit_color(unit_id))}) 354 | else: 355 | data[label].append(row[label]) 356 | 357 | df = pd.DataFrame(data=data) 358 | formatters = {label: unit_formatter for label in labels if label.startswith("unit_id")} 359 | self.table 
= pn.widgets.Tabulator( 360 | df, 361 | formatters=formatters, 362 | height=400, 363 | layout="fit_data", 364 | show_index=False, 365 | hidden_columns=["group_ids"], 366 | disabled=True, 367 | selectable=1, 368 | sortable=False 369 | ) 370 | 371 | # Add click handler with double click detection 372 | self.table.on_click(self._panel_on_click) 373 | self.table_area.update(self.table) 374 | 375 | def _panel_calculate_merges(self, event): 376 | import panel as pn 377 | self.table_area.update(pn.indicators.LoadingSpinner(size=50, value=True)) 378 | self.get_potential_merges() 379 | self.refresh() 380 | 381 | def _panel_on_method_change(self, event): 382 | self.method = event.new 383 | self.layout[1] = self.method_params_selectors[self.method] 384 | 385 | def _panel_on_click(self, event): 386 | # set unit visibility 387 | row = event.row 388 | self.table.selection = [row] 389 | self._panel_update_visible_pair(row) 390 | 391 | def _panel_update_visible_pair(self, row): 392 | table_row = self.table.value.iloc[row] 393 | visible_unit_ids = [] 394 | for name, value in zip(table_row.index, table_row): 395 | if name.startswith("unit_id"): 396 | unit_id = value["id"] 397 | visible_unit_ids.append(unit_id) 398 | self.controller.set_visible_unit_ids(visible_unit_ids) 399 | self.notify_unit_visibility_changed() 400 | 401 | def _panel_handle_shortcut(self, event): 402 | if event.data == "accept": 403 | selected = self.table.selection 404 | for row in selected: 405 | group_ids = self.table.value.iloc[row].group_ids 406 | self.accept_group_merge(group_ids) 407 | self.notify_manual_curation_updated() 408 | elif event.data == "next": 409 | next_row = min(self.table.selection[0] + 1, len(self.table.value) - 1) 410 | self.table.selection = [next_row] 411 | self._panel_update_visible_pair(next_row) 412 | elif event.data == "previous": 413 | previous_row = max(self.table.selection[0] - 1, 0) 414 | self.table.selection = [previous_row] 415 | self._panel_update_visible_pair(previous_row) 416 
| 417 | def _panel_on_spike_selection_changed(self): 418 | pass 419 | 420 | def _panel_on_unit_visibility_changed(self): 421 | pass 422 | 423 | 424 | 425 | MergeView._gui_help_txt = """ 426 | ## Merge View 427 | 428 | This view allows you to compute potential merges between units based on their similarity or using the auto merge function. 429 | Select the method to use for merging units. 430 | The available methods are: 431 | - similarity: Computes the similarity between units based on their features. 432 | - automerge: uses the auto merge function in SpikeInterface to find potential merges. 433 | 434 | Click "Calculate merges" to compute the potential merges. When finished, the table will be populated 435 | with the potential merges. 436 | 437 | ### Controls 438 | - **left click** : select a potential merge group 439 | - **arrow up/down** : navigate through the potential merge groups 440 | - **ctrl + a** : accept the selected merge group 441 | """ 442 | -------------------------------------------------------------------------------- /spikeinterface_gui/myqt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Helper for importing Qt bindings library 4 | see 5 | http://mikeboers.com/blog/2015/07/04/static-libraries-in-a-dynamic-world#the-fold 6 | """ 7 | 8 | 9 | class ModuleProxy(object): 10 | 11 | def __init__(self, prefixes, modules): 12 | self.prefixes = prefixes 13 | self.modules = modules 14 | 15 | def __getattr__(self, name): 16 | 17 | if QT_MODE == 'PySide6' and name == 'pyqtSignal': 18 | name = 'Signal' 19 | 20 | for prefix in self.prefixes: 21 | fullname = prefix + name 22 | for module in self.modules: 23 | obj = getattr(module, fullname, None) 24 | if obj is not None: 25 | setattr(self, name, obj) # cache it 26 | return obj 27 | raise AttributeError(name) 28 | 29 | QT_MODE = None 30 | 31 | 32 | if QT_MODE is None: 33 | try: 34 | import PySide6 35 | from PySide6 import QtCore, QtGui, 
QtWidgets 36 | QT_MODE = 'PySide6' 37 | except ImportError: 38 | pass 39 | 40 | if QT_MODE is None: 41 | try: 42 | import PyQt6 43 | from PyQt6 import QtCore, QtGui, QtWidgets 44 | QT_MODE = 'PyQt6' 45 | except ImportError: 46 | pass 47 | 48 | if QT_MODE is None: 49 | try: 50 | import PyQt5 51 | from PyQt5 import QtCore, QtGui, QtWidgets 52 | QT_MODE = 'PyQt5' 53 | except ImportError: 54 | pass 55 | 56 | #~ print(QT_MODE) 57 | 58 | if QT_MODE == 'PySide6': 59 | QT = ModuleProxy(['', 'Q', 'Qt'], [QtCore.Qt, QtCore, QtGui, QtWidgets]) 60 | elif QT_MODE == 'PyQt6': 61 | QT = ModuleProxy(['', 'Q', 'Qt'], [QtCore.Qt, QtCore, QtGui, QtWidgets]) 62 | elif QT_MODE == 'PyQt5': 63 | QT = ModuleProxy(['', 'Q', 'Qt'], [QtCore.Qt, QtCore, QtGui, QtWidgets]) 64 | else: 65 | QT = None 66 | 67 | if QT is not None: 68 | from pyqtgraph import mkQApp 69 | -------------------------------------------------------------------------------- /spikeinterface_gui/similarityview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.cm 3 | import matplotlib.colors 4 | 5 | from .view_base import ViewBase 6 | 7 | 8 | 9 | 10 | class SimilarityView(ViewBase): 11 | _supported_backend = ['qt', 'panel'] 12 | _depend_on = ["template_similarity"] 13 | _settings = [ 14 | {'name': 'method', 'type': 'list', 'limits' : ['l1', 'l2', 'cosine'] }, 15 | {'name': 'colormap', 'type': 'list', 'limits' : ['viridis', 'jet', 'gray', 'hot', ] }, 16 | {'name': 'show_all', 'type': 'bool', 'value' : True }, 17 | ] 18 | _need_compute = True 19 | 20 | def __init__(self, controller=None, parent=None, backend="qt"): 21 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 22 | self.similarity = self.controller.get_similarity(method=None) 23 | 24 | def get_similarity_data(self): 25 | unit_ids = self.controller.unit_ids 26 | 27 | if self.similarity is None: 28 | return None, None 29 | 30 | if self.settings["show_all"]: 31 | 
visible_mask = np.ones(len(unit_ids), dtype="bool") 32 | s = self.similarity 33 | else: 34 | visible_mask = self.controller.get_units_visibility_mask() 35 | s = self.similarity[visible_mask, :][:, visible_mask] 36 | 37 | if not np.any(visible_mask): 38 | return None, None 39 | 40 | return s, visible_mask 41 | 42 | def select_unit_pair_on_click(self, x, y, reset=True): 43 | unit_ids = self.controller.unit_ids 44 | 45 | if self.settings['show_all']: 46 | visible_ids = unit_ids 47 | else: 48 | visible_ids = self.get_visible_unit_ids() 49 | 50 | n = len(visible_ids) 51 | 52 | inside = (0 <= x <= n) and (0 <= y <= n) 53 | 54 | if not inside: 55 | return 56 | 57 | unit_id0 = unit_ids[int(np.floor(x))] 58 | unit_id1 = unit_ids[int(np.floor(y))] 59 | 60 | if reset: 61 | self.controller.set_all_unit_visibility_off() 62 | self.controller.set_unit_visibility(unit_id0, True) 63 | self.controller.set_unit_visibility(unit_id1, True) 64 | 65 | self.notify_unit_visibility_changed() 66 | self.refresh() 67 | 68 | 69 | ## Qt ## 70 | def _qt_make_layout(self): 71 | from .myqt import QT 72 | import pyqtgraph as pg 73 | from .utils_qt import ViewBoxHandlingClickToPositionWithCtrl 74 | 75 | self.layout = QT.QVBoxLayout() 76 | 77 | self.graphicsview = pg.GraphicsView() 78 | self.layout.addWidget(self.graphicsview) 79 | 80 | self.viewBox = ViewBoxHandlingClickToPositionWithCtrl() 81 | self.viewBox.clicked.connect(self._qt_select_pair) 82 | self.viewBox.disableAutoRange() 83 | 84 | self.plot = pg.PlotItem(viewBox=self.viewBox) 85 | self.graphicsview.setCentralItem(self.plot) 86 | self.plot.hideButtons() 87 | 88 | self.image = pg.ImageItem() 89 | self.plot.addItem(self.image) 90 | 91 | self.plot.hideAxis('bottom') 92 | self.plot.hideAxis('left') 93 | 94 | self._text_items = [] 95 | 96 | 97 | self.similarity = self.controller.get_similarity(method=self.settings['method']) 98 | self.on_settings_changed()#this do refresh 99 | 100 | def _on_settings_changed(self): 101 | 102 | # TODO : check if 
method have changed or not 103 | # self.similarity = None 104 | 105 | N = 512 106 | cmap_name = self.settings['colormap'] 107 | cmap = matplotlib.colormaps[cmap_name].resampled(N) 108 | 109 | lut = [] 110 | for i in range(N): 111 | r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i)) 112 | lut.append([r*255,g*255,b*255]) 113 | self.lut = np.array(lut, dtype='uint8') 114 | 115 | 116 | self.refresh() 117 | 118 | def _compute(self): 119 | self.similarity = self.controller.compute_similarity(method=self.settings['method']) 120 | 121 | def _qt_refresh(self): 122 | import pyqtgraph as pg 123 | 124 | unit_ids = self.controller.unit_ids 125 | 126 | if self.similarity is None: 127 | self.image.hide() 128 | return 129 | 130 | similarity, visible_mask = self.get_similarity_data() 131 | 132 | if not np.any(visible_mask): 133 | self.image.hide() 134 | return 135 | 136 | _max = np.max(self.similarity) 137 | self.image.setImage(similarity, lut=self.lut, levels=[0, _max]) 138 | self.image.show() 139 | self.plot.setXRange(0, similarity.shape[0]) 140 | self.plot.setYRange(0, similarity.shape[1]) 141 | 142 | pos = 0 143 | 144 | for item in self._text_items: 145 | self.plot.removeItem(item) 146 | 147 | if np.sum(visible_mask) < 10: 148 | for unit_index, unit_id in enumerate(self.controller.unit_ids): 149 | if not visible_mask[unit_index]: 150 | continue 151 | for i in range(2): 152 | item = pg.TextItem(text=f'{unit_id}', color='#FFFFFF', anchor=(0.5, 0.5), border=None) 153 | self.plot.addItem(item) 154 | if i==0: 155 | item.setPos(pos + 0.5, 0) 156 | else: 157 | item.setPos(0, pos + 0.5) 158 | self._text_items.append(item) 159 | pos += 1 160 | 161 | 162 | 163 | def _qt_select_pair(self, x, y, reset): 164 | 165 | self.select_unit_pair_on_click(x, y, reset=reset) 166 | 167 | 168 | ## panel ## 169 | def _panel_make_layout(self): 170 | import panel as pn 171 | import bokeh.plotting as bpl 172 | from .utils_panel import _bg_color 173 | from bokeh.models import ColumnDataSource, 
LinearColorMapper 174 | from bokeh.events import Tap 175 | 176 | 177 | # Create Bokeh figure 178 | self.figure = bpl.figure( 179 | sizing_mode="stretch_both", 180 | tools="reset,wheel_zoom,tap", 181 | title="Similarity Matrix", 182 | background_fill_color=_bg_color, 183 | border_fill_color=_bg_color, 184 | outline_line_color="white", 185 | styles={"flex": "1"} 186 | ) 187 | self.figure.toolbar.logo = None 188 | 189 | # Create initial color mapper 190 | N = 512 191 | cmap = matplotlib.colormaps[self.settings['colormap']] 192 | self.color_mapper = LinearColorMapper( 193 | palette=[matplotlib.colors.rgb2hex(cmap(i)[:3]) for i in np.linspace(0, 1, N)], low=0, high=1 194 | ) 195 | 196 | self.image_source = ColumnDataSource({"image": [np.zeros((1, 1))], "dw": [1], "dh": [1]}) 197 | self.image_glyph = self.figure.image( 198 | image="image", x=0, y=0, dw="dw", dh="dh", color_mapper=self.color_mapper, source=self.image_source 199 | ) 200 | 201 | self.text_source = ColumnDataSource({"x": [], "y": [], "text": []}) 202 | self.text_glyphs = self.figure.text( 203 | x="x", 204 | y="y", 205 | text="text", 206 | source=self.text_source, 207 | text_color="white", 208 | text_align="center", 209 | text_baseline="middle", 210 | ) 211 | 212 | self.figure.on_event(Tap, self._panel_on_tap) 213 | 214 | self.layout = pn.Column( 215 | self.figure, 216 | styles={"display": "flex", "flex-direction": "column"}, 217 | sizing_mode="stretch_both" 218 | ) 219 | 220 | def _panel_refresh(self): 221 | similarity, visible_mask = self.get_similarity_data() 222 | 223 | 224 | if similarity is None: 225 | return 226 | 227 | self.color_mapper.low = 0 228 | self.color_mapper.high = np.max(self.similarity) 229 | 230 | self.image_source.data.update({"image": [similarity], "dw": [similarity.shape[1]], "dh": [similarity.shape[0]]}) 231 | 232 | # Update text labels 233 | x_positions = [] 234 | y_positions = [] 235 | texts = [] 236 | pos = 0 237 | 238 | for unit_index, unit_id in 
enumerate(self.controller.unit_ids): 239 | if not visible_mask[unit_index]: 240 | continue 241 | # Add labels on both axes 242 | x_positions.extend([pos + 0.5, 0]) 243 | y_positions.extend([0, pos + 0.5]) 244 | texts.extend([str(unit_id), str(unit_id)]) 245 | pos += 1 246 | 247 | self.text_source.data.update({"x": x_positions, "y": y_positions, "text": texts}) 248 | 249 | # Update plot ranges 250 | self.figure.x_range.start = 0 251 | self.figure.x_range.end = similarity.shape[1] 252 | self.figure.y_range.start = 0 253 | self.figure.y_range.end = similarity.shape[0] 254 | 255 | def _panel_on_tap(self, event): 256 | if event.x is None or event.y is None: 257 | return 258 | 259 | self.select_unit_pair_on_click(event.x, event.y, reset=True) 260 | 261 | 262 | 263 | SimilarityView._gui_help_txt = """ 264 | ## Similarity View 265 | 266 | This view displays the template similarity matrix between units. 267 | 268 | ### Controls 269 | - **left click** : select a pair of units to show in the unit view. 
270 | """ 271 | -------------------------------------------------------------------------------- /spikeinterface_gui/spikeamplitudeview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from .basescatterview import BaseScatterView 5 | 6 | 7 | class SpikeAmplitudeView(BaseScatterView): 8 | _depend_on = ["spike_amplitudes"] 9 | _settings = BaseScatterView._settings + [ 10 | {'name': 'noise_level', 'type': 'bool', 'value' : True }, 11 | {'name': 'noise_factor', 'type': 'int', 'value' : 5 }, 12 | ] 13 | 14 | def __init__(self, controller=None, parent=None, backend="qt"): 15 | y_label = "Amplitude (uV)" 16 | spike_data = controller.spike_amplitudes 17 | # set noise level to False by default in panel 18 | if backend == 'panel': 19 | noise_level_settings_index = [s["name"] for s in SpikeAmplitudeView._settings].index("noise_level") 20 | SpikeAmplitudeView._settings[noise_level_settings_index]['value'] = False 21 | BaseScatterView.__init__( 22 | self, 23 | controller=controller, 24 | parent=parent, 25 | backend=backend, 26 | y_label=y_label, 27 | spike_data=spike_data, 28 | ) 29 | 30 | def _qt_make_layout(self): 31 | super()._qt_make_layout() 32 | self.noise_harea = [] 33 | if self.settings["noise_level"]: 34 | self._qt_add_noise_area() 35 | 36 | def _qt_refresh(self): 37 | super()._qt_refresh() 38 | # average noise across channels 39 | if self.settings["noise_level"] and len(self.noise_harea) == 0: 40 | self._qt_add_noise_area() 41 | # remove noise area if not selected 42 | elif not self.settings["noise_level"] and len(self.noise_harea) > 0: 43 | for n in self.noise_harea: 44 | self.plot2.removeItem(n) 45 | self.noise_harea = [] 46 | 47 | def _qt_add_noise_area(self): 48 | import pyqtgraph as pg 49 | 50 | n = self.settings["noise_factor"] 51 | noise = np.mean(self.controller.noise_levels) 52 | alpha_factor = 50 / n 53 | for i in range(1, n + 1): 54 | n = self.plot2.addItem( 55 | 
pg.LinearRegionItem(values=(-i * noise, i * noise), orientation="horizontal", 56 | brush=(255, 255, 255, int(i * alpha_factor)), pen=(0, 0, 0, 0)) 57 | ) 58 | self.noise_harea.append(n) 59 | 60 | def _panel_refresh(self): 61 | super()._panel_refresh() 62 | # update noise area 63 | self.noise_harea = [] 64 | if self.settings['noise_level'] and len(self.noise_harea) == 0: 65 | self._panel_add_noise_area() 66 | else: 67 | self.noise_harea = [] 68 | 69 | def _panel_add_noise_area(self): 70 | self.noise_harea = [] 71 | noise = np.mean(self.controller.noise_levels) 72 | n = self.settings['noise_factor'] 73 | alpha_factor = 50 / n 74 | for i in range(1, n + 1): 75 | h = self.hist_fig.harea( 76 | y="y", 77 | x1="x1", 78 | x2="x2", 79 | source={ 80 | "y": [-i * noise, i * noise], 81 | "x1": [0, 0], 82 | "x2": [10_000, 10_000], 83 | }, 84 | alpha=int(i * alpha_factor) / 255, # Match Qt alpha scaling 85 | color="lightgray", 86 | ) 87 | self.noise_harea.append(h) 88 | 89 | 90 | SpikeAmplitudeView._gui_help_txt = """ 91 | ## Spike Amplitude View 92 | 93 | Check amplitudes of spikes across the recording time or in a histogram 94 | comparing the distribution of ampltidues to the noise levels. 
95 | 96 | ### Controls 97 | - **select** : activate lasso selection to select individual spikes 98 | """ 99 | -------------------------------------------------------------------------------- /spikeinterface_gui/spikedepthview.py: -------------------------------------------------------------------------------- 1 | from .basescatterview import BaseScatterView 2 | 3 | 4 | class SpikeDepthView(BaseScatterView): 5 | _depend_on = ["spike_locations"] 6 | 7 | def __init__(self, controller=None, parent=None, backend="qt"): 8 | y_label = "Depth (um)" 9 | spike_data = controller.spike_depths 10 | BaseScatterView.__init__( 11 | self, 12 | controller=controller, 13 | parent=parent, 14 | backend=backend, 15 | y_label=y_label, 16 | spike_data=spike_data, 17 | ) 18 | 19 | 20 | 21 | SpikeDepthView._gui_help_txt = """ 22 | ## Spike Depth View 23 | 24 | Check deppth of spikes across the recording time or in a histogram. 25 | 26 | ### Controls 27 | - **select** : activate lasso selection to select individual spikes 28 | """ 29 | -------------------------------------------------------------------------------- /spikeinterface_gui/spikelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | from .view_base import ViewBase 5 | 6 | 7 | _columns = ['#', 'unit_id', 'segment_index', 'sample_index', 'channel_index', 'rand_selected'] 8 | 9 | 10 | def get_qt_spike_model(): 11 | # this getter is to protect import QT when using panel 12 | 13 | from .myqt import QT 14 | 15 | class SpikeModel(QT.QAbstractItemModel): 16 | 17 | def __init__(self, parent =None, controller=None, columns=[]): 18 | QT.QAbstractItemModel.__init__(self,parent) 19 | self.controller = controller 20 | self.columns = columns 21 | # self.refresh_colors() 22 | 23 | self.visible_ind = self.controller.get_indices_spike_visible() 24 | 25 | def columnCount(self , parentIndex): 26 | return len(self.columns) 27 | 28 | def rowCount(self, parentIndex): 29 | 
if not parentIndex.isValid(): 30 | return int(self.visible_ind.size) 31 | else : 32 | return 0 33 | 34 | def index(self, row, column, parentIndex): 35 | if not parentIndex.isValid(): 36 | return self.createIndex(row, column, None) 37 | else: 38 | return QT.QModelIndex() 39 | 40 | def parent(self, index): 41 | return QT.QModelIndex() 42 | 43 | def data(self, index, role): 44 | # one row per visible spike: visible_ind maps the row number to an absolute spike index 45 | if not index.isValid(): 46 | return None 47 | 48 | if role not in (QT.Qt.DisplayRole, QT.Qt.DecorationRole): 49 | return 50 | 51 | col = index.column() 52 | row = index.row() 53 | 54 | abs_ind = self.visible_ind[row] 55 | spike = self.controller.spikes[abs_ind] 56 | unit_id = self.controller.unit_ids[spike['unit_index']] 57 | 58 | if role ==QT.Qt.DisplayRole : 59 | if col == 0: 60 | return '{}'.format(abs_ind) 61 | elif col == 1: 62 | return '{}'.format(unit_id) 63 | elif col == 2: 64 | return '{}'.format(spike['segment_index']) 65 | elif col == 3: 66 | return '{}'.format(spike['sample_index']) 67 | elif col == 4: 68 | return '{}'.format(spike['channel_index']) 69 | elif col == 5: 70 | return '{}'.format(spike['rand_selected']) 71 | else: 72 | return None 73 | elif role == QT.Qt.DecorationRole : 74 | if col != 0: 75 | return None 76 | if unit_id in self.icons: 77 | return self.icons[unit_id] 78 | else: 79 | return None 80 | else : 81 | return None 82 | 83 | 84 | def flags(self, index): 85 | if not index.isValid(): 86 | return QT.Qt.NoItemFlags 87 | return QT.Qt.ItemIsEnabled | QT.Qt.ItemIsSelectable #| Qt.ItemIsDragEnabled 88 | 89 | def headerData(self, section, orientation, role): 90 | if orientation == QT.Qt.Horizontal and role == QT.Qt.DisplayRole: 91 | return self.columns[section] 92 | return 93 | # cache one 10x10 solid-color icon per unit; data() shows it in column 0 94 | def refresh_colors(self, qcolors): 95 | self.icons = { } 96 | for unit_id, qcolor in qcolors.items(): 97 | pix = QT.QPixmap(10,10 ) 98 | pix.fill(qcolor) 99 | self.icons[unit_id] = QT.QIcon(pix) 100 | 101 | def refresh(self): 102 | self.visible_ind = 
self.controller.get_indices_spike_visible() 103 | self.layoutChanged.emit() 104 | # empty the model without querying the controller (used where a full refresh is too slow) 105 | def clear(self): 106 | self.visible_ind = np.array([]) 107 | self.layoutChanged.emit() 108 | 109 | return SpikeModel 110 | 111 | 112 | 113 | class SpikeListView(ViewBase): 114 | _supported_backend = ['qt', 'panel'] 115 | _settings = [ 116 | {'name': 'select_change_channel_visibility', 'type': 'bool', 'value': False}, 117 | ] 118 | 119 | def __init__(self, controller=None, parent=None, backend="qt"): 120 | ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) 121 | # store the selection (absolute spike indices) in the controller and notify the other views 122 | def handle_selection(self, inds): 123 | self.controller.set_indices_spike_selected(inds) 124 | self.notify_spike_selection_changed() 125 | 126 | if len(inds) == 1 and self.settings['select_change_channel_visibility']: 127 | # also change channel for centering trace view. 128 | sparsity_mask = self.controller.get_sparsity_mask() 129 | unit_index = self.controller.spikes[inds[0]]['unit_index'] 130 | visible_channel_inds, = np.nonzero(sparsity_mask[unit_index, :]) 131 | 132 | # check if channel visibility must be changed 133 | if not np.all(np.isin(visible_channel_inds, self.controller.visible_channel_inds)): 134 | self.controller.set_channel_visibility(visible_channel_inds) 135 | self.notify_channel_visibility_changed() 136 | 137 | ## Qt ## 138 | def _qt_make_layout(self): 139 | from .myqt import QT 140 | 141 | # this getter is to protect import QT 142 | SpikeModel = get_qt_spike_model() 143 | 144 | self.layout = QT.QVBoxLayout() 145 | 146 | h = QT.QHBoxLayout() 147 | self.layout.addLayout(h) 148 | 149 | self.label = QT.QLabel('') 150 | h.addWidget(self.label) 151 | 152 | # h.addStretch() 153 | 154 | but = QT.QPushButton('↻ spikes') 155 | # h.addWidget(but) 156 | tb = self.qt_widget.view_toolbar 157 | tb.addWidget(but) 158 | but.clicked.connect(self.refresh) 159 | 160 | self.tree = QT.QTreeView(minimumWidth = 100, uniformRowHeights = True, 161 | selectionMode= 
QT.QAbstractItemView.ExtendedSelection, selectionBehavior = QT.QTreeView.SelectRows, 162 | contextMenuPolicy = QT.Qt.CustomContextMenu,) 163 | 164 | 165 | self.layout.addWidget(self.tree) 166 | 167 | self.model = SpikeModel(controller=self.controller, columns=_columns) 168 | qcolors = {unit_id:self.get_unit_color(unit_id) for unit_id in self.controller.unit_ids} 169 | self.model.refresh_colors(qcolors) 170 | 171 | self.tree.setModel(self.model) 172 | self.tree.selectionModel().selectionChanged.connect(self._qt_on_tree_selection) 173 | 174 | for i in range(self.model.columnCount(None)): 175 | self.tree.resizeColumnToContents(i) 176 | self.tree.setColumnWidth(0,80) 177 | 178 | 179 | def _qt_refresh_label(self): 180 | n1 = self.controller.spikes.size 181 | n2 = self.controller.get_indices_spike_visible().size 182 | n3 = self.controller.get_indices_spike_selected().size 183 | txt = f'All spikes : {n1} - visible : {n2} - selected : {n3}' 184 | self.label.setText(txt) 185 | 186 | def _qt_refresh(self): 187 | self._qt_refresh_label() 188 | qcolors = {unit_id:self.get_unit_color(unit_id) for unit_id in self.controller.unit_ids} 189 | self.model.refresh_colors(qcolors) 190 | self.model.refresh() 191 | 192 | def _qt_on_tree_selection(self): 193 | inds = [] 194 | for index in self.tree.selectedIndexes(): 195 | if index.column() == 0: 196 | ind = self.model.visible_ind[index.row()] 197 | inds.append(ind) 198 | 199 | self.handle_selection(inds) 200 | self._qt_refresh_label() 201 | 202 | def _qt_on_unit_visibility_changed(self): 203 | # we cannot refresh this list in real time while moving channel/unit visibility 204 | # it is too slow. So the list is cleared.
205 | self._qt_refresh_label() 206 | self.model.clear() 207 | qcolors = {unit_id:self.get_unit_color(unit_id) for unit_id in self.controller.unit_ids} 208 | self.model.refresh_colors(qcolors) 209 | 210 | def _qt_on_unit_color_changed(self): 211 | # rebuilding this list in real time is too slow, so on a color change 212 | # the list is cleared and only the unit icons are refreshed. 213 | self._qt_refresh_label() 214 | self.model.clear() 215 | qcolors = {unit_id:self.get_unit_color(unit_id) for unit_id in self.controller.unit_ids} 216 | self.model.refresh_colors(qcolors) 217 | 218 | 219 | def _qt_on_spike_selection_changed(self): 220 | from .myqt import QT 221 | self.tree.selectionModel().selectionChanged.disconnect(self._qt_on_tree_selection) 222 | # slot detached so the programmatic selection below does not re-enter _qt_on_tree_selection 223 | selected_inds = self.controller.get_indices_spike_selected() 224 | visible_inds = self.controller.get_indices_spike_visible() 225 | row_selected, = np.nonzero(np.isin(visible_inds, selected_inds)) 226 | 227 | if row_selected.size>100:# otherwise this is very slow: only the first 10 matching rows are selected 228 | row_selected = row_selected[:10] 229 | 230 | # change selection 231 | self.tree.selectionModel().clearSelection() 232 | flags = QT.QItemSelectionModel.Select #| QItemSelectionModel.Rows 233 | itemsSelection = QT.QItemSelection() 234 | for r in row_selected: 235 | for c in range(2): 236 | index = self.tree.model().index(r,c,QT.QModelIndex()) 237 | ir = QT.QItemSelectionRange( index ) 238 | itemsSelection.append(ir) 239 | self.tree.selectionModel().select(itemsSelection , flags) 240 | 241 | # set selection visible 242 | if len(row_selected)>=1: 243 | index = self.tree.model().index(row_selected[0],0,QT.QModelIndex()) 244 | self.tree.scrollTo(index) 245 | 246 | self.tree.selectionModel().selectionChanged.connect(self._qt_on_tree_selection) 247 | self._qt_refresh_label() 248 | 249 | ## panel ## 250 | def _panel_make_layout(self): 251 | import panel as pn 252 | import pandas as pd 253 | from .utils_panel import spike_formatter, SelectableTabulator #KeyboardShortcut,
KeyboardShortcuts 254 | 255 | # pn.extension('tabulator') 256 | 257 | # Configure columns for tabulator 258 | df = pd.DataFrame(columns=_columns) 259 | formatters = {"unit_id": spike_formatter} 260 | 261 | # Create tabulator instance 262 | self.table = SelectableTabulator( 263 | df, 264 | layout="fit_data", 265 | formatters=formatters, 266 | frozen_columns=["#", "unit_id"], 267 | sizing_mode="stretch_both", 268 | show_index=False, 269 | selectable=True, 270 | disabled=True, 271 | pagination=None, 272 | # SelectableTabulator functions 273 | parent_view=self, 274 | refresh_table_function=self._panel_refresh_table, 275 | conditional_shortcut=self.is_view_active, 276 | ) 277 | # Add selection event handler 278 | self.table.param.watch(self._panel_on_selection_changed, "selection") 279 | 280 | self.refresh_button = pn.widgets.Button(name="↻ spikes", button_type="default", sizing_mode="stretch_width") 281 | self.refresh_button.on_click(self._panel_on_refresh_click) 282 | 283 | self.clear_button = pn.widgets.Button(name="Clear", button_type="default", sizing_mode="stretch_width") 284 | self.clear_button.on_click(self._panel_on_clear_click) 285 | 286 | self.info_text = pn.pane.HTML("") 287 | 288 | # Create main layout 289 | self.layout = pn.Column( 290 | self.info_text, 291 | pn.Row( 292 | self.clear_button, 293 | self.refresh_button, 294 | ), 295 | self.table, 296 | sizing_mode="stretch_both", 297 | ) 298 | 299 | self.last_clicked = None 300 | self.last_clicked_row = None 301 | self.last_selected_row = None 302 | 303 | def _panel_refresh(self): 304 | pass 305 | 306 | def _panel_refresh_table(self): 307 | import matplotlib.colors as mcolors 308 | import pandas as pd 309 | 310 | visible_inds = self.controller.get_indices_spike_visible() 311 | unit_ids = self.controller.unit_ids 312 | spikes = self.controller.spikes[visible_inds] 313 | 314 | spike_unit_ids = [] 315 | for i, spike in enumerate(spikes): 316 | unit_id = unit_ids[spike['unit_index']] 317 | color = 
mcolors.to_hex(self.controller.get_unit_color(unit_id)) 318 | spike_unit_ids.append({"id": unit_id, "color": color}) 319 | 320 | # Prepare data for tabulator 321 | data = { 322 | '#': visible_inds, 323 | 'unit_id': spike_unit_ids, 324 | 'segment_index': spikes['segment_index'], 325 | 'sample_index': spikes['sample_index'], 326 | 'channel_index': spikes['channel_index'], 327 | 'rand_selected': spikes['rand_selected'] 328 | } 329 | 330 | # Update table data 331 | df = pd.DataFrame(data) 332 | self.table.value = df 333 | 334 | selected_inds = self.controller.get_indices_spike_selected() 335 | if len(selected_inds) == 0: 336 | self.table.selection = [] 337 | else: 338 | # Find the rows corresponding to the selected indices 339 | row_selected, = np.nonzero(np.isin(visible_inds, selected_inds)) 340 | self.table.selection = [int(r) for r in row_selected] 341 | 342 | self._panel_refresh_label() 343 | 344 | def _panel_on_refresh_click(self, event): 345 | self._panel_refresh_label() 346 | self._panel_refresh_table() 347 | self.notify_active_view_updated() 348 | 349 | def _panel_on_clear_click(self, event): 350 | self.controller.set_indices_spike_selected([]) 351 | self.table.selection = [] 352 | self.notify_spike_selection_changed() 353 | self._panel_refresh_label() 354 | self.notify_active_view_updated() 355 | 356 | def _panel_on_selection_changed(self, event=None): 357 | selection = event.new 358 | if len(selection) == 0: 359 | self.handle_selection([]) 360 | else: 361 | absolute_indices = self.controller.get_indices_spike_visible()[np.array(selection)] 362 | self.handle_selection(absolute_indices) 363 | self._panel_refresh_label() 364 | 365 | def _panel_refresh_label(self): 366 | n1 = self.controller.spikes.size 367 | n2 = self.controller.get_indices_spike_visible().size 368 | n3 = self.controller.get_indices_spike_selected().size 369 | txt = f"All spikes: {n1} - visible: {n2} - selected: {n3}" 370 | self.info_text.object = txt 371 | 372 | def 
_panel_on_unit_visibility_changed(self): 373 | import pandas as pd 374 | # Clear the table when visibility changes 375 | self.table.value = pd.DataFrame(columns=_columns, data=[]) 376 | self._panel_refresh_label() 377 | 378 | def _panel_on_spike_selection_changed(self): 379 | selected_inds = self.controller.get_indices_spike_selected() 380 | visible_inds = self.controller.get_indices_spike_visible() 381 | row_selected, = np.nonzero(np.isin(visible_inds, selected_inds)) 382 | row_selected = [int(r) for r in row_selected] 383 | # Update the selection in the table 384 | self.table.selection = row_selected 385 | self._panel_refresh_label() 386 | 387 | 388 | SpikeListView._gui_help_txt = """ 389 | ## Spike list View 390 | 391 | Show all spikes of the visible units. When spikes are selected, they are highlighted in the Spike Amplitude View and the ND SCatter View. 392 | When a single spike is selected, the Trace and TraceMap Views are centered on it. 393 | 394 | ### Controls 395 | * **↻ spikes**: refresh the spike list 396 | * **clear**: clear the selected spikes 397 | * **shift + arrow up/down** : select next/previous spike and make it visible alone 398 | """ 399 | -------------------------------------------------------------------------------- /spikeinterface_gui/tests/debug_views.py: -------------------------------------------------------------------------------- 1 | import spikeinterface_gui as sigui 2 | from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict 3 | 4 | from spikeinterface_gui.controller import Controller 5 | from spikeinterface_gui.myqt import mkQApp 6 | from spikeinterface_gui.viewlist import possible_class_views 7 | from spikeinterface_gui.backend_qt import ViewWidget 8 | 9 | 10 | import spikeinterface.full as si 11 | 12 | 13 | 14 | from pathlib import Path 15 | 16 | 17 | # test_folder = Path(__file__).parent / 'my_dataset_small' 18 | test_folder = Path(__file__).parent / 'my_dataset_big' 19 | # 
test_folder = Path(__file__).parent / 'my_dataset_multiprobe' 20 | 21 | 22 | def debug_one_view(): 23 | 24 | app = mkQApp() 25 | analyzer = si.load_sorting_analyzer(test_folder / "sorting_analyzer") 26 | 27 | curation_dict = make_curation_dict(analyzer) 28 | # curation_dict = None 29 | 30 | curation = curation_dict is not None 31 | 32 | controller = Controller(analyzer, verbose=True, curation=curation, curation_data=curation_dict) 33 | 34 | # view_class = possible_class_views['unitlist'] 35 | view_class = possible_class_views['mainsettings'] 36 | # view_class = possible_class_views['spikeamplitude'] 37 | widget = ViewWidget(view_class) 38 | view = view_class(controller=controller, parent=widget, backend='qt') 39 | widget.set_view(view) 40 | widget.show() 41 | view.refresh() 42 | 43 | app.exec() 44 | 45 | 46 | if __name__ == '__main__': 47 | debug_one_view() 48 | -------------------------------------------------------------------------------- /spikeinterface_gui/tests/iframe/iframe_server.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import webbrowser 3 | from pathlib import Path 4 | import argparse 5 | from flask import Flask, send_file, jsonify 6 | 7 | app = Flask(__name__) 8 | panel_server = None 9 | panel_url = None 10 | panel_thread = None 11 | panel_port_global = None 12 | 13 | @app.route('/') 14 | def index(): 15 | """Serve the iframe test HTML page""" 16 | return send_file('iframe_test.html') 17 | 18 | @app.route('/start_test_server') 19 | def start_test_server(): 20 | """Start the Panel server in a separate thread""" 21 | global panel_server, panel_url, panel_thread, panel_port_global 22 | 23 | # If a server is already running, return its URL 24 | if panel_url: 25 | return jsonify({"success": True, "url": panel_url}) 26 | 27 | # Make sure the test dataset exists 28 | test_folder = Path(__file__).parent / "my_dataset" 29 | if not test_folder.is_dir(): 30 | from 
spikeinterface_gui.tests.testingtools import make_analyzer_folder 31 | make_analyzer_folder(test_folder) 32 | 33 | # Function to run the Panel server in a thread 34 | def run_panel_server(): 35 | global panel_server, panel_url, panel_port_global 36 | try: 37 | # Start the Panel server with curation enabled 38 | # Use a direct import to avoid circular imports 39 | from spikeinterface import load_sorting_analyzer 40 | from spikeinterface_gui import run_mainwindow 41 | 42 | # Load the analyzer 43 | analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer") 44 | 45 | # Start the Panel server directly 46 | win = run_mainwindow( 47 | analyzer, 48 | backend="panel", 49 | start_app=False, 50 | verbose=True, 51 | curation=True, 52 | make_servable=True 53 | ) 54 | 55 | # Start the server manually 56 | import panel as pn 57 | pn.serve(win.main_layout, port=panel_port_global, address="localhost", show=False, start=True) 58 | 59 | # Get the server URL 60 | panel_url = f"http://localhost:{panel_port_global}" 61 | panel_server = win 62 | 63 | print(f"Panel server started at {panel_url}") 64 | except Exception as e: 65 | print(f"Error starting Panel server: {e}") 66 | import traceback 67 | traceback.print_exc() 68 | 69 | # Start the Panel server in a separate thread 70 | panel_thread = threading.Thread(target=run_panel_server) 71 | panel_thread.daemon = True 72 | panel_thread.start() 73 | 74 | # Give the server some time to start 75 | import time 76 | time.sleep(5) # Increased wait time 77 | 78 | # Check if the server is actually running 79 | import requests 80 | try: 81 | response = requests.get(f"http://localhost:{panel_port_global}", timeout=2) 82 | if response.status_code == 200: 83 | return jsonify({"success": True, "url": f"http://localhost:{panel_port_global}"}) 84 | else: 85 | return jsonify({"success": False, "error": f"Server returned status code {response.status_code}"}) 86 | except requests.exceptions.RequestException as e: 87 | return jsonify({"success": 
False, "error": f"Could not connect to Panel server: {str(e)}"}) 88 | 89 | @app.route('/stop_test_server') 90 | def stop_test_server(): 91 | """Stop the Panel server""" 92 | global panel_server, panel_url, panel_thread 93 | 94 | if panel_server: 95 | # Clean up resources 96 | # clean_all(Path(__file__).parent / 'my_dataset') 97 | panel_url = None 98 | panel_server = None 99 | return jsonify({"success": True}) 100 | else: 101 | return jsonify({"success": False, "error": "No server running"}) 102 | 103 | def main(flask_port=5000, panel_port=5006): 104 | """Start the Flask server and open the browser""" 105 | global panel_port_global 106 | panel_port_global = panel_port 107 | # Open the browser 108 | webbrowser.open(f'http://localhost:{flask_port}') 109 | 110 | # Start the Flask server 111 | app.run(debug=False, port=flask_port) 112 | 113 | 114 | parser = argparse.ArgumentParser(description="Run the Flask and Panel servers.") 115 | parser.add_argument('--flask-port', type=int, default=5000, help="Port for the Flask server (default: 5000)") 116 | parser.add_argument('--panel-port', type=int, default=5006, help="Port for the Panel server (default: 5006)") 117 | 118 | if __name__ == '__main__': 119 | args = parser.parse_args() 120 | 121 | main(flask_port=int(args.flask_port), panel_port=int(args.panel_port)) 122 | -------------------------------------------------------------------------------- /spikeinterface_gui/tests/iframe/iframe_test.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | 6 |No data received yet...84 |