├── .flake8 ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── atlaselectrophysiology ├── AdaptedAxisItem.py ├── ColorBar.py ├── README.md ├── __init__.py ├── alignment_with_easyqc.py ├── compare_alignments.py ├── create_overview_plots.py ├── ephys_atlas_gui.py ├── ephys_atlas_image.png ├── ephys_gui_setup.py ├── example_code.py ├── extract_files.py ├── get_scale_factor.py ├── load_data.py ├── load_data_local.py ├── load_histology.py ├── plot_data.py ├── qc_table.py ├── rendering.py ├── root.obj ├── sandbox.py └── subject_scaling.py ├── atlasview ├── __init__.py ├── atlasview.py ├── channels_test.npy ├── launcher.py ├── region_values.npy ├── sliceview.ui └── topview.ui ├── data_exploration_gui ├── README.md ├── cluster.py ├── cluster_class.py ├── data_class.py ├── data_explore_gui.py ├── data_model.py ├── filter.py ├── filter_class.py ├── gui_main.py ├── load_data.py ├── misc.py ├── misc_class.py ├── plot.py ├── plot_class.py ├── scatter.py ├── scatter_class.py └── utils.py ├── dlc ├── DLC_labeled_video.py ├── Example_DLC_access.ipynb ├── README.md ├── __init__.py ├── get_dlc_traces.py ├── overview_plot_dlc.py ├── stream_dlc_labeled_frames.py └── wheel_dlc_viewer.py ├── ephysfeatures ├── __init__.py ├── features_across_region.py ├── prepare_features.py └── qrangeslider.py ├── histology ├── __init__.py ├── atlas_mpl.py ├── atlas_mpl.ui └── transform-tracks.py ├── launch_phy ├── README.md ├── cluster_table.py ├── defined_metrics.py ├── metrics.py ├── phy_launcher.py ├── plugins │ └── phy_plugin.py └── populate_cluster_table.py ├── mesoscope_gui └── gui.py ├── needles2 ├── __init__.py ├── coverageUI.ui ├── layerUI.ui ├── mainUI.ui ├── needles_viewer.py ├── probe_model.py ├── regionUI.ui ├── run_needles2.py ├── spike_features.py └── tableUI.ui ├── qt_helpers ├── __init__.py ├── qt.py └── qt_matplotlib.py ├── requirements.txt ├── run_tests ├── setup.py ├── tests ├── __init__.py ├── fixtures │ └── data_alignmentqc_gui.npz 
├── test_alignment_qc_gui.py └── test_task_qc_viewer.py └── viewspikes ├── README.md ├── datoviz.py ├── example_view_ephys_session.py ├── gui.py ├── load_ma_data.py ├── plots.py └── raster.ui /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 99 3 | ignore = W504 4 | exclude = 5 | .git, 6 | __pycache__, 7 | __init__.py, 8 | data_exploration_gui/ 9 | launch_phy/ 10 | histology/ 11 | qt_matplotlib.py 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master, develop ] 6 | pull_request: 7 | branches: [ master, develop ] 8 | 9 | jobs: 10 | build: 11 | name: build (${{ matrix.python-version }}, ${{ matrix.os }}) 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | max-parallel: 3 15 | matrix: 16 | os: ["ubuntu-latest", "macos-latest", "windows-latest"] 17 | python-version: ["3.12"] 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - uses: conda-incubator/setup-miniconda@v2.0.0 22 | with: 23 | auto-update-conda: true 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install requirements 27 | shell: bash -l {0} 28 | run: | 29 | conda activate test 30 | pip install --requirement requirements.txt 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IBL apps stuff 2 | dlc/*.mp4 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | 
*.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # VSCode 135 | .vscode/ 136 | 137 | scratch/ 138 | 139 | # pycharm 140 | .idea/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 International Brain Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # iblapps 2 | IBL related applications that rely on unsupported libraries such as pyqt5. See package README files for more details. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/__init__.py -------------------------------------------------------------------------------- /atlaselectrophysiology/ColorBar.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui 2 | import pyqtgraph as pg 3 | import matplotlib 4 | import numpy as np 5 | from pyqtgraph.functions import makeARGB 6 | 7 | 8 | class ColorBar(pg.GraphicsWidget): 9 | 10 | def __init__(self, cmap_name, cbin=256, parent=None, data=None): 11 | pg.GraphicsWidget.__init__(self) 12 | 13 | # Create colour map from matplotlib colourmap name 14 | self.cmap_name = cmap_name 15 | cmap = matplotlib.cm.get_cmap(self.cmap_name) 16 | if type(cmap) == matplotlib.colors.LinearSegmentedColormap: 17 | cbins = np.linspace(0.0, 1.0, cbin) 18 | colors = (cmap(cbins)[np.newaxis, :, :3][0]).tolist() 19 | else: 20 | colors = cmap.colors 21 | colors = [(np.array(c) * 255).astype(int).tolist() + [255.] 
for c in colors] 22 | positions = np.linspace(0, 1, len(colors)) 23 | self.map = pg.ColorMap(positions, colors) 24 | self.lut = self.map.getLookupTable() 25 | self.grad = self.map.getGradient() 26 | 27 | def getBrush(self, data, levels=None): 28 | if levels is None: 29 | levels = [np.min(data), np.max(data)] 30 | brush_rgb, _ = makeARGB(data[:, np.newaxis], levels=levels, lut=self.lut, useRGBA=True) 31 | brush = [QtGui.QColor(*col) for col in np.squeeze(brush_rgb)] 32 | return brush 33 | 34 | def getColourMap(self): 35 | return self.lut 36 | 37 | def makeColourBar(self, width, height, fig, min=0, max=1, label='', lim=False): 38 | self.cbar = HorizontalBar(width, height, self.grad) 39 | ax = fig.getAxis('top') 40 | ax.setPen('k') 41 | ax.setTextPen('k') 42 | ax.setStyle(stopAxisAtTick=((True, True))) 43 | # labelStyle = {'font-size': '8pt'} 44 | ax.setLabel(label) 45 | ax.setHeight(30) 46 | if lim: 47 | ax.setTicks([[(0, str(np.around(min, 2))), (width, str(np.around(max, 2)))]]) 48 | else: 49 | ax.setTicks([[(0, str(np.around(min, 2))), (width / 2, 50 | str(np.around(min + (max - min) / 2, 2))), 51 | (width, str(np.around(max, 2)))], 52 | [(width / 4, str(np.around(min + (max - min) / 4, 2))), 53 | (3 * width / 4, str(np.around(min + (3 * (max - min) / 4), 2)))]]) 54 | fig.setXRange(0, width) 55 | fig.setYRange(0, height) 56 | 57 | return self.cbar 58 | 59 | 60 | class HorizontalBar(pg.GraphicsWidget): 61 | def __init__(self, width, height, grad): 62 | pg.GraphicsWidget.__init__(self) 63 | self.width = width 64 | self.height = height 65 | self.grad = grad 66 | QtGui.QPainter() 67 | 68 | def paint(self, p, *args): 69 | p.setPen(QtCore.Qt.NoPen) 70 | self.grad.setStart(0, self.height / 2) 71 | self.grad.setFinalStop(self.width, self.height / 2) 72 | p.setBrush(pg.QtGui.QBrush(self.grad)) 73 | p.drawRect(QtCore.QRectF(0, 0, self.width, self.height)) 74 | 75 | 76 | class VerticalBar(pg.GraphicsWidget): 77 | def __init__(self, width, height, grad): 78 | 
pg.GraphicsWidget.__init__(self) 79 | self.width = width 80 | self.height = height 81 | self.grad = grad 82 | QtGui.QPainter() 83 | 84 | def paint(self, p, *args): 85 | p.setPen(QtCore.Qt.NoPen) 86 | self.grad.setStart(self.width / 2, self.height) 87 | self.grad.setFinalStop(self.width / 2, 0) 88 | p.setBrush(pg.QtGui.QBrush(self.grad)) 89 | p.drawRect(QtCore.QRectF(0, 0, self.width, self.height)) 90 | -------------------------------------------------------------------------------- /atlaselectrophysiology/README.md: -------------------------------------------------------------------------------- 1 | # Ephys Atlas GUI 2 | 3 | GUI to allow user to align electrophysiology data with histology data. Please refer to this wiki page for information on installation and usage https://github.com/int-brain-lab/iblapps/wiki 4 | 5 | 6 | -------------------------------------------------------------------------------- /atlaselectrophysiology/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlaselectrophysiology/__init__.py -------------------------------------------------------------------------------- /atlaselectrophysiology/compare_alignments.py: -------------------------------------------------------------------------------- 1 | # import modules 2 | from one.api import ONE 3 | from ibllib.pipes.ephys_alignment import EphysAlignment 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib 7 | import iblatlas.atlas as atlas 8 | from pathlib import Path 9 | # Instantiate brain atlas and one 10 | brain_atlas = atlas.AllenAtlas(25) 11 | one = ONE() 12 | 13 | fig_path = Path('C:/Users/Mayo/Documents/PYTHON/alignment_figures/scale_factor') 14 | # Find eid of interest 15 | aligned_sess = one.alyx.rest('trajectories', 'list', provenance='Ephys aligned histology track', 16 | 
django='probe_insertion__session__project__name__icontains,' 17 | 'ibl_neuropixel_brainwide_01,' 18 | 'probe_insertion__session__qc__lt,30') 19 | eids = np.array([s['session']['id'] for s in aligned_sess]) 20 | probes = np.array([s['probe_name'] for s in aligned_sess]) 21 | 22 | json = [s['json'] for s in aligned_sess] 23 | idx_none = [i for i, val in enumerate(json) if val is None] 24 | json_val = np.delete(json, idx_none) 25 | keys = [list(s.keys()) for s in json_val] 26 | 27 | eids = np.delete(eids, idx_none) 28 | probes = np.delete(probes, idx_none) 29 | 30 | # Find index of json fields with 2 or more keys 31 | len_key = [len(s) for s in keys] 32 | idx_several = [i for i, val in enumerate(keys) if len(val) >= 2] 33 | eid_several = eids[idx_several] 34 | probe_several = probes[idx_several] 35 | 36 | 37 | # subject = 'KS023' 38 | # date = '2019-12-10' 39 | # sess_no = 1 40 | # probe_label = 'probe00' 41 | # eid = one.search(subject=subject, date=date, number=sess_no)[0] 42 | # eid = 'e2448a52-2c22-4ecc-bd48-632789147d9c' 43 | 44 | small_scaling = [["CSHL045", "2020-02-25", "probe00", "2020-06-12T10:40:22_noam.roth"], 45 | ["CSHL045", "2020-02-25", "probe01", "2020-09-12T23:43:03_petrina.lau"], 46 | ["CSHL047", "2020-01-20", "probe00", "2020-09-28T08:18:19_noam.roth"], 47 | ["CSHL047", "2020-01-27", "probe00", "2020-09-13T15:51:21_petrina.lau"], 48 | ["CSHL049", "2020-01-08", "probe00", "2020-09-14T15:44:56_nate"], 49 | ["CSHL051", "2020-02-05", "probe00", "2020-06-12T14:05:49_guido"], 50 | ["CSHL055", "2020-02-18", "probe00", "2020-08-13T12:07:08_jeanpaul"], 51 | ["CSH_ZAD_001", "2020-01-13", "probe00", "2020-09-22T17:23:50_petrina.lau"], 52 | ["KS014", "2019-12-03", "probe01", "2020-06-17T19:42:02_Karolina_Socha"], 53 | ["KS016", "2019-12-05", "probe01", "2020-06-18T10:26:54_Karolina_Socha"], 54 | ["KS020", "2020-02-06", "probe00", "2020-09-13T15:08:15_petrina.lau"], 55 | ["NYU-11", "2020-02-21", "probe01", "2020-09-13T11:19:59_petrina.lau"], 56 | ["SWC_014", 
"2019-12-15", "probe01", "2020-07-26T22:24:39_noam.roth"], 57 | ["ZM_2240", "2020-01-23", "probe00", "2020-06-05T14:57:46_guido"]] 58 | 59 | big_scaling = [["CSHL045", "2020-02-25", "probe00", "2020-06-12T10:40:22_noam.roth"], 60 | ["CSHL045", "2020-02-25", "probe01", "2020-09-12T23:43:03_petrina.lau"], 61 | ["CSH_ZAD_001", "2020-01-13", "probe00", "2020-09-22T17:23:50_petrina.lau"], 62 | ["KS014", "2019-12-03", "probe00", "2020-06-17T15:15:01_Karolina_Socha"], 63 | ["KS014", "2019-12-03", "probe01", "2020-06-17T19:42:02_Karolina_Socha"], 64 | ["KS014", "2019-12-04", "probe00", "2020-09-12T16:39:14_petrina.lau"], 65 | ["KS014", "2019-12-06", "probe00", "2020-06-17T13:40:00_Karolina_Socha"], 66 | ["KS014", "2019-12-07", "probe00", "2020-06-17T16:21:35_Karolina_Socha"], 67 | ["KS016", "2019-12-04", "probe00", "2020-08-13T14:02:44_jeanpaul"], 68 | ["KS016", "2019-12-05", "probe00", "2020-06-18T10:49:53_Karolina_Socha"], 69 | ["KS023", "2019-12-07", "probe00", "2020-09-09T14:21:35_nate"], 70 | ["KS023", "2019-12-07", "probe01", "2020-06-18T15:50:55_Karolina_Socha"], 71 | ["KS023", "2019-12-10", "probe01", "2020-06-12T13:59:02_guido"], 72 | ["KS023", "2019-12-11", "probe01", "2020-06-18T16:04:18_Karolina_Socha"], 73 | ["NYU-11", "2020-02-21", "probe01", "2020-09-13T11:19:59_petrina.lau"], 74 | ["NYU-12", "2020-01-22", "probe00", "2020-09-13T19:53:43_petrina.lau"], 75 | ["SWC_014", "2019-12-12", "probe00", "2020-07-27T11:27:16_noam.roth"], 76 | ["SWC_038", "2020-08-01", "probe01", "2020-08-31T12:32:05_nate"], 77 | ["ibl_witten_14", "2019-12-11", "probe00", "2020-06-14T15:33:45_noam.roth"]] 78 | 79 | for sess in small_scaling: 80 | eid = one.search(subject=sess[0], date=sess[1])[0] 81 | probe_label = sess[2] 82 | # for eid, probe_label in zip(eid_several, probe_several): 83 | trajectory = one.alyx.rest('trajectories', 'list', provenance='Ephys aligned histology track', 84 | session=eid, probe=probe_label) 85 | 86 | subject = trajectory[0]['session']['subject'] 87 | date = 
trajectory[0]['session']['start_time'][0:10] 88 | chn_coords = one.load(eid, dataset_types=['channels.localCoordinates'])[0] 89 | depths = chn_coords[:, 1] 90 | 91 | insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe_label) 92 | xyz_picks = np.array(insertion[0]['json']['xyz_picks']) / 1e6 93 | 94 | alignments = trajectory[0]['json'] 95 | 96 | def plot_regions(region, label, colour, ax): 97 | for reg, col in zip(region, colour): 98 | height = np.abs(reg[1] - reg[0]) 99 | color = col / 255 100 | ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w') 101 | 102 | ax.set_yticks(label[:, 0].astype(int)) 103 | ax.set_yticklabels(label[:, 1]) 104 | ax.yaxis.set_tick_params(labelsize=10) 105 | ax.tick_params(axis="y", direction="in", pad=-50) 106 | ax.set_ylim([20, 3840]) 107 | ax.get_xaxis().set_visible(False) 108 | 109 | def plot_scaling(region, scale, mapper, ax): 110 | for reg, col in zip(region_scaled, scale_factor): 111 | height = np.abs(reg[1] - reg[0]) 112 | color = np.array(mapper.to_rgba(col, bytes=True)) / 255 113 | ax.bar(x=1.1, height=height, width=0.2, color=color, bottom=reg[0], edgecolor='w') 114 | 115 | sec_ax = ax.secondary_yaxis('right') 116 | sec_ax.set_yticks(np.mean(region, axis=1)) 117 | sec_ax.set_yticklabels(np.around(scale, 2)) 118 | sec_ax.tick_params(axis="y", direction="in") 119 | sec_ax.set_ylim([20, 3840]) 120 | 121 | fig, ax = plt.subplots(1, len(alignments) + 1, figsize=(15, 15)) 122 | ephysalign = EphysAlignment(xyz_picks, depths, brain_atlas=brain_atlas) 123 | feature, track, _ = ephysalign.get_track_and_feature() 124 | channels_orig = ephysalign.get_channel_locations(feature, track) 125 | region, region_label = ephysalign.scale_histology_regions(feature, track) 126 | region_scaled, scale_factor = ephysalign.get_scale_factor(region) 127 | region_colour = ephysalign.region_colour 128 | 129 | norm = matplotlib.colors.Normalize(vmin=0.5, vmax=1.5, clip=True) 130 | mapper = 
matplotlib.cm.ScalarMappable(norm=norm, cmap=matplotlib.cm.seismic) 131 | 132 | ax_i = fig.axes[0] 133 | plot_regions(region, region_label, region_colour, ax_i) 134 | plot_scaling(region_scaled, scale_factor, mapper, ax_i) 135 | ax_i.set_title('Original') 136 | 137 | for iK, key in enumerate(alignments): 138 | # Location of reference lines used for alignmnet 139 | feature = np.array(alignments[key][0]) 140 | track = np.array(alignments[key][1]) 141 | user = key[20:] 142 | # Instantiate EphysAlignment object 143 | ephysalign = EphysAlignment(xyz_picks, depths, track_prev=track, feature_prev=feature, 144 | brain_atlas=brain_atlas) 145 | 146 | channels = ephysalign.get_channel_locations(feature, track) 147 | avg_dist = np.mean(np.sqrt(np.sum((channels - channels_orig) ** 2, axis=1)), axis=0) 148 | region, region_label = ephysalign.scale_histology_regions(feature, track) 149 | region_scaled, scale_factor = ephysalign.get_scale_factor(region) 150 | 151 | ax_i = fig.axes[iK + 1] 152 | plot_regions(region, region_label, region_colour, ax_i) 153 | plot_scaling(region_scaled, scale_factor, mapper, ax_i) 154 | ax_i.set_title(user + '\n Avg dist = ' + str(np.around(avg_dist * 1e6, 2))) 155 | 156 | fig.suptitle(subject + '_' + str(date) + '_' + probe_label, fontsize=16) 157 | plt.show() 158 | fig.savefig(fig_path.joinpath(subject + '_' + str(date) + '_' + probe_label + '.png'), dpi=100) 159 | plt.close(fig) 160 | -------------------------------------------------------------------------------- /atlaselectrophysiology/create_overview_plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pathlib import Path 3 | import glob 4 | 5 | 6 | def make_overview_plot(folder, sess_info, save_folder=None): 7 | 8 | image_folder = folder 9 | image_info = sess_info 10 | if not save_folder: 11 | save_folder = image_folder 12 | 13 | def load_image(image_name, ax): 14 | with image_name as ifile: 15 | image = 
plt.imread(ifile) 16 | 17 | ax.spines['right'].set_visible(False) 18 | ax.spines['top'].set_visible(False) 19 | ax.spines['left'].set_visible(False) 20 | ax.spines['bottom'].set_visible(False) 21 | ax.set_axis_off() 22 | ax.set_aspect('equal') 23 | ax.imshow(image) 24 | return image 25 | 26 | fig = plt.figure(constrained_layout=True, figsize=(18, 9)) 27 | gs = fig.add_gridspec(3, 18) 28 | gs.update(wspace=0.025, hspace=0.05) 29 | 30 | ignore_img_plots = ['leftGabor', 'rightGabor', 'noiseOn', 'valveOn', 'toneOn'] 31 | img_row_order = [0, 0, 0, 0, 0, 0, 1, 1, 1] 32 | img_column_order = [0, 3, 6, 9, 12, 15, 0, 3, 6] 33 | img_idx = [0, 5, 4, 6, 7, 8, 1, 2, 3] 34 | img_files = glob.glob(str(image_folder.joinpath(image_info + 'img_*.png'))) 35 | img_files = [img for img in img_files if not any([ig in img for ig in ignore_img_plots])] 36 | img_files_sort = [img_files[idx] for idx in img_idx] 37 | 38 | for iF, file in enumerate(img_files_sort): 39 | ax = fig.add_subplot(gs[img_row_order[iF], img_column_order[iF]:img_column_order[iF] + 3]) 40 | load_image(Path(file), ax) 41 | 42 | ignore_probe_plots = ['RF Map'] 43 | probe_row_order = [1, 1, 1, 1, 1, 1, 2, 2, 2] 44 | probe_column_order = [9, 10, 11, 12, 13, 14, 12, 13, 14] 45 | probe_idx = [0, 3, 1, 2, 4, 5, 6] 46 | probe_files = glob.glob(str(image_folder.joinpath(image_info + 'probe_*.png'))) 47 | probe_files = [probe for probe in probe_files if not any([pr in probe for pr in 48 | ignore_probe_plots])] 49 | probe_files_sort = [probe_files[idx] for idx in probe_idx] 50 | line_files = glob.glob(str(image_folder.joinpath(image_info + 'line_*.png'))) 51 | 52 | for iF, file in enumerate(probe_files_sort + line_files): 53 | ax = fig.add_subplot(gs[probe_row_order[iF], probe_column_order[iF]]) 54 | load_image(Path(file), ax) 55 | 56 | slice_files = glob.glob(str(image_folder.joinpath(image_info + 'slice_*.png'))) 57 | slice_row_order = [2, 2, 2, 2] 58 | slice_idx = [0, 1, 2, 3] 59 | slice_column_order = [0, 3, 6, 9] 60 | 
slice_files_sort = [slice_files[idx] for idx in slice_idx] 61 | 62 | for iF, file in enumerate(slice_files_sort): 63 | ax = fig.add_subplot(gs[slice_row_order[iF], 64 | slice_column_order[iF]:slice_column_order[iF] + 3]) 65 | load_image(Path(file), ax) 66 | 67 | slice_files = glob.glob(str(image_folder.joinpath(image_info + 'slice_zoom*.png'))) 68 | slice_row_order = [2, 2, 2, 2] 69 | slice_idx = [0, 1, 2, 3] 70 | slice_column_order = [2, 5, 8, 11] 71 | slice_files_sort = [slice_files[idx] for idx in slice_idx] 72 | 73 | for iF, file in enumerate(slice_files_sort): 74 | ax = fig.add_subplot(gs[slice_row_order[iF], slice_column_order[iF]]) 75 | load_image(Path(file), ax) 76 | 77 | hist_files = glob.glob(str(image_folder.joinpath(image_info + 'hist*.png'))) 78 | for iF, file in enumerate(hist_files): 79 | ax = fig.add_subplot(gs[1:3, 15:18]) 80 | load_image(Path(file), ax) 81 | 82 | ax.text(0.5, 0, image_info[:-1], va="center", ha="center", transform=ax.transAxes) 83 | plt.savefig(save_folder.joinpath(image_info + "overview.png"), 84 | bbox_inches='tight', pad_inches=0) 85 | # plt.close() 86 | # plt.show() 87 | -------------------------------------------------------------------------------- /atlaselectrophysiology/ephys_atlas_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlaselectrophysiology/ephys_atlas_image.png -------------------------------------------------------------------------------- /atlaselectrophysiology/example_code.py: -------------------------------------------------------------------------------- 1 | from one.api import ONE 2 | from atlaselectrophysiology.alignment_with_easyqc import viewer 3 | 4 | one = ONE() 5 | probe_id = 'ce397420-3cd2-4a55-8fd1-5e28321981f4' 6 | 7 | 8 | av = viewer(probe_id, one=one) 9 | -------------------------------------------------------------------------------- 
/atlaselectrophysiology/extract_files.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import logging 3 | 4 | from tqdm import tqdm 5 | import spikeglx 6 | import numpy as np 7 | from ibldsp import fourier, utils 8 | from scipy import signal 9 | import one.alf.io as alfio 10 | import ibllib.ephys.ephysqc as ephysqc 11 | 12 | from phylib.io import alf 13 | 14 | _logger = logging.getLogger('ibllib') 15 | 16 | 17 | RMS_WIN_LENGTH_SECS = 3 18 | WELCH_WIN_LENGTH_SAMPLES = 1024 19 | 20 | 21 | def rmsmap(fbin, spectra=True): 22 | """ 23 | Computes RMS map in time domain and spectra for each channel of Neuropixel probe 24 | 25 | :param fbin: binary file in spike glx format (will look for attached metatdata) 26 | :type fbin: str or pathlib.Path 27 | :param spectra: whether to compute the power spectrum (only need for lfp data) 28 | :type: bool 29 | :return: a dictionary with amplitudes in channeltime space, channelfrequency space, time 30 | and frequency scales 31 | """ 32 | if not isinstance(fbin, spikeglx.Reader): 33 | sglx = spikeglx.Reader(fbin) 34 | sglx.open() 35 | rms_win_length_samples = 2 ** np.ceil(np.log2(sglx.fs * RMS_WIN_LENGTH_SECS)) 36 | # the window generator will generates window indices 37 | wingen = utils.WindowGenerator(ns=sglx.ns, nswin=rms_win_length_samples, overlap=0) 38 | # pre-allocate output dictionary of numpy arrays 39 | win = {'TRMS': np.zeros((wingen.nwin, sglx.nc)), 40 | 'nsamples': np.zeros((wingen.nwin,)), 41 | 'fscale': fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / sglx.fs, one_sided=True), 42 | 'tscale': wingen.tscale(fs=sglx.fs)} 43 | win['spectral_density'] = np.zeros((len(win['fscale']), sglx.nc)) 44 | # loop through the whole session 45 | with tqdm(total=wingen.nwin) as pbar: 46 | for first, last in wingen.firstlast: 47 | D = sglx.read_samples(first_sample=first, last_sample=last)[0].transpose() 48 | # remove low frequency noise below 1 Hz 49 | D = fourier.hp(D, 1 / 
sglx.fs, [0, 1]) 50 | iw = wingen.iw 51 | win['TRMS'][iw, :] = utils.rms(D) 52 | win['nsamples'][iw] = D.shape[1] 53 | if spectra: 54 | # the last window may be smaller than what is needed for welch 55 | if last - first < WELCH_WIN_LENGTH_SAMPLES: 56 | continue 57 | # compute a smoothed spectrum using welch method 58 | _, w = signal.welch( 59 | D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES, 60 | detrend='constant', return_onesided=True, scaling='density', axis=-1 61 | ) 62 | win['spectral_density'] += w.T 63 | # print at least every 20 windows 64 | if (iw % min(20, max(int(np.floor(wingen.nwin / 75)), 1))) == 0: 65 | pbar.update(iw) 66 | 67 | sglx.close() 68 | return win 69 | 70 | 71 | def extract_rmsmap(fbin, out_folder=None, spectra=True): 72 | """ 73 | Wrapper for rmsmap that outputs _ibl_ephysRmsMap and _ibl_ephysSpectra ALF files 74 | 75 | :param fbin: binary file in spike glx format (will look for attached metatdata) 76 | :param out_folder: folder in which to store output ALF files. Default uses the folder in which 77 | the `fbin` file lives. 
78 | :param spectra: whether to compute the power spectrum (only need for lfp data) 79 | :type: bool 80 | :return: None 81 | """ 82 | _logger.info(f"Computing QC for {fbin}") 83 | sglx = spikeglx.Reader(fbin) 84 | # check if output ALF files exist already: 85 | if out_folder is None: 86 | out_folder = Path(fbin).parent 87 | else: 88 | out_folder = Path(out_folder) 89 | alf_object_time = f'ephysTimeRms{sglx.type.upper()}' 90 | alf_object_freq = f'ephysSpectralDensity{sglx.type.upper()}' 91 | 92 | # crunch numbers 93 | rms = rmsmap(fbin, spectra=spectra) 94 | # output ALF files, single precision with the optional label as suffix before extension 95 | if not out_folder.exists(): 96 | out_folder.mkdir() 97 | tdict = {'rms': rms['TRMS'].astype(np.single), 'timestamps': rms['tscale'].astype(np.single)} 98 | alfio.save_object_npy(out_folder, object=alf_object_time, dico=tdict, namespace='iblqc') 99 | if spectra: 100 | fdict = {'power': rms['spectral_density'].astype(np.single), 101 | 'freqs': rms['fscale'].astype(np.single)} 102 | alfio.save_object_npy( 103 | out_folder, object=alf_object_freq, dico=fdict, namespace='iblqc') 104 | 105 | 106 | def _sample2v(ap_file): 107 | """ 108 | Convert raw ephys data to Volts 109 | """ 110 | md = spikeglx.read_meta_data(ap_file.with_suffix('.meta')) 111 | s2v = spikeglx._conversion_sample2v_from_meta(md) 112 | return s2v['ap'][0] 113 | 114 | 115 | def ks2_to_alf(ks_path, bin_path, out_path, bin_file=None, ampfactor=1, label=None, force=True): 116 | """ 117 | Convert Kilosort 2 output to ALF dataset for single probe data 118 | :param ks_path: 119 | :param bin_path: path of raw data 120 | :param out_path: 121 | :return: 122 | """ 123 | m = ephysqc.phy_model_from_ks2_path(ks2_path=ks_path, bin_path=bin_path, bin_file=bin_file) 124 | ac = alf.EphysAlfCreator(m) 125 | ac.convert(out_path, label=label, force=force, ampfactor=ampfactor) 126 | 127 | # set depths to spike_depths to catch cases where it can't be computed from pc features (e.g 
in case of KS3) 128 | m.depths = np.load(out_path.joinpath('spikes.depths.npy')) 129 | ephysqc.spike_sorting_metrics_ks2(ks_path, m, save=True, save_path=out_path) 130 | 131 | 132 | def extract_data(ks_path, ephys_path, out_path): 133 | efiles = spikeglx.glob_ephys_files(ephys_path) 134 | 135 | for efile in efiles: 136 | if efile.get('ap') and efile.ap.exists(): 137 | ks2_to_alf(ks_path, ephys_path, out_path, bin_file=efile.ap, 138 | ampfactor=_sample2v(efile.ap), label=None, force=True) 139 | 140 | extract_rmsmap(efile.ap, out_folder=out_path, spectra=False) 141 | if efile.get('lf') and efile.lf.exists(): 142 | extract_rmsmap(efile.lf, out_folder=out_path) 143 | 144 | 145 | # if __name__ == '__main__': 146 | # 147 | # ephys_path = Path('C:/Users/Mayo/Downloads/raw_ephys_data') 148 | # ks_path = Path('C:/Users/Mayo/Downloads/KS2') 149 | # out_path = Path('C:/Users/Mayo/Downloads/alf') 150 | # extract_data(ks_path, ephys_path, out_path) 151 | -------------------------------------------------------------------------------- /atlaselectrophysiology/get_scale_factor.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Extract channel locations from reference points of previous alignments saved in json field of 3 | trajectory object 4 | Create plot showing histology regions which channels pass through as well as coronal slice with 5 | channel locations shown 6 | ''' 7 | 8 | # import modules 9 | from oneibl.one import ONE 10 | from ibllib.pipes.ephys_alignment import EphysAlignment 11 | import numpy as np 12 | import pandas as pd 13 | import matplotlib.pyplot as plt 14 | import seaborn as sns 15 | import iblatlas.atlas as atlas 16 | 17 | 18 | # Instantiate brain atlas and one 19 | brain_atlas = atlas.AllenAtlas(25) 20 | one = ONE() 21 | 22 | # Find eid of interest 23 | subject = 'KS022' 24 | 25 | # Find the ephys aligned trajectory for eid probe combination 26 | trajectory = one.alyx.rest('trajectories', 'list',
provenance='Ephys aligned histology track', 27 | subject=subject) 28 | 29 | # Load in channels.localCoordinates dataset type 30 | chn_coords = one.load(trajectory[0]['session']['id'], 31 | dataset_types=['channels.localCoordinates'])[0] 32 | depths = chn_coords[:, 1] 33 | 34 | subject_summary = pd.DataFrame(columns=['Session', 'User', 'Scale Factor', 'Avg Scale Factor']) 35 | sf = [] 36 | sess = [] 37 | user = [] 38 | for traj in trajectory: 39 | alignments = traj['json'] 40 | 41 | insertion = one.alyx.rest('insertions', 'list', session=traj['session']['id'], 42 | name=traj['probe_name']) 43 | xyz_picks = np.array(insertion[0]['json']['xyz_picks']) / 1e6 44 | 45 | session_info = traj['session']['start_time'][:10] + '_' + traj['probe_name'] 46 | 47 | for iK, key in enumerate(alignments): 48 | # Location of reference lines used for alignment 49 | feature = np.array(alignments[key][0]) 50 | track = np.array(alignments[key][1]) 51 | user = key[:19] 52 | # Instantiate EphysAlignment object 53 | ephysalign = EphysAlignment(xyz_picks, depths, track_prev=track, feature_prev=feature) 54 | region_scaled, _ = ephysalign.scale_histology_regions(feature, track) 55 | _, scale_factor = ephysalign.get_scale_factor(region_scaled) 56 | 57 | if np.all(np.round(np.diff(scale_factor), 3) == 0): 58 | # Case where there is no scaling but just an offset 59 | scale_factor = np.array([1]) 60 | avg_sf = 1 61 | else: 62 | if feature.size > 4: 63 | # Case where 3 or more reference lines have been placed so take gradient of 64 | # linear fit to represent average scaling factor 65 | avg_sf = scale_factor[0] 66 | else: 67 | # Case where 2 reference lines have been used. Only have local scaling between 68 | # two reference lines, everywhere else scaling is 1.
Use the local scaling as the 69 | # average scaling factor 70 | avg_sf = np.mean(scale_factor[1:-1]) 71 | 72 | for iS, sf in enumerate(scale_factor): 73 | if iS == 0: 74 | subject_summary = pd.concat([subject_summary, pd.DataFrame([ 75 | {'Session': session_info, 'User': user, 'Scale Factor': sf, 76 | 'Avg Scale Factor': avg_sf}])], ignore_index=True) 77 | else: 78 | subject_summary = pd.concat([subject_summary, pd.DataFrame([ 79 | {'Session': session_info, 'User': user, 'Scale Factor': sf, 80 | 'Avg Scale Factor': np.nan}])], ignore_index=True) 81 | 82 | fig, ax = plt.subplots(figsize=(10, 8)) 83 | sns.swarmplot(x='Session', y='Scale Factor', hue='User', data=subject_summary, ax=ax) 84 | sns.swarmplot(x='Session', y='Avg Scale Factor', hue='User', size=8, linewidth=1, 85 | data=subject_summary, ax=ax) 86 | # ensures value in legend isn't repeated 87 | handles, labels = ax.get_legend_handles_labels() 88 | by_label = dict(zip(labels, handles)) 89 | ax.legend(by_label.values(), by_label.keys()) 90 | 91 | plt.show() 92 | -------------------------------------------------------------------------------- /atlaselectrophysiology/load_histology.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import requests 4 | import re 5 | from one import params 6 | from one.webclient import http_download_file 7 | import SimpleITK as sitk 8 | 9 | 10 | def download_histology_data(subject, lab): 11 | 12 | if lab == 'hoferlab': 13 | lab_temp = 'mrsicflogellab' 14 | else: 15 | lab_temp = lab 16 | 17 | par = params.get() 18 | 19 | try: 20 | FLAT_IRON_HIST_REL_PATH = Path('histology', lab_temp, subject, 21 | 'downsampledStacks_25', 'sample2ARA') 22 | baseurl = (par.HTTP_DATA_SERVER + '/' + '/'.join(FLAT_IRON_HIST_REL_PATH.parts)) 23 | r = requests.get(baseurl, auth=(par.HTTP_DATA_SERVER_LOGIN, par.HTTP_DATA_SERVER_PWD)) 24 | r.raise_for_status() 25 | except Exception as err: 26 | print(err) 27 | try: 28 | subject_rem = subject.replace("_", "") 29 |
FLAT_IRON_HIST_REL_PATH = Path('histology', lab_temp, subject_rem, 30 | 'downsampledStacks_25', 'sample2ARA') 31 | baseurl = (par.HTTP_DATA_SERVER + '/' + '/'.join(FLAT_IRON_HIST_REL_PATH.parts)) 32 | r = requests.get(baseurl, auth=(par.HTTP_DATA_SERVER_LOGIN, par.HTTP_DATA_SERVER_PWD)) 33 | r.raise_for_status() 34 | except Exception as err: 35 | if lab_temp == 'churchlandlab_ucla': 36 | try: 37 | lab_temp = 'churchlandlab' 38 | FLAT_IRON_HIST_REL_PATH = Path('histology', lab_temp, subject, 39 | 'downsampledStacks_25', 'sample2ARA') 40 | baseurl = (par.HTTP_DATA_SERVER + '/' + '/'.join(FLAT_IRON_HIST_REL_PATH.parts)) 41 | r = requests.get(baseurl, auth=(par.HTTP_DATA_SERVER_LOGIN, par.HTTP_DATA_SERVER_PWD)) 42 | r.raise_for_status() 43 | except Exception as err: 44 | print(err) 45 | path_to_nrrd = None 46 | return path_to_nrrd 47 | else: 48 | print(err) 49 | path_to_nrrd = None 50 | return path_to_nrrd 51 | 52 | 53 | tif_files = [] 54 | for line in r.text.splitlines(): 55 | result = re.findall('href="(.*).tif"', line) 56 | if result: 57 | tif_files.append(result[0] + '.tif') 58 | 59 | CACHE_DIR = params.get_cache_dir().joinpath(lab, 'Subjects', subject, 'histology') 60 | CACHE_DIR.mkdir(exist_ok=True, parents=True) 61 | path_to_files = [] 62 | for file in tif_files: 63 | path_to_image = Path(CACHE_DIR, file) 64 | if not path_to_image.exists(): 65 | url = (baseurl + '/' + file) 66 | http_download_file(url, target_dir=CACHE_DIR, 67 | username=par.HTTP_DATA_SERVER_LOGIN, 68 | password=par.HTTP_DATA_SERVER_PWD) 69 | 70 | path_to_nrrd = tif2nrrd(path_to_image) 71 | path_to_files.append(path_to_nrrd) 72 | 73 | if len(path_to_files) > 3: 74 | path_to_files = path_to_files[1:3] 75 | 76 | return path_to_files 77 | 78 | 79 | def tif2nrrd(path_to_image): 80 | path_to_nrrd = Path(path_to_image).with_suffix('.nrrd') 81 | if not path_to_nrrd.exists(): 82 | reader = sitk.ImageFileReader() 83 | reader.SetImageIO("TIFFImageIO") 84 | reader.SetFileName(str(path_to_image)) 85 | img =
reader.Execute() 86 | 87 | new_img = sitk.PermuteAxes(img, [2, 1, 0]) 88 | new_img = sitk.Flip(new_img, [True, False, False]) 89 | new_img.SetSpacing([1, 1, 1]) 90 | writer = sitk.ImageFileWriter() 91 | writer.SetImageIO("NrrdImageIO") 92 | writer.SetFileName(str(path_to_nrrd)) 93 | writer.Execute(new_img) 94 | 95 | return path_to_nrrd 96 | -------------------------------------------------------------------------------- /atlaselectrophysiology/qc_table.py: -------------------------------------------------------------------------------- 1 | import datajoint as dj 2 | 3 | schema = dj.schema('group_shared_ephys') 4 | 5 | 6 | @schema 7 | class EphysQC(dj.Imported): 8 | definition = """ 9 | probe_insertion_uuid: uuid # probe insertion uuid 10 | -> reference.LabMember # user name 11 | --- 12 | alignment_qc=null: enum('high', 'medium', 'low') # confidence in alignment 13 | ephys_qc=null: enum('pass', 'critical', 'warning') # quality of ephys 14 | ephys_qc_description=null: varchar(255) #Description for ephys_qc 15 | """ 16 | -------------------------------------------------------------------------------- /atlaselectrophysiology/rendering.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import cv2 5 | import vtk 6 | from matplotlib import pyplot as plt # noqa 7 | import mayavi.mlab as mlab 8 | 9 | 10 | def add_mesh(fig, obj_file, color=(1., 1., 1.), opacity=0.4): 11 | """ 12 | Adds a mesh object from an *.obj file to the mayavi figure 13 | :param fig: mayavi figure 14 | :param obj_file: full path to a local *.obj file 15 | :param color: rgb tuple of floats between 0 and 1 16 | :param opacity: float between 0 and 1 17 | :return: tuple (vtk mapper, vtk actor) 18 | """ 19 | reader = vtk.vtkOBJReader() 20 | reader.SetFileName(str(obj_file)) 21 | reader.Update() 22 | mapper = vtk.vtkPolyDataMapper() 23 | mapper.SetInputConnection(reader.GetOutputPort()) 24 | actor = vtk.vtkActor() 25 |
actor.SetMapper(mapper) 26 | actor.GetProperty().SetOpacity(opacity) 27 | actor.GetProperty().SetColor(color) 28 | fig.scene.add_actor(actor) 29 | fig.scene.render() 30 | return mapper, actor 31 | 32 | 33 | def figure(grid=False, **kwargs): 34 | """ 35 | Creates a mayavi figure with the brain atlas mesh 36 | :return: mayavi figure 37 | """ 38 | fig = mlab.figure(bgcolor=(1, 1, 1), **kwargs) 39 | # engine = mlab.get_engine() # Returns the running mayavi engine. 40 | obj_file = Path(__file__).parent.joinpath("root.obj") 41 | mapper, actor = add_mesh(fig, obj_file) 42 | 43 | if grid: 44 | # https://vtk.org/Wiki/VTK/Examples/Python/Visualization/CubeAxesActor 45 | cubeAxesActor = vtk.vtkCubeAxesActor() 46 | cubeAxesActor.SetMapper(mapper) 47 | cubeAxesActor.SetBounds(mapper.GetBounds()) 48 | cubeAxesActor.SetCamera(fig.scene.renderer._vtk_obj.GetActiveCamera()) 49 | cubeAxesActor.SetXTitle("AP (um)") 50 | cubeAxesActor.SetYTitle("DV (um)") 51 | cubeAxesActor.SetZTitle("ML (um)") 52 | cubeAxesActor.GetTitleTextProperty(0).SetColor(1.0, 0.0, 0.0) 53 | cubeAxesActor.GetLabelTextProperty(0).SetColor(1.0, 0.0, 0.0) 54 | cubeAxesActor.GetTitleTextProperty(1).SetColor(0.0, 1.0, 0.0) 55 | cubeAxesActor.GetLabelTextProperty(1).SetColor(0.0, 1.0, 0.0) 56 | cubeAxesActor.GetTitleTextProperty(2).SetColor(0.0, 0.0, 1.0) 57 | cubeAxesActor.GetLabelTextProperty(2).SetColor(0.0, 0.0, 1.0) 58 | cubeAxesActor.DrawXGridlinesOn() 59 | cubeAxesActor.DrawYGridlinesOn() 60 | cubeAxesActor.DrawZGridlinesOn() 61 | fig.scene.add_actor(cubeAxesActor) 62 | 63 | mlab.view(azimuth=180, elevation=0) 64 | mlab.view(azimuth=210, elevation=210, reset_roll=False) 65 | 66 | return fig 67 | 68 | 69 | def rotating_video(output_file, mfig, fps=12, secs=6): 70 | # ffmpeg -i input.webm -pix_fmt rgb24 output.gif 71 | # ffmpeg -i certification.webm -vf scale=640:-1 -r 10 -f image2pipe -vcodec ppm - | convert -delay 5 -loop 0 - certification.gif # noqa 72 | 73 | file_video = Path(output_file) 74 | if
file_video.suffix == '.avi': 75 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 76 | elif file_video.suffix == '.webm': 77 | fourcc = cv2.VideoWriter_fourcc(*'VP80') 78 | else: 79 | raise NotImplementedError(f"Extension {file_video.suffix} not supported") 80 | 81 | mlab.view(azimuth=180, elevation=0) 82 | mfig.scene.render() 83 | mfig.scene._lift() 84 | frame = mlab.screenshot(figure=mfig, mode='rgb', antialiased=True) 85 | w, h, _ = frame.shape 86 | video = cv2.VideoWriter(str(file_video), fourcc, float(fps), (h, w)) 87 | 88 | # import time 89 | for e in np.linspace(-180, 180, secs * fps): 90 | # frame = np.random.randint(0, 256, (w, h, 3), dtype=np.uint8) 91 | mlab.view(azimuth=0, elevation=e, reset_roll=False) 92 | mfig.scene.render() 93 | frame = mlab.screenshot(figure=mfig, mode='rgb', antialiased=True) 94 | print(e, (h, w), frame.shape) 95 | video.write(np.flip(frame, axis=2)) # bgr instead of rgb... 96 | # time.sleep(0.05) 97 | 98 | video.release() 99 | cv2.destroyAllWindows() 100 | -------------------------------------------------------------------------------- /atlaselectrophysiology/sandbox.py: -------------------------------------------------------------------------------- 1 | from brainbox.io.one import load_spike_sorting, load_channel_locations 2 | from oneibl.one import ONE 3 | 4 | one = ONE(base_url="https://dev.alyx.internationalbrainlab.org") 5 | eids = one.search(subject='ZM_2407', task_protocol='ephys') 6 | 7 | channels = load_channel_locations(eids[0], one=one) 8 | spikes, clusters = load_spike_sorting(eids[0], one=one) 9 | -------------------------------------------------------------------------------- /atlasview/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlasview/__init__.py -------------------------------------------------------------------------------- /atlasview/channels_test.npy:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlasview/channels_test.npy -------------------------------------------------------------------------------- /atlasview/launcher.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from atlasview import atlasview # meh, this import will have to change 4 | av = atlasview.view() # keep a reference to the returned view here, otherwise the garbage collector will 5 | # collect it and the viewer is lost 6 | 7 | """ Roadmap 8 | - swap volumes combox (label RGB option / density) 9 | - overlay brain regions with transparencye 10 | - overlay volumes (for example coverage with transparency) 11 | - overlay plots: probes channels 12 | - tilted slices 13 | - coordinate swaps: add Allen / Needles / Voxel options 14 | - should we add horizontal slices ? 15 | """ 16 | 17 | # add brain regions feature: 18 | reg_values = np.load(Path(atlasview.__file__).parent.joinpath('region_values.npy')) 19 | av.add_regions_feature(reg_values, 'Blues', opacity=0.7) 20 | 21 | # add scatter feature: 22 | chans = np.load( 23 | Path(atlasview.__file__).parent.joinpath('channels_test.npy'), allow_pickle=True) 24 | av.add_scatter_feature(chans) 25 | -------------------------------------------------------------------------------- /atlasview/region_values.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlasview/region_values.npy -------------------------------------------------------------------------------- /atlasview/sliceview.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | SliceWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 1101 10 | 737 11 | 12 | 13 | 14 | MainWindow 15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 100 27 | 0 28 | 29 | 30 | 31 | 32 | 120 33 | 16777215 34 | 35 | 36 | 37 | QFrame::StyledPanel 38 | 39 | 40 | QFrame::Raised 41 | 42 | 43 | 44 | 45 | 46 | 47 | 70 48 | 0 49 | 50 | 51 | 52 | 53 | 80 54 | 20 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 15 67 | 0 68 | 69 | 70 | 71 | 72 | 10 73 | 20 74 | 75 | 76 | 77 | y 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 15 86 | 0 87 | 88 | 89 | 90 | 91 | 10 92 | 20 93 | 94 | 95 | 96 | ix 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 15 105 | 0 106 | 107 | 108 | 109 | 110 | 10 111 | 20 112 | 113 | 114 | 115 | iy 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 70 124 | 0 125 | 126 | 127 | 128 | 129 | 80 130 | 20 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 15 143 | 0 144 | 145 | 146 | 147 | 148 | 10 149 | 20 150 | 151 | 152 | 153 | x 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 70 162 | 0 163 | 164 | 165 | 166 | 167 | 80 168 | 20 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 70 181 | 0 182 | 183 | 184 | 185 | 186 | 80 187 | 20 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | Qt::Vertical 199 | 200 | 201 | 202 | 20 203 | 40 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | v 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | Qt::Vertical 236 | 237 | 238 | QSizePolicy::Fixed 239 | 240 | 241 | 242 | 20 243 | 5 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | Qt::Horizontal 255 | 256 | 257 | QSizePolicy::Maximum 258 | 259 | 260 | 261 | 5 262 | 17 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | Qt::Horizontal 271 | 272 | 273 | QSizePolicy::Fixed 274 | 275 | 276 | 277 | 5 278 | 20 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 0 297 | 0 298 | 3 299 | 22 300 | 301 | 302 | 303 | 304 | 305 | 306 | PlotWidget 307 | QGraphicsView 308 |
pyqtgraph
309 |
310 |
311 | 312 | 313 |
314 | -------------------------------------------------------------------------------- /atlasview/topview.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | TopView 4 | 5 | 6 | 7 | 0 8 | 0 9 | 435 10 | 461 11 | 12 | 13 | 14 | MainWindow 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 100 26 | 0 27 | 28 | 29 | 30 | QFrame::StyledPanel 31 | 32 | 33 | QFrame::Raised 34 | 35 | 36 | 37 | 38 | 20 39 | 20 40 | 67 41 | 17 42 | 43 | 44 | 45 | Annotation 46 | 47 | 48 | 49 | 50 | 51 | 40 52 | 50 53 | 16 54 | 160 55 | 56 | 57 | 58 | 20 59 | 60 | 61 | Qt::Vertical 62 | 63 | 64 | 65 | 66 | 67 | 10 68 | 220 69 | 67 70 | 17 71 | 72 | 73 | 74 | Image 75 | 76 | 77 | 78 | 79 | 80 | 0 81 | 350 82 | 93 83 | 27 84 | 85 | 86 | 87 | 88 | 89 | 90 | 10 91 | 330 92 | 67 93 | 17 94 | 95 | 96 | 97 | Mapping 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 0 108 | 0 109 | 435 110 | 24 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | PlotWidget 119 | QGraphicsView 120 |
pyqtgraph
121 |
122 |
123 | 124 | 125 |
126 | -------------------------------------------------------------------------------- /data_exploration_gui/README.md: -------------------------------------------------------------------------------- 1 | # Data Exploration GUI 2 | 3 | GUI that allows users to explore ephys data from the IBL task 4 | 5 | ## Setup 6 | 7 | Install the IBL environment by following [these instructions](https://github.com/int-brain-lab/iblenv#iblenv-installation-guide) 8 | 9 | Go to the ```data_exploration_gui``` folder 10 | 11 | ``` 12 | cd iblapps/data_exploration_gui 13 | ``` 14 | 15 | ### Using the GUI 16 | To launch the GUI, run the following from the command line. You can specify either a probe insertion id, 17 | e.g. 18 | ``` 19 | python data_explore_gui.py -pid 9657af01-50bd-4120-8303-416ad9e24a51 20 | ``` 21 | 22 | or an eid and probe name, e.g. 23 | ``` 24 | python data_explore_gui.py -eid 7f6b86f9-879a-4ea2-8531-294a221af5d0 -name probe00 25 | ``` 26 | -------------------------------------------------------------------------------- /data_exploration_gui/cluster.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtWidgets 2 | from data_exploration_gui import utils 3 | 4 | 5 | class ClusterGroup: 6 | 7 | def __init__(self): 8 | 9 | self.cluster_buttons = QtWidgets.QButtonGroup() 10 | self.cluster_group = QtWidgets.QGroupBox('Sort Clusters By:') 11 | self.cluster_layout = QtWidgets.QHBoxLayout() 12 | for i, val in enumerate(utils.SORT_CLUSTER_OPTIONS): 13 | button = QtWidgets.QRadioButton(val) 14 | if i == 0: 15 | button.setChecked(True) 16 | else: 17 | button.setChecked(False) 18 | self.cluster_buttons.addButton(button) 19 | self.cluster_layout.addWidget(button) 20 | 21 | self.cluster_group.setLayout(self.cluster_layout) 22 | 23 | self.cluster_colours = QtWidgets.QGroupBox() 24 | h_layout_col = QtWidgets.QHBoxLayout() 25 | h_layout_lab = QtWidgets.QHBoxLayout() 26 | for val in utils.UNIT_OPTIONS: 27 | img = QtWidgets.QLabel() 28 |
28 | label = QtWidgets.QLabel(val) 29 | pix = QtGui.QPixmap(40, 5) 30 | pix.fill(utils.colours[val]) 31 | img.setPixmap(pix) 32 | h_layout_lab.addWidget(img) 33 | h_layout_col.addWidget(label) 34 | 35 | v_layout = QtWidgets.QVBoxLayout() 36 | v_layout.addLayout(h_layout_lab) 37 | v_layout.addLayout(h_layout_col) 38 | 39 | self.cluster_colours.setLayout(v_layout) 40 | 41 | self.cluster_list = QtWidgets.QListWidget() 42 | 43 | self.cluster_next_button = QtWidgets.QPushButton('Next') 44 | self.cluster_next_button.setFixedSize(90, 30) 45 | self.cluster_previous_button = QtWidgets.QPushButton('Previous') 46 | self.cluster_previous_button.setFixedSize(90, 30) 47 | 48 | self.cluster_list_group = QtWidgets.QGroupBox() 49 | 50 | group_layout = QtWidgets.QGridLayout() 51 | group_layout.addWidget(self.cluster_group, 0, 0, 1, 3) 52 | group_layout.addWidget(self.cluster_colours, 1, 0, 1, 3) 53 | group_layout.addWidget(self.cluster_list, 2, 0, 3, 3) 54 | group_layout.addWidget(self.cluster_previous_button, 5, 0) 55 | group_layout.addWidget(self.cluster_next_button, 5, 2) 56 | self.cluster_list_group.setLayout(group_layout) 57 | 58 | self.reset() 59 | 60 | def reset(self): 61 | self.cluster_list.clear() 62 | self.clust_colour_ks = None 63 | self.clust_colour_ibl = None 64 | self.clust_ids = None 65 | self.clust = None 66 | self.clust_prev = None 67 | 68 | def populate(self, clust_ids, clust_colour_ks, clust_colour_ibl): 69 | self.clust_ids = clust_ids 70 | self.clust_colour_ks = clust_colour_ks 71 | self.clust_colour_ibl = clust_colour_ibl 72 | 73 | for idx, val in enumerate(clust_ids): 74 | item = QtWidgets.QListWidgetItem('Cluster Number ' + str(val)) 75 | icon = utils.get_icon(self.clust_colour_ibl[idx], self.clust_colour_ks[idx], 20) 76 | item.setIcon(QtGui.QIcon(icon)) 77 | self.cluster_list.addItem(item) 78 | self.cluster_list.setCurrentRow(0) 79 | 80 | def update_list_icon(self): 81 | item = self.cluster_list.item(self.clust_prev) 82 | icon =
utils.get_icon(self.clust_colour_ibl[self.clust_prev], 83 | self.clust_colour_ks[self.clust_prev], 12) 84 | item.setIcon(QtGui.QIcon(icon)) 85 | item.setText(' Cluster Number ' + str(self.clust_ids[self.clust_prev])) 86 | 87 | def update_current_row(self, clust): 88 | self.clust = clust 89 | self.cluster_list.setCurrentRow(self.clust) 90 | 91 | def on_next_cluster_clicked(self): 92 | self.clust += 1 93 | self.cluster_list.setCurrentRow(self.clust) 94 | return self.clust 95 | 96 | def on_previous_cluster_clicked(self): 97 | self.clust -= 1 98 | self.cluster_list.setCurrentRow(self.clust) 99 | return self.clust 100 | 101 | def on_cluster_list_clicked(self): 102 | self.clust = self.cluster_list.currentRow() 103 | return self.clust 104 | 105 | def update_cluster_index(self): 106 | self.clust_prev = self.clust 107 | return self.clust_prev 108 | 109 | def initialise_cluster_index(self): 110 | self.clust = 0 111 | self.clust_prev = 0 112 | 113 | return self.clust, self.clust_prev 114 | 115 | def get_selected_sort(self): 116 | return self.cluster_buttons.checkedButton().text() 117 | -------------------------------------------------------------------------------- /data_exploration_gui/cluster_class.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | 3 | 4 | class ClusterGroup: 5 | 6 | def __init__(self): 7 | 8 | self.cluster_buttons = QtWidgets.QButtonGroup() 9 | self.cluster_group = QtWidgets.QGroupBox('Sort Clusters By:') 10 | self.cluster_layout = QtWidgets.QHBoxLayout() 11 | cluster_options = ['cluster no.', 'no.
of spikes', 'good units'] 12 | for i, val in enumerate(cluster_options): 13 | button = QtWidgets.QRadioButton(val) 14 | if i == 0: 15 | button.setChecked(True) 16 | else: 17 | button.setChecked(False) 18 | self.cluster_buttons.addButton(button, id=i) 19 | self.cluster_layout.addWidget(button) 20 | 21 | self.cluster_group.setLayout(self.cluster_layout) 22 | 23 | self.cluster_list = QtWidgets.QListWidget() 24 | self.cluster_list.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) 25 | 26 | self.cluster_next_button = QtWidgets.QPushButton('Next') 27 | self.cluster_next_button.setFixedSize(90, 30) 28 | self.cluster_previous_button = QtWidgets.QPushButton('Previous') 29 | self.cluster_previous_button.setFixedSize(90, 30) 30 | 31 | self.cluster_list_group = QtWidgets.QGroupBox() 32 | self.cluster_list_group.setFixedSize(400, 200) 33 | self.group_widget() 34 | 35 | self.clust_colour = [] 36 | self.clust_ids = [] 37 | self.clust = [] 38 | self.clust_prev = [] 39 | 40 | 41 | def group_widget(self): 42 | group_layout = QtWidgets.QGridLayout() 43 | group_layout.addWidget(self.cluster_group, 0, 0, 1, 3) 44 | group_layout.addWidget(self.cluster_list, 1, 0, 3, 3) 45 | group_layout.addWidget(self.cluster_previous_button, 4, 0) 46 | group_layout.addWidget(self.cluster_next_button, 4, 2) 47 | self.cluster_list_group.setLayout(group_layout) 48 | 49 | def reset(self): 50 | self.cluster_list.clear() 51 | #self.cluster_option1.setChecked(True) 52 | self.clust_colour = [] 53 | self.clust_ids = [] 54 | self.clust = [] 55 | self.clust_prev = [] 56 | 57 | def populate(self, clust_ids, clust_colour): 58 | self.clust_ids = clust_ids 59 | self.clust_colour = clust_colour 60 | icon = QtGui.QPixmap(20, 20) 61 | for idx, val in enumerate(clust_ids): 62 | item = QtWidgets.QListWidgetItem('Cluster Number ' + str(val)) 63 | icon.fill(self.clust_colour[idx]) 64 | item.setIcon(QtGui.QIcon(icon)) 65 | self.cluster_list.addItem(item) 66 | self.cluster_list.setCurrentRow(0) 67 | 68 | def update_list_icon(self): 69 | item =
self.cluster_list.item(self.clust_prev) 70 | icon = QtGui.QPixmap(10, 10) 71 | icon.fill(self.clust_colour[self.clust_prev]) 72 | item.setIcon(QtGui.QIcon(icon)) 73 | item.setText(' Cluster Number ' + str(self.clust_ids[self.clust_prev])) 74 | 75 | def update_current_row(self, clust): 76 | self.clust = clust 77 | self.cluster_list.setCurrentRow(self.clust) 78 | 79 | def on_next_cluster_clicked(self): 80 | self.clust += 1 81 | self.cluster_list.setCurrentRow(self.clust) 82 | return self.clust 83 | 84 | def on_previous_cluster_clicked(self): 85 | self.clust -= 1 86 | self.cluster_list.setCurrentRow(self.clust) 87 | return self.clust 88 | 89 | def on_cluster_list_clicked(self): 90 | self.clust = self.cluster_list.currentRow() 91 | return self.clust 92 | 93 | def update_cluster_index(self): 94 | self.clust_prev = self.clust 95 | return self.clust_prev 96 | 97 | def initialise_cluster_index(self): 98 | self.clust = 0 99 | self.clust_prev = 0 100 | 101 | return self.clust, self.clust_prev 102 | -------------------------------------------------------------------------------- /data_exploration_gui/data_class.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtWidgets 2 | import numpy as np 3 | import os 4 | import one.alf.io as alfio 5 | from brainbox.processing import get_units_bunch, compute_cluster_average 6 | from brainbox.population.decode import xcorr 7 | from brainbox.singlecell import calculate_peths 8 | from brainbox.io.spikeglx import extract_waveforms 9 | from pathlib import Path 10 | 11 | 12 | class DataGroup: 13 | def __init__(self): 14 | 15 | self.waveform_button = QtWidgets.QPushButton('Generate Waveform') 16 | self.waveform_list = QtWidgets.QListWidget() 17 | self.waveform_list.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) 18 | self.waveform_text = QtWidgets.QLabel('No.
of spikes =') 19 | 20 | waveform_layout = QtWidgets.QGridLayout() 21 | waveform_layout.addWidget(self.waveform_button, 0, 0) 22 | waveform_layout.addWidget(self.waveform_text, 1, 0) 23 | waveform_layout.addWidget(self.waveform_list, 0, 1, 3, 1) 24 | 25 | self.waveform_group = QtWidgets.QGroupBox() 26 | self.waveform_group.setLayout(waveform_layout) 27 | 28 | 29 | # For PETHs and rasters 30 | self.t_before = 0.4 31 | self.t_after = 1 32 | 33 | # For autocorrelogram 34 | self.autocorr_window = 0.1 35 | self.autocorr_bin = 0.001 36 | 37 | 38 | # For waveform (N.B. in ms) 39 | self.waveform_window = 2 40 | self.CAR = False 41 | 42 | def load(self, folder_path): 43 | self.folder_path = folder_path 44 | self.find_files() 45 | self.load_data() 46 | self.compute_timescales() 47 | 48 | return self.ephys_file_path, self.gui_path 49 | 50 | def find_files(self): 51 | self.probe_path = self.folder_path 52 | self.alf_path = self.folder_path.parent 53 | dir_path = self.folder_path.parent.parent 54 | self.gui_path = os.path.join(self.probe_path, 'gui') 55 | probe = os.path.split(self.probe_path)[1] 56 | ephys_path = os.path.join(dir_path, 'raw_ephys_data', probe) 57 | 58 | self.ephys_file_path = [] 59 | 60 | try: 61 | for i in os.listdir(ephys_path): 62 | if 'ap' in i and 'bin' in i: 63 | self.ephys_file_path = os.path.join(ephys_path, i) 64 | except OSError: 65 | self.ephys_file_path = [] 66 | 67 | def load_data(self): 68 | self.spikes = alfio.load_object(self.probe_path, 'spikes') 69 | self.trials = alfio.load_object(self.alf_path, 'trials') 70 | self.clusters = alfio.load_object(self.probe_path, 'clusters') 71 | self.prepare_data(self.spikes, self.clusters, self.trials) 72 | # self.ids = np.unique(self.spikes.clusters) 73 | # self.metrics = np.array(self.clusters.metrics.ks2_label[self.ids]) 74 | # self.colours = np.array(self.clusters.metrics.ks2_label[self.ids]) 75 | # self.colours[np.where(self.colours == 'mua')[0]] = QtGui.QColor('#fdc086') 76 | # self.colours[np.where(self.colours
== 'good')[0]] = QtGui.QColor('#7fc97f') 77 | # 78 | # file_count = 0 79 | # if os.path.isdir(self.gui_path): 80 | # for i in os.listdir(self.gui_path): 81 | # if 'depth' in i: 82 | # self.depths = np.load(Path(self.gui_path + '/cluster_depths.npy')) 83 | # file_count += 1 84 | # elif 'amp' in i: 85 | # self.amps = np.load(Path(self.gui_path + '/cluster_amps.npy')) 86 | # file_count += 1 87 | # elif 'nspikes' in i: 88 | # self.nspikes = np.load(Path(self.gui_path + '/cluster_nspikes.npy')) 89 | # file_count += 1 90 | # if file_count != 3: 91 | # self.compute_depth_and_amplitudes() 92 | # else: 93 | # os.mkdir(self.gui_path) 94 | # self.compute_depth_and_amplitudes() 95 | # 96 | # self.sort_by_id = np.arange(len(self.ids)) 97 | # self.sort_by_nspikes = np.argsort(self.nspikes) 98 | # self.sort_by_nspikes = self.sort_by_nspikes[::-1] 99 | # self.sort_by_good = np.append(np.where(self.metrics == 'good')[0], np.where(self.metrics == 'mua')[0]) 100 | # self.n_trials = len(self.trials['contrastLeft']) 101 | 102 | def prepare_data(self, spikes, clusters, trials): 103 | self.spikes = spikes 104 | self.clusters = clusters 105 | self.trials = trials 106 | self.ids = np.unique(spikes.clusters) 107 | self.metrics = np.array(clusters.metrics.ks2_label[self.ids]) 108 | self.colours = np.array(clusters.metrics.ks2_label[self.ids]) 109 | self.colours[np.where(self.colours == 'mua')[0]] = QtGui.QColor('#fdc086') 110 | self.colours[np.where(self.colours == 'good')[0]] = QtGui.QColor('#7fc97f') 111 | _, self.depths, self.nspikes = compute_cluster_average(spikes.clusters, spikes.depths) 112 | _, self.amps, _ = compute_cluster_average(spikes.clusters, spikes.amps) 113 | self.amps = self.amps * 1e6 114 | self.sort_by_id = np.arange(len(self.ids)) 115 | self.sort_by_nspikes = np.argsort(self.nspikes) 116 | self.sort_by_nspikes = self.sort_by_nspikes[::-1] 117 | self.sort_by_good = np.append(np.where(self.metrics == 'good')[0], 118 | np.where(self.metrics == 'mua')[0]) 119 | self.n_trials
= len(trials['contrastLeft']) 120 | 121 | 122 | def compute_depth_and_amplitudes(self): 123 | units_b = get_units_bunch(self.spikes) 124 | self.depths = [] 125 | self.amps = [] 126 | self.nspikes = [] 127 | for clu in self.ids: 128 | self.depths = np.append(self.depths, np.nanmean(units_b.depths[str(clu)])) 129 | self.amps = np.append(self.amps, np.nanmean(units_b.amps[str(clu)]) * 1e6) 130 | self.nspikes = np.append(self.nspikes, len(units_b.amps[str(clu)])) 131 | 132 | np.save((self.gui_path + '/cluster_depths'), self.depths) 133 | np.save((self.gui_path + '/cluster_amps'), self.amps) 134 | np.save((self.gui_path + '/cluster_nspikes'), self.nspikes) 135 | 136 | def sort_data(self, order): 137 | self.clust_ids = self.ids[order] 138 | self.clust_amps = self.amps[order] 139 | self.clust_depths = self.depths[order] 140 | self.clust_colours = self.colours[order] 141 | 142 | return self.clust_ids, self.clust_amps, self.clust_depths, self.clust_colours 143 | 144 | def reset(self): 145 | self.waveform_list.clear() 146 | 147 | def populate(self, clust): 148 | self.reset() 149 | self.clus_idx = np.where(self.spikes.clusters == self.clust_ids[clust])[0] 150 | 151 | if len(self.clus_idx) <= 500: 152 | self.spk_intervals = [0, len(self.clus_idx)] 153 | else: 154 | self.spk_intervals = np.arange(0, len(self.clus_idx), 500) 155 | self.spk_intervals = np.append(self.spk_intervals, len(self.clus_idx)) 156 | 157 | for idx in range(0, len(self.spk_intervals) - 1): 158 | item = QtWidgets.QListWidgetItem(str(self.spk_intervals[idx]) + ':' + str(self.spk_intervals[idx+1])) 159 | self.waveform_list.addItem(item) 160 | 161 | self.waveform_list.setCurrentRow(0) 162 | self.n_waveform = 0 163 | 164 | def compute_peth(self, trial_type, clust, trials_id): 165 | peths, _ = calculate_peths(self.spikes.times, self.spikes.clusters, 166 | [self.clust_ids[clust]], self.trials[trial_type][trials_id], self.t_before, self.t_after) 167 | 168 | peth_mean = peths.means[0, :] 169 | peth_std =
peths.stds[0, :] / np.sqrt(len(trials_id)) 170 | t_peth = peths.tscale 171 | 172 | return t_peth, peth_mean, peth_std 173 | 174 | def compute_rasters(self, trial_type, clust, trials_id): 175 | self.x = np.empty(0) 176 | self.y = np.empty(0) 177 | spk_times = self.spikes.times[self.spikes.clusters == self.clust_ids[clust]] 178 | for idx, val in enumerate(self.trials[trial_type][trials_id]): 179 | spks_to_include = np.bitwise_and(spk_times >= val - self.t_before, spk_times <= val + self.t_after) 180 | trial_spk_times = spk_times[spks_to_include] 181 | trial_spk_times_aligned = trial_spk_times - val 182 | trial_no = (np.ones(len(trial_spk_times_aligned))) * idx * 10 183 | self.x = np.append(self.x, trial_spk_times_aligned) 184 | self.y = np.append(self.y, trial_no) 185 | 186 | return self.x, self.y, self.n_trials 187 | 188 | def compute_autocorr(self, clust): 189 | self.clus_idx = np.where(self.spikes.clusters == self.clust_ids[clust])[0] 190 | 191 | x_corr = xcorr(self.spikes.times[self.clus_idx], self.spikes.clusters[self.clus_idx], 192 | self.autocorr_bin, self.autocorr_window) 193 | 194 | corr = x_corr[0, 0, :] 195 | 196 | return self.t_autocorr, corr 197 | 198 | def compute_template(self, clust): 199 | template = (self.clusters.waveforms[self.clust_ids[clust], :, 0]) * 1e6 200 | return self.t_template, template 201 | 202 | def compute_waveform(self, clust): 203 | if len(self.ephys_file_path) != 0: 204 | spk_times = self.spikes.times[self.clus_idx][self.spk_intervals[self.n_waveform]:self.spk_intervals[self.n_waveform + 1]] 205 | max_ch = self.clusters['channels'][self.clust_ids[clust]] 206 | wf = extract_waveforms(self.ephys_file_path, spk_times, max_ch, t = self.waveform_window, car = self.CAR) 207 | wf_mean = np.mean(wf[:,:,0], axis = 0) 208 | wf_std = np.std(wf[:,:,0], axis = 0) 209 | 210 | return self.t_waveform, wf_mean, wf_std 211 | 212 | def compute_timescales(self): 213 | self.t_autocorr = np.arange((self.autocorr_window/2) - self.autocorr_window,
(self.autocorr_window/2)+ self.autocorr_bin, self.autocorr_bin) 214 | sr = 30000 215 | n_template = len(self.clusters.waveforms[self.ids[0], :, 0]) 216 | self.t_template = 1e3 * (np.arange(n_template)) / sr 217 | 218 | n_waveform = int(sr / 1000 * (self.waveform_window)) 219 | self.t_waveform = 1e3 * (np.arange(n_waveform)) / sr 220 | #self.t_waveform = 1e3 * (np.arange(n_waveform/2 - n_waveform, n_waveform/2, 1))/sr 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /data_exploration_gui/filter.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtWidgets, QtGui 2 | import numpy as np 3 | from data_exploration_gui import utils 4 | 5 | 6 | class FilterGroup: 7 | def __init__(self): 8 | 9 | # button to reset filters to default 10 | self.reset_button = QtWidgets.QPushButton('Reset Filters') 11 | 12 | # checkboxes for contrasts 13 | self.contrasts = utils.CONTRAST_OPTIONS 14 | self.contrast_buttons = QtWidgets.QButtonGroup() 15 | self.contrast_buttons.setExclusive(False) 16 | self.contrast_group = QtWidgets.QGroupBox('Stimulus Contrast') 17 | self.contrast_layout = QtWidgets.QVBoxLayout() 18 | for val in self.contrasts: 19 | button = QtWidgets.QCheckBox(str(val * 100)) 20 | button.setCheckState(QtCore.Qt.Checked) 21 | self.contrast_buttons.addButton(button) 22 | self.contrast_layout.addWidget(button) 23 | 24 | self.contrast_group.setLayout(self.contrast_layout) 25 | 26 | # button for whether to overlay trial plots 27 | self.hold_button = QtWidgets.QCheckBox('Hold') 28 | self.hold_button.setCheckState(QtCore.Qt.Checked) 29 | 30 | # checkboxes for trial options 31 | self.trial_buttons = QtWidgets.QButtonGroup() 32 | # Just for now 33 | self.trial_group = QtWidgets.QGroupBox('Trial Options') 34 | self.trial_layout = QtWidgets.QVBoxLayout() 35 | for val in utils.TRIAL_OPTIONS: 36 |
button = QtWidgets.QCheckBox(val) 37 | if val == utils.TRIAL_OPTIONS[0]: 38 | button.setCheckState(QtCore.Qt.Checked) 39 | else: 40 | button.setCheckState(QtCore.Qt.Unchecked) 41 | self.trial_buttons.addButton(button) 42 | self.trial_layout.addWidget(button) 43 | 44 | self.trial_colours = QtWidgets.QVBoxLayout() 45 | for val in utils.TRIAL_OPTIONS: 46 | img = QtWidgets.QLabel() 47 | pix = QtGui.QPixmap(40, 5) 48 | pix.fill(utils.colours[val]) 49 | img.setPixmap(pix) 50 | self.trial_colours.addWidget(img) 51 | 52 | hlayout = QtWidgets.QHBoxLayout() 53 | hlayout.addLayout(self.trial_layout) 54 | hlayout.addLayout(self.trial_colours) 55 | 56 | self.trial_group.setLayout(hlayout) 57 | 58 | # radio buttons for order options 59 | self.order_buttons = QtWidgets.QButtonGroup() 60 | self.order_group = QtWidgets.QGroupBox('Order Trials By:') 61 | self.order_layout = QtWidgets.QVBoxLayout() 62 | for val in utils.ORDER_OPTIONS: 63 | button = QtWidgets.QRadioButton(val) 64 | if val == utils.ORDER_OPTIONS[0]: 65 | button.setChecked(True) 66 | else: 67 | button.setChecked(False) 68 | self.order_buttons.addButton(button) 69 | self.order_layout.addWidget(button) 70 | 71 | self.order_group.setLayout(self.order_layout) 72 | 73 | # radio buttons for sort options 74 | self.sort_buttons = QtWidgets.QButtonGroup() 75 | self.sort_group = QtWidgets.QGroupBox('Sort Trials By:') 76 | self.sort_layout = QtWidgets.QVBoxLayout() 77 | for val in utils.SORT_OPTIONS: 78 | button = QtWidgets.QRadioButton(val) 79 | if val == utils.SORT_OPTIONS[0]: 80 | button.setChecked(True) 81 | else: 82 | button.setChecked(False) 83 | self.sort_buttons.addButton(button) 84 | self.sort_layout.addWidget(button) 85 | self.sort_group.setLayout(self.sort_layout) 86 | 87 | self.event_group = QtWidgets.QGroupBox('Align to:') 88 | self.event_list = QtGui.QStandardItemModel() 89 | self.event_combobox = QtWidgets.QComboBox() 90 | self.event_combobox.setModel(self.event_list) 91 | self.event_layout = QtWidgets.QVBoxLayout() 
92 | self.event_layout.addWidget(self.event_combobox) 93 | self.event_group.setLayout(self.event_layout) 94 | 95 | self.behav_group = QtWidgets.QGroupBox('Show behaviour:') 96 | self.behav_list = QtGui.QStandardItemModel() 97 | self.behav_combobox = QtWidgets.QComboBox() 98 | self.behav_combobox.setModel(self.behav_list) 99 | self.behav_layout = QtWidgets.QVBoxLayout() 100 | self.behav_layout.addWidget(self.behav_combobox) 101 | self.behav_group.setLayout(self.behav_layout) 102 | 103 | # Now group everything into one big box 104 | self.filter_options_group = QtWidgets.QGroupBox() 105 | group_layout = QtWidgets.QVBoxLayout() 106 | group_layout.addWidget(self.reset_button) 107 | group_layout.addWidget(self.contrast_group) 108 | group_layout.addWidget(self.hold_button) 109 | group_layout.addWidget(self.trial_group) 110 | group_layout.addWidget(self.order_group) 111 | group_layout.addWidget(self.sort_group) 112 | group_layout.addWidget(self.event_group) 113 | group_layout.addWidget(self.behav_group) 114 | self.filter_options_group.setLayout(group_layout) 115 | 116 | def populate(self, event_options, behav_options): 117 | self.populate_lists(event_options, self.event_list, self.event_combobox) 118 | self.populate_lists(behav_options, self.behav_list, self.behav_combobox) 119 | 120 | def populate_lists(self, data, list_name, combobox): 121 | 122 | list_name.clear() 123 | for dat in data: 124 | item = QtGui.QStandardItem(dat) 125 | item.setEditable(False) 126 | list_name.appendRow(item) 127 | 128 | # This makes sure the drop down menu is wide enough to show full length of string 129 | min_width = combobox.fontMetrics().width(max(data, key=len)) 130 | min_width += combobox.view().autoScrollMargin() 131 | min_width += combobox.style().pixelMetric(QtWidgets.QStyle.PM_ScrollBarExtent) 132 | combobox.view().setMinimumWidth(min_width) 133 | 134 | # Set the default to be the first option 135 | combobox.setCurrentIndex(0) 136 | 137 | def set_selected_event(self, event): 138 | for 
it in range(self.event_combobox.count()): 139 | if self.event_combobox.itemText(it) == event: 140 | self.event_combobox.setCurrentIndex(it) 141 | 142 | def get_selected_event(self): 143 | return self.event_combobox.currentText() 144 | 145 | def get_selected_behaviour(self): 146 | return self.behav_combobox.currentText() 147 | 148 | def get_selected_contrasts(self): 149 | contrasts = [] 150 | for button in self.contrast_buttons.buttons(): 151 | if button.isChecked(): 152 | contrasts.append(np.float(button.text()) / 100) 153 | 154 | return contrasts 155 | 156 | def get_selected_trials(self): 157 | trials = [] 158 | for button in self.trial_buttons.buttons(): 159 | if button.isChecked(): 160 | trials.append(button.text()) 161 | return trials 162 | 163 | def get_selected_order(self): 164 | return self.order_buttons.checkedButton().text() 165 | 166 | def get_selected_sort(self): 167 | return self.sort_buttons.checkedButton().text() 168 | 169 | def get_hold_status(self): 170 | return self.hold_button.isChecked() 171 | 172 | def get_selected_filters(self): 173 | contrasts = self.get_selected_contrasts() 174 | order = self.get_selected_order() 175 | sort = self.get_selected_sort() 176 | hold = self.get_hold_status() 177 | 178 | return contrasts, order, sort, hold 179 | 180 | def set_sorted_button(self, sort): 181 | for button in self.sort_buttons.buttons(): 182 | if button.text() == sort: 183 | button.setChecked(True) 184 | else: 185 | button.setChecked(False) 186 | 187 | def reset_filters(self, contrasts=True): 188 | 189 | if contrasts: 190 | for button in self.contrast_buttons.buttons(): 191 | button.setChecked(True) 192 | 193 | for button in self.trial_buttons.buttons(): 194 | if button.text() == utils.TRIAL_OPTIONS[0]: 195 | #if not button.isChecked(): 196 | button.setChecked(True) 197 | else: 198 | button.setChecked(False) 199 | 200 | for button in self.order_buttons.buttons(): 201 | if button.text() == utils.ORDER_OPTIONS[0]: 202 | button.setChecked(True) 203 | else: 
204 | button.setChecked(False) 205 | 206 | for button in self.sort_buttons.buttons(): 207 | if button.text() == utils.SORT_OPTIONS[0]: 208 | button.setChecked(True) 209 | else: 210 | button.setChecked(False) -------------------------------------------------------------------------------- /data_exploration_gui/load_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from one.api import ONE 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser(description='Load in subject info') 6 | 7 | parser.add_argument('-s', '--subject', default=False, required=False, 8 | help='Subject Name') 9 | parser.add_argument('-d', '--date', default=False, required=False, 10 | help='Date of session YYYY-MM-DD') 11 | parser.add_argument('-n', '--session_no', default=1, required=False, 12 | help='Session Number', type=int) 13 | parser.add_argument('-e', '--eid', default=False, required=False, 14 | help='Session eid') 15 | parser.add_argument('-p', '--probe_label', default=False, required=True, 16 | help='Probe Label') 17 | 18 | args = parser.parse_args() 19 | 20 | one = ONE() 21 | if not args.eid: 22 | if not np.all(np.array([args.subject, args.date, args.session_no], 23 | dtype=object)): 24 | print('Must give Subject, Date and Session number') 25 | else: 26 | eid = one.search(subject=str(args.subject), date=str(args.date), number=args.session_no)[0] 27 | print(eid) 28 | else: 29 | eid = str(args.eid) 30 | 31 | _ = one.load_object(eid, obj='trials', collection='alf', 32 | download_only=True) 33 | _ = one.load_object(eid, obj='spikes', attribute=['times', 'clusters', 'amps', 'depths'], 34 | collection=f'alf/{str(args.probe_label)}', download_only=True) 35 | _ = one.load_object(eid, obj='clusters', collection=f'alf/{str(args.probe_label)}', 36 | download_only=True) 37 | -------------------------------------------------------------------------------- /data_exploration_gui/misc.py: 
-------------------------------------------------------------------------------- 1 | from PyQt5 import QtWidgets 2 | from data_exploration_gui import utils 3 | 4 | class MiscGroup: 5 | def __init__(self): 6 | 7 | self.qc_group = QtWidgets.QGroupBox() 8 | self.dlc_warning_label = QtWidgets.QLabel() 9 | self.sess_qc_group = QtWidgets.QHBoxLayout() 10 | self.clust_qc_group = QtWidgets.QHBoxLayout() 11 | 12 | self.sess_qc_labels = [] 13 | self.clust_qc_labels = [] 14 | 15 | for qc in utils.SESS_QC: 16 | qc_label = QtWidgets.QLabel(f'{qc}:') 17 | self.sess_qc_group.addWidget(qc_label) 18 | self.sess_qc_labels.append(qc_label) 19 | 20 | for qc in utils.CLUSTER_QC: 21 | qc_label = QtWidgets.QLabel(f'{qc}: ') 22 | self.clust_qc_group.addWidget(qc_label) 23 | self.clust_qc_labels.append(qc_label) 24 | 25 | vlayout = QtWidgets.QVBoxLayout() 26 | vlayout.addWidget(self.dlc_warning_label) 27 | vlayout.addLayout(self.sess_qc_group) 28 | vlayout.addLayout(self.clust_qc_group) 29 | self.qc_group.setLayout(vlayout) 30 | 31 | def set_sess_qc_text(self, data): 32 | for label in self.sess_qc_labels: 33 | title = label.text().split(':')[0] 34 | text = title + ': ' + str(data[title]) 35 | label.setText(text) 36 | 37 | def set_clust_qc_text(self, data): 38 | for label in self.clust_qc_labels: 39 | title = label.text().split(':')[0] 40 | text = title + ': ' + str(data[title]) 41 | label.setText(text) 42 | 43 | def set_dlc_label(self, aligned): 44 | if not aligned: 45 | self.dlc_warning_label.setText(utils.dlc_warning) -------------------------------------------------------------------------------- /data_exploration_gui/misc_class.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | 3 | import pandas as pd 4 | 5 | 6 | class MiscGroup: 7 | def __init__(self): 8 | 9 | folder_prompt = QtWidgets.QLabel('Select Folder') 10 | self.folder_line = QtWidgets.QLineEdit() 11 | self.folder_button = 
QtWidgets.QToolButton() 12 | self.folder_button.setText('...') 13 | 14 | folder_layout = QtWidgets.QHBoxLayout() 15 | folder_layout.addWidget(folder_prompt) 16 | folder_layout.addWidget(self.folder_line) 17 | folder_layout.addWidget(self.folder_button) 18 | 19 | self.folder_group = QtWidgets.QGroupBox() 20 | self.folder_group.setLayout(folder_layout) 21 | 22 | clust_list1_title = QtWidgets.QLabel('Go Cue Aligned Clusters') 23 | self.clust_list1 = QtWidgets.QListWidget() 24 | self.remove_clust_button1 = QtWidgets.QPushButton('Remove') 25 | 26 | clust_list2_title = QtWidgets.QLabel('Feedback Aligned Clusters') 27 | self.clust_list2 = QtWidgets.QListWidget() 28 | self.remove_clust_button2 = QtWidgets.QPushButton('Remove') 29 | self.save_clust_button = QtWidgets.QPushButton('Save Clusters') 30 | 31 | clust_interest_layout = QtWidgets.QGridLayout() 32 | clust_interest_layout.addWidget(clust_list1_title, 0, 0) 33 | clust_interest_layout.addWidget(self.remove_clust_button1, 0, 1) 34 | clust_interest_layout.addWidget(self.clust_list1, 1, 0, 1, 2) 35 | clust_interest_layout.addWidget(clust_list2_title, 2, 0) 36 | clust_interest_layout.addWidget(self.remove_clust_button2, 2, 1) 37 | clust_interest_layout.addWidget(self.clust_list2, 3, 0, 1, 2) 38 | clust_interest_layout.addWidget(self.save_clust_button, 4, 0) 39 | 40 | self.clust_interest = QtWidgets.QGroupBox() 41 | self.clust_interest.setLayout(clust_interest_layout) 42 | self.clust_interest.setFixedSize(250, 350) 43 | 44 | self.terminal = QtWidgets.QTextBrowser() 45 | self.terminal_button = QtWidgets.QPushButton('Clear') 46 | self.terminal_button.setFixedSize(70, 30) 47 | self.terminal_button.setParent(self.terminal) 48 | self.terminal_button.move(580, 0) 49 | 50 | 51 | def on_save_button_clicked(self, gui_path): 52 | save_file = QtWidgets.QFileDialog() 53 | save_file.setAcceptMode(QtWidgets.QFileDialog.AcceptSave) 54 | file = save_file.getSaveFileName(None, 'Save File', gui_path, 'CSV (*.csv)') 55 | print(file[0]) 56 | 
if file[0]: 57 | save_clusters = pd.DataFrame(columns=['go_Cue_aligned', 'feedback_aligned']) 58 | 59 | go_cue_clusters = [] 60 | for idx in range(self.clust_list1.count()): 61 | cluster_no = self.clust_list1.item(idx).text() 62 | go_cue_clusters.append(int(cluster_no[8:])) 63 | 64 | go_cue = pd.Series(go_cue_clusters) 65 | save_clusters['go_Cue_aligned'] = go_cue 66 | 67 | feedback_clusters = [] 68 | for idx in range(self.clust_list2.count()): 69 | cluster_no = self.clust_list2.item(idx).text() 70 | feedback_clusters.append(int(cluster_no[8:])) 71 | 72 | feedback = pd.Series(feedback_clusters) 73 | save_clusters['feedback_aligned'] = feedback 74 | 75 | save_clusters.to_csv(file[0], index=None, header=True) 76 | 77 | return file[0] 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /data_exploration_gui/scatter.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore 2 | import pyqtgraph as pg 3 | import numpy as np 4 | 5 | 6 | class ScatterGroup: 7 | 8 | def __init__(self): 9 | 10 | self.fig_scatter = pg.PlotWidget(background='w') 11 | self.fig_scatter.setMouseTracking(True) 12 | self.fig_scatter.scene().sigMouseMoved.connect(self.on_mouse_hover) 13 | self.fig_scatter.setLabel('bottom', 'Cluster Amplitude (uV)') 14 | self.fig_scatter.setLabel('left', 'Cluster Distance From Tip (uV)') 15 | 16 | self.scatter_plot = pg.ScatterPlotItem() 17 | self.scatter_text = pg.TextItem(color='k') 18 | self.scatter_text.hide() 19 | 20 | self.fig_scatter.addItem(self.scatter_plot) 21 | self.fig_scatter.addItem(self.scatter_text) 22 | 23 | self.reset() 24 | 25 | def reset(self): 26 | self.scatter_plot.setData() 27 | self.clust_amps = None 28 | self.clust_depths = None 29 | self.clust_color_ks = None 30 | self.clust_color_ibl = None 31 | self.clust_ids = None 32 | self.point_pos = None 33 | self.point_pos_prev = None 34 | 35 | def populate(self, clust_amps, clust_depths, 
clust_ids, clust_color_ks, clust_color_ibl): 36 | self.clust_amps = clust_amps 37 | self.clust_depths = clust_depths 38 | self.clust_color_ks = clust_color_ks 39 | self.clust_color_ibl = clust_color_ibl 40 | self.clust_ids = clust_ids 41 | self.scatter_plot.setData(x=self.clust_amps, y=self.clust_depths, size=10) 42 | self.scatter_plot.setBrush(clust_color_ibl) 43 | self.scatter_plot.setPen(clust_color_ks) 44 | self.x_min = np.min(self.clust_amps) - 2 45 | self.x_max = np.max(self.clust_amps) + 2 46 | self.fig_scatter.setXRange(min=self.x_min, max=self.x_max) 47 | self.fig_scatter.setYRange(min=0, max=4000) 48 | 49 | def on_scatter_plot_clicked(self, point): 50 | self.point_pos = point[0].pos() 51 | clust = np.argwhere(self.scatter_plot.data['x'] == self.point_pos.x())[0][0] 52 | return clust 53 | 54 | def on_mouse_hover(self, pos): 55 | mouse_pos = self.scatter_plot.mapFromScene(pos) 56 | point = self.scatter_plot.pointsAt(mouse_pos) 57 | 58 | if len(point) != 0: 59 | point_pos = point[0].pos() 60 | clust = np.argwhere(self.scatter_plot.data['x'] == point_pos.x())[0][0] 61 | self.scatter_text.setText('Cluster no. 
' + str(self.clust_ids[clust])) 62 | self.scatter_text.setPos(mouse_pos.x(), (mouse_pos.y())) 63 | self.scatter_text.show() 64 | else: 65 | self.scatter_text.hide() 66 | 67 | def update_scatter_icon(self, clust_prev): 68 | point = self.scatter_plot.pointsAt(self.point_pos) 69 | if len(point) > 1: 70 | p_diff = [] 71 | for p in point: 72 | p_diff.append(np.abs(p.pos().x() - self.point_pos.x())) 73 | idx = np.argmin(p_diff) 74 | point[idx].setBrush('k') 75 | point[idx].setPen('k') 76 | else: 77 | point[0].setBrush('k') 78 | point[0].setPen('k') 79 | 80 | point_prev = self.scatter_plot.pointsAt(self.point_pos_prev) 81 | if len(point_prev) > 1: 82 | p_diff = [] 83 | for p in point_prev: 84 | p_diff.append(np.abs(p.pos().x() - self.point_pos_prev.x())) 85 | idx = np.argmin(p_diff) 86 | point_prev[idx].setPen(self.clust_color_ks[clust_prev]) 87 | point_prev[idx].setBrush(self.clust_color_ibl[clust_prev]) 88 | point_prev[idx].setSize(8) 89 | else: 90 | point_prev[0].setPen(self.clust_color_ks[clust_prev]) 91 | point_prev[0].setBrush(self.clust_color_ibl[clust_prev]) 92 | point_prev[0].setSize(8) 93 | 94 | def set_current_point(self, clust): 95 | self.point_pos.setX(self.clust_amps[clust]) 96 | self.point_pos.setY(self.clust_depths[clust]) 97 | 98 | def update_prev_point(self): 99 | self.point_pos_prev.setX(self.point_pos.x()) 100 | self.point_pos_prev.setY(self.point_pos.y()) 101 | 102 | def initialise_scatter_index(self): 103 | self.point_pos = QtCore.QPointF() 104 | self.point_pos_prev = QtCore.QPointF() 105 | self.point_pos.setX(self.clust_amps[0]) 106 | self.point_pos.setY(self.clust_depths[0]) 107 | self.point_pos_prev.setX(self.clust_amps[0]) 108 | self.point_pos_prev.setY(self.clust_depths[0]) 109 | 110 | point = self.scatter_plot.pointsAt(self.point_pos) 111 | 112 | if len(point) > 1: 113 | p_diff = [] 114 | for p in point: 115 | p_diff.append(np.abs(p.pos().x() - self.point_pos.x())) 116 | idx = np.argmin(p_diff) 117 | point[idx].setBrush('k') 118 | 
point[idx].setPen('k') 119 | else: 120 | point[0].setBrush('k') 121 | point[0].setPen('k') 122 | -------------------------------------------------------------------------------- /data_exploration_gui/scatter_class.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtCore, QtGui, QtWidgets 2 | import pyqtgraph as pg 3 | import numpy as np 4 | 5 | class ScatterGroup: 6 | 7 | def __init__(self): 8 | 9 | self.fig_scatter = pg.PlotWidget(background='w') 10 | self.fig_scatter.setMouseTracking(True) 11 | self.fig_scatter.setLabel('bottom', 'Cluster Amplitude (uV)') 12 | self.fig_scatter.setLabel('left', 'Cluster Distance From Tip (uV)') 13 | 14 | self.scatter_plot = pg.ScatterPlotItem() 15 | self.scatter_text = pg.TextItem(color='k') 16 | self.scatter_text.hide() 17 | 18 | self.scatter_reset_button = QtWidgets.QPushButton('Reset Axis') 19 | self.scatter_reset_button.setFixedSize(70, 30) 20 | self.scatter_reset_button.setParent(self.fig_scatter) 21 | 22 | self.fig_scatter.addItem(self.scatter_plot) 23 | self.fig_scatter.addItem(self.scatter_text) 24 | self.fig_scatter.setFixedSize(400, 580) 25 | 26 | self.clust_amps = [] 27 | self.clust_depths = [] 28 | self.clust_color = [] 29 | self.clust_ids = [] 30 | self.point_pos = [] 31 | self.point_pos_prev = [] 32 | 33 | def reset(self): 34 | self.scatter_plot.setData() 35 | self.clust_amps = [] 36 | self.clust_depths = [] 37 | self.clust_color = [] 38 | self.clust_ids = [] 39 | self.point_pos = [] 40 | self.point_pos_prev = [] 41 | 42 | def populate(self, clust_amps, clust_depths, clust_ids, clust_color): 43 | self.clust_amps = clust_amps 44 | self.clust_depths = clust_depths 45 | self.clust_color = clust_color 46 | self.clust_ids = clust_ids 47 | self.scatter_plot.setData(x=self.clust_amps, y=self.clust_depths, size=8) 48 | self.scatter_plot.setBrush(clust_color) 49 | self.scatter_plot.setPen(QtGui.QColor(30, 50, 2)) 50 | self.x_min = self.clust_amps.min() - 2 51 | 
self.x_max = self.clust_amps.max() - 2 52 | self.fig_scatter.setXRange(min=self.x_min, max=self.x_max) 53 | self.fig_scatter.setYRange(min=0, max=4000) 54 | 55 | def on_scatter_plot_clicked(self, point): 56 | self.point_pos = point[0].pos() 57 | clust = np.argwhere(self.scatter_plot.data['x'] == self.point_pos.x())[0][0] 58 | return clust 59 | 60 | def on_mouse_hover(self, pos): 61 | mouse_pos = self.scatter_plot.mapFromScene(pos) 62 | point = self.scatter_plot.pointsAt(mouse_pos) 63 | 64 | if len(point) != 0: 65 | point_pos = point[0].pos() 66 | clust = np.argwhere(self.scatter_plot.data['x'] == point_pos.x())[0][0] 67 | self.scatter_text.setText('Cluster no. ' + str(self.clust_ids[clust])) 68 | self.scatter_text.setPos(mouse_pos.x(), (mouse_pos.y())) 69 | self.scatter_text.show() 70 | else: 71 | self.scatter_text.hide() 72 | 73 | def on_scatter_plot_reset(self): 74 | self.fig_scatter.setXRange(min=self.x_min, max=self.x_max) 75 | self.fig_scatter.setYRange(min=0, max=4000) 76 | 77 | def update_scatter_icon(self, clust_prev): 78 | point = self.scatter_plot.pointsAt(self.point_pos) 79 | if len(point) > 1: 80 | p_diff = [] 81 | for p in point: 82 | p_diff.append(np.abs(p.pos().x() - self.point_pos.x())) 83 | idx = np.argmin(p_diff) 84 | point[idx].setBrush('b') 85 | point[idx].setPen('b') 86 | else: 87 | point[0].setBrush('b') 88 | point[0].setPen('b') 89 | 90 | point_prev = self.scatter_plot.pointsAt(self.point_pos_prev) 91 | if len(point_prev) > 1: 92 | p_diff = [] 93 | for p in point_prev: 94 | p_diff.append(np.abs(p.pos().x() - self.point_pos_prev.x())) 95 | idx = np.argmin(p_diff) 96 | point_prev[idx].setPen('w') 97 | point_prev[idx].setBrush(self.clust_color[clust_prev]) 98 | else: 99 | point_prev[0].setPen('w') 100 | point_prev[0].setBrush(self.clust_color[clust_prev]) 101 | 102 | def set_current_point(self, clust): 103 | self.point_pos.setX(self.clust_amps[clust]) 104 | self.point_pos.setY(self.clust_depths[clust]) 105 | 106 | return self.point_pos 107 | 108 
| def update_prev_point(self): 109 | self.point_pos_prev.setX(self.point_pos.x()) 110 | self.point_pos_prev.setY(self.point_pos.y()) 111 | 112 | return self.point_pos_prev 113 | 114 | def initialise_scatter_index(self): 115 | self.point_pos = QtCore.QPointF() 116 | self.point_pos_prev = QtCore.QPointF() 117 | self.point_pos.setX(self.clust_amps[0]) 118 | self.point_pos.setY(self.clust_depths[0]) 119 | self.point_pos_prev.setX(self.clust_amps[0]) 120 | self.point_pos_prev.setY(self.clust_depths[0]) 121 | 122 | point = self.scatter_plot.pointsAt(self.point_pos) 123 | point[0].setBrush('b') 124 | point[0].setPen('b') 125 | 126 | return self.point_pos, self.point_pos_prev 127 | -------------------------------------------------------------------------------- /data_exploration_gui/utils.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtCore 2 | from iblutil.util import Bunch 3 | import copy 4 | 5 | colours = {'all': QtGui.QColor('#808080'), 6 | 'correct': QtGui.QColor('#1f77b4'), 7 | 'incorrect': QtGui.QColor('#d62728'), 8 | 'left': QtGui.QColor('#2ca02c'), 9 | 'right': QtGui.QColor('#bcbd22'), 10 | 'left correct': QtGui.QColor('#17becf'), 11 | 'right correct': QtGui.QColor('#9467bd'), 12 | 'left incorrect': QtGui.QColor('#8c564b'), 13 | 'right incorrect': QtGui.QColor('#ff7f0e'), 14 | 'KS good': QtGui.QColor('#1f77b4'), 15 | 'KS mua': QtGui.QColor('#fdc086'), 16 | 'IBL good': QtGui.QColor('#7fc97f'), 17 | 'IBL bad': QtGui.QColor('#d62728'), 18 | 'line': QtGui.QColor('#7732a8'), 19 | 'no metric': QtGui.QColor('#989898')} 20 | 21 | 22 | all = Bunch() 23 | idx = Bunch() 24 | idx['colours'] = [] 25 | idx['text'] = [] 26 | side = Bunch() 27 | side['colours'] = [colours['left'], colours['right']] 28 | side['text'] = ['left', 'right'] 29 | choice = Bunch() 30 | choice['colours'] = [colours['correct'], colours['incorrect']] 31 | choice['text'] = ['correct', 'incorrect'] 32 | choice_and_side = Bunch() 33 | 
choice_and_side['colours'] = [colours['left correct'], colours['right correct'], 34 | colours['left incorrect'], colours['right incorrect']] 35 | choice_and_side['text'] = ['left correct', 'right correct', 'left incorrect', 'right incorrect'] 36 | all['idx'] = idx 37 | all['side'] = side 38 | all['choice'] = choice 39 | all['choice and side'] = choice_and_side 40 | 41 | 42 | correct = Bunch() 43 | side = Bunch() 44 | side['colours'] = [colours['left correct'], colours['right correct']] 45 | side['text'] = ['left correct', 'right correct'] 46 | choice = Bunch() 47 | choice['colours'] = [colours['correct']] 48 | choice['text'] = ['correct'] 49 | correct['idx'] = idx 50 | correct['side'] = side 51 | correct['choice'] = choice 52 | correct['choice and side'] = side 53 | 54 | incorrect = Bunch() 55 | side = Bunch() 56 | side['colours'] = [colours['left incorrect'], colours['right incorrect']] 57 | side['text'] = ['left incorrect', 'right incorrect'] 58 | choice = Bunch() 59 | choice['colours'] = [colours['incorrect']] 60 | choice['text'] = ['incorrect'] 61 | incorrect['idx'] = idx 62 | incorrect['side'] = side 63 | incorrect['choice'] = choice 64 | incorrect['choice and side'] = side 65 | 66 | 67 | left = Bunch() 68 | side = Bunch() 69 | side['colours'] = [colours['left']] 70 | side['text'] = ['left'] 71 | choice = Bunch() 72 | choice['colours'] = [colours['left correct'], colours['left incorrect']] 73 | choice['text'] = ['left correct', 'left incorrect'] 74 | left['idx'] = idx 75 | left['side'] = side 76 | left['choice'] = choice 77 | left['choice and side'] = choice 78 | 79 | 80 | right = Bunch() 81 | side = Bunch() 82 | side['colours'] = [colours['right']] 83 | side['text'] = ['right'] 84 | choice = Bunch() 85 | choice['colours'] = [colours['right correct'], colours['right incorrect']] 86 | choice['text'] = ['right correct', 'right incorrect'] 87 | right['idx'] = idx 88 | right['side'] = side 89 | right['choice'] = choice 90 | right['choice and side'] = choice 91 | 
92 | left_correct = Bunch() 93 | side = Bunch() 94 | side['colours'] = [colours['left correct']] 95 | side['text'] = ['left correct'] 96 | left_correct['idx'] = idx 97 | left_correct['side'] = side 98 | left_correct['choice'] = side 99 | left_correct['choice and side'] = side 100 | 101 | right_correct = Bunch() 102 | side = Bunch() 103 | side['colours'] = [colours['right correct']] 104 | side['text'] = ['right correct'] 105 | right_correct['idx'] = idx 106 | right_correct['side'] = side 107 | right_correct['choice'] = side 108 | right_correct['choice and side'] = side 109 | 110 | left_incorrect = Bunch() 111 | side = Bunch() 112 | side['colours'] = [colours['left incorrect']] 113 | side['text'] = ['left incorrect'] 114 | left_incorrect['idx'] = idx 115 | left_incorrect['side'] = side 116 | left_incorrect['choice'] = side 117 | left_incorrect['choice and side'] = side 118 | 119 | right_incorrect = Bunch() 120 | side = Bunch() 121 | side['colours'] = [colours['right incorrect']] 122 | side['text'] = ['right incorrect'] 123 | right_incorrect['idx'] = idx 124 | right_incorrect['side'] = side 125 | right_incorrect['choice'] = side 126 | right_incorrect['choice and side'] = side 127 | 128 | RASTER_OPTIONS = Bunch() 129 | RASTER_OPTIONS['all'] = all 130 | RASTER_OPTIONS['left'] = left 131 | RASTER_OPTIONS['right'] = right 132 | RASTER_OPTIONS['correct'] = correct 133 | RASTER_OPTIONS['incorrect'] = incorrect 134 | RASTER_OPTIONS['left correct'] = left_correct 135 | RASTER_OPTIONS['right correct'] = right_correct 136 | RASTER_OPTIONS['left incorrect'] = left_incorrect 137 | RASTER_OPTIONS['right incorrect'] = right_incorrect 138 | 139 | all = Bunch() 140 | all['colour'] = copy.copy(colours['all']) 141 | all['fill'] = colours['all'] 142 | all['text'] = 'all' 143 | 144 | left = Bunch() 145 | left['colour'] = copy.copy(colours['left']) 146 | left['fill'] = colours['left'] 147 | left['text'] = 'left' 148 | 149 | right = Bunch() 150 | right['colour'] = 
copy.copy(colours['right']) 151 | right['fill'] = colours['right'] 152 | right['text'] = 'right' 153 | 154 | correct = Bunch() 155 | correct['colour'] = copy.copy(colours['correct']) 156 | correct['fill'] = colours['correct'] 157 | correct['text'] = 'correct' 158 | 159 | incorrect = Bunch() 160 | incorrect['colour'] = copy.copy(colours['incorrect']) 161 | incorrect['fill'] = colours['incorrect'] 162 | incorrect['text'] = 'incorrect' 163 | 164 | left_correct = Bunch() 165 | left_correct['colour'] = copy.copy(colours['left correct']) 166 | left_correct['fill'] = colours['left correct'] 167 | left_correct['text'] = 'left correct' 168 | 169 | right_correct = Bunch() 170 | right_correct['colour'] = copy.copy(colours['right correct']) 171 | right_correct['fill'] = colours['right correct'] 172 | right_correct['text'] = 'right correct' 173 | 174 | left_incorrect = Bunch() 175 | left_incorrect['colour'] = copy.copy(colours['left incorrect']) 176 | left_incorrect['fill'] = colours['left incorrect'] 177 | left_incorrect['text'] = 'left incorrect' 178 | 179 | right_incorrect = Bunch() 180 | right_incorrect['colour'] = copy.copy(colours['right incorrect']) 181 | right_incorrect['fill'] = colours['right incorrect'] 182 | right_incorrect['text'] = 'right incorrect' 183 | 184 | PSTH_OPTIONS = Bunch() 185 | PSTH_OPTIONS['all'] = all 186 | PSTH_OPTIONS['left'] = left 187 | PSTH_OPTIONS['right'] = right 188 | PSTH_OPTIONS['correct'] = correct 189 | PSTH_OPTIONS['incorrect'] = incorrect 190 | PSTH_OPTIONS['left correct'] = left_correct 191 | PSTH_OPTIONS['right correct'] = right_correct 192 | PSTH_OPTIONS['left incorrect'] = left_incorrect 193 | PSTH_OPTIONS['right incorrect'] = right_incorrect 194 | 195 | MAP_SIDE_OPTIONS = Bunch() 196 | MAP_SIDE_OPTIONS['all'] = 'all' 197 | MAP_SIDE_OPTIONS['left'] = 'left' 198 | MAP_SIDE_OPTIONS['right'] = 'right' 199 | MAP_SIDE_OPTIONS['correct'] = 'all' 200 | MAP_SIDE_OPTIONS['incorrect'] = 'all' 201 | MAP_SIDE_OPTIONS['left correct'] = 'left' 
202 | MAP_SIDE_OPTIONS['right correct'] = 'right' 203 | MAP_SIDE_OPTIONS['left incorrect'] = 'left' 204 | MAP_SIDE_OPTIONS['right incorrect'] = 'right' 205 | # MAP_CHOICE_OPTIONS / MAP_SORT_OPTIONS map each trial-option key (same keys as TRIAL_OPTIONS) to a choice filter and a sort key 206 | MAP_CHOICE_OPTIONS = Bunch() 207 | MAP_CHOICE_OPTIONS['all'] = 'all' 208 | MAP_CHOICE_OPTIONS['left'] = 'all' 209 | MAP_CHOICE_OPTIONS['right'] = 'all' 210 | MAP_CHOICE_OPTIONS['correct'] = 'correct' 211 | MAP_CHOICE_OPTIONS['incorrect'] = 'incorrect' 212 | MAP_CHOICE_OPTIONS['left correct'] = 'correct' 213 | MAP_CHOICE_OPTIONS['right correct'] = 'correct' 214 | MAP_CHOICE_OPTIONS['left incorrect'] = 'incorrect' 215 | MAP_CHOICE_OPTIONS['right incorrect'] = 'incorrect' 216 | 217 | MAP_SORT_OPTIONS = Bunch() 218 | MAP_SORT_OPTIONS['all'] = 'idx' 219 | MAP_SORT_OPTIONS['left'] = 'side' 220 | MAP_SORT_OPTIONS['right'] = 'side' 221 | MAP_SORT_OPTIONS['correct'] = 'choice' 222 | MAP_SORT_OPTIONS['incorrect'] = 'choice' 223 | MAP_SORT_OPTIONS['left correct'] = 'choice and side' 224 | MAP_SORT_OPTIONS['right correct'] = 'choice and side' 225 | MAP_SORT_OPTIONS['left incorrect'] = 'choice and side' 226 | MAP_SORT_OPTIONS['right incorrect'] = 'choice and side' 227 | 228 | 229 | TRIAL_OPTIONS = ['all', 'correct', 'incorrect', 'left', 'right', 'left correct', 'left incorrect', 230 | 'right correct', 'right incorrect'] 231 | CONTRAST_OPTIONS = [1, 0.25, 0.125, 0.0625, 0] 232 | ORDER_OPTIONS = ['trial num', 'reaction time'] 233 | SORT_OPTIONS = ['idx', 'choice', 'side', 'choice and side'] 234 | UNIT_OPTIONS = ['IBL good', 'IBL bad', 'KS good', 'KS mua'] 235 | SORT_CLUSTER_OPTIONS = ['ids', 'n spikes', 'IBL good', 'KS good'] 236 | 237 | SESS_QC = ['task', 'behavior', 'dlcLeft', 'dlcRight', 'videoLeft', 'videoRight'] 238 | CLUSTER_QC = ['noise_cutoff', 'amp_median', 'slidingRP_viol'] 239 | 240 | dlc_warning = 'WARNING: dlc points and timestamps differ in length, dlc points are ' \ 241 | 'not aligned correctly' 242 | 243 | 244 | def get_icon(col_outer, col_inner, pix_size): 245 | """Return a pix_size x pix_size QPixmap filled with col_outer, with a centred square of half the side filled with col_inner (colours presumably QColor / Qt.GlobalColor values - confirm).""" 246 | p1 = QtGui.QPixmap(pix_size,
pix_size) 247 | p1.fill(col_outer) 248 | p2 = QtGui.QPixmap(pix_size, pix_size) 249 | p2.fill(QtCore.Qt.transparent) # transparent canvas that only carries the inner square 250 | p = QtGui.QPainter(p2) 251 | p.fillRect(int(pix_size / 4), int(pix_size / 4), int(pix_size / 2), 252 | int(pix_size / 2), col_inner) 253 | p.end() 254 | # composite the inner square (p2) over a copy of the outer fill (p1) 255 | result = QtGui.QPixmap(p1) 256 | painter = QtGui.QPainter(result) 257 | painter.drawPixmap(QtCore.QPoint(), p1) # NOTE(review): result already copies p1, so this redraw only changes the output when col_outer is translucent - confirm it is intended 258 | painter.setCompositionMode(QtGui.QPainter.CompositionMode_SourceOver) 259 | painter.drawPixmap(result.rect(), p2, p2.rect()) 260 | painter.end() 261 | 262 | return result 263 | -------------------------------------------------------------------------------- /dlc/README.md: -------------------------------------------------------------------------------- 1 | ## How to get example frames with DLC labeles painted on top 2 | Avoiding to download lengthy videos, with `stream_dlc_labeled_frames.py` one can stream 5 example frames randomly picked across a whole video, specified by eid and video type ('body', 'left' or 'right'). 3 | 4 | To use it, start an ipython session, then type `run /path/to/stream_dlc_labeled_frames.py` to load the script. Then type `stream_save_labeled_frames(eid, video_type)` after setting the path where the labelled example frames should be saved as png images locally, `save_images_folder`. 5 | 6 | ## How to make DLC-labeled video 7 | With the script DLC_labeled_video.py one can make DLC-labeled videos. The script downloads a specific IBL video ('body', 'left' or 'right') for some session (eid). It also downloads the wheel info and will print the wheel angle onto each frame. 8 | To use it, start an ipython session, then type `run /path/to/DLC_labeled_video.py` to load the script. 
Next type 9 | 10 | `Viewer(eid, video_type, trial_range, save_video=True, eye_zoom=False)` 11 | 12 | This will display and save a labeled video for a particular trial range. E.g.
`Viewer('3663d82b-f197-4e8b-b299-7b803a155b84', 'left', [5,7])` will display and save a DLC-labeled video from the left camera of the session with eid '3663d82b-f197-4e8b-b299-7b803a155b84', from trial 5 to trial 7. There is also an option to show a zoomed view of the pupil only. 13 | 14 | See `Example_DLC_access.ipynb` for more instructions and an example of how to load DLC results for other potential applications. 15 | 16 | ## Wheel and DLC live viewer 17 | Plotting the wheel data alongside the video and DLC can be done with `wheel_dlc_viewer`. The viewer loops over the 18 | video frames for a given trial and shows the wheel position plotted against time, indicating the 19 | detected wheel movements. __NB__: Besides the IBL environment dependencies, this module also 20 | requires cv2. 21 | 22 | ### Running from command line 23 | You can run the viewer from the terminal. In Windows, running from the Anaconda prompt within 24 | iblenv ensures the paths are correctly set (`conda activate iblenv`). 25 | 26 | The code below will show the wheel data for a particular session, with the tongue and paw DLC 27 | features overlaid: 28 | 29 | ```bash 30 | python wheel_dlc_viewer.py --eid 77224050-7848-4680-ad3c-109d3bcd562c --dlc tongue,paws 31 | ``` 32 | Pressing the space bar toggles playback of the video, and the left and right arrows allow you 33 | to step frame by frame. Key bindings also allow you to move between trials and toggle the legend. 34 | 35 | Calling the script without the `--eid` flag causes a random session to be selected. You can find a 36 | full list of input arguments and key bindings by calling the script with the `--help` flag: 37 | ```bash 38 | python wheel_dlc_viewer.py -h 39 | ``` 40 | 41 | ### Running from within Python 42 | You can also import the Viewer from within Python...
43 | 44 | Example 1 - inspect trial 100 of a given session, looking at the right camera 45 | ```python 46 | from wheel_dlc_viewer import Viewer 47 | eid = '77224050-7848-4680-ad3c-109d3bcd562c' 48 | v = Viewer(eid=eid, trial=100, camera='right') 49 | ``` 50 | 51 | Example 2 - pick a random session to inspect, showing all DLC 52 | ```python 53 | from wheel_dlc_viewer import Viewer 54 | Viewer(dlc_features='all') 55 | ``` 56 | 57 | For more details, see the docstring for the Viewer. 58 | -------------------------------------------------------------------------------- /dlc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/dlc/__init__.py -------------------------------------------------------------------------------- /dlc/get_dlc_traces.py: -------------------------------------------------------------------------------- 1 | import alf.io 2 | import numpy as np 3 | from oneibl.one import ONE 4 | 5 | def get_DLC(eid,video_type): 6 | '''load dlc traces 7 | load dlc traces for a given session and 8 | video type. 
9 | 10 | :param eid: A session eid 11 | :param video_type: string in 'left', 'right', body' 12 | :return: array of times and dict with dlc points 13 | as keys and x,y coordinates as values, 14 | for each frame id 15 | ''' 16 | one = ONE() 17 | D = one.load(eid, dataset_types = ['camera.dlc', 'camera.times']) 18 | alf_path = one.path_from_eid(eid) / 'alf' 19 | cam0 = alf.io.load_object( 20 | alf_path, 21 | '%sCamera' % 22 | video_type, 23 | namespace='ibl') 24 | Times = cam0['times'] 25 | cam = cam0['dlc'] 26 | points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()]) 27 | XYs = {} 28 | for point in points: 29 | x = np.ma.masked_where( 30 | cam[point + '_likelihood'] < 0.9, cam[point + '_x']) 31 | x = x.filled(np.nan) 32 | y = np.ma.masked_where( 33 | cam[point + '_likelihood'] < 0.9, cam[point + '_y']) 34 | y = y.filled(np.nan) 35 | XYs[point] = np.array([x, y]) 36 | 37 | return Times, XYs 38 | -------------------------------------------------------------------------------- /dlc/stream_dlc_labeled_frames.py: -------------------------------------------------------------------------------- 1 | from oneibl.one import ONE 2 | from ibllib.io.video import get_video_meta, get_video_frames_preload 3 | import numpy as np 4 | from pathlib import Path 5 | import cv2 6 | import alf.io 7 | import matplotlib 8 | from random import sample 9 | import time 10 | 11 | 12 | def stream_save_labeled_frames(eid, video_type): 13 | 14 | startTime = time.time() 15 | ''' 16 | For a given eid and camera type, stream 17 | sample frames, print DLC labels on them 18 | and save 19 | ''' 20 | 21 | # eid = '5522ac4b-0e41-4c53-836a-aaa17e82b9eb' 22 | # video_type = 'left' 23 | 24 | n_frames = 5 # sample 5 random frames 25 | 26 | save_images_folder = '/home/mic/DLC_QC/example_frames/' 27 | one = ONE() 28 | info = '_'.join(np.array(str(one.path_from_eid(eid)).split('/'))[[5,7,8]]) 29 | print(info, video_type) 30 | 31 | r = one.list(eid, 'dataset_types') 32 | 33 | dtypes_DLC = 
['_ibl_rightCamera.times.npy', 34 | '_ibl_leftCamera.times.npy', 35 | '_ibl_bodyCamera.times.npy', 36 | '_iblrig_leftCamera.raw.mp4', 37 | '_iblrig_rightCamera.raw.mp4', 38 | '_iblrig_bodyCamera.raw.mp4', 39 | '_ibl_leftCamera.dlc.pqt', 40 | '_ibl_rightCamera.dlc.pqt', 41 | '_ibl_bodyCamera.dlc.pqt'] 42 | 43 | dtype_names = [x['name'] for x in r] 44 | 45 | assert all([i in dtype_names for i in dtypes_DLC] 46 | ), 'For this eid, not all data available' 47 | 48 | 49 | D = one.load(eid, 50 | dataset_types=['camera.times', 'camera.dlc'], 51 | dclass_output=True) 52 | alf_path = Path(D.local_path[0]).parent.parent / 'alf' 53 | 54 | cam0 = alf.io.load_object( 55 | alf_path, 56 | '%sCamera' % 57 | video_type, 58 | namespace='ibl') 59 | 60 | Times = cam0['times'] 61 | 62 | cam = cam0['dlc'] 63 | points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()]) 64 | 65 | XYs = {} 66 | for point in points: 67 | x = np.ma.masked_where( 68 | cam[point + '_likelihood'] < 0.9, cam[point + '_x']) 69 | x = x.filled(np.nan) 70 | y = np.ma.masked_where( 71 | cam[point + '_likelihood'] < 0.9, cam[point + '_y']) 72 | y = y.filled(np.nan) 73 | XYs[point] = np.array( 74 | [x, y]) 75 | 76 | 77 | 78 | if video_type != 'body': 79 | d = list(points) 80 | d.remove('tube_top') 81 | d.remove('tube_bottom') 82 | points = np.array(d) 83 | 84 | # stream frames 85 | recs = [x for x in r if f'{video_type}Camera.raw.mp4' in x['name']][0]['file_records'] 86 | video_path = [x['data_url'] for x in recs if x['data_url'] is not None][0] 87 | vid_meta = get_video_meta(video_path) 88 | 89 | frame_idx = sample(range(vid_meta['length']), n_frames) 90 | print('frame indices:', frame_idx) 91 | frames = get_video_frames_preload(video_path, 92 | frame_idx, 93 | mask=np.s_[:, :, 0]) 94 | size = [vid_meta['width'], vid_meta['height']] 95 | #return XYs, frames 96 | 97 | x0 = 0 98 | x1 = size[0] 99 | y0 = 0 100 | y1 = size[1] 101 | if video_type == 'left': 102 | dot_s = 10 # [px] for painting DLC dots 103 | 
else: 104 | dot_s = 5 105 | 106 | # writing stuff on frames 107 | font = cv2.FONT_HERSHEY_SIMPLEX 108 | 109 | if video_type == 'left': 110 | bottomLeftCornerOfText = (20, 1000) 111 | fontScale = 4 112 | else: 113 | bottomLeftCornerOfText = (10, 500) 114 | fontScale = 2 115 | 116 | lineType = 2 117 | 118 | # assign a color to each DLC point (now: all points red) 119 | cmap = matplotlib.cm.get_cmap('Set1') 120 | CR = np.arange(len(points)) / len(points) 121 | 122 | block = np.ones((2 * dot_s, 2 * dot_s, 3)) 123 | 124 | k = 0 125 | for frame in frames: 126 | 127 | gray = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) 128 | 129 | 130 | # print session info 131 | fontColor = (255, 255, 255) 132 | cv2.putText(gray, 133 | info, 134 | bottomLeftCornerOfText, 135 | font, 136 | fontScale / 4, 137 | fontColor, 138 | lineType) 139 | 140 | # print time 141 | Time = round(Times[frame_idx[k]], 3) 142 | a, b = bottomLeftCornerOfText 143 | bottomLeftCornerOfText0 = (int(a * 10 + b / 2), b) 144 | 145 | 146 | 147 | a, b = bottomLeftCornerOfText 148 | bottomLeftCornerOfText0 = (int(a * 10 + b / 2), b) 149 | cv2.putText(gray, 150 | ' time: ' + str(Time), 151 | bottomLeftCornerOfText0, 152 | font, 153 | fontScale / 2, 154 | fontColor, 155 | lineType) 156 | 157 | 158 | 159 | # print DLC dots 160 | ll = 0 161 | for point in points: 162 | 163 | # Put point color legend 164 | fontColor = (np.array([cmap(CR[ll])]) * 255)[0][:3] 165 | a, b = bottomLeftCornerOfText 166 | if video_type == 'right': 167 | bottomLeftCornerOfText2 = (a, a * 2 * (1 + ll)) 168 | else: 169 | bottomLeftCornerOfText2 = (b, a * 2 * (1 + ll)) 170 | fontScale2 = fontScale / 4 171 | cv2.putText(gray, point, 172 | bottomLeftCornerOfText2, 173 | font, 174 | fontScale2, 175 | fontColor, 176 | lineType) 177 | 178 | X0 = XYs[point][0][frame_idx[k]] 179 | Y0 = XYs[point][1][frame_idx[k]] 180 | 181 | X = Y0 182 | Y = X0 183 | 184 | #print(point,X,Y) 185 | if not np.isnan(X) and not np.isnan(Y): 186 | try: 187 | col = 
(np.array([cmap(CR[ll])]) * 255)[0][:3] 188 | # col = np.array([0, 0, 255]) # all points red 189 | X = X.astype(int) 190 | Y = Y.astype(int) 191 | 192 | uu = block * col 193 | gray[X - dot_s:X + dot_s, Y - dot_s:Y + dot_s] = uu 194 | 195 | except Exception as e: 196 | print('frame', frame_idx[k]) 197 | print(e) 198 | ll += 1 199 | 200 | gray = gray[y0:y1, x0:x1] 201 | # cv2.imshow('frame', gray) 202 | cv2.imwrite(f'{save_images_folder}{eid}_frame_{frame_idx[k]}.png', 203 | gray) 204 | cv2.waitKey(1) 205 | k += 1 206 | 207 | print(f'{n_frames} frames done in', np.round(time.time() - startTime)) 208 | -------------------------------------------------------------------------------- /ephysfeatures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/ephysfeatures/__init__.py -------------------------------------------------------------------------------- /ephysfeatures/prepare_features.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import one.alf.io as alfio 4 | import pandas as pd 5 | from one.api import ONE 6 | from brainbox.processing import bincount2D 7 | one = ONE() 8 | 9 | root_path = Path(r'C:\Users\Mayo\Downloads\FlatIron') 10 | files = list(root_path.rglob('electrodeSites.mlapdv.npy')) 11 | 12 | all_df_chns = [] 13 | 14 | for file in files: 15 | try: 16 | session_path = Path(*file.parts[:10]) 17 | ref = one.path2ref(session_path) 18 | eid = one.path2eid(session_path) 19 | probe = file.parts[11] 20 | 21 | lfp = alfio.load_object(session_path.joinpath('raw_ephys_data', probe), 'ephysSpectralDensityLF', namespace='iblqc') 22 | mean_lfp = 10 * np.log10(np.mean(lfp.power[:,:-1], axis=0)) 23 | 24 | try: 25 | ap = alfio.load_object(session_path.joinpath('raw_ephys_data', probe), 'ephysTimeRmsAP', namespace='iblqc') 26 | mean_ap = 
np.mean(ap.rms[:, :384], axis=0) 27 | except Exception as err: 28 | mean_ap = 50 * np.ones((384)) 29 | 30 | try: 31 | spikes = alfio.load_object(session_path.joinpath(f'alf/{probe}/pykilosort'), 'spikes') 32 | kp_idx = ~np.isnan(spikes['depths']) 33 | T_BIN = np.max(spikes['times']) 34 | D_BIN = 10 35 | chn_min = 10 36 | chn_max = 3840 37 | nspikes, times, depths = bincount2D(spikes['times'][kp_idx], 38 | spikes['depths'][kp_idx], 39 | T_BIN, D_BIN, 40 | ylim=[chn_min, chn_max]) 41 | 42 | amp, times, depths = bincount2D(spikes['amps'][kp_idx], 43 | spikes['depths'][kp_idx], 44 | T_BIN, D_BIN, ylim=[chn_min, chn_max], 45 | weights=spikes['amps'][kp_idx]) 46 | mean_fr = nspikes[:, 0] / T_BIN 47 | mean_amp = np.divide(amp[:, 0], nspikes[:, 0]) * 1e6 48 | mean_amp[np.isnan(mean_amp)] = 0 49 | remove_bins = np.where(nspikes[:, 0] < 50)[0] 50 | mean_amp[remove_bins] = 0 51 | except Exception as err: 52 | mean_fr = np.ones((384)) 53 | mean_amp = np.ones((384)) 54 | depths = np.ones((384)) * 20 55 | 56 | 57 | 58 | channels = alfio.load_object(file.parent, 'electrodeSites') 59 | data_chns = {} 60 | data_chns['x'] = channels['mlapdv'][:, 0] 61 | data_chns['y'] = channels['mlapdv'][:, 1] 62 | data_chns['z'] = channels['mlapdv'][:, 2] 63 | data_chns['axial_um'] = channels['localCoordinates'][:, 0] 64 | data_chns['lateral_um'] = channels['localCoordinates'][:, 1] 65 | data_chns['lfp'] = mean_lfp 66 | data_chns['ap'] = mean_ap 67 | data_chns['depth_line'] = depths 68 | data_chns['fr'] = mean_fr 69 | data_chns['amp'] = mean_amp 70 | data_chns['region_id'] = channels['brainLocationIds_ccf_2017'] 71 | 72 | df_chns = pd.DataFrame.from_dict(data_chns) 73 | df_chns['subject'] = ref['subject'] 74 | df_chns['date'] = str(ref['date']) 75 | df_chns['probe'] = probe 76 | df_chns['pid'] = eid + probe 77 | 78 | 79 | all_df_chns.append(df_chns) 80 | except Exception as err: 81 | print(err) 82 | 83 | 84 | 85 | 86 | 87 | 
-------------------------------------------------------------------------------- /histology/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/histology/__init__.py -------------------------------------------------------------------------------- /histology/atlas_mpl.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import matplotlib 3 | import numpy as np 4 | from PyQt5 import QtCore, QtWidgets, uic 5 | 6 | from qt_helpers import qt 7 | from qt_helpers.qt_matplotlib import BaseMplCanvas 8 | import iblatlas.atlas as atlas 9 | 10 | # Make sure that matplotlib uses the Qt5 backend 11 | matplotlib.use('Qt5Agg') 12 | 13 | 14 | class Model: 15 | """ 16 | Container for the data and state of the application: the brain atlas and the current AP position (ap_um, micrometres) 17 | """ 18 | brain_atlas: atlas.BrainAtlas # atlas used for slicing and region lookup 19 | ap_um: float # current AP position of the displayed slice, in micrometres 20 | 21 | def __init__(self, brain_atlas=None, ap_um=0): 22 | self.ap_um = ap_um 23 | # load the brain atlas (default: 25 um Allen atlas) 24 | if brain_atlas is None: 25 | self.brain_atlas = atlas.AllenAtlas(res_um=25) 26 | else: 27 | self.brain_atlas = brain_atlas 28 | 29 | 30 | class MyStaticMplCanvas(BaseMplCanvas): 31 | """Matplotlib canvas showing an atlas slice; on mouse motion it writes the hovered ml/dv/ap coordinates and brain region into the main window labels.""" 32 | def __init__(self, *args, **kwargs): 33 | super(MyStaticMplCanvas, self).__init__(*args, **kwargs) 34 | self.mpl_connect("motion_notify_event", self.on_move) 35 | 36 | def on_move(self, event): # event.xdata / event.ydata are the ml / dv axes coordinates (None outside the axes) 37 | mw = qt.get_main_window() 38 | ap = mw._model.ap_um 39 | xlab = '' if event.xdata is None else '{: 6.0f}'.format(event.xdata) 40 | ylab = '' if event.ydata is None else '{: 
6.0f}'.format(event.ydata) 41 | mw.label_ml.setText(xlab) 42 | mw.label_dv.setText(ylab) 43 | mw.label_ap.setText('{: 6.0f}'.format(ap)) 44 | if event.xdata is None or event.ydata is None: 45 | return 46 | # if within the bounds of the Atlas, label with the current structure hovered onto; NOTE(review): no explicit bounds check is done here, behaviour of get_labels for points outside the volume is not shown - confirm 47 | xyz = np.array([[event.xdata, ap, event.ydata]]) / 1e6 # um -> presumably metres, the unit the atlas calls take here 48 | regions
= mw._model.brain_atlas.regions 49 | id_label = mw._model.brain_atlas.get_labels(xyz) 50 | il = np.where(regions.id == id_label)[0] 51 | mw.label_acronym.setText(regions.acronym[il][0]) 52 | mw.label_structure.setText(regions.name[il][0]) 53 | mw.label_id.setText(str(id_label)) 54 | 55 | def update_slice(self, volume='image'): # redraw the axis=1 slice at the model's ap_um from the given volume ('image' or 'annotation') 56 | mw = qt.get_main_window() 57 | ap_um = mw._model.ap_um 58 | im = mw._model.brain_atlas.slice(ap_um / 1e6, axis=1, volume=volume) 59 | im = np.swapaxes(im, 0, 1) 60 | self.axes.images[0].set_data(im) 61 | self.draw() 62 | 63 | 64 | class AtlasViewer(QtWidgets.QMainWindow): # main window loaded from atlas_mpl.ui, with a File/Quit menu and an Image/Annotation toggle 65 | def __init__(self, brain_atlas=None, ap_um=0): 66 | # init the figure 67 | super(AtlasViewer, self).__init__() 68 | uic.loadUi(str(Path(__file__).parent.joinpath('atlas_mpl.ui')), self) 69 | self.setAttribute(QtCore.Qt.WA_DeleteOnClose) 70 | self.file_menu = QtWidgets.QMenu('&File', self) 71 | self.file_menu.addAction('&Quit', self.fileQuit, 72 | QtCore.Qt.CTRL + QtCore.Qt.Key_Q) 73 | self.menuBar().addMenu(self.file_menu) 74 | self.pb_toggle.clicked.connect(self.update_slice) 75 | 76 | # init model 77 | self._model = Model(brain_atlas=brain_atlas, ap_um=ap_um) 78 | # display the coronal slice in the mpl widget 79 | self._model.brain_atlas.plot_cslice(ap_coordinate=ap_um / 1e6, 80 | ax=self.mpl_widget.axes) 81 | 82 | def fileQuit(self): 83 | self.close() 84 | 85 | def closeEvent(self, ce): 86 | self.fileQuit() 87 | 88 | def update_slice(self, ce): # toggle the button label and display the volume named by its previous label; ce (signal argument) is unused 89 | volume = self.pb_toggle.text() 90 | if volume == 'Annotation': 91 | self.pb_toggle.setText('Image') 92 | elif volume == 'Image': 93 | self.pb_toggle.setText('Annotation') 94 | self.mpl_widget.update_slice(volume.lower()) 95 | 96 | 97 | def 
viewatlas(brain_atlas=None, title=None, ap_um=0): # create the Qt app, show an AtlasViewer and return it with its matplotlib axes 98 | qt.create_app() 99 | av = AtlasViewer(brain_atlas, ap_um=ap_um) 100 | av.setWindowTitle(title) 101 | av.show() 102 | return av, av.mpl_widget.axes 103 | 104 | 105 | if __name__ == "__main__": 106 | w = viewatlas(title="Allen
Atlas - IBL") 107 | qt.run_app() 108 | -------------------------------------------------------------------------------- /histology/atlas_mpl.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | AtlasViewer 4 | 5 | 6 | 7 | 0 8 | 0 9 | 986 10 | 860 11 | 12 | 13 | 14 | 15 | 0 16 | 0 17 | 18 | 19 | 20 | MainWindow 21 | 22 | 23 | background-color: rgb(255, 255, 255); 24 | 25 | 26 | 27 | 28 | 0 29 | 0 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 16777215 38 | 20 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 16777215 51 | 20 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 16777215 64 | 20 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 16777215 77 | 20 78 | 79 | 80 | 81 | dv (um) 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 16777215 90 | 20 91 | 92 | 93 | 94 | structure 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 16777215 103 | 20 104 | 105 | 106 | 107 | ap (um) 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 16777215 116 | 20 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 16777215 129 | 20 130 | 131 | 132 | 133 | ml (um) 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 16777215 142 | 20 143 | 144 | 145 | 146 | atlas id 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 16777215 155 | 20 156 | 157 | 158 | 159 | acronym 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 16777215 168 | 20 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 16777215 181 | 20 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 188 194 | 0 195 | 196 | 197 | 198 | 199 | 10 200 | 16777215 201 | 202 | 203 | 204 | Annotation 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 0 213 | 316 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 0 224 | 0 225 | 986 226 | 22 227 | 228 | 229 | 230 | 231 | 232 | 233 | toolBar 234 | 235 | 236 | TopToolBarArea 237 | 238 | 239 | false 240 | 241 | 242 | 243 | 244 | 245 | MyStaticMplCanvas 246 | QWidget 247 |
iblapps.histology.atlas_mpl
248 | 1 249 |
250 |
251 | 252 | 253 | 254 | pb_toggle 255 | released() 256 | mpl_widget 257 | update() 258 | 259 | 260 | 956 261 | 255 262 | 263 | 264 | 524 265 | 563 266 | 267 | 268 | 269 | 270 |
271 | -------------------------------------------------------------------------------- /histology/transform-tracks.py: -------------------------------------------------------------------------------- 1 | """Re-Scale and Transform Traced Electrode Tracks 2 | 3 | This script will convert traced tracks from full resolution stacks, into 4 | the downsampled stack and finally transform them into Standard (ARA) 5 | space. 6 | 7 | 8 | DEPENDENCIES 9 | 10 | For this script to run, elastix must be installed. See instruction at 11 | elastix.isi.uu.nl/download.php. 12 | 13 | Check elastix is installed by running at command line: 14 | 15 | $ elastix --version 16 | elastix version: 4.800 17 | 18 | 19 | GENERATING TRACED TRACKS 20 | 21 | The tracks must be generated using Lasagna and the add_line_plugin. 22 | 23 | 24 | The resulting .csv files must be saved into a directory titled 25 | sparsedata/ inside the Sample ROOT Directory: 26 | 27 | - [KS006]: 28 | - sparsedata: 29 | - track01.csv 30 | - track02.csv 31 | 32 | NB: [KS006] can be any sample ID. 33 | 34 | The Sample ROOT Directory must also have the following structure and 35 | files: 36 | 37 | - [KS006]: 38 | - downsampledStacks_25: 39 | - dsKS006_190508_112940_25_25_GR.txt 40 | - ARA2sample: 41 | - TransformParameters.0.txt 42 | - TransformParameters.1.txt 43 | 44 | The script uses information in dsKS006_190508_112940_25_25_GR.txt, 45 | TransformParameters.0.txt and TransformParameters.1.txt to perform the 46 | transformation. 47 | 48 | The output of the script is a series of new csv files, with the track 49 | coordinates now in the registered ARA space. 
50 | 51 | - [KS006]: 52 | - downsampledStacks_25: 53 | - rescaled-transformed-sparsedata: 54 | - track01.csv 55 | - track02.csv 56 | 57 | 58 | """ 59 | 60 | import sys 61 | import os 62 | import re 63 | import glob 64 | import csv 65 | import subprocess 66 | from os import listdir 67 | from os.path import isfile, join 68 | from pathlib import Path 69 | import platform 70 | 71 | #print("Python version") 72 | #print (sys.version) 73 | #print("Version info.") 74 | #print (sys.version_info) 75 | 76 | #print( os.getcwd() ) 77 | 78 | xy = 0.0 # num to hold x/y resolution 79 | z = 0.0 # num to hold z resolution 80 | 81 | 82 | rescaleFile = glob.glob("downsampledStacks_25" + os.sep + "ds*GR.txt")[0] 83 | 84 | print(rescaleFile) 85 | 86 | 87 | with open (rescaleFile, 'rt') as myfile: 88 | 89 | for line in myfile: 90 | 91 | #print(line) 92 | 93 | if line.lower().find("x/y:") != -1: 94 | for t in line.split(): 95 | try: 96 | xy = float(t) 97 | except ValueError: 98 | pass 99 | 100 | if line.lower().find("z:") != -1: 101 | for t in line.split(): 102 | try: 103 | z = float(t) 104 | except ValueError: 105 | pass 106 | 107 | #print(xy) 108 | 109 | #print(z) 110 | 111 | 112 | # NEXT - change the path in TransformParameters.1.txt file to point to correct ABSOLUTE path to TransformParameters.0.txt file: 113 | 114 | tp = os.getcwd() + os.sep + "downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.0.txt" 115 | 116 | #print(tp) 117 | 118 | 119 | #tp1 = "downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt" 120 | 121 | #for line in fileinput.input(tp1, inplace = 1): 122 | # print line.replace("foo", "bar"), 123 | 124 | fp = open("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt","r+") 125 | 126 | fg = open("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "new_TransformParameters.1.txt","w") 127 | 128 | for line in fp: 129 | 130 | if line.find("(InitialTransformParametersFileName") != -1: 
131 | new_line = line.replace(line,'(InitialTransformParametersFileName '+tp+')\n') 132 | fg.write(new_line) 133 | 134 | else: 135 | fg.write(line) 136 | 137 | fg.close() 138 | 139 | fp.close() 140 | 141 | os.remove("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt") 142 | os.rename("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "new_TransformParameters.1.txt","downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt") 143 | 144 | 145 | # have extracted the resolution change - now open each CSV file of track coords, and transform them: 146 | 147 | # Read the CSV files: 148 | 149 | # get list of files: 150 | csvFiles = [f for f in listdir("sparsedata" + os.sep) if isfile(join("sparsedata" + os.sep, f))] 151 | 152 | for csvFilePath in csvFiles: 153 | 154 | print(csvFilePath) 155 | # get the total number of lines in the file first: 156 | csvfile = open("sparsedata" + os.sep + csvFilePath) 157 | linenum = len(csvfile.readlines()) 158 | csvfile.close() 159 | 160 | csvFilePathNoExt = csvFilePath[:-4] 161 | 162 | 163 | with open( ("sparsedata" + os.sep + csvFilePath), newline='') as csvfile: 164 | # write to rescaled-file: 165 | csvOut = open( ('downsampledStacks_25' + os.sep + 'rescaled-' + csvFilePath), 'w') 166 | pts = open('downsampledStacks_25' + os.sep + 'rescaled-'+csvFilePathNoExt+'.txt', 'w') 167 | pts.write("point\n") 168 | pts.write("%s\n" % linenum ) 169 | spamreader = csv.reader(csvfile, delimiter=',', quotechar='|') 170 | for row in spamreader: 171 | # NB: CSV file is laid out ZXY! 172 | row[0] = round( (float(row[0])/ z), 6) 173 | 174 | row[1] = round( (float(row[1])/ xy), 6) 175 | 176 | row[2] = round( (float(row[2])/ xy), 6) 177 | 178 | # re-write as PTS: XYZ! 179 | pts.write("%s " % row[1]) 180 | pts.write("%s " % row[2]) 181 | pts.write("%s\n" % row[0]) 182 | 183 | # also re-write as csv: ZXY! 
184 | csvOut.write("%s," % row[0]) 185 | csvOut.write("%s," % row[1]) 186 | csvOut.write("%s\n" % row[2]) 187 | 188 | pts.close() 189 | csvOut.close() 190 | 191 | 192 | # next, use system call to convert the points into ARA space: 193 | # Use the INVERSE of sample2ARA -> ARA2sample! 194 | 195 | pts = "downsampledStacks_25" + os.sep + "rescaled-"+csvFilePathNoExt+'.txt' 196 | out = "downsampledStacks_25" + os.sep + "rescaled-transformed-sparsedata" + os.sep 197 | # NOTE - need to use TransformParameters.1.txt - which points to TP0! 198 | tp = "downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt" 199 | 200 | # make output DIR: 201 | if os.path.isdir(out)==False: 202 | os.mkdir(out) 203 | 204 | cmd = "transformix -def " + pts + " -out " + out + " -tp " + tp 205 | 206 | print(cmd) 207 | 208 | returned_value = subprocess.call(cmd, shell=True) # returns the exit code in unix 209 | print('returned value:', returned_value) 210 | 211 | # finally, want to convert outputpoints.txt into the ORIGINAL CSV format: 212 | outputPts = open(out+"outputpoints.txt", "r") 213 | 214 | outPtsCsv = open( (r"" + out + csvFilePathNoExt + "_transformed.csv"), 'w') 215 | 216 | for aline in outputPts: 217 | values = aline.split() 218 | # XYZ: 219 | # print("XYZ: ", values[22], values[23], values[24] ) 220 | # NB: CSV file is laid out ZXY! 
221 | outPtsCsv.write("%s," % values[24]) 222 | outPtsCsv.write("%s," % values[22]) 223 | outPtsCsv.write("%s\n" % values[23]) 224 | 225 | outputPts.close() 226 | 227 | # remove the temp files: 228 | os.remove('downsampledStacks_25' + os.sep + 'rescaled-' + csvFilePath) 229 | os.remove('downsampledStacks_25' + os.sep + 'rescaled-'+csvFilePathNoExt+'.txt') 230 | os.remove(out+"outputpoints.txt") 231 | os.remove(out+"transformix.log") 232 | 233 | 234 | -------------------------------------------------------------------------------- /launch_phy/README.md: -------------------------------------------------------------------------------- 1 | ## Launch Phy with an eid and probe_name 2 | 3 | To launch Phy, first follow [these instructions](https://github.com//int-brain-lab/iblenv) for setting up a unified IBL conda environment. 4 | 5 | Then in your terminal, make sure you are on the 'develop' branch of this repository, activate your unified conda environment, and run either: 6 | 7 | `python -s subject -d date -n session_no -p probe_name` 8 | e.g `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -s KS022 -d 2019-12-10 -n 1 -p probe00` 9 | 10 | or: 11 | 12 | `python -s subject -e eid -p probe_name` 13 | e.g. `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -e a3df91c8-52a6-4afa-957b-3479a7d0897c -p probe00` 14 | 15 | 16 | or: 17 | `python phy_launcher.py -pid 2aea57ac-d3d0-4a09-b6fa-0aa9d58d1e11` 18 | 19 | Quality metrics will be computed within phy and displayed automatically. 20 | Description of metrics can be found here: https://docs.google.com/document/d/1ba_krsfm4epiAd0zbQ8hdvDN908P9VZOpTxkkH3P_ZY/edit# 21 | 22 | ## Manual Curation 23 | Manual curation comprises three different steps 24 | 25 | 1) **Labelling clusters** (assigning each cluster with a label Good, Mua or Noise) 26 | * Select a cluster within Cluster View and click on `Edit -> Move best to` and assign your chosen label. 
27 | * Alternatively, you can use the shortcuts alt+G, alt+M and alt+N to label the cluster as good, mua or noise, respectively. 28 | 29 | 2) **Additional notes associated with clusters** (extra information about clusters, for example if a cluster looks like an artifact or drift) 30 | * Select a cluster within Cluster View 31 | * Ensure snippet mode is enabled by clicking `File -> Enable Snippet Mode` 32 | * Type `:l notes your_note` and then hit enter (while typing you should see the text appear in the lower left 33 | hand corner of the main window) 34 | * A new column with the heading **notes** should have been created in the Cluster View window 35 | 36 | Make sure you hit the save button frequently during manual curation so your results are saved and can be 37 | recovered in case Phy freezes or crashes! 38 | 39 | 40 | ## Upload manual labels to datajoint 41 | 42 | Once you have completed manual curation, the labels that you assigned to clusters and any additional notes that you 43 | added can be uploaded and stored in a datajoint table by running either: 44 | 45 | `python populate_cluster_table.py -s subject -d date -n session_no -p probe_name` 46 | e.g. `python int-brain-lab\iblapps\launch_phy\populate_cluster_table.py -s KS022 -d 2019-12-10 -n 1 -p probe00` 47 | 48 | or: 49 | 50 | `python populate_cluster_table.py -e eid -p probe_name` 51 | e.g.
`python int-brain-lab\iblapps\launch_phy\populate_cluster_table.py -e a3df91c8-52a6-4afa-957b-3479a7d0897c -p probe00` 52 | 53 | -------------------------------------------------------------------------------- /launch_phy/cluster_table.py: -------------------------------------------------------------------------------- 1 | import datajoint as dj 2 | from ibl_pipeline import reference 3 | 4 | schema = dj.schema('group_shared_ephys') 5 | dj.config["enable_python_native_blobs"] = True 6 | 7 | 8 | @schema 9 | class ClusterLabel(dj.Imported): 10 | definition = """ 11 | cluster_uuid: uuid # uuid of cluster 12 | -> reference.LabMember #user name 13 | --- 14 | label_time: datetime # date on which labelling was done 15 | cluster_label=null: varchar(255) # user assigned label 16 | cluster_note=null: varchar(255) # user note about cluster 17 | """ 18 | 19 | 20 | @schema 21 | class MergedClusters(dj.Imported): 22 | definition = """ 23 | cluster_uuid: uuid # uuid of merged cluster 24 | --- 25 | merged_uuid: longblob # array of uuids of original clusters that form merged cluster 26 | """ 27 | -------------------------------------------------------------------------------- /launch_phy/phy_launcher.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import pandas as pd 5 | 6 | from phy.apps.template import TemplateController, template_gui 7 | from phy.gui.qt import create_app, run_app 8 | from phylib import add_default_handler 9 | from one.api import ONE 10 | from brainbox.metrics.single_units import quick_unit_metrics 11 | from pathlib import Path 12 | from brainbox.io.one import SpikeSortingLoader 13 | 14 | 15 | def launch_phy(probe_name=None, eid=None, pid=None, subj=None, date=None, sess_no=None, one=None): 16 | """ 17 | Launch phy given an eid and probe name. 18 | """ 19 | 20 | # This is a first draft, no error handling and a draft dataset list. 
21 | 22 | # Load data from probe # 23 | # -------------------- # 24 | 25 | from one.api import ONE 26 | from iblatlas.atlas import AllenAtlas 27 | from brainbox.io.one import SpikeSortingLoader 28 | import spikeglx 29 | one = one or ONE() 30 | ba = AllenAtlas() 31 | 32 | datasets = [ 33 | 'spikes.times', 34 | 'spikes.clusters', 35 | 'spikes.amps', 36 | 'spikes.templates', 37 | 'spikes.samples', 38 | 'spikes.depths', 39 | 'clusters.uuids', 40 | 'clusters.metrics', 41 | 'clusters.waveforms', 42 | 'clusters.waveformsChannels', 43 | 'clusters.depths', 44 | 'clusters.amps', 45 | 'clusters.channels'] 46 | 47 | if pid is None: 48 | ssl = SpikeSortingLoader(eid=eid, pname=probe_name, one=one, atlas=ba) 49 | else: 50 | ssl = SpikeSortingLoader(pid=pid, one=one, atlas=ba) 51 | ssl.download_spike_sorting(dataset_types=datasets) 52 | ssl.download_spike_sorting_object('templates') 53 | ssl.download_spike_sorting_object('spikes_subset') 54 | 55 | alf_dir = ssl.session_path.joinpath(ssl.collection) 56 | 57 | if not alf_dir.joinpath('clusters.metrics.pqt').exists(): 58 | spikes, clusters, channels = ssl.load_spike_sorting() 59 | ssl.merge_clusters(spikes, clusters, channels, cache_dir=alf_dir) 60 | 61 | raw_file = next(ssl.session_path.joinpath('raw_ephys_folder', ssl.pname).glob('*.ap.*bin'), None) 62 | 63 | if raw_file is not None: 64 | sr = spikeglx.Reader(raw_file) 65 | sample_rate = sr.fs 66 | n_channel_dat = sr.nc - sr.nsync 67 | else: 68 | sample_rate = 30000 69 | n_channel_dat = 384 70 | 71 | # Launch phy # 72 | # -------------------- # 73 | add_default_handler('DEBUG', logging.getLogger("phy")) 74 | add_default_handler('DEBUG', logging.getLogger("phylib")) 75 | create_app() 76 | controller = TemplateController(dat_path=raw_file, dir_path=alf_dir, dtype=np.int16, 77 | n_channels_dat=n_channel_dat, sample_rate=sample_rate, 78 | plugins=['IBLMetricsPlugin'], 79 | plugin_dirs=[Path(__file__).resolve().parent / 'plugins']) 80 | gui = controller.create_gui() 81 | gui.show() 82 
| run_app() 83 | gui.close() 84 | controller.model.close() 85 | 86 | 87 | if __name__ == '__main__': 88 | """ 89 | `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -e a3df91c8-52a6-4afa-957b-3479a7d0897c -p probe00` 90 | `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -pid c07d13ed-e387-4457-8e33-1d16aed3fd92` 91 | """ 92 | from argparse import ArgumentParser 93 | import numpy as np 94 | 95 | parser = ArgumentParser() 96 | parser.add_argument('-s', '--subject', default=False, required=False, 97 | help='Subject Name') 98 | parser.add_argument('-d', '--date', default=False, required=False, 99 | help='Date of session YYYY-MM-DD') 100 | parser.add_argument('-n', '--session_no', default=1, required=False, 101 | help='Session Number', type=int) 102 | parser.add_argument('-e', '--eid', default=False, required=False, 103 | help='Session eid') 104 | parser.add_argument('-p', '--probe_label', default=False, required=False, 105 | help='Probe Label') 106 | parser.add_argument('-pid', '--pid', default=False, required=False, 107 | help='Probe ID') 108 | 109 | args = parser.parse_args() 110 | if args.eid: 111 | launch_phy(probe_name=str(args.probe_label), eid=str(args.eid)) 112 | elif args.pid: 113 | launch_phy(pid=str(args.pid)) 114 | else: 115 | if not np.all(np.array([args.subject, args.date, args.session_no], 116 | dtype=object)): 117 | print('Must give Subject, Date and Session number') 118 | else: 119 | launch_phy(probe_name=str(args.probe_label), subj=str(args.subject), 120 | date=str(args.date), sess_no=args.session_no) 121 | # launch_phy('probe00', subj='KS022', 122 | # date='2019-12-10', sess_no=1) 123 | -------------------------------------------------------------------------------- /launch_phy/plugins/phy_plugin.py: -------------------------------------------------------------------------------- 1 | # import from plugins/cluster_metrics.py 2 | """Show how to add a custom cluster metrics.""" 3 | 4 | import logging 5 | import numpy as np 6 | from phy 
import IPlugin 7 | import pandas as pd 8 | from pathlib import Path 9 | from brainbox.metrics.single_units import quick_unit_metrics 10 | 11 | 12 | class IBLMetricsPlugin(IPlugin): 13 | def attach_to_controller(self, controller): 14 | """Note that this function is called at initialization time, *before* the supervisor is 15 | created. The `controller.cluster_metrics` items are then passed to the supervisor when 16 | constructing it.""" 17 | 18 | clusters_file = Path(controller.dir_path.joinpath('clusters.metrics.pqt')) 19 | if clusters_file.exists(): 20 | self.metrics = pd.read_parquet(clusters_file) 21 | else: 22 | self.metrics = None 23 | return 24 | 25 | def amplitudes(cluster_id): 26 | amps = controller.get_cluster_amplitude(cluster_id) 27 | return amps * 1e6 28 | 29 | def amp_max(cluster_id): 30 | return self.metrics['amp_max'].iloc[cluster_id] * 1e6 31 | 32 | def amp_min(cluster_id): 33 | return self.metrics['amp_min'].iloc[cluster_id] * 1e6 34 | 35 | def amp_median(cluster_id): 36 | return self.metrics['amp_median'].iloc[cluster_id] * 1e6 37 | 38 | def amp_std_dB(cluster_id): 39 | return self.metrics['amp_std_dB'].iloc[cluster_id] 40 | 41 | def contamination(cluster_id): 42 | return self.metrics['contamination'].iloc[cluster_id] 43 | 44 | def contamination_alt(cluster_id): 45 | return self.metrics['contamination_alt'].iloc[cluster_id] 46 | 47 | def drift(cluster_id): 48 | return self.metrics['drift'].iloc[cluster_id] 49 | 50 | def missed_spikes_est(cluster_id): 51 | return self.metrics['missed_spikes_est'].iloc[cluster_id] 52 | 53 | def noise_cutoff(cluster_id): 54 | return self.metrics['noise_cutoff'].iloc[cluster_id] 55 | 56 | def presence_ratio(cluster_id): 57 | return self.metrics['presence_ratio'].iloc[cluster_id] 58 | 59 | def presence_ratio_std(cluster_id): 60 | return self.metrics['presence_ratio_std'].iloc[cluster_id] 61 | 62 | def slidingRP_viol(cluster_id): 63 | return self.metrics['slidingRP_viol'].iloc[cluster_id] 64 | 65 | def 
spike_count(cluster_id): 66 | return self.metrics['spike_count'].iloc[cluster_id] 67 | 68 | def firing_rate(cluster_id): 69 | return self.metrics['firing_rate'].iloc[cluster_id] 70 | 71 | def label(cluster_id): 72 | return self.metrics['label'].iloc[cluster_id] 73 | 74 | def ks2_label(cluster_id): 75 | if 'ks2_label' in self.metrics.columns: 76 | return self.metrics['ks2_label'].iloc[cluster_id] 77 | else: 78 | return 'nan' 79 | 80 | controller.cluster_metrics['amplitudes'] = controller.context.memcache(amplitudes) 81 | controller.cluster_metrics['amp_max'] = controller.context.memcache(amp_max) 82 | controller.cluster_metrics['amp_min'] = controller.context.memcache(amp_min) 83 | controller.cluster_metrics['amp_median'] = controller.context.memcache(amp_median) 84 | controller.cluster_metrics['amp_std_dB'] = controller.context.memcache(amp_std_dB) 85 | controller.cluster_metrics['contamination'] = controller.context.memcache(contamination) 86 | controller.cluster_metrics['contamination_alt'] = controller.context.memcache(contamination_alt) 87 | controller.cluster_metrics['drift'] = controller.context.memcache(drift) 88 | controller.cluster_metrics['missed_spikes_est'] = controller.context.memcache(missed_spikes_est) 89 | controller.cluster_metrics['noise_cutoff'] = controller.context.memcache(noise_cutoff) 90 | controller.cluster_metrics['presence_ratio'] = controller.context.memcache(presence_ratio) 91 | controller.cluster_metrics['presence_ratio_std'] = controller.context.memcache(presence_ratio_std) 92 | controller.cluster_metrics['slidingRP_viol'] = controller.context.memcache(slidingRP_viol) 93 | controller.cluster_metrics['spike_count'] = controller.context.memcache(spike_count) 94 | controller.cluster_metrics['firing_rate'] = controller.context.memcache(firing_rate) 95 | controller.cluster_metrics['label'] = controller.context.memcache(label) 96 | controller.cluster_metrics['ks2_label'] = controller.context.memcache(ks2_label) 97 | 
-------------------------------------------------------------------------------- /launch_phy/populate_cluster_table.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | import uuid 4 | from datetime import datetime 5 | 6 | from tqdm import tqdm 7 | from one import alf 8 | import pandas as pd 9 | from one.api import ONE 10 | from one import params 11 | 12 | from launch_phy import cluster_table 13 | 14 | 15 | def populate_dj_with_phy(probe_label, eid=None, subj=None, date=None, 16 | sess_no=None, one=None): 17 | if one is None: 18 | one = ONE() 19 | 20 | if eid is None: 21 | eid = one.search(subject=subj, date=date, number=sess_no)[0] 22 | 23 | sess_path = one.eid2path(eid) 24 | 25 | cols = one.list_collections(eid) 26 | if f'alf/{probe_label}/pykilosort' in cols: 27 | collection = f'alf/{probe_label}/pykilosort' 28 | else: 29 | collection = f'alf/{probe_label}' 30 | 31 | 32 | alf_path = sess_path.joinpath(collection) 33 | 34 | cluster_path = Path(alf_path, 'spikes.clusters.npy') 35 | template_path = Path(alf_path, 'spikes.templates.npy') 36 | 37 | # Compare spikes.clusters with spikes.templates to find which clusters have been merged 38 | phy_clusters = np.load(cluster_path) 39 | id_phy = np.unique(phy_clusters) 40 | orig_clusters = np.load(template_path) 41 | id_orig = np.unique(orig_clusters) 42 | 43 | uuid_list = alf.io.load_file_content(alf_path.joinpath('clusters.uuids.csv')) 44 | 45 | # First deal with merged clusters and make sure they have cluster uuids assigned 46 | # Find the original cluster ids that have been merged into a new cluster 47 | merged_idx = np.setdiff1d(id_orig, id_phy) 48 | 49 | # See if any clusters have been merged, if not skip to the next bit 50 | if np.any(merged_idx): 51 | # Make association between original cluster and new cluster id and save in dict 52 | merge_list = {} 53 | for m in merged_idx: 54 | idx = phy_clusters[np.where(orig_clusters == m)[0][0]] 
55 | if idx in merge_list: 56 | merge_list[idx].append(m) 57 | else: 58 | merge_list[idx] = [m] 59 | 60 | # Create a dataframe from the dict 61 | merge_clust = pd.DataFrame(columns={'cluster_idx', 'merged_uuid', 'merged_id'}) 62 | for key, value in merge_list.items(): 63 | value_uuid = uuid_list['uuids'][value] 64 | merge_clust = merge_clust.append({'cluster_idx': key, 'merged_uuid': tuple(value_uuid), 65 | 'merged_idx': tuple(value)}, 66 | ignore_index=True) 67 | 68 | # Get the dj table that has previously stored merged clusters and store in frame 69 | merge = cluster_table.MergedClusters() 70 | merge_dj = pd.DataFrame(columns={'cluster_uuid', 'merged_uuid'}) 71 | merge_dj['cluster_uuid'] = merge.fetch('cluster_uuid').astype(str) 72 | merge_dj['merged_uuid'] = tuple(map(tuple, merge.fetch('merged_uuid'))) 73 | 74 | # Merge the two dataframe to see if any merge combinations already have a cluster_uuid 75 | merge_comb = pd.merge(merge_dj, merge_clust, on=['merged_uuid'], how='outer') 76 | 77 | # Find the merged clusters that do not have a uuid assigned 78 | no_uuid = np.where(pd.isnull(merge_comb['cluster_uuid']))[0] 79 | 80 | # Assign new uuid to new merge pairs and add to the merge table 81 | for nid in no_uuid: 82 | new_uuid = str(uuid.uuid4()) 83 | merge_comb['cluster_uuid'].iloc[nid] = new_uuid 84 | merge.insert1( 85 | dict(cluster_uuid=new_uuid, merged_uuid=merge_comb['merged_uuid'].iloc[nid]), 86 | allow_direct_insert=True) 87 | 88 | # Add all the uuids to the cluster_uuid frame with index according to cluster id from phy 89 | for idx, c_uuid in zip(merge_comb['cluster_idx'].values, 90 | merge_comb['cluster_uuid'].values): 91 | uuid_list.loc[idx] = c_uuid 92 | 93 | csv_path = Path(alf_path, 'merge_info.csv') 94 | merge_comb = merge_comb.reindex(columns=['cluster_idx', 'cluster_uuid', 'merged_idx', 95 | 'merged_uuid']) 96 | 97 | try: 98 | merge_comb.to_csv(csv_path, index=False) 99 | except Exception as err: 100 | print(err) 101 | print('Close merge_info.csv 
file and then relaunch script') 102 | sys.exit(1) 103 | else: 104 | print('No merges detected, continuing...') 105 | 106 | # Now populate datajoint with cluster labels 107 | user = params.get().ALYX_LOGIN 108 | current_date = datetime.now().replace(microsecond=0) 109 | 110 | try: 111 | cluster_group = alf.io.load_file_content(alf_path.joinpath('cluster_group.tsv')) 112 | except Exception as err: 113 | print(err) 114 | print('Could not find cluster group file output from phy') 115 | sys.exit(1) 116 | 117 | try: 118 | cluster_notes = alf.io.load_file_content(alf_path.joinpath('cluster_notes.tsv')) 119 | cluster_info = pd.merge(cluster_group, cluster_notes, on=['cluster_id'], how='outer') 120 | except Exception as err: 121 | cluster_info = cluster_group 122 | cluster_info['notes'] = None 123 | 124 | cluster_info = cluster_info.where(cluster_info.notnull(), None) 125 | cluster_info['cluster_uuid'] = uuid_list['uuids'][cluster_info['cluster_id']].values 126 | 127 | # dj table that holds data 128 | cluster = cluster_table.ClusterLabel() 129 | 130 | # Find clusters that have already been labelled by user 131 | old_clust = cluster & cluster_info & {'user_name': user} 132 | 133 | dj_clust = pd.DataFrame() 134 | dj_clust['cluster_uuid'] = (old_clust.fetch('cluster_uuid')).astype(str) 135 | dj_clust['cluster_label'] = old_clust.fetch('cluster_label') 136 | 137 | # First find the new clusters to insert into datajoint 138 | idx_new = np.where(np.isin(cluster_info['cluster_uuid'], 139 | dj_clust['cluster_uuid'], invert=True))[0] 140 | cluster_uuid = cluster_info['cluster_uuid'][idx_new].values 141 | cluster_label = cluster_info['group'][idx_new].values 142 | cluster_note = cluster_info['notes'][idx_new].values 143 | 144 | if idx_new.size != 0: 145 | print('Populating dj with ' + str(idx_new.size) + ' new labels') 146 | else: 147 | print('No new labels to add') 148 | for iClust, iLabel, iNote in zip(tqdm(cluster_uuid), cluster_label, cluster_note): 149 | 
cluster.insert1(dict(cluster_uuid=iClust, user_name=user, 150 | label_time=current_date, 151 | cluster_label=iLabel, 152 | cluster_note=iNote), 153 | allow_direct_insert=True) 154 | 155 | # Next look through clusters already on datajoint and check if any labels have 156 | # been changed 157 | comp_clust = pd.merge(cluster_info, dj_clust, on='cluster_uuid') 158 | idx_change = np.where(comp_clust['group'] != comp_clust['cluster_label'])[0] 159 | 160 | cluster_uuid = comp_clust['cluster_uuid'][idx_change].values 161 | cluster_label = comp_clust['group'][idx_change].values 162 | cluster_note = comp_clust['notes'][idx_change].values 163 | 164 | # Populate table 165 | if idx_change.size != 0: 166 | print('Replacing label of ' + str(idx_change.size) + ' clusters') 167 | else: 168 | print('No labels to change') 169 | for iClust, iLabel, iNote in zip(tqdm(cluster_uuid), cluster_label, cluster_note): 170 | prev_clust = cluster & {'user_name': user} & {'cluster_uuid': iClust} 171 | cluster.insert1(dict(*prev_clust.proj(), 172 | label_time=current_date, 173 | cluster_label=iLabel, 174 | cluster_note=iNote), 175 | allow_direct_insert=True, replace=True) 176 | 177 | print('Upload to datajoint complete') 178 | 179 | 180 | if __name__ == '__main__': 181 | from argparse import ArgumentParser 182 | import numpy as np 183 | 184 | parser = ArgumentParser() 185 | parser.add_argument('-s', '--subject', default=False, required=False, 186 | help='Subject Name') 187 | parser.add_argument('-d', '--date', default=False, required=False, 188 | help='Date of session YYYY-MM-DD') 189 | parser.add_argument('-n', '--session_no', default=1, required=False, 190 | help='Session Number', type=int) 191 | parser.add_argument('-e', '--eid', default=False, required=False, 192 | help='Session eid') 193 | parser.add_argument('-p', '--probe_label', default=False, required=True, 194 | help='Probe Label') 195 | args = parser.parse_args() 196 | 197 | if args.eid: 198 | 
populate_dj_with_phy(str(args.probe_label), eid=str(args.eid)) 199 | else: 200 | if not np.all(np.array([args.subject, args.date, args.session_no], 201 | dtype=object)): 202 | print('Must give Subject, Date and Session number') 203 | else: 204 | populate_dj_with_phy(str(args.probe_label), subj=str(args.subject), 205 | date=str(args.date), sess_no=args.session_no) 206 | 207 | #populate_dj_with_phy('probe00', subj='KS022', date='2019-12-10', sess_no=1) 208 | -------------------------------------------------------------------------------- /needles2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/needles2/__init__.py -------------------------------------------------------------------------------- /needles2/layerUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Form 4 | 5 | 6 | 7 | 0 8 | 0 9 | 400 10 | 82 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | 18 | 19 | 20 20 | 20 21 | 341 22 | 41 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /needles2/mainUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | MainWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 1501 10 | 884 11 | 12 | 13 | 14 | MainWindow 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 0 23 | 0 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | - 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 0 88 | 0 89 | 1501 90 | 21 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | PlotWidget 99 | QGraphicsView 100 |
pyqtgraph
101 |
102 |
103 | 104 | 105 |
106 | -------------------------------------------------------------------------------- /needles2/needles_viewer.py: -------------------------------------------------------------------------------- 1 | from iblviewer.atlas_controller import AtlasController 2 | import vedo 3 | from iblviewer.atlas_model import AtlasModel, AtlasUIModel, CameraModel 4 | from iblviewer.slicer_model import SlicerModel 5 | 6 | from iblviewer.atlas_view import AtlasView 7 | from iblviewer.volume_view import VolumeView 8 | from iblviewer.slicer_view import SlicerView 9 | import iblviewer.utils as utils 10 | 11 | from ipyvtk_simple.viewer import ViewInteractiveWidget 12 | 13 | 14 | class NeedlesViewer(AtlasController): 15 | def __init__(self): 16 | super(NeedlesViewer, self).__init__() 17 | 18 | def initialize(self, plot, resolution=25, mapping='Allen', volume_mode=None, num_windows=1, render=False): 19 | vedo.settings.allowInteraction = False 20 | self.plot = plot 21 | self.plot_window_id = 0 22 | self.model = AtlasModel() 23 | self.model.initialize(resolution) 24 | self.model.load_allen_volume(mapping, volume_mode) 25 | self.model.initialize_slicers() 26 | 27 | self.view = AtlasView(self.plot, self.model) 28 | self.view.initialize() 29 | self.view.volume = VolumeView(self.plot, self.model.volume, self.model) 30 | 31 | #pn = SlicerModel.NAME_XYZ_POSITIVE 32 | nn = SlicerModel.NAME_XYZ_NEGATIVE 33 | 34 | #pxs_model = self.model.find_model(pn[0], self.model.slicers) 35 | #self.px_slicer = SlicerView(self.plot, self.view.volume, pxs_model, self.model) 36 | #pys_model = self.model.find_model(pn[1], self.model.slicers) 37 | #self.py_slicer = SlicerView(self.plot, self.view.volume, pys_model, self.model) 38 | #pzs_model = self.model.find_model(pn[2], self.model.slicers) 39 | #self.pz_slicer = SlicerView(self.plot, self.view.volume, pzs_model, self.model) 40 | 41 | nxs_model = self.model.find_model(nn[0], self.model.slicers) 42 | self.nx_slicer = SlicerView(self.plot, self.view.volume, nxs_model, 
self.model) 43 | nys_model = self.model.find_model(nn[1], self.model.slicers) 44 | self.ny_slicer = SlicerView(self.plot, self.view.volume, nys_model, self.model) 45 | nzs_model = self.model.find_model(nn[2], self.model.slicers) 46 | self.nz_slicer = SlicerView(self.plot, self.view.volume, nzs_model, self.model) 47 | 48 | self.slicers = [self.nx_slicer, self.ny_slicer, self.nz_slicer] 49 | 50 | vedo.settings.defaultFont = self.model.ui.font 51 | self.initialize_embed_ui(slicer_target=self.view.volume) 52 | 53 | self.plot.show(interactive=False) 54 | self.handle_transfer_function_update() 55 | # By default, the atlas volume is our target 56 | self.model.camera.target = self.view.volume.actor 57 | # We start with a sagittal view 58 | self.set_left_view() 59 | 60 | self.render() 61 | 62 | def initialize_embed_ui(self, slicer_target=None): 63 | s_kw = self.model.ui.slider_config 64 | d = self.view.volume.model.dimensions 65 | if d is None: 66 | return 67 | 68 | #self.add_slider('px', self.update_px_slicer, 0, int(d[0]), 0, pos=(0.05, 0.065, 0.12), 69 | # title='+X', **s_kw) 70 | #self.add_slider('py', self.update_py_slicer, 0, int(d[1]), 0, pos=(0.2, 0.065, 0.12), 71 | # title='+Y', **s_kw) 72 | #self.add_slider('pz', self.update_pz_slicer, 0, int(d[2]), 0, pos=(0.35, 0.065, 0.12), 73 | # title='+Z', **s_kw) 74 | 75 | def update_slicer(self, slicer_view, value): 76 | """ 77 | Update a given slicer with the given value 78 | :param slicer_view: SlicerView instance 79 | :param value: Value 80 | """ 81 | volume = self.view.volume 82 | model = slicer_view.model 83 | model.set_value(value) 84 | model.clipping_planes = volume.get_clipping_planes(model.axis) 85 | slicer_view.update(add_to_scene=self.model.slices_visible) 86 | 87 | 88 | -------------------------------------------------------------------------------- /needles2/regionUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Form 4 | 5 | 6 | 7 | 0 8 | 0 9 | 209 10 | 210 11 
| 12 | 13 | 14 | 15 | 0 16 | 0 17 | 18 | 19 | 20 | Form 21 | 22 | 23 | 24 | 25 | 10 26 | 10 27 | 291 28 | 210 29 | 30 | 31 | 32 | 33 | 0 34 | 0 35 | 36 | 37 | 38 | QFrame::StyledPanel 39 | 40 | 41 | QFrame::Raised 42 | 43 | 44 | 45 | 46 | 10 47 | 10 48 | 100 49 | 30 50 | 51 | 52 | 53 | 54 | 75 55 | true 56 | 57 | 58 | 59 | Selected Region: 60 | 61 | 62 | 63 | 64 | 65 | 10 66 | 40 67 | 171 68 | 20 69 | 70 | 71 | 72 | QComboBox::AdjustToContentsOnFirstShow 73 | 74 | 75 | 0 76 | 77 | 78 | 79 | 80 | 81 | 10 82 | 70 83 | 90 84 | 30 85 | 86 | 87 | 88 | 89 | 75 90 | true 91 | 92 | 93 | 94 | Current Region: 95 | 96 | 97 | 98 | 99 | 100 | 110 101 | 70 102 | 80 103 | 30 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 10 114 | 100 115 | 30 116 | 30 117 | 118 | 119 | 120 | 121 | 75 122 | true 123 | 124 | 125 | 126 | ML: 127 | 128 | 129 | 130 | 131 | 132 | 50 133 | 100 134 | 60 135 | 30 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 10 146 | 130 147 | 30 148 | 30 149 | 150 | 151 | 152 | 153 | 75 154 | true 155 | 156 | 157 | 158 | AP: 159 | 160 | 161 | 162 | 163 | 164 | 50 165 | 130 166 | 60 167 | 30 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 10 178 | 160 179 | 50 180 | 30 181 | 182 | 183 | 184 | 185 | 75 186 | true 187 | 188 | 189 | 190 | DV: 191 | 192 | 193 | 194 | 195 | 196 | 50 197 | 160 198 | 50 199 | 30 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /needles2/tableUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Form 4 | 5 | 6 | 7 | 0 8 | 0 9 | 400 10 | 300 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | 18 | 19 | 10 20 | 10 21 | 321 22 | 271 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /qt_helpers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/qt_helpers/__init__.py -------------------------------------------------------------------------------- /qt_helpers/qt.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from functools import wraps 4 | 5 | from PyQt5 import QtWidgets 6 | 7 | _logger = logging.getLogger('ibllib') 8 | 9 | 10 | def get_main_window(): 11 | """ Get the Main window of a QT application""" 12 | app = QtWidgets.QApplication.instance() 13 | return [w for w in app.topLevelWidgets() if isinstance(w, QtWidgets.QMainWindow)][0] 14 | 15 | 16 | def create_app(): 17 | """Create a Qt application.""" 18 | global QT_APP 19 | QT_APP = QtWidgets.QApplication.instance() 20 | if QT_APP is None: # pragma: no cover 21 | QT_APP = QtWidgets.QApplication(sys.argv) 22 | return QT_APP 23 | 24 | 25 | def require_qt(func): 26 | """Function decorator to specify that a function requires a Qt application. 27 | Use this decorator to specify that a function needs a running 28 | Qt application before it can run. An error is raised if that is not 29 | the case. 
30 | """ 31 | @wraps(func) 32 | def wrapped(*args, **kwargs): 33 | if not QtWidgets.QApplication.instance(): # pragma: no cover 34 | _logger.warning("Creating a Qt application.") 35 | create_app() 36 | return func(*args, **kwargs) 37 | return wrapped 38 | 39 | 40 | @require_qt 41 | def run_app(): # pragma: no cover 42 | """Run the Qt application.""" 43 | global QT_APP 44 | return QT_APP.exit(QT_APP.exec_()) 45 | -------------------------------------------------------------------------------- /qt_helpers/qt_matplotlib.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | # Make sure that we are using QT5 3 | matplotlib.use('Qt5Agg') 4 | from PyQt5 import QtCore, QtWidgets, uic 5 | 6 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 7 | from matplotlib.figure import Figure 8 | 9 | 10 | class BaseMplCanvas(FigureCanvas): 11 | """ 12 | ui 13 | model 14 | """ 15 | def __init__(self, parent=None, ui=None, model=None, width=5, height=4, dpi=100): 16 | fig = Figure(figsize=(width, height), dpi=dpi) 17 | self.axes = fig.add_subplot(111) 18 | self._ui = ui 19 | self._model = model 20 | self.compute_initial_figure() 21 | 22 | FigureCanvas.__init__(self, fig) 23 | self.setParent(parent) 24 | 25 | FigureCanvas.setSizePolicy(self, 26 | QtWidgets.QSizePolicy.Expanding, 27 | QtWidgets.QSizePolicy.Expanding) 28 | FigureCanvas.updateGeometry(self) 29 | self.setCursor(QtCore.Qt.CrossCursor) 30 | 31 | def compute_initial_figure(self): 32 | pass 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easyqc 2 | ibllib 3 | pyqtgraph 4 | simpleITK 5 | pyqt5 -------------------------------------------------------------------------------- /run_tests: -------------------------------------------------------------------------------- 1 | python -m unittest discover 2 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('requirements.txt') as f: 4 | require = [x.strip() for x in f.readlines() if not x.startswith('git+')] 5 | 6 | setup( 7 | name='iblapps', 8 | version='0.1', 9 | python_requires='>=3.10', 10 | packages=find_packages(), 11 | include_package_data=True, 12 | install_requires=require, 13 | entry_points={ 14 | 'console_scripts': [ 15 | 'atlas=atlasview.atlasview:main', 16 | 'ephys-align=atlaselectrophysiology.ephys_atlas_gui:launch_offline', 17 | ] 18 | }, 19 | ) 20 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/tests/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/data_alignmentqc_gui.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/tests/fixtures/data_alignmentqc_gui.npz -------------------------------------------------------------------------------- /tests/test_task_qc_viewer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | 5 | 6 | class TestTaskQCViewer(unittest.TestCase): 7 | def setUp(self): 8 | iblapps_path = Path(__file__).parts[:-2] 9 | self.task_qc_path = Path(*iblapps_path) / 'task_qc_viewer' / 'task_qc.py' 10 | self.UUID = '1211f4af-d3e4-4c4e-9d0b-75a0bc2bf1f0' 11 | 12 | def test_build(self): 13 | os.system(f'python {self.task_qc_path} {self.UUID}') 14 | 15 | def test_build_fail(self): 16 | 
os.system(f'python {self.task_qc_path} "BAD_ID"') 17 | -------------------------------------------------------------------------------- /viewspikes/README.md: -------------------------------------------------------------------------------- 1 | # View Spikes 2 | Electrophysiology display tools for IBL insertions. 3 | 4 | ## Run instructions 5 | Look at the examples file at the root of the package. 6 | 7 | ## Install Instructions 8 | 9 | Pre-requisites: 10 | - IBL environment installed https://github.com/int-brain-lab/iblenv 11 | - ONE parameters setup to download data 12 | 13 | ``` 14 | conda activate iblenv 15 | pip install viewephys 16 | pip install -U pyqtgraph 17 | ``` 18 | 19 | 20 | ## Roadmap 21 | - multi-probe displays for sessions with several probes 22 | - make sure NP2.0 4 shanks is supported 23 | - display LFP 24 | - speed up the raster plots in pyqtgraph 25 | -------------------------------------------------------------------------------- /viewspikes/datoviz.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | 5 | import datoviz as dviz 6 | 7 | # ------------------------------------------------------------------------------------------------- 8 | # Raster viewer 9 | # ------------------------------------------------------------------------------------------------- 10 | 11 | 12 | class RasterView: 13 | def __init__(self): 14 | self.canvas = dviz.canvas(show_fps=True) 15 | self.panel = self.canvas.panel(controller='axes') 16 | self.visual = self.panel.visual('point') 17 | self.pvars = {'ms': 2., 'alpha': .03} 18 | self.gui = self.canvas.gui('XY') 19 | self.gui.control("label", "Coords", value="(0, 0)") 20 | 21 | def set_spikes(self, spikes): 22 | pos = np.c_[spikes.times, spikes.depths, np.zeros_like(spikes.times)] 23 | color = dviz.colormap(20 * np.log10(spikes.amps), cmap='cividis', alpha=self.pvars['alpha']) 24 | self.visual.data('pos', pos) 25 | 
self.visual.data('color', color) 26 | self.visual.data('ms', np.array([self.pvars['ms']])) 27 | 28 | 29 | class RasterController: 30 | _time_select_cb = None 31 | 32 | def __init__(self, model, view): 33 | self.m = model 34 | self.v = view 35 | self.v.canvas.connect(self.on_mouse_move) 36 | self.v.canvas.connect(self.on_key_press) 37 | self.redraw() 38 | 39 | def redraw(self): 40 | print('redraw', self.v.pvars) 41 | self.v.set_spikes(self.m.spikes) 42 | 43 | def on_mouse_move(self, x, y, modifiers=()): 44 | p = self.v.canvas.panel_at(x, y) 45 | if not p: 46 | return 47 | # Then, we transform into the data coordinate system 48 | # Supported coordinate systems: 49 | # target_cds='data' / 'scene' / 'vulkan' / 'framebuffer' / 'window' 50 | xd, yd = p.pick(x, y) 51 | self.v.gui.set_value("Coords", f"({xd:0.2f}, {yd:0.2f})") 52 | 53 | def on_key_press(self, key, modifiers=()): 54 | print(key, modifiers) 55 | if key == 'a' and modifiers == ('control',): 56 | self.v.pvars['alpha'] = np.minimum(self.v.pvars['alpha'] + 0.1, 1.) 57 | elif key == 'z' and modifiers == ('control',): 58 | self.v.pvars['alpha'] = np.maximum(self.v.pvars['alpha'] - 0.1, 0.) 
59 | elif key == 'page_up': 60 | self.v.pvars['ms'] = np.minimum(self.v.pvars['ms'] * 1.1, 20) 61 | elif key == 'page_down': 62 | self.v.pvars['ms'] = np.maximum(self.v.pvars['ms'] / 1.1, 1) 63 | else: 64 | return 65 | self.redraw() 66 | 67 | 68 | @dataclass 69 | class RasterModel: 70 | spikes: dict 71 | 72 | 73 | def raster(spikes): 74 | rm = RasterController(RasterModel(spikes), RasterView()) 75 | dviz.run() 76 | -------------------------------------------------------------------------------- /viewspikes/example_view_ephys_session.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from one.api import ONE 3 | from brainbox.io.one import EphysSessionLoader, SpikeSortingLoader 4 | from iblapps.viewspikes.gui import view_raster 5 | 6 | PATH_CACHE = Path("/datadisk/Data/NAWG/01_lick_artefacts/openalyx") 7 | 8 | one = ONE(base_url="https://openalyx.internationalbrainlab.org", cache_dir=PATH_CACHE) 9 | 10 | pid = '5135e93f-2f1f-4301-9532-b5ad62548c49' 11 | eid, pname = one.pid2eid(pid) 12 | 13 | 14 | self = view_raster(pid=pid, one=one, stream=False) 15 | 16 | -------------------------------------------------------------------------------- /viewspikes/gui.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import scipy.signal 5 | from PyQt5 import QtWidgets, QtCore, QtGui, uic 6 | import pyqtgraph as pg 7 | 8 | from iblutil.numerical import bincount2D 9 | from viewephys.gui import viewephys 10 | import ibldsp 11 | from brainbox.io.one import EphysSessionLoader, SpikeSortingLoader 12 | from iblatlas.atlas import BrainRegions 13 | 14 | 15 | regions = BrainRegions() 16 | T_BIN = .007 # time bin size in secs 17 | D_BIN = 20 # depth bin size in um 18 | 19 | SNS_PALETTE = [(0.12156862745098039, 0.4666666666666667, 0.7058823529411765), 20 | (1.0, 0.4980392156862745, 0.054901960784313725), 21 | (0.17254901960784313, 
0.6274509803921569, 0.17254901960784313), 22 | (0.8392156862745098, 0.15294117647058825, 0.1568627450980392), 23 | (0.5803921568627451, 0.403921568627451, 0.7411764705882353), 24 | (0.5490196078431373, 0.33725490196078434, 0.29411764705882354), 25 | (0.8901960784313725, 0.4666666666666667, 0.7607843137254902), 26 | (0.4980392156862745, 0.4980392156862745, 0.4980392156862745), 27 | (0.7372549019607844, 0.7411764705882353, 0.13333333333333333), 28 | (0.09019607843137255, 0.7450980392156863, 0.8117647058823529)] 29 | 30 | YMIN, YMAX = (-1, 4000) 31 | 32 | 33 | def get_trial_events_to_display(trials): 34 | errors = trials['feedback_times'][trials['feedbackType'] == -1].values 35 | errors = np.sort(np.r_[errors, errors + .5]) 36 | gocue = trials['goCue_times'].values 37 | gocue = np.sort(np.r_[gocue, gocue + .11]) 38 | trial_events = dict( 39 | goCue_times=gocue, 40 | error_times=errors, 41 | reward_times=trials['feedback_times'][trials['feedbackType'] == 1].values) 42 | return trial_events 43 | 44 | 45 | def view_raster(pid, one, stream=True): 46 | from qt import create_app 47 | app = create_app() 48 | ssl = SpikeSortingLoader(one=one, pid=pid) 49 | sl = EphysSessionLoader(one=one, eid=ssl.eid) 50 | sl.load_trials() 51 | spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples']) 52 | clusters = ssl.merge_clusters(spikes, clusters, channels) 53 | return RasterView(ssl, spikes, clusters, channels, trials=sl.trials, stream=stream) 54 | 55 | 56 | class RasterView(QtWidgets.QMainWindow): 57 | plotItem_raster: pg.PlotWidget = None 58 | 59 | def __init__(self, ssl, spikes, clusters, channels=None, trials=None, stream=True, *args, **kwargs): 60 | self.ssl = ssl 61 | self.spikes = spikes 62 | self.clusters = clusters 63 | self.channels = channels 64 | self.trials = trials 65 | # self.sr_lf = ssl.raw_electrophysiology(band='lf', stream=False) that would be cool to also have the LFP 66 | self.sr = ssl.raw_electrophysiology(band='ap', stream=stream) 67 | 
self.eqcs = [] 68 | super(RasterView, self).__init__(*args, **kwargs) 69 | # wave by Diana Militano from the Noun Projectp 70 | uic.loadUi(Path(__file__).parent.joinpath('raster.ui'), self) 71 | background_color = self.palette().color(self.backgroundRole()) 72 | self.plotItem_raster.setAspectLocked(False) 73 | self.imageItem_raster = pg.ImageItem() 74 | self.plotItem_raster.setBackground(background_color) 75 | self.plotItem_raster.addItem(self.imageItem_raster) 76 | self.viewBox_raster = self.plotItem_raster.getPlotItem().getViewBox() 77 | s = self.viewBox_raster.scene() 78 | # vb.scene().sigMouseMoved.connect(self.mouseMoveEvent) 79 | s.sigMouseClicked.connect(self.mouseClick) 80 | ################################################### set image 81 | iok = ~np.isnan(spikes.depths) 82 | self.raster, self.rtimes, self.depths = bincount2D( 83 | spikes.times[iok], spikes.depths[iok], T_BIN, D_BIN) 84 | self.imageItem_raster.setImage(self.raster.T) 85 | transform = [T_BIN, 0., 0., 0., D_BIN, 0., - .5, - .5, 1.] 
86 | self.transform = np.array(transform).reshape((3, 3)).T 87 | self.imageItem_raster.setTransform(QtGui.QTransform(*transform)) 88 | self.plotItem_raster.setLimits(xMin=0, xMax=self.rtimes[-1], yMin=0, yMax=self.depths[-1]) 89 | # set colormap 90 | cm = pg.colormap.get('Greys', source='matplotlib') # prepare a linear color map 91 | bar = pg.ColorBarItem(values=(0, .5), colorMap=cm) # prepare interactive color bar 92 | # Have ColorBarItem control colors of img and appear in 'plot': 93 | bar.setImageItem(self.imageItem_raster) 94 | ################################################## plot location 95 | # self.view.layers[label] = {'layer': new_scatter, 'type': 'scatter'} 96 | self.line_eqc = pg.PlotCurveItem() 97 | self.plotItem_raster.addItem(self.line_eqc) 98 | ################################################## plot trials 99 | if self.trials is not None: 100 | trial_times = get_trial_events_to_display(trials) 101 | self.trial_lines = {} 102 | for i, k in enumerate(trial_times): 103 | self.trial_lines[k] = pg.PlotCurveItem() 104 | self.plotItem_raster.addItem(self.trial_lines[k]) 105 | x = np.tile(trial_times[k][:, np.newaxis], (1, 2)).flatten() 106 | y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(trial_times[k].shape[0] / 2 + 1))[:trial_times[k].shape[0] * 2] 107 | self.trial_lines[k].setData(x=x.flatten(), y=y.flatten(), pen=pg.mkPen(np.array(SNS_PALETTE[i]) * 255, width=2)) 108 | self.show() 109 | 110 | def mouseClick(self, event): 111 | """Draws a line on the raster and display in EasyQC""" 112 | if not event.double(): 113 | return 114 | qxy = self.imageItem_raster.mapFromScene(event.scenePos()) 115 | x = qxy.x() 116 | self.show_ephys(t0=self.rtimes[int(x - 1)]) 117 | ymax = np.max(self.depths) + 50 118 | self.line_eqc.setData(x=x + np.array([-.5, -.5, .5, .5]), 119 | y=np.array([0, ymax, ymax, 0]), 120 | pen=pg.mkPen((0, 255, 0))) 121 | 122 | 123 | def keyPressEvent(self, e): 124 | """ 125 | page-up / ctrl + a : gain up 126 | page-down / ctrl + z : gain 
down 127 | :param e: 128 | """ 129 | k, m = (e.key(), e.modifiers()) 130 | # page up / ctrl + a 131 | if k == QtCore.Qt.Key_PageUp or ( 132 | m == QtCore.Qt.ControlModifier and k == QtCore.Qt.Key_A): 133 | self.imageItem_raster.setLevels([0, self.imageItem_raster.levels[1] / 1.4]) 134 | # page down / ctrl + z 135 | elif k == QtCore.Qt.Key_PageDown or ( 136 | m == QtCore.Qt.ControlModifier and k == QtCore.Qt.Key_Z): 137 | self.imageItem_raster.setLevels([0, self.imageItem_raster.levels[1] * 1.4]) 138 | 139 | def show_ephys(self, t0, tlen=1.8): 140 | """ 141 | :param t0: behaviour time in seconds at which to start the view 142 | :param tlen: 143 | :return: 144 | """ 145 | print(t0) 146 | s0 = int(self.ssl.samples2times(t0, direction='reverse')) 147 | s1 = s0 + int(self.sr.fs * tlen) 148 | raw = self.sr[s0:s1, : - self.sr.nsync].T 149 | 150 | butter_kwargs = {'N': 3, 'Wn': 300 / self.sr.fs * 2, 'btype': 'highpass'} 151 | 152 | sos = scipy.signal.butter(**butter_kwargs, output='sos') 153 | butt = scipy.signal.sosfiltfilt(sos, raw) 154 | destripe = ibldsp.voltage.destripe(raw, fs=self.sr.fs) 155 | 156 | self.eqc_raw = viewephys(butt, self.sr.fs, channels=self.channels, br=regions, title='butt', t0=t0, t_scalar=1) 157 | self.eqc_des = viewephys(destripe, self.sr.fs, channels=self.channels, br=regions, title='destripe', t0=t0, t_scalar=1) 158 | 159 | eqc_xrange = [t0 + tlen / 2 - 0.01, t0 + tlen / 2 + 0.01] 160 | self.eqc_des.viewBox_seismic.setXRange(*eqc_xrange) 161 | self.eqc_raw.viewBox_seismic.setXRange(*eqc_xrange) 162 | 163 | # we slice the spikes using the samples according to ephys time, but display in session times 164 | slice_spikes = slice(np.searchsorted(self.spikes['samples'], s0), np.searchsorted(self.spikes['samples'], s1)) 165 | t = self.spikes['times'][slice_spikes] 166 | ic = self.spikes.clusters[slice_spikes] 167 | 168 | iok = self.clusters['label'][ic] == 1 169 | 170 | for eqc in [self.eqc_des, self.eqc_raw]: 171 | eqc.ctrl.add_scatter(t[~iok], 
self.clusters.channels[ic[~iok]], (255, 0, 0, 100), label='bad units') 172 | eqc.ctrl.add_scatter(t[iok], self.clusters.channels[ic[iok]], rgb=(0, 255, 0, 100), label='good units') 173 | 174 | if self.trials is not None: 175 | trial_events = get_trial_events_to_display(self.trials) 176 | for i, k in enumerate(trial_events): 177 | ie = np.logical_and(trial_events[k] >= t0, trial_events[k] <= (t0 + tlen)) 178 | if np.sum(ie) == 0: 179 | continue 180 | te = trial_events[k][ie] 181 | x = np.tile(te[:, np.newaxis], (1, 2)).flatten() 182 | y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(te.shape[0] / 2 + 1))[:te.shape[0] * 2] 183 | for eqc in [self.eqc_des, self.eqc_raw]: 184 | eqc.ctrl.add_curve(x, y, rgb=(np.array(SNS_PALETTE[i]) * 255).astype(int), label=k) 185 | print(te) 186 | 187 | -------------------------------------------------------------------------------- /viewspikes/load_ma_data.py: -------------------------------------------------------------------------------- 1 | from one.api import ONE 2 | import one.alf.io as alfio 3 | from iblutil.util import Bunch 4 | from qt_helpers import qt 5 | import numpy as np 6 | import pyqtgraph as pg 7 | 8 | 9 | import atlaselectrophysiology.ephys_atlas_gui as alignment_window 10 | import data_exploration_gui.gui_main as trial_window 11 | 12 | 13 | # some extra controls 14 | 15 | class AlignmentWindow(alignment_window.MainWindow): 16 | def __init__(self, offline=False, probe_id=None, one=None): 17 | super(AlignmentWindow, self).__init__(probe_id=probe_id, one=one) 18 | self.trial_gui = None 19 | 20 | 21 | def cluster_clicked(self, item, point): 22 | clust = super().cluster_clicked(item, point) 23 | print(clust) 24 | self.trial_gui.on_cluster_chosen(clust) 25 | 26 | def add_trials_to_raster(self, trial_key='feedback_times'): 27 | self.selected_trials = self.trial_gui.data.trials[trial_key] 28 | x, y = self.vertical_lines(self.selected_trials, 0, 3840) 29 | trial_curve = pg.PlotCurveItem() 30 | trial_curve.setData(x=x, y=y, 
pen=self.rpen_dot, connect='finite') 31 | trial_curve.setClickable(True) 32 | self.fig_img.addItem(trial_curve) 33 | self.fig_img.scene().sigMouseClicked.connect(self.on_mouse_clicked) 34 | trial_curve.sigClicked.connect(self.trial_line_clicked) 35 | 36 | 37 | def vertical_lines(self, x, ymin, ymax): 38 | 39 | x = np.tile(x, (3, 1)) 40 | x[2, :] = np.nan 41 | y = np.zeros_like(x) 42 | y[0, :] = ymin 43 | y[1, :] = ymax 44 | y[2, :] = np.nan 45 | 46 | return x.T.flatten(), y.T.flatten() 47 | 48 | def trial_line_clicked(self, ev): 49 | self.clicked = ev 50 | 51 | def on_mouse_clicked(self, event): 52 | if not event.double() and type(self.clicked) == pg.PlotCurveItem: 53 | self.pos = self.data_plot.mapFromScene(event.scenePos()) 54 | x = self.pos.x() * self.x_scale 55 | trial_id = np.argmin(np.abs(self.selected_trials - x)) 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | # highlight the trial in the trial gui 65 | 66 | 67 | 68 | self.clicked = None 69 | 70 | 71 | 72 | 73 | class TrialWindow(trial_window.MainWindow): 74 | def __init__(self): 75 | super(TrialWindow, self).__init__() 76 | self.alignment_gui = None 77 | self.scat = None 78 | 79 | def on_scatter_plot_clicked(self, scatter, point): 80 | super().on_scatter_plot_clicked(scatter, point) 81 | self.add_clust_scatter() 82 | 83 | def on_cluster_list_clicked(self): 84 | super().on_cluster_list_clicked() 85 | self.add_clust_scatter() 86 | 87 | def on_next_cluster_clicked(self): 88 | super().on_next_cluster_clicked() 89 | self.add_clust_scatter() 90 | 91 | def on_previous_cluster_clicked(self): 92 | super().on_previous_cluster_clicked() 93 | self.add_clust_scatter() 94 | 95 | def add_clust_scatter(self): 96 | if not self.scat: 97 | self.scat = pg.ScatterPlotItem() 98 | self.alignment_gui.fig_img.addItem(self.scat) 99 | 100 | self.scat.setData(self.data.spikes.times[self.data.clus_idx], 101 | self.data.spikes.depths[self.data.clus_idx], brush='r') 102 | 103 | # need to get out spikes.times and spikes.depths 104 | 105 | 
106 | 107 | 108 | 109 | 110 | def load_data(eid, probe, one=None): 111 | one = one or ONE() 112 | session_path = one.path_from_eid(eid).joinpath('alf') 113 | probe_path = session_path.joinpath(probe) 114 | data = Bunch() 115 | data['trials'] = alfio.load_object(session_path, 'trials', namespace='ibl') 116 | data['spikes'] = alfio.load_object(probe_path, 'spikes') 117 | data['clusters'] = alfio.load_object(probe_path, 'clusters') 118 | 119 | return data 120 | 121 | def viewer(probe_id=None, one=None): 122 | """ 123 | """ 124 | probe = one.alyx.rest('insertions', 'list', id=probe_id)[0] 125 | data = load_data(probe['session'], probe['name'], one=one) 126 | qt.create_app() 127 | av = AlignmentWindow(probe_id=probe_id, one=one) 128 | bv = TrialWindow() 129 | bv.on_data_given(data) 130 | av.trial_gui = bv 131 | bv.alignment_gui = av 132 | 133 | av.show() 134 | bv.show() 135 | return av, bv 136 | 137 | -------------------------------------------------------------------------------- /viewspikes/plots.py: -------------------------------------------------------------------------------- 1 | # import modules 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import scipy.signal 5 | import pyqtgraph as pg 6 | 7 | import iblatlas.atlas as atlas 8 | from neuropixel import SITES_COORDINATES 9 | from ibllib.pipes.ephys_alignment import EphysAlignment 10 | from ibllib.plots import wiggle, color_cycle 11 | 12 | brain_atlas = atlas.AllenAtlas() 13 | # Instantiate brain atlas and one 14 | 15 | 16 | def show_psd(data, fs, ax=None): 17 | psd = np.zeros((data.shape[0], 129)) 18 | for tr in np.arange(data.shape[0]): 19 | f, psd[tr, :] = scipy.signal.welch(data[tr, :], fs=fs) 20 | 21 | if ax is None: 22 | fig, ax = plt.subplots() 23 | ax.plot(f, 10 * np.log10(psd.T), color='gray', alpha=0.1) 24 | ax.plot(f, 10 * np.log10(np.mean(psd, axis=0).T), color='red') 25 | ax.set_xlabel('Frequency (Hz)') 26 | ax.set_ylabel('PSD (dB rel V/Hz)') 27 | ax.set_ylim(-150, -110) 28 | 
ax.set_xlim(0, fs / 2) 29 | plt.show() 30 | 31 | 32 | def plot_insertion(pid, one=None): 33 | # Find eid of interest 34 | assert one 35 | insertion = one.alyx.rest('insertions', 'list', id=pid)[0] 36 | probe_label = insertion['name'] 37 | eid = insertion['session'] 38 | 39 | # Load in channels.localCoordinates dataset type 40 | chn_coords = one.load(eid, dataset_types=['channels.localCoordinates'])[0] 41 | depths = chn_coords[:, 1] 42 | 43 | # Find the ephys aligned trajectory for eid probe combination 44 | trajs = one.alyx.rest('trajectories', 'list', session=eid, probe=probe_label) 45 | #provenance=, 46 | traj_aligned = next(filter(lambda x: x['provenance'] == 'Ephys aligned histology track', trajs), None) 47 | if traj_aligned is None: 48 | raise NotImplementedError(f"Plots only aligned insertions so far - TODO") 49 | else: 50 | plot_alignment(insertion, traj_aligned, -1) 51 | # Extract all alignments from the json field of object 52 | # Load in the initial user xyz_picks obtained from track traccing 53 | 54 | 55 | def plot_alignment(insertion, traj, ind=None): 56 | depths = SITES_COORDINATES[:, 1] 57 | xyz_picks = np.array(insertion['json']['xyz_picks']) / 1e6 58 | 59 | alignments = traj['json'].copy() 60 | k = list(alignments.keys())[-1] # if only I had a Walrus available ! 61 | alignments = {k: alignments[k]} 62 | # Create a figure and arrange using gridspec 63 | widths = [1, 2.5] 64 | heights = [1] * len(alignments) 65 | gs_kw = dict(width_ratios=widths, height_ratios=heights) 66 | fig, axis = plt.subplots(len(alignments), 2, constrained_layout=True, 67 | gridspec_kw=gs_kw, figsize=(8, 9)) 68 | 69 | # Iterate over all alignments for trajectory 70 | # 1. Plot brain regions that channel pass through 71 | # 2. Plot coronal slice along trajectory with location of channels shown as red points 72 | # 3. 
Save results for each alignment into a dict - channels 73 | channels = {} 74 | for iK, key in enumerate(alignments): 75 | 76 | # Location of reference lines used for alignmnet 77 | feature = np.array(alignments[key][0]) 78 | track = np.array(alignments[key][1]) 79 | chn_coords = SITES_COORDINATES 80 | # Instantiate EphysAlignment object 81 | ephysalign = EphysAlignment(xyz_picks, depths, track_prev=track, feature_prev=feature) 82 | 83 | # Find xyz location of all channels 84 | xyz_channels = ephysalign.get_channel_locations(feature, track) 85 | # Find brain region that each channel is located in 86 | brain_regions = ephysalign.get_brain_locations(xyz_channels) 87 | # Add extra keys to store all useful information as one bunch object 88 | brain_regions['xyz'] = xyz_channels 89 | brain_regions['lateral'] = chn_coords[:, 0] 90 | brain_regions['axial'] = chn_coords[:, 1] 91 | 92 | # Store brain regions result in channels dict with same key as in alignment 93 | channel_info = {key: brain_regions} 94 | channels.update(channel_info) 95 | 96 | # For plotting -> extract the boundaries of the brain regions, as well as CCF label and colour 97 | region, region_label, region_colour, _ = ephysalign.get_histology_regions(xyz_channels, depths) 98 | 99 | # Make plot that shows the brain regions that channels pass through 100 | ax_regions = fig.axes[iK * 2] 101 | for reg, col in zip(region, region_colour): 102 | height = np.abs(reg[1] - reg[0]) 103 | bottom = reg[0] 104 | color = col / 255 105 | ax_regions.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w') 106 | ax_regions.set_yticks(region_label[:, 0].astype(int)) 107 | ax_regions.yaxis.set_tick_params(labelsize=8) 108 | ax_regions.get_xaxis().set_visible(False) 109 | ax_regions.set_yticklabels(region_label[:, 1]) 110 | ax_regions.spines['right'].set_visible(False) 111 | ax_regions.spines['top'].set_visible(False) 112 | ax_regions.spines['bottom'].set_visible(False) 113 | ax_regions.hlines([0, 3840], 
*ax_regions.get_xlim(), linestyles='dashed', linewidth=3, 114 | colors='k') 115 | # ax_regions.plot(np.ones(channel_depths_track.shape), channel_depths_track, '*r') 116 | 117 | # Make plot that shows coronal slice that trajectory passes through with location of channels 118 | # shown in red 119 | ax_slice = fig.axes[iK * 2 + 1] 120 | brain_atlas.plot_tilted_slice(xyz_channels, axis=1, ax=ax_slice) 121 | ax_slice.plot(xyz_channels[:, 0] * 1e6, xyz_channels[:, 2] * 1e6, 'r*') 122 | ax_slice.title.set_text(insertion['id'] + '\n' + str(key)) 123 | 124 | # Make sure the plot displays 125 | plt.show() 126 | 127 | 128 | def overlay_spikes(self, spikes, clusters, channels, rgb=None, label='default', symbol='x', size=8): 129 | first_sample = self.ctrl.model.t0 / self.ctrl.model.si 130 | last_sample = first_sample + self.ctrl.model.ns 131 | 132 | ifirst, ilast = np.searchsorted(spikes['samples'], [first_sample, last_sample]) 133 | tspi = spikes['samples'][ifirst:ilast].astype(np.float64) * self.ctrl.model.si 134 | xspi = channels['rawInd'][clusters['channels'][spikes['clusters'][ifirst:ilast]]] 135 | 136 | n_side_by_side = int(self.ctrl.model.ntr / 384) 137 | print(n_side_by_side) 138 | if n_side_by_side > 1: 139 | addx = (np.zeros([1, xspi.size]) + np.array([self.ctrl.model.ntr * r for r in range(n_side_by_side)])[:, np.newaxis]).flatten() 140 | xspi = np.tile(xspi, n_side_by_side) + addx 141 | yspi = np.tile(tspi, n_side_by_side) 142 | if self.ctrl.model.taxis == 1: 143 | self.ctrl.add_scatter(xspi, tspi, label=label) 144 | else: 145 | self.ctrl.add_scatter(tspi, xspi, label=label) 146 | sc = self.layers[label]['layer'] 147 | sc.setSize(size) 148 | sc.setSymbol(symbol) 149 | # sc.setPen(pg.mkPen((0, 255, 0, 155), width=1)) 150 | if rgb is None: 151 | rgbs = [list((rgb * 255).astype(np.uint8)) for rgb in color_cycle(spikes['clusters'][ifirst:ilast])] 152 | sc.setBrush([pg.mkBrush(rgb) for rgb in rgbs]) 153 | sc.setPen([pg.mkPen(rgb) for rgb in rgbs]) 154 | else: 155 | 
sc.setBrush(pg.mkBrush(rgb)) 156 | sc.setPen(pg.mkPen(rgb)) 157 | return sc, tspi, xspi 158 | 159 | # sc.setData(x=xspi, y=tspi, brush=pg.mkBrush((255, 0, 0))) 160 | def callback(sc, points, evt): 161 | NTR = 12 162 | NS = 128 163 | qxy = self.imageItem_seismic.mapFromScene(evt.scenePos()) 164 | # tr, _, _, s, _ = self.ctrl.cursor2timetraceamp(qxy) 165 | itr, itime = self.ctrl.cursor2ind(qxy) 166 | print(itr, itime) 167 | h = self.ctrl.model.header 168 | x = self.ctrl.model.data 169 | trsel = np.arange(np.max([0, itr - NTR]), np.min([self.ctrl.model.ntr, itr + NTR])) 170 | ordre = np.lexsort((h['x'][trsel], h['y'][trsel])) 171 | trsel = trsel[ordre] 172 | w = x[trsel, int(itime) - NS:int(itime) + NS] 173 | wiggle(-w.T * 10000, fs=1 / self.ctrl.model.si, t0=(itime - NS) * self.ctrl.model.si) 174 | # hw = {k:h[k][trsel] for k in h} 175 | 176 | # s.sigMouseClicked.disconnect() 177 | sc.sigClicked.connect(callback) 178 | # self.ctrl.remove_all_layers() 179 | -------------------------------------------------------------------------------- /viewspikes/raster.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | MainWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 999 10 | 656 11 | 12 | 13 | 14 | MainWindow 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | Qt::Horizontal 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 0 38 | 0 39 | 999 40 | 22 41 | 42 | 43 | 44 | 45 | File 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | open... 55 | 56 | 57 | 58 | 59 | 60 | PlotWidget 61 | QGraphicsView 62 |
pyqtgraph
63 |
64 |
65 | 66 | 67 |
68 | --------------------------------------------------------------------------------