├── .flake8
├── .github
└── workflows
│ └── ci.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── atlaselectrophysiology
├── AdaptedAxisItem.py
├── ColorBar.py
├── README.md
├── __init__.py
├── alignment_with_easyqc.py
├── compare_alignments.py
├── create_overview_plots.py
├── ephys_atlas_gui.py
├── ephys_atlas_image.png
├── ephys_gui_setup.py
├── example_code.py
├── extract_files.py
├── get_scale_factor.py
├── load_data.py
├── load_data_local.py
├── load_histology.py
├── plot_data.py
├── qc_table.py
├── rendering.py
├── root.obj
├── sandbox.py
└── subject_scaling.py
├── atlasview
├── __init__.py
├── atlasview.py
├── channels_test.npy
├── launcher.py
├── region_values.npy
├── sliceview.ui
└── topview.ui
├── data_exploration_gui
├── README.md
├── cluster.py
├── cluster_class.py
├── data_class.py
├── data_explore_gui.py
├── data_model.py
├── filter.py
├── filter_class.py
├── gui_main.py
├── load_data.py
├── misc.py
├── misc_class.py
├── plot.py
├── plot_class.py
├── scatter.py
├── scatter_class.py
└── utils.py
├── dlc
├── DLC_labeled_video.py
├── Example_DLC_access.ipynb
├── README.md
├── __init__.py
├── get_dlc_traces.py
├── overview_plot_dlc.py
├── stream_dlc_labeled_frames.py
└── wheel_dlc_viewer.py
├── ephysfeatures
├── __init__.py
├── features_across_region.py
├── prepare_features.py
└── qrangeslider.py
├── histology
├── __init__.py
├── atlas_mpl.py
├── atlas_mpl.ui
└── transform-tracks.py
├── launch_phy
├── README.md
├── cluster_table.py
├── defined_metrics.py
├── metrics.py
├── phy_launcher.py
├── plugins
│ └── phy_plugin.py
└── populate_cluster_table.py
├── mesoscope_gui
└── gui.py
├── needles2
├── __init__.py
├── coverageUI.ui
├── layerUI.ui
├── mainUI.ui
├── needles_viewer.py
├── probe_model.py
├── regionUI.ui
├── run_needles2.py
├── spike_features.py
└── tableUI.ui
├── qt_helpers
├── __init__.py
├── qt.py
└── qt_matplotlib.py
├── requirements.txt
├── run_tests
├── setup.py
├── tests
├── __init__.py
├── fixtures
│ └── data_alignmentqc_gui.npz
├── test_alignment_qc_gui.py
└── test_task_qc_viewer.py
└── viewspikes
├── README.md
├── datoviz.py
├── example_view_ephys_session.py
├── gui.py
├── load_ma_data.py
├── plots.py
└── raster.ui
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 99
3 | ignore = W504
4 | exclude =
5 | .git,
6 | __pycache__,
7 | __init__.py,
8 | data_exploration_gui/
9 | launch_phy/
10 | histology/
11 | qt_matplotlib.py
12 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [ master, develop ]
6 | pull_request:
7 | branches: [ master, develop ]
8 |
9 | jobs:
10 | build:
11 | name: build (${{ matrix.python-version }}, ${{ matrix.os }})
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | max-parallel: 3
15 | matrix:
16 | os: ["ubuntu-latest", "macos-latest", "windows-latest"]
17 | python-version: ["3.12"]
18 | steps:
19 | - uses: actions/checkout@v3
20 |
21 | - uses: conda-incubator/setup-miniconda@v2.0.0
22 | with:
23 | auto-update-conda: true
24 | python-version: ${{ matrix.python-version }}
25 |
26 | - name: Install requirements
27 | shell: bash -l {0}
28 | run: |
29 | conda activate test
30 | pip install --requirement requirements.txt
31 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # IBL apps stuff
2 | dlc/*.mp4
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # VSCode
135 | .vscode/
136 |
137 | scratch/
138 |
139 | # pycharm
140 | .idea/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 International Brain Laboratory
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # iblapps
2 | IBL-related applications that rely on unsupported libraries such as PyQt5. See package README files for more details.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/__init__.py
--------------------------------------------------------------------------------
/atlaselectrophysiology/ColorBar.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore, QtGui
2 | import pyqtgraph as pg
3 | import matplotlib
4 | import numpy as np
5 | from pyqtgraph.functions import makeARGB
6 |
7 |
class ColorBar(pg.GraphicsWidget):
    """Colour-bar helper built from a matplotlib colourmap name.

    Provides a lookup table, QColor brushes for scatter data, and a horizontal
    gradient bar that can be attached to the top axis of a pyqtgraph figure.
    """

    def __init__(self, cmap_name, cbin=256, parent=None, data=None):
        """
        :param cmap_name: name of a matplotlib colourmap, e.g. 'viridis'
        :param cbin: number of samples taken along a continuous colourmap
        :param parent: unused, kept for interface compatibility
        :param data: unused, kept for interface compatibility
        """
        pg.GraphicsWidget.__init__(self)

        # Create colour map from matplotlib colourmap name
        # NOTE(review): matplotlib.cm.get_cmap is deprecated (removed in
        # matplotlib 3.9); migrate to matplotlib.colormaps[name] once the
        # pinned matplotlib version allows it.
        self.cmap_name = cmap_name
        cmap = matplotlib.cm.get_cmap(self.cmap_name)
        # isinstance instead of `type(...) ==` so LinearSegmentedColormap
        # subclasses take the continuous-sampling branch as well
        if isinstance(cmap, matplotlib.colors.LinearSegmentedColormap):
            # Continuous map: sample cbin RGB triplets over [0, 1]
            cbins = np.linspace(0.0, 1.0, cbin)
            colors = (cmap(cbins)[np.newaxis, :, :3][0]).tolist()
        else:
            # Listed map: scale discrete RGB entries to 0-255 and append alpha
            colors = cmap.colors
            colors = [(np.array(c) * 255).astype(int).tolist() + [255.] for c in colors]
        positions = np.linspace(0, 1, len(colors))
        self.map = pg.ColorMap(positions, colors)
        self.lut = self.map.getLookupTable()
        self.grad = self.map.getGradient()

    def getBrush(self, data, levels=None):
        """Map an array of scalar values to a list of QColor brushes.

        :param data: 1D numpy array of values to colour-code
        :param levels: [min, max] used to normalise data; defaults to the data range
        :return: list of QtGui.QColor, one per element of data
        """
        if levels is None:
            levels = [np.min(data), np.max(data)]
        brush_rgb, _ = makeARGB(data[:, np.newaxis], levels=levels, lut=self.lut, useRGBA=True)
        brush = [QtGui.QColor(*col) for col in np.squeeze(brush_rgb)]
        return brush

    def getColourMap(self):
        """Return the lookup table of the colourmap (N x 3/4 uint8 array)."""
        return self.lut

    def makeColourBar(self, width, height, fig, min=0, max=1, label='', lim=False):
        """Create a horizontal colour bar and configure fig's top axis as its scale.

        :param width: bar width in axis coordinates
        :param height: bar height in axis coordinates
        :param fig: pyqtgraph plot item whose 'top' axis displays the tick labels
        :param min: value displayed at the left edge (shadows builtin; kept for API compat)
        :param max: value displayed at the right edge (shadows builtin; kept for API compat)
        :param label: axis label text
        :param lim: if True only show the two end ticks, otherwise show mid/quarter ticks too
        :return: the HorizontalBar widget (also stored on self.cbar)
        """
        self.cbar = HorizontalBar(width, height, self.grad)
        ax = fig.getAxis('top')
        ax.setPen('k')
        ax.setTextPen('k')
        ax.setStyle(stopAxisAtTick=((True, True)))
        ax.setLabel(label)
        ax.setHeight(30)
        if lim:
            ax.setTicks([[(0, str(np.around(min, 2))), (width, str(np.around(max, 2)))]])
        else:
            ax.setTicks([[(0, str(np.around(min, 2))), (width / 2,
                         str(np.around(min + (max - min) / 2, 2))),
                         (width, str(np.around(max, 2)))],
                        [(width / 4, str(np.around(min + (max - min) / 4, 2))),
                         (3 * width / 4, str(np.around(min + (3 * (max - min) / 4), 2)))]])
        fig.setXRange(0, width)
        fig.setYRange(0, height)

        return self.cbar
58 |
59 |
class HorizontalBar(pg.GraphicsWidget):
    """Widget that paints a left-to-right colour-gradient rectangle."""

    def __init__(self, width, height, grad):
        """
        :param width: rectangle width in widget coordinates
        :param height: rectangle height in widget coordinates
        :param grad: QLinearGradient obtained from a pg.ColorMap
        """
        pg.GraphicsWidget.__init__(self)
        self.width = width
        self.height = height
        self.grad = grad
        # NOTE(review): removed a stray `QtGui.QPainter()` statement here - it
        # constructed a painter bound to no device and discarded it immediately.

    def paint(self, p, *args):
        # Fill the widget rect with the gradient running horizontally
        p.setPen(QtCore.Qt.NoPen)
        self.grad.setStart(0, self.height / 2)
        self.grad.setFinalStop(self.width, self.height / 2)
        p.setBrush(pg.QtGui.QBrush(self.grad))
        p.drawRect(QtCore.QRectF(0, 0, self.width, self.height))
74 |
75 |
class VerticalBar(pg.GraphicsWidget):
    """Widget that paints a bottom-to-top colour-gradient rectangle."""

    def __init__(self, width, height, grad):
        """
        :param width: rectangle width in widget coordinates
        :param height: rectangle height in widget coordinates
        :param grad: QLinearGradient obtained from a pg.ColorMap
        """
        pg.GraphicsWidget.__init__(self)
        self.width = width
        self.height = height
        self.grad = grad
        # NOTE(review): removed a stray `QtGui.QPainter()` statement here - it
        # constructed a painter bound to no device and discarded it immediately.

    def paint(self, p, *args):
        # Fill the widget rect with the gradient running vertically (bottom up)
        p.setPen(QtCore.Qt.NoPen)
        self.grad.setStart(self.width / 2, self.height)
        self.grad.setFinalStop(self.width / 2, 0)
        p.setBrush(pg.QtGui.QBrush(self.grad))
        p.drawRect(QtCore.QRectF(0, 0, self.width, self.height))
90 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/README.md:
--------------------------------------------------------------------------------
1 | # Ephys Atlas GUI
2 |
3 | GUI to allow user to align electrophysiology data with histology data. Please refer to this wiki page for information on installation and usage https://github.com/int-brain-lab/iblapps/wiki
4 |
5 |
6 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlaselectrophysiology/__init__.py
--------------------------------------------------------------------------------
/atlaselectrophysiology/compare_alignments.py:
--------------------------------------------------------------------------------
'''
Compare alignments of the same probe insertion performed by different users.
For each session in the hard-coded lists below, plot the scaled histology
regions and per-region scale factors of every stored alignment side by side.
'''
# import modules
from one.api import ONE
from ibllib.pipes.ephys_alignment import EphysAlignment
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import iblatlas.atlas as atlas
from pathlib import Path
# Instantiate brain atlas and one
brain_atlas = atlas.AllenAtlas(25)
one = ONE()

# Output folder for the comparison figures (hard-coded local path)
fig_path = Path('C:/Users/Mayo/Documents/PYTHON/alignment_figures/scale_factor')
# Find eid of interest
# All aligned trajectories from the brainwide project with session qc < 30
aligned_sess = one.alyx.rest('trajectories', 'list', provenance='Ephys aligned histology track',
                             django='probe_insertion__session__project__name__icontains,'
                                    'ibl_neuropixel_brainwide_01,'
                                    'probe_insertion__session__qc__lt,30')
eids = np.array([s['session']['id'] for s in aligned_sess])
probes = np.array([s['probe_name'] for s in aligned_sess])

# The trajectory json holds one entry per stored alignment; drop trajectories
# with no alignment at all
json = [s['json'] for s in aligned_sess]
idx_none = [i for i, val in enumerate(json) if val is None]
json_val = np.delete(json, idx_none)
keys = [list(s.keys()) for s in json_val]

eids = np.delete(eids, idx_none)
probes = np.delete(probes, idx_none)

# Find index of json fields with 2 or more keys
len_key = [len(s) for s in keys]  # NOTE(review): computed but unused below
idx_several = [i for i, val in enumerate(keys) if len(val) >= 2]
eid_several = eids[idx_several]
probe_several = probes[idx_several]


# subject = 'KS023'
# date = '2019-12-10'
# sess_no = 1
# probe_label = 'probe00'
# eid = one.search(subject=subject, date=date, number=sess_no)[0]
# eid = 'e2448a52-2c22-4ecc-bd48-632789147d9c'

# Hand-picked [subject, date, probe, alignment key] combinations.
# NOTE(review): only `small_scaling` is iterated by the loop below;
# presumably `big_scaling` is swapped in manually - confirm intended usage.
small_scaling = [["CSHL045", "2020-02-25", "probe00", "2020-06-12T10:40:22_noam.roth"],
                 ["CSHL045", "2020-02-25", "probe01", "2020-09-12T23:43:03_petrina.lau"],
                 ["CSHL047", "2020-01-20", "probe00", "2020-09-28T08:18:19_noam.roth"],
                 ["CSHL047", "2020-01-27", "probe00", "2020-09-13T15:51:21_petrina.lau"],
                 ["CSHL049", "2020-01-08", "probe00", "2020-09-14T15:44:56_nate"],
                 ["CSHL051", "2020-02-05", "probe00", "2020-06-12T14:05:49_guido"],
                 ["CSHL055", "2020-02-18", "probe00", "2020-08-13T12:07:08_jeanpaul"],
                 ["CSH_ZAD_001", "2020-01-13", "probe00", "2020-09-22T17:23:50_petrina.lau"],
                 ["KS014", "2019-12-03", "probe01", "2020-06-17T19:42:02_Karolina_Socha"],
                 ["KS016", "2019-12-05", "probe01", "2020-06-18T10:26:54_Karolina_Socha"],
                 ["KS020", "2020-02-06", "probe00", "2020-09-13T15:08:15_petrina.lau"],
                 ["NYU-11", "2020-02-21", "probe01", "2020-09-13T11:19:59_petrina.lau"],
                 ["SWC_014", "2019-12-15", "probe01", "2020-07-26T22:24:39_noam.roth"],
                 ["ZM_2240", "2020-01-23", "probe00", "2020-06-05T14:57:46_guido"]]

big_scaling = [["CSHL045", "2020-02-25", "probe00", "2020-06-12T10:40:22_noam.roth"],
               ["CSHL045", "2020-02-25", "probe01", "2020-09-12T23:43:03_petrina.lau"],
               ["CSH_ZAD_001", "2020-01-13", "probe00", "2020-09-22T17:23:50_petrina.lau"],
               ["KS014", "2019-12-03", "probe00", "2020-06-17T15:15:01_Karolina_Socha"],
               ["KS014", "2019-12-03", "probe01", "2020-06-17T19:42:02_Karolina_Socha"],
               ["KS014", "2019-12-04", "probe00", "2020-09-12T16:39:14_petrina.lau"],
               ["KS014", "2019-12-06", "probe00", "2020-06-17T13:40:00_Karolina_Socha"],
               ["KS014", "2019-12-07", "probe00", "2020-06-17T16:21:35_Karolina_Socha"],
               ["KS016", "2019-12-04", "probe00", "2020-08-13T14:02:44_jeanpaul"],
               ["KS016", "2019-12-05", "probe00", "2020-06-18T10:49:53_Karolina_Socha"],
               ["KS023", "2019-12-07", "probe00", "2020-09-09T14:21:35_nate"],
               ["KS023", "2019-12-07", "probe01", "2020-06-18T15:50:55_Karolina_Socha"],
               ["KS023", "2019-12-10", "probe01", "2020-06-12T13:59:02_guido"],
               ["KS023", "2019-12-11", "probe01", "2020-06-18T16:04:18_Karolina_Socha"],
               ["NYU-11", "2020-02-21", "probe01", "2020-09-13T11:19:59_petrina.lau"],
               ["NYU-12", "2020-01-22", "probe00", "2020-09-13T19:53:43_petrina.lau"],
               ["SWC_014", "2019-12-12", "probe00", "2020-07-27T11:27:16_noam.roth"],
               ["SWC_038", "2020-08-01", "probe01", "2020-08-31T12:32:05_nate"],
               ["ibl_witten_14", "2019-12-11", "probe00", "2020-06-14T15:33:45_noam.roth"]]
78 |
for sess in small_scaling:
    # Each entry is [subject, date, probe_name, alignment_key]
    eid = one.search(subject=sess[0], date=sess[1])[0]
    probe_label = sess[2]
    # for eid, probe_label in zip(eid_several, probe_several):
    trajectory = one.alyx.rest('trajectories', 'list', provenance='Ephys aligned histology track',
                               session=eid, probe=probe_label)

    subject = trajectory[0]['session']['subject']
    date = trajectory[0]['session']['start_time'][0:10]
    # NOTE(review): `one.load(eid, dataset_types=...)` is the legacy ONE API;
    # with one.api.ONE this would be one.load_dataset(eid, ...) - confirm the
    # ONE version pinned for this project before running.
    chn_coords = one.load(eid, dataset_types=['channels.localCoordinates'])[0]
    depths = chn_coords[:, 1]

    insertion = one.alyx.rest('insertions', 'list', session=eid, name=probe_label)
    xyz_picks = np.array(insertion[0]['json']['xyz_picks']) / 1e6

    # One json entry per stored alignment, keyed by '<timestamp>_<user>'
    alignments = trajectory[0]['json']

    def plot_regions(region, label, colour, ax):
        """Draw the histology regions as a stacked bar with acronym labels."""
        for reg, col in zip(region, colour):
            height = np.abs(reg[1] - reg[0])
            color = col / 255
            ax.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')

        ax.set_yticks(label[:, 0].astype(int))
        ax.set_yticklabels(label[:, 1])
        ax.yaxis.set_tick_params(labelsize=10)
        ax.tick_params(axis="y", direction="in", pad=-50)
        ax.set_ylim([20, 3840])
        ax.get_xaxis().set_visible(False)

    def plot_scaling(region, scale, mapper, ax):
        """Overlay a narrow bar colour-coded by local scale factor, values on the right."""
        # BUG FIX: this loop previously iterated the enclosing-scope variables
        # `region_scaled` / `scale_factor` instead of its own parameters; it
        # only worked because every caller happened to pass those same objects.
        for reg, col in zip(region, scale):
            height = np.abs(reg[1] - reg[0])
            color = np.array(mapper.to_rgba(col, bytes=True)) / 255
            ax.bar(x=1.1, height=height, width=0.2, color=color, bottom=reg[0], edgecolor='w')

        sec_ax = ax.secondary_yaxis('right')
        sec_ax.set_yticks(np.mean(region, axis=1))
        sec_ax.set_yticklabels(np.around(scale, 2))
        sec_ax.tick_params(axis="y", direction="in")
        sec_ax.set_ylim([20, 3840])

    # One column for the original (uncorrected) track plus one per alignment
    fig, ax = plt.subplots(1, len(alignments) + 1, figsize=(15, 15))
    ephysalign = EphysAlignment(xyz_picks, depths, brain_atlas=brain_atlas)
    feature, track, _ = ephysalign.get_track_and_feature()
    channels_orig = ephysalign.get_channel_locations(feature, track)
    region, region_label = ephysalign.scale_histology_regions(feature, track)
    region_scaled, scale_factor = ephysalign.get_scale_factor(region)
    region_colour = ephysalign.region_colour

    # Diverging colour map centred on a scale factor of 1
    norm = matplotlib.colors.Normalize(vmin=0.5, vmax=1.5, clip=True)
    mapper = matplotlib.cm.ScalarMappable(norm=norm, cmap=matplotlib.cm.seismic)

    ax_i = fig.axes[0]
    plot_regions(region, region_label, region_colour, ax_i)
    plot_scaling(region_scaled, scale_factor, mapper, ax_i)
    ax_i.set_title('Original')

    for iK, key in enumerate(alignments):
        # Location of reference lines used for alignment
        feature = np.array(alignments[key][0])
        track = np.array(alignments[key][1])
        user = key[20:]
        # Instantiate EphysAlignment object
        ephysalign = EphysAlignment(xyz_picks, depths, track_prev=track, feature_prev=feature,
                                    brain_atlas=brain_atlas)

        channels = ephysalign.get_channel_locations(feature, track)
        # Mean euclidean displacement of each channel vs the original locations
        avg_dist = np.mean(np.sqrt(np.sum((channels - channels_orig) ** 2, axis=1)), axis=0)
        region, region_label = ephysalign.scale_histology_regions(feature, track)
        region_scaled, scale_factor = ephysalign.get_scale_factor(region)

        ax_i = fig.axes[iK + 1]
        plot_regions(region, region_label, region_colour, ax_i)
        plot_scaling(region_scaled, scale_factor, mapper, ax_i)
        # Displacement reported in micrometres
        ax_i.set_title(user + '\n Avg dist = ' + str(np.around(avg_dist * 1e6, 2)))

    fig.suptitle(subject + '_' + str(date) + '_' + probe_label, fontsize=16)
    plt.show()
    fig.savefig(fig_path.joinpath(subject + '_' + str(date) + '_' + probe_label + '.png'), dpi=100)
    plt.close(fig)
160 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/create_overview_plots.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from pathlib import Path
3 | import glob
4 |
5 |
def make_overview_plot(folder, sess_info, save_folder=None):
    """Assemble the per-panel png files saved by the alignment GUI into one overview figure.

    :param folder: pathlib.Path of the folder containing the individual panel images
    :param sess_info: filename prefix identifying the session (ends with a separator char)
    :param save_folder: pathlib.Path where '<sess_info>overview.png' is written;
        defaults to `folder`
    """

    image_folder = folder
    image_info = sess_info
    if not save_folder:
        save_folder = image_folder

    def load_image(image_name, ax):
        """Draw the image on ax with all axis decoration removed."""
        # BUG FIX: previously `with image_name as ifile:` used pathlib.Path as a
        # context manager, which was deprecated and removed in Python 3.13;
        # read the file directly instead.
        image = plt.imread(image_name)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.set_axis_off()
        ax.set_aspect('equal')
        ax.imshow(image)
        return image

    fig = plt.figure(constrained_layout=True, figsize=(18, 9))
    gs = fig.add_gridspec(3, 18)
    gs.update(wspace=0.025, hspace=0.05)

    # Fixed grid layout; the *_idx lists reorder the globbed files into position.
    # NOTE(review): glob order is platform dependent - the hard-coded indices
    # assume the ordering the GUI produced; consider sorting if panels shuffle.
    ignore_img_plots = ['leftGabor', 'rightGabor', 'noiseOn', 'valveOn', 'toneOn']
    img_row_order = [0, 0, 0, 0, 0, 0, 1, 1, 1]
    img_column_order = [0, 3, 6, 9, 12, 15, 0, 3, 6]
    img_idx = [0, 5, 4, 6, 7, 8, 1, 2, 3]
    img_files = glob.glob(str(image_folder.joinpath(image_info + 'img_*.png')))
    img_files = [img for img in img_files if not any([ig in img for ig in ignore_img_plots])]
    img_files_sort = [img_files[idx] for idx in img_idx]

    for iF, file in enumerate(img_files_sort):
        ax = fig.add_subplot(gs[img_row_order[iF], img_column_order[iF]:img_column_order[iF] + 3])
        load_image(Path(file), ax)

    ignore_probe_plots = ['RF Map']
    probe_row_order = [1, 1, 1, 1, 1, 1, 2, 2, 2]
    probe_column_order = [9, 10, 11, 12, 13, 14, 12, 13, 14]
    probe_idx = [0, 3, 1, 2, 4, 5, 6]
    probe_files = glob.glob(str(image_folder.joinpath(image_info + 'probe_*.png')))
    probe_files = [probe for probe in probe_files if not any([pr in probe for pr in
                                                              ignore_probe_plots])]
    probe_files_sort = [probe_files[idx] for idx in probe_idx]
    line_files = glob.glob(str(image_folder.joinpath(image_info + 'line_*.png')))

    for iF, file in enumerate(probe_files_sort + line_files):
        ax = fig.add_subplot(gs[probe_row_order[iF], probe_column_order[iF]])
        load_image(Path(file), ax)

    # NOTE(review): 'slice_*.png' also matches the 'slice_zoom*' files handled
    # below, so the hard-coded indices here depend on glob ordering - confirm.
    slice_files = glob.glob(str(image_folder.joinpath(image_info + 'slice_*.png')))
    slice_row_order = [2, 2, 2, 2]
    slice_idx = [0, 1, 2, 3]
    slice_column_order = [0, 3, 6, 9]
    slice_files_sort = [slice_files[idx] for idx in slice_idx]

    for iF, file in enumerate(slice_files_sort):
        ax = fig.add_subplot(gs[slice_row_order[iF],
                                slice_column_order[iF]:slice_column_order[iF] + 3])
        load_image(Path(file), ax)

    slice_files = glob.glob(str(image_folder.joinpath(image_info + 'slice_zoom*.png')))
    slice_row_order = [2, 2, 2, 2]
    slice_idx = [0, 1, 2, 3]
    slice_column_order = [2, 5, 8, 11]
    slice_files_sort = [slice_files[idx] for idx in slice_idx]

    for iF, file in enumerate(slice_files_sort):
        ax = fig.add_subplot(gs[slice_row_order[iF], slice_column_order[iF]])
        load_image(Path(file), ax)

    hist_files = glob.glob(str(image_folder.joinpath(image_info + 'hist*.png')))
    for iF, file in enumerate(hist_files):
        ax = fig.add_subplot(gs[1:3, 15:18])
        load_image(Path(file), ax)

    # Session label under the last panel drawn
    ax.text(0.5, 0, image_info[:-1], va="center", ha="center", transform=ax.transAxes)
    plt.savefig(save_folder.joinpath(image_info + "overview.png"),
                bbox_inches='tight', pad_inches=0)
87 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/ephys_atlas_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlaselectrophysiology/ephys_atlas_image.png
--------------------------------------------------------------------------------
/atlaselectrophysiology/example_code.py:
--------------------------------------------------------------------------------
'''
Minimal example: launch the easyqc-based alignment viewer for one probe insertion.
'''
from one.api import ONE
from atlaselectrophysiology.alignment_with_easyqc import viewer

one = ONE()
# UUID of the probe insertion (Alyx 'insertions' id) to display
probe_id = 'ce397420-3cd2-4a55-8fd1-5e28321981f4'


av = viewer(probe_id, one=one)
9 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/extract_files.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import logging
3 |
4 | from tqdm import tqdm
5 | import spikeglx
6 | import numpy as np
7 | from ibldsp import fourier, utils
8 | from scipy import signal
9 | import one.alf.io as alfio
10 | import ibllib.ephys.ephysqc as ephysqc
11 |
12 | from phylib.io import alf
13 |
14 | _logger = logging.getLogger('ibllib')
15 |
16 |
17 | RMS_WIN_LENGTH_SECS = 3
18 | WELCH_WIN_LENGTH_SAMPLES = 1024
19 |
20 |
def rmsmap(fbin, spectra=True):
    """
    Computes RMS map in time domain and spectra for each channel of Neuropixel probe

    :param fbin: binary file in spike glx format (will look for attached metadata)
    :type fbin: str, pathlib.Path or spikeglx.Reader
    :param spectra: whether to compute the power spectrum (only need for lfp data)
    :type: bool
    :return: a dictionary with amplitudes in channeltime space, channelfrequency space, time
     and frequency scales
    """
    # BUG FIX: previously `sglx` was only assigned when `fbin` was NOT already a
    # Reader, so passing an existing spikeglx.Reader raised NameError below.
    if isinstance(fbin, spikeglx.Reader):
        sglx = fbin
    else:
        sglx = spikeglx.Reader(fbin)
        sglx.open()
    rms_win_length_samples = 2 ** np.ceil(np.log2(sglx.fs * RMS_WIN_LENGTH_SECS))
    # the window generator generates window indices
    wingen = utils.WindowGenerator(ns=sglx.ns, nswin=rms_win_length_samples, overlap=0)
    # pre-allocate output dictionary of numpy arrays
    win = {'TRMS': np.zeros((wingen.nwin, sglx.nc)),
           'nsamples': np.zeros((wingen.nwin,)),
           'fscale': fourier.fscale(WELCH_WIN_LENGTH_SAMPLES, 1 / sglx.fs, one_sided=True),
           'tscale': wingen.tscale(fs=sglx.fs)}
    win['spectral_density'] = np.zeros((len(win['fscale']), sglx.nc))
    # loop through the whole session
    with tqdm(total=wingen.nwin) as pbar:
        for first, last in wingen.firstlast:
            D = sglx.read_samples(first_sample=first, last_sample=last)[0].transpose()
            # remove low frequency noise below 1 Hz
            D = fourier.hp(D, 1 / sglx.fs, [0, 1])
            iw = wingen.iw
            win['TRMS'][iw, :] = utils.rms(D)
            win['nsamples'][iw] = D.shape[1]
            if spectra:
                # the last window may be smaller than what is needed for welch
                if last - first < WELCH_WIN_LENGTH_SAMPLES:
                    continue
                # compute a smoothed spectrum using welch method
                _, w = signal.welch(
                    D, fs=sglx.fs, window='hann', nperseg=WELCH_WIN_LENGTH_SAMPLES,
                    detrend='constant', return_onesided=True, scaling='density', axis=-1
                )
                win['spectral_density'] += w.T
            # print at least every 20 windows
            if (iw % min(20, max(int(np.floor(wingen.nwin / 75)), 1))) == 0:
                pbar.update(iw)

    # NOTE(review): this also closes a Reader passed in by the caller - confirm
    # callers do not expect their reader to stay open afterwards.
    sglx.close()
    return win
69 |
70 |
def extract_rmsmap(fbin, out_folder=None, spectra=True):
    """
    Wrapper for rmsmap that outputs _ibl_ephysRmsMap and _ibl_ephysSpectra ALF files

    :param fbin: binary file in spike glx format (will look for attached metadata)
    :param out_folder: folder in which to store output ALF files. Default uses the folder in which
     the `fbin` file lives.
    :param spectra: whether to compute the power spectrum (only need for lfp data)
    :type: bool
    :return: None
    """
    _logger.info(f"Computing QC for {fbin}")
    # Reader is only needed here to resolve the acquisition band (ap/lf)
    sglx = spikeglx.Reader(fbin)
    if out_folder is None:
        out_folder = Path(fbin).parent
    else:
        out_folder = Path(out_folder)
    # ALF object names carry the acquisition band as a suffix (AP or LF)
    alf_object_time = f'ephysTimeRms{sglx.type.upper()}'
    alf_object_freq = f'ephysSpectralDensity{sglx.type.upper()}'

    # crunch numbers
    rms = rmsmap(fbin, spectra=spectra)
    # output ALF files, single precision with the optional label as suffix before extension
    if not out_folder.exists():
        # parents/exist_ok: robust to nested output paths and concurrent creation
        # (plain mkdir() raised if the parent folder was missing)
        out_folder.mkdir(parents=True, exist_ok=True)
    tdict = {'rms': rms['TRMS'].astype(np.single), 'timestamps': rms['tscale'].astype(np.single)}
    alfio.save_object_npy(out_folder, object=alf_object_time, dico=tdict, namespace='iblqc')
    if spectra:
        fdict = {'power': rms['spectral_density'].astype(np.single),
                 'freqs': rms['fscale'].astype(np.single)}
        alfio.save_object_npy(
            out_folder, object=alf_object_freq, dico=fdict, namespace='iblqc')
104 |
105 |
def _sample2v(ap_file):
    """Return the AP-band sample-to-volts conversion factor read from the .meta file."""
    meta = spikeglx.read_meta_data(ap_file.with_suffix('.meta'))
    return spikeglx._conversion_sample2v_from_meta(meta)['ap'][0]
113 |
114 |
def ks2_to_alf(ks_path, bin_path, out_path, bin_file=None, ampfactor=1, label=None, force=True):
    """
    Convert Kilosort 2 output to ALF dataset for single probe data
    :param ks_path: path to the Kilosort 2 output folder
    :param bin_path: path of raw data
    :param out_path: destination folder for the ALF files
    :param bin_file: optional path to a specific raw binary file
    :param ampfactor: sample-to-volts factor applied to amplitudes during conversion
    :param label: optional suffix label for the output files
    :param force: overwrite existing output files if True
    :return:
    """
    m = ephysqc.phy_model_from_ks2_path(ks2_path=ks_path, bin_path=bin_path, bin_file=bin_file)
    ac = alf.EphysAlfCreator(m)
    ac.convert(out_path, label=label, force=force, ampfactor=ampfactor)

    # set depths to spike_depths to catch cases where it can't be computed from pc features (e.g in case of KS3)
    m.depths = np.load(out_path.joinpath('spikes.depths.npy'))
    ephysqc.spike_sorting_metrics_ks2(ks_path, m, save=True, save_path=out_path)
130 |
131 |
def extract_data(ks_path, ephys_path, out_path):
    """Run ALF conversion and RMS/spectral QC for every probe found under `ephys_path`."""
    for efile in spikeglx.glob_ephys_files(ephys_path):
        # AP band: convert sorting output to ALF, then compute the RMS map
        if efile.get('ap') and efile.ap.exists():
            ks2_to_alf(ks_path, ephys_path, out_path, bin_file=efile.ap,
                       ampfactor=_sample2v(efile.ap), label=None, force=True)
            extract_rmsmap(efile.ap, out_folder=out_path, spectra=False)
        # LF band: RMS map plus power spectrum
        if efile.get('lf') and efile.lf.exists():
            extract_rmsmap(efile.lf, out_folder=out_path)
143 |
144 |
145 | # if __name__ == '__main__':
146 | #
147 | # ephys_path = Path('C:/Users/Mayo/Downloads/raw_ephys_data')
148 | # ks_path = Path('C:/Users/Mayo/Downloads/KS2')
149 | # out_path = Path('C:/Users/Mayo/Downloads/alf')
150 | # extract_data(ks_path, ephys_path, out_path)
151 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/get_scale_factor.py:
--------------------------------------------------------------------------------
1 | '''
2 | Extract channel locations from reference points of previous alignments saved in json field of
3 | trajectory object
4 | Create plot showing histology regions which channels pass through as well as coronal slice with
5 | channel locations shown
6 | '''
7 |
8 | # import modules
9 | from oneibl.one import ONE
10 | from ibllib.pipes.ephys_alignment import EphysAlignment
11 | import numpy as np
12 | import pandas as pd
13 | import matplotlib.pyplot as plt
14 | import seaborn as sns
15 | import iblatlas.atlas as atlas
16 |
17 |
# Instantiate brain atlas and one
brain_atlas = atlas.AllenAtlas(25)
one = ONE()

# Find eid of interest
subject = 'KS022'

# Find the ephys aligned trajectory for eid probe combination
trajectory = one.alyx.rest('trajectories', 'list', provenance='Ephys aligned histology track',
                           subject=subject)

# Load in channels.localCoordinates dataset type
chn_coords = one.load(trajectory[0]['session']['id'],
                      dataset_types=['channels.localCoordinates'])[0]
depths = chn_coords[:, 1]

# Accumulate one row per scale-factor sample; the average is stored only on the
# first row of each alignment so the larger marker is plotted once per alignment.
# (DataFrame.append was removed in pandas 2.0; a set literal for columns also
# made the column order nondeterministic — build the frame once at the end.)
rows = []
for traj in trajectory:
    alignments = traj['json']

    insertion = one.alyx.rest('insertions', 'list', session=traj['session']['id'],
                              name=traj['probe_name'])
    xyz_picks = np.array(insertion[0]['json']['xyz_picks']) / 1e6

    session_info = traj['session']['start_time'][:10] + '_' + traj['probe_name']

    for key in alignments:
        # Location of reference lines used for alignment
        feature = np.array(alignments[key][0])
        track = np.array(alignments[key][1])
        # alignment keys start with a timestamp followed by the user name
        user = key[:19]
        # Instantiate EphysAlignment object
        ephysalign = EphysAlignment(xyz_picks, depths, track_prev=track, feature_prev=feature)
        region_scaled, _ = ephysalign.scale_histology_regions(feature, track)
        _, scale_factor = ephysalign.get_scale_factor(region_scaled)

        if np.all(np.round(np.diff(scale_factor), 3) == 0):
            # Case where there is no scaling but just an offset
            scale_factor = np.array([1])
            avg_sf = 1
        elif feature.size > 4:
            # Case where 3 or more reference lines have been placed so take gradient of
            # linear fit to represent average scaling factor
            avg_sf = scale_factor[0]
        else:
            # Case where 2 reference lines have been used. Only have local scaling between
            # two reference lines, everywhere else scaling is 1. Use the local scaling as the
            # average scaling factor
            avg_sf = np.mean(scale_factor[1:-1])

        for iS, sf_val in enumerate(scale_factor):
            rows.append({'Session': session_info, 'User': user, 'Scale Factor': sf_val,
                         'Avg Scale Factor': avg_sf if iS == 0 else np.nan})

subject_summary = pd.DataFrame(rows, columns=['Session', 'User', 'Scale Factor',
                                              'Avg Scale Factor'])

fig, ax = plt.subplots(figsize=(10, 8))
sns.swarmplot(x='Session', y='Scale Factor', hue='User', data=subject_summary, ax=ax)
sns.swarmplot(x='Session', y='Avg Scale Factor', hue='User', size=8, linewidth=1,
              data=subject_summary, ax=ax)
# ensures value in legend isn't repeated
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(by_label.values(), by_label.keys())

plt.show()
92 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/load_histology.py:
--------------------------------------------------------------------------------
1 |
2 | from pathlib import Path
3 | import requests
4 | import re
5 | from one import params
6 | from one.webclient import http_download_file
7 | import SimpleITK as sitk
8 |
9 |
def download_histology_data(subject, lab):
    """Download the downsampled histology stacks for a subject and convert
    them to nrrd.

    Candidate server locations are tried in order: (lab, subject),
    (lab, subject without underscores) and, for 'churchlandlab_ucla' only,
    ('churchlandlab', subject) — mirroring the layouts found on FlatIron.

    :param subject: subject nickname
    :param lab: lab name (hoferlab data lives under the mrsicflogellab folder)
    :return: list of paths to nrrd files (middle two stacks only when more
             than three exist), or None when nothing could be found
    """
    # hoferlab histology is stored under the mrsicflogellab folder
    lab_temp = 'mrsicflogellab' if lab == 'hoferlab' else lab

    par = params.get()

    # candidate (lab, subject) locations, tried in order; previously this was
    # three near-identical nested try/except blocks
    candidates = [(lab_temp, subject), (lab_temp, subject.replace("_", ""))]
    if lab_temp == 'churchlandlab_ucla':
        candidates.append(('churchlandlab', subject))

    baseurl = None
    r = None
    for cand_lab, cand_subject in candidates:
        hist_rel_path = Path('histology', cand_lab, cand_subject,
                             'downsampledStacks_25', 'sample2ARA')
        baseurl = (par.HTTP_DATA_SERVER + '/' + '/'.join(hist_rel_path.parts))
        try:
            r = requests.get(baseurl, auth=(par.HTTP_DATA_SERVER_LOGIN,
                                            par.HTTP_DATA_SERVER_PWD))
            r.raise_for_status()
            break
        except Exception as err:
            print(err)
            r = None
    if r is None:
        # no candidate location worked
        return None

    # scrape the directory listing for the tif stacks
    tif_files = []
    for line in r.text.splitlines():
        result = re.findall('href="(.*).tif"', line)
        if result:
            tif_files.append(result[0] + '.tif')

    CACHE_DIR = params.get_cache_dir().joinpath(lab, 'Subjects', subject, 'histology')
    CACHE_DIR.mkdir(exist_ok=True, parents=True)
    path_to_files = []
    for file in tif_files:
        path_to_image = Path(CACHE_DIR, file)
        if not path_to_image.exists():
            url = (baseurl + '/' + file)
            http_download_file(url, target_dir=CACHE_DIR,
                               username=par.HTTP_DATA_SERVER_LOGIN,
                               password=par.HTTP_DATA_SERVER_PWD)

        path_to_files.append(tif2nrrd(path_to_image))

    # keep only the two middle stacks when more than three are present
    if len(path_to_files) > 3:
        path_to_files = path_to_files[1:3]

    return path_to_files
77 |
78 |
def tif2nrrd(path_to_image):
    """Convert a tif stack to nrrd, skipping conversion when the nrrd exists.

    The volume axes are permuted to (z, y, x), the first axis is mirrored and
    the spacing reset to unit voxels before writing.

    :param path_to_image: path to the source .tif file
    :return: path to the sibling .nrrd file
    """
    path_to_nrrd = Path(path_to_image).with_suffix('.nrrd')
    if path_to_nrrd.exists():
        # already converted on a previous run
        return path_to_nrrd

    reader = sitk.ImageFileReader()
    reader.SetImageIO("TIFFImageIO")
    reader.SetFileName(str(path_to_image))
    volume = reader.Execute()

    volume = sitk.PermuteAxes(volume, [2, 1, 0])
    volume = sitk.Flip(volume, [True, False, False])
    volume.SetSpacing([1, 1, 1])

    writer = sitk.ImageFileWriter()
    writer.SetImageIO("NrrdImageIO")
    writer.SetFileName(str(path_to_nrrd))
    writer.Execute(volume)

    return path_to_nrrd
96 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/qc_table.py:
--------------------------------------------------------------------------------
import datajoint as dj

# Shared DataJoint schema for ephys alignment QC tables.
schema = dj.schema('group_shared_ephys')


@schema
class EphysQC(dj.Imported):
    # Manual QC entered for an ephys alignment: one row per
    # (probe insertion, user). The enum defaults are null so a row may carry
    # either or both QC judgements.
    definition = """
    probe_insertion_uuid: uuid # probe insertion uuid
    -> reference.LabMember # user name
    ---
    alignment_qc=null: enum('high', 'medium', 'low') # confidence in alignment
    ephys_qc=null: enum('pass', 'critical', 'warning') # quality of ephys
    ephys_qc_description=null: varchar(255) #Description for ephys_qc
    """
16 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/rendering.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import cv2
5 | import vtk
6 | from matplotlib import pyplot as plt # noqa
7 | import mayavi.mlab as mlab
8 |
9 |
def add_mesh(fig, obj_file, color=(1., 1., 1.), opacity=0.4):
    """
    Adds a mesh object from an *.obj file to the mayavi figure
    :param fig: mayavi figure
    :param obj_file: full path to a local *.obj file
    :param color: rgb tuple of floats between 0 and 1
    :param opacity: float between 0 and 1
    :return: (vtkPolyDataMapper, vtkActor) tuple for the added mesh
    """
    reader = vtk.vtkOBJReader()
    reader.SetFileName(str(obj_file))
    reader.Update()
    mapper = vtk.vtkPolyDataMapper()
    mapper.SetInputConnection(reader.GetOutputPort())
    actor = vtk.vtkActor()
    actor.SetMapper(mapper)
    # translucent by default so actors inside the mesh stay visible
    actor.GetProperty().SetOpacity(opacity)
    actor.GetProperty().SetColor(color)
    fig.scene.add_actor(actor)
    fig.scene.render()
    return mapper, actor
31 |
32 |
def figure(grid=False, **kwargs):
    """
    Creates a mayavi figure with the brain atlas mesh
    :param grid: if True, overlay a labelled AP/DV/ML axes grid (vtkCubeAxesActor)
    :param kwargs: forwarded to mlab.figure
    :return: mayavi figure
    """
    fig = mlab.figure(bgcolor=(1, 1, 1), **kwargs)
    # engine = mlab.get_engine() # Returns the running mayavi engine.
    obj_file = Path(__file__).parent.joinpath("root.obj")
    mapper, actor = add_mesh(fig, obj_file)

    if grid:
        # https://vtk.org/Wiki/VTK/Examples/Python/Visualization/CubeAxesActor
        cubeAxesActor = vtk.vtkCubeAxesActor()
        cubeAxesActor.SetMapper(mapper)
        cubeAxesActor.SetBounds(mapper.GetBounds())
        cubeAxesActor.SetCamera(fig.scene.renderer._vtk_obj.GetActiveCamera())
        cubeAxesActor.SetXTitle("AP (um)")
        cubeAxesActor.SetYTitle("DV (um)")
        cubeAxesActor.SetZTitle("ML (um)")
        # one colour per axis: X red, Y green, Z blue (titles and labels)
        cubeAxesActor.GetTitleTextProperty(0).SetColor(1.0, 0.0, 0.0)
        cubeAxesActor.GetLabelTextProperty(0).SetColor(1.0, 0.0, 0.0)
        cubeAxesActor.GetTitleTextProperty(1).SetColor(0.0, 1.0, 0.0)
        cubeAxesActor.GetLabelTextProperty(1).SetColor(0.0, 1.0, 0.0)
        cubeAxesActor.GetTitleTextProperty(2).SetColor(0.0, 0.0, 1.0)
        cubeAxesActor.GetLabelTextProperty(2).SetColor(0.0, 0.0, 1.0)
        cubeAxesActor.DrawXGridlinesOn()
        cubeAxesActor.DrawYGridlinesOn()
        cubeAxesActor.DrawZGridlinesOn()
        fig.scene.add_actor(cubeAxesActor)

    # NOTE(review): two successive view calls — presumably the first normalises
    # the orientation before the final tilt is applied; confirm before changing
    mlab.view(azimuth=180, elevation=0)
    mlab.view(azimuth=210, elevation=210, reset_roll=False)

    return fig
67 |
68 |
def rotating_video(output_file, mfig, fps=12, secs=6):
    """Render a full rotation of a mayavi figure to a video file.

    :param output_file: output video path; .avi (XVID) and .webm (VP80) supported
    :param mfig: mayavi figure to capture
    :param fps: frames per second of the output video
    :param secs: duration of the rotation in seconds
    :raises NotImplementedError: for unsupported file extensions
    """
    # ffmpeg -i input.webm -pix_fmt rgb24 output.gif
    # ffmpeg -i certification.webm -vf scale=640:-1 -r 10 -f image2pipe -vcodec ppm - | convert -delay 5 -loop 0 - certification.gif # noqa

    file_video = Path(output_file)
    if file_video.suffix == '.avi':
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
    elif file_video.suffix == '.webm':
        fourcc = cv2.VideoWriter_fourcc(*'VP80')
    else:
        # bug fix: the exception used to be instantiated but never raised,
        # which left `fourcc` undefined and crashed later with a NameError
        raise NotImplementedError(f"Extension {file_video.suffix} not supported")

    mlab.view(azimuth=180, elevation=0)
    mfig.scene.render()
    mfig.scene._lift()
    frame = mlab.screenshot(figure=mfig, mode='rgb', antialiased=True)
    # screenshot returns (rows, cols, 3); VideoWriter wants (width, height)
    w, h, _ = frame.shape
    video = cv2.VideoWriter(str(file_video), fourcc, float(fps), (h, w))

    for e in np.linspace(-180, 180, secs * fps):
        mlab.view(azimuth=0, elevation=e, reset_roll=False)
        mfig.scene.render()
        frame = mlab.screenshot(figure=mfig, mode='rgb', antialiased=True)
        print(e, (h, w), frame.shape)
        video.write(np.flip(frame, axis=2))  # OpenCV expects BGR, screenshot is RGB

    video.release()
    cv2.destroyAllWindows()
100 |
--------------------------------------------------------------------------------
/atlaselectrophysiology/sandbox.py:
--------------------------------------------------------------------------------
from brainbox.io.one import load_spike_sorting, load_channel_locations
from oneibl.one import ONE

# Scratch script: fetch spike sorting and channel locations for the first
# ZM_2407 ephys session found on the dev Alyx database.
one = ONE(base_url="https://dev.alyx.internationalbrainlab.org")
eids = one.search(subject='ZM_2407', task_protocol='ephys')

# channel locations and spike sorting for the first matching session
channels = load_channel_locations(eids[0], one=one)
spikes, clusters = load_spike_sorting(eids[0], one=one)
9 |
--------------------------------------------------------------------------------
/atlasview/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlasview/__init__.py
--------------------------------------------------------------------------------
/atlasview/channels_test.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlasview/channels_test.npy
--------------------------------------------------------------------------------
/atlasview/launcher.py:
--------------------------------------------------------------------------------
from pathlib import Path
import numpy as np
from atlasview import atlasview  # TODO: this import arrangement will need changing
av = atlasview.view()  # need to have an output argument here or the garbage collector will clean
# it up and boom

""" Roadmap
- swap volumes combox (label RGB option / density)
- overlay brain regions with transparencye
- overlay volumes (for example coverage with transparency)
- overlay plots: probes channels
- tilted slices
- coordinate swaps: add Allen / Needles / Voxel options
- should we add horizontal slices ?
"""

# add brain regions feature:
reg_values = np.load(Path(atlasview.__file__).parent.joinpath('region_values.npy'))
av.add_regions_feature(reg_values, 'Blues', opacity=0.7)

# add scatter feature:
chans = np.load(
    Path(atlasview.__file__).parent.joinpath('channels_test.npy'), allow_pickle=True)
av.add_scatter_feature(chans)
25 |
--------------------------------------------------------------------------------
/atlasview/region_values.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/atlasview/region_values.npy
--------------------------------------------------------------------------------
/atlasview/sliceview.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | SliceWindow
4 |
5 |
6 |
7 | 0
8 | 0
9 | 1101
10 | 737
11 |
12 |
13 |
14 | MainWindow
15 |
16 |
17 | -
18 |
19 |
20 | -
21 |
22 |
-
23 |
24 |
25 |
26 | 100
27 | 0
28 |
29 |
30 |
31 |
32 | 120
33 | 16777215
34 |
35 |
36 |
37 | QFrame::StyledPanel
38 |
39 |
40 | QFrame::Raised
41 |
42 |
43 |
-
44 |
45 |
46 |
47 | 70
48 | 0
49 |
50 |
51 |
52 |
53 | 80
54 | 20
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 | -
63 |
64 |
65 |
66 | 15
67 | 0
68 |
69 |
70 |
71 |
72 | 10
73 | 20
74 |
75 |
76 |
77 | y
78 |
79 |
80 |
81 | -
82 |
83 |
84 |
85 | 15
86 | 0
87 |
88 |
89 |
90 |
91 | 10
92 | 20
93 |
94 |
95 |
96 | ix
97 |
98 |
99 |
100 | -
101 |
102 |
103 |
104 | 15
105 | 0
106 |
107 |
108 |
109 |
110 | 10
111 | 20
112 |
113 |
114 |
115 | iy
116 |
117 |
118 |
119 | -
120 |
121 |
122 |
123 | 70
124 | 0
125 |
126 |
127 |
128 |
129 | 80
130 | 20
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 | -
139 |
140 |
141 |
142 | 15
143 | 0
144 |
145 |
146 |
147 |
148 | 10
149 | 20
150 |
151 |
152 |
153 | x
154 |
155 |
156 |
157 | -
158 |
159 |
160 |
161 | 70
162 | 0
163 |
164 |
165 |
166 |
167 | 80
168 | 20
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 | -
177 |
178 |
179 |
180 | 70
181 | 0
182 |
183 |
184 |
185 |
186 | 80
187 | 20
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 | -
196 |
197 |
198 | Qt::Vertical
199 |
200 |
201 |
202 | 20
203 | 40
204 |
205 |
206 |
207 |
208 | -
209 |
210 |
211 |
212 |
213 |
214 |
215 | -
216 |
217 |
218 | v
219 |
220 |
221 |
222 | -
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 | -
233 |
234 |
235 | Qt::Vertical
236 |
237 |
238 | QSizePolicy::Fixed
239 |
240 |
241 |
242 | 20
243 | 5
244 |
245 |
246 |
247 |
248 | -
249 |
250 |
251 | -
252 |
253 |
254 | Qt::Horizontal
255 |
256 |
257 | QSizePolicy::Maximum
258 |
259 |
260 |
261 | 5
262 | 17
263 |
264 |
265 |
266 |
267 | -
268 |
269 |
270 | Qt::Horizontal
271 |
272 |
273 | QSizePolicy::Fixed
274 |
275 |
276 |
277 | 5
278 | 20
279 |
280 |
281 |
282 |
283 |
284 |
285 | -
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 | 0
297 | 0
298 | 3
299 | 22
300 |
301 |
302 |
303 |
304 |
305 |
306 | PlotWidget
307 | QGraphicsView
308 |
309 |
310 |
311 |
312 |
313 |
314 |
--------------------------------------------------------------------------------
/atlasview/topview.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | TopView
4 |
5 |
6 |
7 | 0
8 | 0
9 | 435
10 | 461
11 |
12 |
13 |
14 | MainWindow
15 |
16 |
17 |
18 | -
19 |
20 |
21 | -
22 |
23 |
24 |
25 | 100
26 | 0
27 |
28 |
29 |
30 | QFrame::StyledPanel
31 |
32 |
33 | QFrame::Raised
34 |
35 |
36 |
37 |
38 | 20
39 | 20
40 | 67
41 | 17
42 |
43 |
44 |
45 | Annotation
46 |
47 |
48 |
49 |
50 |
51 | 40
52 | 50
53 | 16
54 | 160
55 |
56 |
57 |
58 | 20
59 |
60 |
61 | Qt::Vertical
62 |
63 |
64 |
65 |
66 |
67 | 10
68 | 220
69 | 67
70 | 17
71 |
72 |
73 |
74 | Image
75 |
76 |
77 |
78 |
79 |
80 | 0
81 | 350
82 | 93
83 | 27
84 |
85 |
86 |
87 |
88 |
89 |
90 | 10
91 | 330
92 | 67
93 | 17
94 |
95 |
96 |
97 | Mapping
98 |
99 |
100 |
101 |
102 |
103 |
104 |
114 |
115 |
116 |
117 |
118 | PlotWidget
119 | QGraphicsView
120 |
121 |
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
/data_exploration_gui/README.md:
--------------------------------------------------------------------------------
1 | # Data Exploration GUI
2 |
3 | GUI to allow user to explore ephys data from IBL task
4 |
5 | ## Setup
6 |
7 | Install ibl environment following [these instructions](https://github.com/int-brain-lab/iblenv#iblenv-installation-guide)
8 |
9 | Go to the ```data_exploration_gui``` folder
10 |
11 | ```
12 | cd iblapps/data_exploration_gui
13 | ```
14 |
15 | ### Using GUI
To launch the GUI, run one of the following commands from the command line. You can specify either a probe insertion id,
e.g.
18 | ```
19 | python data_explore_gui.py -pid 9657af01-50bd-4120-8303-416ad9e24a51
20 | ```
21 |
or an eid and a probe name, e.g.
23 | ```
24 | python data_explore_gui.py -eid 7f6b86f9-879a-4ea2-8531-294a221af5d0 -name probe00
25 | ```
26 |
--------------------------------------------------------------------------------
/data_exploration_gui/cluster.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtGui, QtWidgets
2 | from data_exploration_gui import utils
3 |
4 |
class ClusterGroup:
    """Widget group for the cluster panel: sort-option radio buttons, a
    colour legend for unit labels, the cluster list and next/previous
    navigation buttons."""

    def __init__(self):
        # radio buttons selecting how the cluster list is sorted
        self.cluster_buttons = QtWidgets.QButtonGroup()
        self.cluster_group = QtWidgets.QGroupBox('Sort Clusters By:')
        self.cluster_layout = QtWidgets.QHBoxLayout()
        for idx, option in enumerate(utils.SORT_CLUSTER_OPTIONS):
            radio = QtWidgets.QRadioButton(option)
            # first option is the default selection
            radio.setChecked(idx == 0)
            self.cluster_buttons.addButton(radio)
            self.cluster_layout.addWidget(radio)

        self.cluster_group.setLayout(self.cluster_layout)

        # legend: a row of colour swatches above a row of unit-label captions
        self.cluster_colours = QtWidgets.QGroupBox()
        swatch_row = QtWidgets.QHBoxLayout()
        caption_row = QtWidgets.QHBoxLayout()
        for option in utils.UNIT_OPTIONS:
            swatch = QtWidgets.QLabel()
            caption = QtWidgets.QLabel(option)
            pix = QtGui.QPixmap(40, 5)
            pix.fill(utils.colours[option])
            swatch.setPixmap(pix)
            swatch_row.addWidget(swatch)
            caption_row.addWidget(caption)

        legend_layout = QtWidgets.QVBoxLayout()
        legend_layout.addLayout(swatch_row)
        legend_layout.addLayout(caption_row)

        self.cluster_colours.setLayout(legend_layout)

        self.cluster_list = QtWidgets.QListWidget()

        self.cluster_next_button = QtWidgets.QPushButton('Next')
        self.cluster_next_button.setFixedSize(90, 30)
        self.cluster_previous_button = QtWidgets.QPushButton('Previous')
        self.cluster_previous_button.setFixedSize(90, 30)

        self.cluster_list_group = QtWidgets.QGroupBox()

        grid = QtWidgets.QGridLayout()
        grid.addWidget(self.cluster_group, 0, 0, 1, 3)
        grid.addWidget(self.cluster_colours, 1, 0, 1, 3)
        grid.addWidget(self.cluster_list, 2, 0, 3, 3)
        grid.addWidget(self.cluster_previous_button, 5, 0)
        grid.addWidget(self.cluster_next_button, 5, 2)
        self.cluster_list_group.setLayout(grid)

        self.reset()

    def reset(self):
        """Clear the list widget and drop all cached cluster state."""
        self.cluster_list.clear()
        self.clust_colour_ks = None
        self.clust_colour_ibl = None
        self.clust_ids = None
        self.clust = None
        self.clust_prev = None

    def populate(self, clust_ids, clust_colour_ks, clust_colour_ibl):
        """Fill the list with one entry per cluster (icon shows IBL and KS
        colours) and select the first row."""
        self.clust_ids = clust_ids
        self.clust_colour_ks = clust_colour_ks
        self.clust_colour_ibl = clust_colour_ibl

        for row, cluster_id in enumerate(clust_ids):
            entry = QtWidgets.QListWidgetItem('Cluster Number ' + str(cluster_id))
            large_icon = utils.get_icon(self.clust_colour_ibl[row],
                                        self.clust_colour_ks[row], 20)
            entry.setIcon(QtGui.QIcon(large_icon))
            self.cluster_list.addItem(entry)
        self.cluster_list.setCurrentRow(0)

    def update_list_icon(self):
        """Shrink the previously selected row's icon to the small size."""
        entry = self.cluster_list.item(self.clust_prev)
        small_icon = utils.get_icon(self.clust_colour_ibl[self.clust_prev],
                                    self.clust_colour_ks[self.clust_prev], 12)
        entry.setIcon(QtGui.QIcon(small_icon))
        entry.setText(' Cluster Number ' + str(self.clust_ids[self.clust_prev]))

    def update_current_row(self, clust):
        """Select the given cluster index in the list."""
        self.clust = clust
        self.cluster_list.setCurrentRow(self.clust)

    def on_next_cluster_clicked(self):
        """Advance the selection by one row; returns the new index."""
        self.clust += 1
        self.cluster_list.setCurrentRow(self.clust)
        return self.clust

    def on_previous_cluster_clicked(self):
        """Move the selection back one row; returns the new index."""
        self.clust -= 1
        self.cluster_list.setCurrentRow(self.clust)
        return self.clust

    def on_cluster_list_clicked(self):
        """Sync the cached index with the clicked row; returns it."""
        self.clust = self.cluster_list.currentRow()
        return self.clust

    def update_cluster_index(self):
        """Remember the current index as the previous one; returns it."""
        self.clust_prev = self.clust
        return self.clust_prev

    def initialise_cluster_index(self):
        """Reset both current and previous indices to the first cluster."""
        self.clust = 0
        self.clust_prev = 0

        return self.clust, self.clust_prev

    def get_selected_sort(self):
        """Return the text of the checked sort-option radio button."""
        return self.cluster_buttons.checkedButton().text()
117 |
--------------------------------------------------------------------------------
/data_exploration_gui/cluster_class.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore, QtGui, QtWidgets
2 |
3 |
class ClusterGroup:
    """Legacy cluster-list panel: sort-option radio buttons, the cluster list
    and next/previous navigation buttons."""

    def __init__(self):
        """Build the widgets and initialise the cached cluster state."""
        self.cluster_buttons = QtWidgets.QButtonGroup()
        self.cluster_group = QtWidgets.QGroupBox('Sort Clusters By:')
        self.cluster_layout = QtWidgets.QHBoxLayout()
        cluster_options = ['cluster no.', 'no. of spikes', 'good units']
        for i, val in enumerate(cluster_options):
            button = QtWidgets.QRadioButton(val)
            if i == 0:
                button.setChecked(True)
            else:
                button.setChecked(False)
            self.cluster_buttons.addButton(button, id=i)
            self.cluster_layout.addWidget(button)

        self.cluster_group.setLayout(self.cluster_layout)

        # QListWidget already defaults to single selection; the former bare
        # `self.cluster_list.SingleSelection` statement was a no-op (removed)
        self.cluster_list = QtWidgets.QListWidget()

        self.cluster_next_button = QtWidgets.QPushButton('Next')
        self.cluster_next_button.setFixedSize(90, 30)
        self.cluster_previous_button = QtWidgets.QPushButton('Previous')
        self.cluster_previous_button.setFixedSize(90, 30)

        self.cluster_list_group = QtWidgets.QGroupBox()
        self.cluster_list_group.setFixedSize(400, 200)
        self.group_widget()

        self.clust_colour = []
        self.clust_ids = []
        self.clust = []
        self.clust_prev = []

    def group_widget(self):
        """Lay out the list and navigation buttons inside the group box."""
        group_layout = QtWidgets.QGridLayout()
        group_layout.addWidget(self.cluster_group, 0, 0, 1, 3)
        group_layout.addWidget(self.cluster_list, 1, 0, 3, 3)
        group_layout.addWidget(self.cluster_previous_button, 4, 0)
        group_layout.addWidget(self.cluster_next_button, 4, 2)
        self.cluster_list_group.setLayout(group_layout)

    def reset(self):
        """Clear the list widget and all cached cluster state."""
        self.cluster_list.clear()
        self.clust_colour = []
        self.clust_ids = []
        self.clust = []
        self.clust_prev = []

    def populate(self, clust_ids, clust_colour):
        """Fill the list with one coloured entry per cluster and select row 0.

        :param clust_ids: iterable of cluster ids
        :param clust_colour: iterable of QColor, one per cluster
        """
        self.clust_ids = clust_ids
        self.clust_colour = clust_colour
        icon = QtGui.QPixmap(20, 20)
        for idx, val in enumerate(clust_ids):
            item = QtWidgets.QListWidgetItem('Cluster Number ' + str(val))
            icon.fill(self.clust_colour[idx])
            item.setIcon(QtGui.QIcon(icon))
            self.cluster_list.addItem(item)
        self.cluster_list.setCurrentRow(0)

    def update_list_icon(self):
        """Shrink the previously selected row's icon to the small size."""
        item = self.cluster_list.item(self.clust_prev)
        icon = QtGui.QPixmap(10, 10)
        icon.fill(self.clust_colour[self.clust_prev])
        item.setIcon(QtGui.QIcon(icon))
        item.setText(' Cluster Number ' + str(self.clust_ids[self.clust_prev]))

    def update_current_row(self, clust):
        """Select the given cluster index in the list."""
        self.clust = clust
        self.cluster_list.setCurrentRow(self.clust)

    def on_next_cluster_clicked(self):
        """Advance the selection by one row; returns the new index."""
        self.clust += 1
        self.cluster_list.setCurrentRow(self.clust)
        return self.clust

    def on_previous_cluster_clicked(self):
        """Move the selection back one row; returns the new index."""
        self.clust -= 1
        self.cluster_list.setCurrentRow(self.clust)
        return self.clust

    def on_cluster_list_clicked(self):
        """Sync the cached index with the clicked row; returns it."""
        self.clust = self.cluster_list.currentRow()
        return self.clust

    def update_cluster_index(self):
        """Remember the current index as the previous one; returns it."""
        self.clust_prev = self.clust
        return self.clust_prev

    def initialise_cluster_index(self):
        """Reset both current and previous cluster indices to 0."""
        self.clust = 0
        self.clust_prev = 0

        return self.clust, self.clust_prev
102 |
--------------------------------------------------------------------------------
/data_exploration_gui/data_class.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtGui, QtWidgets
2 | import numpy as np
3 | import os
4 | import one.alf.io as alfio
5 | from brainbox.processing import get_units_bunch, compute_cluster_average
6 | from brainbox.population.decode import xcorr
7 | from brainbox.singlecell import calculate_peths
8 | from brainbox.io.spikeglx import extract_waveforms
9 | from pathlib import Path
10 |
11 |
12 | class DataGroup:
13 | def __init__(self):
14 |
15 | self.waveform_button = QtWidgets.QPushButton('Generate Waveform')
16 | self.waveform_list = QtWidgets.QListWidget()
17 | self.waveform_list.SingleSelection
18 | self.waveform_text = QtWidgets.QLabel('No. of spikes =')
19 |
20 | waveform_layout = QtWidgets.QGridLayout()
21 | waveform_layout.addWidget(self.waveform_button, 0, 0)
22 | waveform_layout.addWidget(self.waveform_text, 1, 0)
23 | waveform_layout.addWidget(self.waveform_list, 0, 1, 3, 1)
24 |
25 | self.waveform_group = QtWidgets.QGroupBox()
26 | self.waveform_group.setLayout(waveform_layout)
27 |
28 |
29 | #For peths and rasters
30 | self.t_before = 0.4
31 | self.t_after = 1
32 |
33 | #For autocorrolelogram
34 | self.autocorr_window = 0.1
35 | self.autocorr_bin = 0.001
36 |
37 |
38 | #For waveform (N.B in ms)
39 | self.waveform_window = 2
40 | self.CAR = False
41 |
    def load(self, folder_path):
        """Load all data for a probe folder: locate files, load the ALF
        objects and precompute the plotting time axes.

        :param folder_path: path to the alf/<probe> folder
        :return: (ephys_file_path, gui_path) — the raw .ap binary path
                 ([] when absent) and the gui cache folder path
        """
        self.folder_path = folder_path
        self.find_files()
        self.load_data()
        self.compute_timescales()

        return self.ephys_file_path, self.gui_path
49 |
50 | def find_files(self):
51 | self.probe_path = self.folder_path
52 | self.alf_path = self.folder_path.parent
53 | dir_path = self.folder_path.parent.parent
54 | self.gui_path = os.path.join(self.probe_path, 'gui')
55 | probe = os.path.split(self.probe_path)[1]
56 | ephys_path = os.path.join(dir_path, 'raw_ephys_data', probe)
57 |
58 | self.ephys_file_path = []
59 |
60 | try:
61 | for i in os.listdir(ephys_path):
62 | if 'ap' in i and 'bin' in i:
63 | self.ephys_file_path = os.path.join(ephys_path, i)
64 | except:
65 | self.ephys_file_path = []
66 |
67 | def load_data(self):
68 | self.spikes = alfio.load_object(self.probe_path, 'spikes')
69 | self.trials = alfio.load_object(self.alf_path, 'trials')
70 | self.clusters = alfio.load_object(self.probe_path, 'clusters')
71 | self.prepare_data(self.spikes, self.clusters, self.trials)
72 | # self.ids = np.unique(self.spikes.clusters)
73 | # self.metrics = np.array(self.clusters.metrics.ks2_label[self.ids])
74 | # self.colours = np.array(self.clusters.metrics.ks2_label[self.ids])
75 | # self.colours[np.where(self.colours == 'mua')[0]] = QtGui.QColor('#fdc086')
76 | # self.colours[np.where(self.colours == 'good')[0]] = QtGui.QColor('#7fc97f')
77 | #
78 | # file_count = 0
79 | # if os.path.isdir(self.gui_path):
80 | # for i in os.listdir(self.gui_path):
81 | # if 'depth' in i:
82 | # self.depths = np.load(Path(self.gui_path + '/cluster_depths.npy'))
83 | # file_count += 1
84 | # elif 'amp' in i:
85 | # self.amps = np.load(Path(self.gui_path + '/cluster_amps.npy'))
86 | # file_count += 1
87 | # elif 'nspikes' in i:
88 | # self.nspikes = np.load(Path(self.gui_path + '/cluster_nspikes.npy'))
89 | # file_count += 1
90 | # if file_count != 3:
91 | # self.compute_depth_and_amplitudes()
92 | # else:
93 | # os.mkdir(self.gui_path)
94 | # self.compute_depth_and_amplitudes()
95 | #
96 | # self.sort_by_id = np.arange(len(self.ids))
97 | # self.sort_by_nspikes = np.argsort(self.nspikes)
98 | # self.sort_by_nspikes = self.sort_by_nspikes[::-1]
99 | # self.sort_by_good = np.append(np.where(self.metrics == 'good')[0], np.where(self.metrics == 'mua')[0])
100 | # self.n_trials = len(self.trials['contrastLeft'])
101 |
102 | def prepare_data(self, spikes, clusters, trials):
103 | self.spikes = spikes
104 | self.clusters = clusters
105 | self.trials = trials
106 | self.ids = np.unique(spikes.clusters)
107 | self.metrics = np.array(clusters.metrics.ks2_label[self.ids])
108 | self.colours = np.array(clusters.metrics.ks2_label[self.ids])
109 | self.colours[np.where(self.colours == 'mua')[0]] = QtGui.QColor('#fdc086')
110 | self.colours[np.where(self.colours == 'good')[0]] = QtGui.QColor('#7fc97f')
111 | _, self.depths, self.nspikes = compute_cluster_average(spikes.clusters, spikes.depths)
112 | _, self.amps, _ = compute_cluster_average(spikes.clusters, spikes.amps)
113 | self.amps = self.amps * 1e6
114 | self.sort_by_id = np.arange(len(self.ids))
115 | self.sort_by_nspikes = np.argsort(self.nspikes)
116 | self.sort_by_nspikes = self.sort_by_nspikes[::-1]
117 | self.sort_by_good = np.append(np.where(self.metrics == 'good')[0],
118 | np.where(self.metrics == 'mua')[0])
119 | self.n_trials = len(trials['contrastLeft'])
120 |
121 |
122 | def compute_depth_and_amplitudes(self):
123 | units_b = get_units_bunch(self.spikes)
124 | self.depths = []
125 | self.amps = []
126 | self.nspikes = []
127 | for clu in self.ids:
128 | self.depths = np.append(self.depths, np.nanmean(units_b.depths[str(clu)]))
129 | self.amps = np.append(self.amps, np.nanmean(units_b.amps[str(clu)]) * 1e6)
130 | self.nspikes = np.append(self.nspikes, len(units_b.amps[str(clu)]))
131 |
132 | np.save((self.gui_path + '/cluster_depths'), self.depths)
133 | np.save((self.gui_path + '/cluster_amps'), self.amps)
134 | np.save((self.gui_path + '/cluster_nspikes'), self.nspikes)
135 |
136 | def sort_data(self, order):
137 | self.clust_ids = self.ids[order]
138 | self.clust_amps = self.amps[order]
139 | self.clust_depths = self.depths[order]
140 | self.clust_colours = self.colours[order]
141 |
142 | return self.clust_ids, self.clust_amps, self.clust_depths, self.clust_colours
143 |
    def reset(self):
        # clear the waveform-interval list before repopulating it for a new cluster
        self.waveform_list.clear()
146 |
147 | def populate(self, clust):
148 | self.reset()
149 | self.clus_idx = np.where(self.spikes.clusters == self.clust_ids[clust])[0]
150 |
151 | if len(self.clus_idx) <= 500:
152 | self.spk_intervals = [0, len(self.clus_idx)]
153 | else:
154 | self.spk_intervals = np.arange(0, len(self.clus_idx), 500)
155 | self.spk_intervals = np.append(self.spk_intervals, len(self.clus_idx))
156 |
157 | for idx in range(0, len(self.spk_intervals) - 1):
158 | item = QtWidgets.QListWidgetItem(str(self.spk_intervals[idx]) + ':' + str(self.spk_intervals[idx+1]))
159 | self.waveform_list.addItem(item)
160 |
161 | self.waveform_list.setCurrentRow(0)
162 | self.n_waveform = 0
163 |
def compute_peth(self, trial_type, clust, trials_id):
    """Compute the peri-event time histogram for one cluster.

    :param trial_type: key into ``self.trials`` selecting the event times
    :param clust: index into ``self.clust_ids``
    :param trials_id: indices of the trials to include
    :return: (time scale, mean firing rate, standard error) arrays
    """
    # second return value (bin size) is unused; previously it was bound to
    # a variable shadowing the builtin `bin`
    peths, _ = calculate_peths(self.spikes.times, self.spikes.clusters,
                               [self.clust_ids[clust]],
                               self.trials[trial_type][trials_id],
                               self.t_before, self.t_after)

    peth_mean = peths.means[0, :]
    # standard error of the mean across the selected trials
    peth_std = peths.stds[0, :] / np.sqrt(len(trials_id))
    t_peth = peths.tscale

    return t_peth, peth_mean, peth_std
173 |
def compute_rasters(self, trial_type, clust, trials_id):
    """Build raster-plot coordinates for one cluster.

    Each spike contributes one (time, row) point, where time is relative to
    the trial event and row is ``trial index * 10``.
    """
    spk_times = self.spikes.times[self.spikes.clusters == self.clust_ids[clust]]
    x_parts = []
    y_parts = []
    for trial_idx, event_time in enumerate(self.trials[trial_type][trials_id]):
        in_window = np.bitwise_and(spk_times >= event_time - self.t_before,
                                   spk_times <= event_time + self.t_after)
        aligned = spk_times[in_window] - event_time
        x_parts.append(aligned)
        y_parts.append(np.full(len(aligned), trial_idx * 10.0))

    self.x = np.concatenate(x_parts) if x_parts else np.empty(0)
    self.y = np.concatenate(y_parts) if y_parts else np.empty(0)

    return self.x, self.y, self.n_trials
187 |
def compute_autocorr(self, clust):
    """Compute the spike-train autocorrelogram of the selected cluster."""
    cluster_mask = self.spikes.clusters == self.clust_ids[clust]
    self.clus_idx = np.where(cluster_mask)[0]

    x_corr = xcorr(self.spikes.times[self.clus_idx],
                   self.spikes.clusters[self.clus_idx],
                   self.autocorr_bin, self.autocorr_window)

    # xcorr returns an (n_clusters, n_clusters, n_bins) array; with a single
    # cluster the autocorrelogram is the [0, 0] slice
    return self.t_autocorr, x_corr[0, 0, :]
197 |
def compute_template(self, clust):
    """Return the template time scale and the template waveform (uV) of the
    selected cluster on its first channel."""
    cluster_id = self.clust_ids[clust]
    # waveforms are stored in V; scale to uV for display
    return self.t_template, self.clusters.waveforms[cluster_id, :, 0] * 1e6
201 |
def compute_waveform(self, clust):
    """Extract raw waveforms for the current spike interval and average them.

    :return: (time scale, mean waveform, std waveform) when a raw ephys file
        is available, otherwise None (as the original fell through).
    """
    if len(self.ephys_file_path) == 0:
        # no raw data available; nothing to compute
        return None

    start = self.spk_intervals[self.n_waveform]
    stop = self.spk_intervals[self.n_waveform + 1]
    spk_times = self.spikes.times[self.clus_idx][start:stop]
    max_ch = self.clusters['channels'][self.clust_ids[clust]]
    wf = extract_waveforms(self.ephys_file_path, spk_times, max_ch,
                           t=self.waveform_window, car=self.CAR)
    wf_mean = np.mean(wf[:, :, 0], axis=0)
    wf_std = np.std(wf[:, :, 0], axis=0)

    return self.t_waveform, wf_mean, wf_std
211 |
def compute_timescales(self, sr=30000):
    """Precompute the time axes for the autocorrelogram, template and raw
    waveform plots.

    :param sr: sampling rate of the recording in Hz (previously hard-coded;
        the default preserves the original behaviour)
    """
    # symmetric window around 0 in steps of the autocorrelogram bin size
    self.t_autocorr = np.arange((self.autocorr_window / 2) - self.autocorr_window,
                                (self.autocorr_window / 2) + self.autocorr_bin,
                                self.autocorr_bin)
    n_template = len(self.clusters.waveforms[self.ids[0], :, 0])
    # template samples converted to milliseconds
    self.t_template = 1e3 * (np.arange(n_template)) / sr

    n_waveform = int(sr / 1000 * self.waveform_window)
    self.t_waveform = 1e3 * (np.arange(n_waveform)) / sr
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
--------------------------------------------------------------------------------
/data_exploration_gui/filter.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore, QtWidgets, QtGui
2 | import numpy as np
3 | from data_exploration_gui import utils
4 |
5 |
class FilterGroup:
    """Panel of filter widgets (stimulus contrast, trial type, ordering,
    sorting, alignment event and behaviour traces) used to select which
    trials and behaviour to display in the data exploration GUI."""

    def __init__(self):

        # button to reset filters to default
        self.reset_button = QtWidgets.QPushButton('Reset Filters')

        # checkboxes for contrasts; labels are percentages (0.25 -> '25.0')
        self.contrasts = utils.CONTRAST_OPTIONS
        self.contrast_buttons = QtWidgets.QButtonGroup()
        self.contrast_buttons.setExclusive(False)
        self.contrast_group = QtWidgets.QGroupBox('Stimulus Contrast')
        self.contrast_layout = QtWidgets.QVBoxLayout()
        for val in self.contrasts:
            button = QtWidgets.QCheckBox(str(val * 100))
            button.setCheckState(QtCore.Qt.Checked)
            self.contrast_buttons.addButton(button)
            self.contrast_layout.addWidget(button)

        self.contrast_group.setLayout(self.contrast_layout)

        # button for whether to overlay trial plots
        self.hold_button = QtWidgets.QCheckBox('Hold')
        self.hold_button.setCheckState(QtCore.Qt.Checked)

        # checkboxes for trial options; only the first option starts checked
        self.trial_buttons = QtWidgets.QButtonGroup()
        self.trial_group = QtWidgets.QGroupBox('Trial Options')
        self.trial_layout = QtWidgets.QVBoxLayout()
        for val in utils.TRIAL_OPTIONS:
            button = QtWidgets.QCheckBox(val)
            if val == utils.TRIAL_OPTIONS[0]:
                button.setCheckState(QtCore.Qt.Checked)
            else:
                button.setCheckState(QtCore.Qt.Unchecked)
            self.trial_buttons.addButton(button)
            self.trial_layout.addWidget(button)

        # colour swatches displayed next to each trial option
        self.trial_colours = QtWidgets.QVBoxLayout()
        for val in utils.TRIAL_OPTIONS:
            img = QtWidgets.QLabel()
            pix = QtGui.QPixmap(40, 5)
            pix.fill(utils.colours[val])
            img.setPixmap(pix)
            self.trial_colours.addWidget(img)

        hlayout = QtWidgets.QHBoxLayout()
        hlayout.addLayout(self.trial_layout)
        hlayout.addLayout(self.trial_colours)

        self.trial_group.setLayout(hlayout)

        # radio buttons for order options; first option selected by default
        self.order_buttons = QtWidgets.QButtonGroup()
        self.order_group = QtWidgets.QGroupBox('Order Trials By:')
        self.order_layout = QtWidgets.QVBoxLayout()
        for val in utils.ORDER_OPTIONS:
            button = QtWidgets.QRadioButton(val)
            button.setChecked(val == utils.ORDER_OPTIONS[0])
            self.order_buttons.addButton(button)
            self.order_layout.addWidget(button)

        self.order_group.setLayout(self.order_layout)

        # radio buttons for sort options; first option selected by default
        self.sort_buttons = QtWidgets.QButtonGroup()
        self.sort_group = QtWidgets.QGroupBox('Sort Trials By:')
        self.sort_layout = QtWidgets.QVBoxLayout()
        for val in utils.SORT_OPTIONS:
            button = QtWidgets.QRadioButton(val)
            button.setChecked(val == utils.SORT_OPTIONS[0])
            self.sort_buttons.addButton(button)
            self.sort_layout.addWidget(button)
        self.sort_group.setLayout(self.sort_layout)

        # drop-down selecting the event to align trials to
        self.event_group = QtWidgets.QGroupBox('Align to:')
        self.event_list = QtGui.QStandardItemModel()
        self.event_combobox = QtWidgets.QComboBox()
        self.event_combobox.setModel(self.event_list)
        self.event_layout = QtWidgets.QVBoxLayout()
        self.event_layout.addWidget(self.event_combobox)
        self.event_group.setLayout(self.event_layout)

        # drop-down selecting which behaviour trace to show
        self.behav_group = QtWidgets.QGroupBox('Show behaviour:')
        self.behav_list = QtGui.QStandardItemModel()
        self.behav_combobox = QtWidgets.QComboBox()
        self.behav_combobox.setModel(self.behav_list)
        self.behav_layout = QtWidgets.QVBoxLayout()
        self.behav_layout.addWidget(self.behav_combobox)
        self.behav_group.setLayout(self.behav_layout)

        # Now group everything into one big box
        self.filter_options_group = QtWidgets.QGroupBox()
        group_layout = QtWidgets.QVBoxLayout()
        group_layout.addWidget(self.reset_button)
        group_layout.addWidget(self.contrast_group)
        group_layout.addWidget(self.hold_button)
        group_layout.addWidget(self.trial_group)
        group_layout.addWidget(self.order_group)
        group_layout.addWidget(self.sort_group)
        group_layout.addWidget(self.event_group)
        group_layout.addWidget(self.behav_group)
        self.filter_options_group.setLayout(group_layout)

    def populate(self, event_options, behav_options):
        """Fill the alignment-event and behaviour drop-downs."""
        self.populate_lists(event_options, self.event_list, self.event_combobox)
        self.populate_lists(behav_options, self.behav_list, self.behav_combobox)

    def populate_lists(self, data, list_name, combobox):
        """Replace the items of *list_name* with the strings in *data* and
        select the first entry of *combobox*."""
        list_name.clear()
        for dat in data:
            item = QtGui.QStandardItem(dat)
            item.setEditable(False)
            list_name.appendRow(item)

        # This makes sure the drop down menu is wide enough to show full length of string
        min_width = combobox.fontMetrics().width(max(data, key=len))
        min_width += combobox.view().autoScrollMargin()
        min_width += combobox.style().pixelMetric(QtWidgets.QStyle.PM_ScrollBarExtent)
        combobox.view().setMinimumWidth(min_width)

        # Set the default to be the first option
        combobox.setCurrentIndex(0)

    def set_selected_event(self, event):
        """Select *event* in the alignment drop-down (no-op if not present)."""
        for it in range(self.event_combobox.count()):
            if self.event_combobox.itemText(it) == event:
                self.event_combobox.setCurrentIndex(it)

    def get_selected_event(self):
        """Return the currently selected alignment event."""
        return self.event_combobox.currentText()

    def get_selected_behaviour(self):
        """Return the currently selected behaviour trace."""
        return self.behav_combobox.currentText()

    def get_selected_contrasts(self):
        """Return the ticked contrasts as fractions (labels are percentages)."""
        contrasts = []
        for button in self.contrast_buttons.buttons():
            if button.isChecked():
                # np.float was removed from NumPy (1.24+); use the builtin
                contrasts.append(float(button.text()) / 100)

        return contrasts

    def get_selected_trials(self):
        """Return the texts of all ticked trial-option checkboxes."""
        trials = []
        for button in self.trial_buttons.buttons():
            if button.isChecked():
                trials.append(button.text())
        return trials

    def get_selected_order(self):
        """Return the text of the selected order radio button."""
        return self.order_buttons.checkedButton().text()

    def get_selected_sort(self):
        """Return the text of the selected sort radio button."""
        return self.sort_buttons.checkedButton().text()

    def get_hold_status(self):
        """Return whether the 'Hold' (overlay plots) box is ticked."""
        return self.hold_button.isChecked()

    def get_selected_filters(self):
        """Return (contrasts, order, sort, hold) in one call."""
        contrasts = self.get_selected_contrasts()
        order = self.get_selected_order()
        sort = self.get_selected_sort()
        hold = self.get_hold_status()

        return contrasts, order, sort, hold

    def set_sorted_button(self, sort):
        """Check the sort radio button whose text equals *sort*."""
        for button in self.sort_buttons.buttons():
            button.setChecked(button.text() == sort)

    def reset_filters(self, contrasts=True):
        """Restore all filters to their defaults.

        :param contrasts: when False, leave the contrast checkboxes as-is
        """
        if contrasts:
            for button in self.contrast_buttons.buttons():
                button.setChecked(True)

        for button in self.trial_buttons.buttons():
            button.setChecked(button.text() == utils.TRIAL_OPTIONS[0])

        for button in self.order_buttons.buttons():
            button.setChecked(button.text() == utils.ORDER_OPTIONS[0])

        for button in self.sort_buttons.buttons():
            button.setChecked(button.text() == utils.SORT_OPTIONS[0])
--------------------------------------------------------------------------------
/data_exploration_gui/load_data.py:
--------------------------------------------------------------------------------
"""Command-line helper: download the trials/spikes/clusters data needed by the
data exploration GUI for one session and probe."""
import argparse

import numpy as np

from one.api import ONE

parser = argparse.ArgumentParser(description='Load in subject info')

parser.add_argument('-s', '--subject', default=False, required=False,
                    help='Subject Name')
parser.add_argument('-d', '--date', default=False, required=False,
                    help='Date of session YYYY-MM-DD')
parser.add_argument('-n', '--session_no', default=1, required=False,
                    help='Session Number', type=int)
parser.add_argument('-e', '--eid', default=False, required=False,
                    help='Session eid')
parser.add_argument('-p', '--probe_label', default=False, required=True,
                    help='Probe Label')

args = parser.parse_args()

one = ONE()
if not args.eid:
    # without an explicit eid we need subject + date to look the session up
    if not np.all(np.array([args.subject, args.date, args.session_no],
                           dtype=object)):
        # exits with status 2; previously this only printed a warning and
        # then crashed with a NameError on the undefined eid below
        parser.error('Must give Subject, Date and Session number')
    eid = one.search(subject=str(args.subject), date=str(args.date), number=args.session_no)[0]
    print(eid)
else:
    eid = str(args.eid)

# download (but do not load) the datasets the GUI reads from disk
_ = one.load_object(eid, obj='trials', collection='alf',
                    download_only=True)
_ = one.load_object(eid, obj='spikes', attribute=['times', 'clusters', 'amps', 'depths'],
                    collection=f'alf/{str(args.probe_label)}', download_only=True)
_ = one.load_object(eid, obj='clusters', collection=f'alf/{str(args.probe_label)}',
                    download_only=True)
--------------------------------------------------------------------------------
/data_exploration_gui/misc.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtWidgets
2 | from data_exploration_gui import utils
3 |
class MiscGroup:
    """Widget group displaying session/cluster QC metrics and a DLC warning."""

    def __init__(self):

        self.qc_group = QtWidgets.QGroupBox()
        self.dlc_warning_label = QtWidgets.QLabel()
        self.sess_qc_group = QtWidgets.QHBoxLayout()
        self.clust_qc_group = QtWidgets.QHBoxLayout()

        # one '<metric>:' label per QC metric; kept so the text can be
        # updated in place later
        self.sess_qc_labels = []
        self.clust_qc_labels = []

        for qc in utils.SESS_QC:
            qc_label = QtWidgets.QLabel(f'{qc}:')
            self.sess_qc_group.addWidget(qc_label)
            self.sess_qc_labels.append(qc_label)

        for qc in utils.CLUSTER_QC:
            qc_label = QtWidgets.QLabel(f'{qc}: ')
            self.clust_qc_group.addWidget(qc_label)
            self.clust_qc_labels.append(qc_label)

        vlayout = QtWidgets.QVBoxLayout()
        vlayout.addWidget(self.dlc_warning_label)
        vlayout.addLayout(self.sess_qc_group)
        vlayout.addLayout(self.clust_qc_group)
        self.qc_group.setLayout(vlayout)

    @staticmethod
    def _set_qc_text(labels, data):
        """Rewrite each label in *labels* as '<metric>: <value>' using the
        mapping *data* (shared by the two public setters below)."""
        for label in labels:
            # the metric name is everything before the first ':'
            title = label.text().split(':')[0]
            label.setText(title + ': ' + str(data[title]))

    def set_sess_qc_text(self, data):
        """Display the session-level QC values."""
        self._set_qc_text(self.sess_qc_labels, data)

    def set_clust_qc_text(self, data):
        """Display the cluster-level QC values."""
        self._set_qc_text(self.clust_qc_labels, data)

    def set_dlc_label(self, aligned):
        """Show a warning when the DLC traces are not aligned."""
        if not aligned:
            self.dlc_warning_label.setText(utils.dlc_warning)
--------------------------------------------------------------------------------
/data_exploration_gui/misc_class.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore, QtGui, QtWidgets
2 |
3 | import pandas as pd
4 |
5 |
class MiscGroup:
    """Widget group holding the folder picker, the two lists of clusters of
    interest and the text terminal of the data exploration GUI."""

    def __init__(self):

        # folder selection row: prompt label, editable path, browse button
        folder_prompt = QtWidgets.QLabel('Select Folder')
        self.folder_line = QtWidgets.QLineEdit()
        self.folder_button = QtWidgets.QToolButton()
        self.folder_button.setText('...')

        folder_layout = QtWidgets.QHBoxLayout()
        folder_layout.addWidget(folder_prompt)
        folder_layout.addWidget(self.folder_line)
        folder_layout.addWidget(self.folder_button)

        self.folder_group = QtWidgets.QGroupBox()
        self.folder_group.setLayout(folder_layout)

        # user-flagged clusters aligned to the go cue / to feedback, each
        # with a 'Remove' button; one shared 'Save Clusters' button
        clust_list1_title = QtWidgets.QLabel('Go Cue Aligned Clusters')
        self.clust_list1 = QtWidgets.QListWidget()
        self.remove_clust_button1 = QtWidgets.QPushButton('Remove')

        clust_list2_title = QtWidgets.QLabel('Feedback Aligned Clusters')
        self.clust_list2 = QtWidgets.QListWidget()
        self.remove_clust_button2 = QtWidgets.QPushButton('Remove')
        self.save_clust_button = QtWidgets.QPushButton('Save Clusters')

        clust_interest_layout = QtWidgets.QGridLayout()
        clust_interest_layout.addWidget(clust_list1_title, 0, 0)
        clust_interest_layout.addWidget(self.remove_clust_button1, 0, 1)
        clust_interest_layout.addWidget(self.clust_list1, 1, 0, 1, 2)
        clust_interest_layout.addWidget(clust_list2_title, 2, 0)
        clust_interest_layout.addWidget(self.remove_clust_button2, 2, 1)
        clust_interest_layout.addWidget(self.clust_list2, 3, 0, 1, 2)
        clust_interest_layout.addWidget(self.save_clust_button, 4, 0)

        self.clust_interest = QtWidgets.QGroupBox()
        self.clust_interest.setLayout(clust_interest_layout)
        self.clust_interest.setFixedSize(250, 350)

        # read-only text console with a 'Clear' button overlaid at the top
        self.terminal = QtWidgets.QTextBrowser()
        self.terminal_button = QtWidgets.QPushButton('Clear')
        self.terminal_button.setFixedSize(70, 30)
        self.terminal_button.setParent(self.terminal)
        # NOTE(review): hard-coded x offset assumes a fixed terminal width — confirm
        self.terminal_button.move(580, 0)

    def on_save_button_clicked(self, gui_path):
        """Prompt for a CSV path and save both cluster-of-interest lists.

        :param gui_path: directory the save dialog opens in
        :return: the chosen file path (empty string if the dialog was cancelled)
        """
        save_file = QtWidgets.QFileDialog()
        save_file.setAcceptMode(QtWidgets.QFileDialog.AcceptSave)
        file = save_file.getSaveFileName(None, 'Save File', gui_path, 'CSV (*.csv)')
        print(file[0])
        if file[0]:
            save_clusters = pd.DataFrame(columns=['go_Cue_aligned', 'feedback_aligned'])

            go_cue_clusters = []
            for idx in range(self.clust_list1.count()):
                cluster_no = self.clust_list1.item(idx).text()
                # NOTE(review): assumes item text is '<8-char prefix><id>' —
                # verify against the code that populates these lists
                go_cue_clusters.append(int(cluster_no[8:]))

            go_cue = pd.Series(go_cue_clusters)
            save_clusters['go_Cue_aligned'] = go_cue

            feedback_clusters = []
            for idx in range(self.clust_list2.count()):
                cluster_no = self.clust_list2.item(idx).text()
                feedback_clusters.append(int(cluster_no[8:]))

            feedback = pd.Series(feedback_clusters)
            save_clusters['feedback_aligned'] = feedback

            save_clusters.to_csv(file[0], index=None, header=True)

        return file[0]
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
/data_exploration_gui/scatter.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore
2 | import pyqtgraph as pg
3 | import numpy as np
4 |
5 |
class ScatterGroup:
    """Amplitude-vs-depth scatter of clusters with hover tooltips and a black
    highlight on the currently selected cluster."""

    def __init__(self):

        self.fig_scatter = pg.PlotWidget(background='w')
        self.fig_scatter.setMouseTracking(True)
        self.fig_scatter.scene().sigMouseMoved.connect(self.on_mouse_hover)
        self.fig_scatter.setLabel('bottom', 'Cluster Amplitude (uV)')
        self.fig_scatter.setLabel('left', 'Cluster Distance From Tip (uV)')

        self.scatter_plot = pg.ScatterPlotItem()
        self.scatter_text = pg.TextItem(color='k')
        self.scatter_text.hide()

        self.fig_scatter.addItem(self.scatter_plot)
        self.fig_scatter.addItem(self.scatter_text)

        self.reset()

    def reset(self):
        """Clear the plot and drop all cached cluster data."""
        self.scatter_plot.setData()
        self.clust_amps = None
        self.clust_depths = None
        self.clust_color_ks = None
        self.clust_color_ibl = None
        self.clust_ids = None
        self.point_pos = None
        self.point_pos_prev = None

    def populate(self, clust_amps, clust_depths, clust_ids, clust_color_ks, clust_color_ibl):
        """Fill the scatter plot; brush = IBL quality colour, pen = KS label colour."""
        self.clust_amps = clust_amps
        self.clust_depths = clust_depths
        self.clust_color_ks = clust_color_ks
        self.clust_color_ibl = clust_color_ibl
        self.clust_ids = clust_ids
        self.scatter_plot.setData(x=self.clust_amps, y=self.clust_depths, size=10)
        self.scatter_plot.setBrush(clust_color_ibl)
        self.scatter_plot.setPen(clust_color_ks)
        # pad the x range slightly either side of the data
        self.x_min = np.min(self.clust_amps) - 2
        self.x_max = np.max(self.clust_amps) + 2
        self.fig_scatter.setXRange(min=self.x_min, max=self.x_max)
        self.fig_scatter.setYRange(min=0, max=4000)

    @staticmethod
    def _closest_point(points, pos):
        """Return the spot in *points* whose x position is nearest *pos*, or
        None when *points* is empty (shared by the highlight methods)."""
        if len(points) == 0:
            return None
        if len(points) == 1:
            return points[0]
        p_diff = [np.abs(p.pos().x() - pos.x()) for p in points]
        return points[int(np.argmin(p_diff))]

    def on_scatter_plot_clicked(self, point):
        """Return the index (into the sorted cluster arrays) of the clicked point."""
        self.point_pos = point[0].pos()
        clust = np.argwhere(self.scatter_plot.data['x'] == self.point_pos.x())[0][0]
        return clust

    def on_mouse_hover(self, pos):
        """Show the cluster id of the point under the cursor, if any."""
        mouse_pos = self.scatter_plot.mapFromScene(pos)
        point = self.scatter_plot.pointsAt(mouse_pos)

        if len(point) != 0:
            point_pos = point[0].pos()
            clust = np.argwhere(self.scatter_plot.data['x'] == point_pos.x())[0][0]
            self.scatter_text.setText('Cluster no. ' + str(self.clust_ids[clust]))
            self.scatter_text.setPos(mouse_pos.x(), (mouse_pos.y()))
            self.scatter_text.show()
        else:
            self.scatter_text.hide()

    def update_scatter_icon(self, clust_prev):
        """Paint the newly selected point black and restore the previous
        selection to its quality colours.

        Previously this indexed ``point[0]`` unconditionally and raised an
        IndexError when ``pointsAt`` returned no points.
        """
        point = self._closest_point(self.scatter_plot.pointsAt(self.point_pos),
                                    self.point_pos)
        if point is not None:
            point.setBrush('k')
            point.setPen('k')

        point_prev = self._closest_point(self.scatter_plot.pointsAt(self.point_pos_prev),
                                         self.point_pos_prev)
        if point_prev is not None:
            point_prev.setPen(self.clust_color_ks[clust_prev])
            point_prev.setBrush(self.clust_color_ibl[clust_prev])
            point_prev.setSize(8)

    def set_current_point(self, clust):
        """Cache the plot position of cluster index *clust* as the current point."""
        self.point_pos.setX(self.clust_amps[clust])
        self.point_pos.setY(self.clust_depths[clust])

    def update_prev_point(self):
        """Remember the current point as the previous selection."""
        self.point_pos_prev.setX(self.point_pos.x())
        self.point_pos_prev.setY(self.point_pos.y())

    def initialise_scatter_index(self):
        """Select the first cluster and highlight its point in black."""
        self.point_pos = QtCore.QPointF()
        self.point_pos_prev = QtCore.QPointF()
        self.point_pos.setX(self.clust_amps[0])
        self.point_pos.setY(self.clust_depths[0])
        self.point_pos_prev.setX(self.clust_amps[0])
        self.point_pos_prev.setY(self.clust_depths[0])

        point = self._closest_point(self.scatter_plot.pointsAt(self.point_pos),
                                    self.point_pos)
        if point is not None:
            point.setBrush('k')
            point.setPen('k')
122 |
--------------------------------------------------------------------------------
/data_exploration_gui/scatter_class.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtCore, QtGui, QtWidgets
2 | import pyqtgraph as pg
3 | import numpy as np
4 |
class ScatterGroup:
    """Legacy amplitude-vs-depth scatter of clusters with hover tooltips,
    a reset-axis button and a blue highlight on the selected cluster."""

    def __init__(self):

        self.fig_scatter = pg.PlotWidget(background='w')
        self.fig_scatter.setMouseTracking(True)
        self.fig_scatter.setLabel('bottom', 'Cluster Amplitude (uV)')
        self.fig_scatter.setLabel('left', 'Cluster Distance From Tip (uV)')

        self.scatter_plot = pg.ScatterPlotItem()
        self.scatter_text = pg.TextItem(color='k')
        self.scatter_text.hide()

        self.scatter_reset_button = QtWidgets.QPushButton('Reset Axis')
        self.scatter_reset_button.setFixedSize(70, 30)
        self.scatter_reset_button.setParent(self.fig_scatter)

        self.fig_scatter.addItem(self.scatter_plot)
        self.fig_scatter.addItem(self.scatter_text)
        self.fig_scatter.setFixedSize(400, 580)

        self.clust_amps = []
        self.clust_depths = []
        self.clust_color = []
        self.clust_ids = []
        self.point_pos = []
        self.point_pos_prev = []

    def reset(self):
        """Clear the plot and drop all cached cluster data."""
        self.scatter_plot.setData()
        self.clust_amps = []
        self.clust_depths = []
        self.clust_color = []
        self.clust_ids = []
        self.point_pos = []
        self.point_pos_prev = []

    def populate(self, clust_amps, clust_depths, clust_ids, clust_color):
        """Fill the scatter plot and fit the x range to the data."""
        self.clust_amps = clust_amps
        self.clust_depths = clust_depths
        self.clust_color = clust_color
        self.clust_ids = clust_ids
        self.scatter_plot.setData(x=self.clust_amps, y=self.clust_depths, size=8)
        self.scatter_plot.setBrush(clust_color)
        self.scatter_plot.setPen(QtGui.QColor(30, 50, 2))
        # pad the x range slightly either side of the data; previously the
        # upper bound was max() - 2, clipping the largest-amplitude points
        self.x_min = self.clust_amps.min() - 2
        self.x_max = self.clust_amps.max() + 2
        self.fig_scatter.setXRange(min=self.x_min, max=self.x_max)
        self.fig_scatter.setYRange(min=0, max=4000)

    @staticmethod
    def _closest_point(points, pos):
        """Return the spot in *points* whose x position is nearest *pos*, or
        None when *points* is empty (shared by the highlight methods)."""
        if len(points) == 0:
            return None
        if len(points) == 1:
            return points[0]
        p_diff = [np.abs(p.pos().x() - pos.x()) for p in points]
        return points[int(np.argmin(p_diff))]

    def on_scatter_plot_clicked(self, point):
        """Return the index (into the sorted cluster arrays) of the clicked point."""
        self.point_pos = point[0].pos()
        clust = np.argwhere(self.scatter_plot.data['x'] == self.point_pos.x())[0][0]
        return clust

    def on_mouse_hover(self, pos):
        """Show the cluster id of the point under the cursor, if any."""
        mouse_pos = self.scatter_plot.mapFromScene(pos)
        point = self.scatter_plot.pointsAt(mouse_pos)

        if len(point) != 0:
            point_pos = point[0].pos()
            clust = np.argwhere(self.scatter_plot.data['x'] == point_pos.x())[0][0]
            self.scatter_text.setText('Cluster no. ' + str(self.clust_ids[clust]))
            self.scatter_text.setPos(mouse_pos.x(), (mouse_pos.y()))
            self.scatter_text.show()
        else:
            self.scatter_text.hide()

    def on_scatter_plot_reset(self):
        """Restore the default axis ranges."""
        self.fig_scatter.setXRange(min=self.x_min, max=self.x_max)
        self.fig_scatter.setYRange(min=0, max=4000)

    def update_scatter_icon(self, clust_prev):
        """Paint the newly selected point blue and restore the previous
        selection to its cluster colour.

        Previously this indexed ``point[0]`` unconditionally and raised an
        IndexError when ``pointsAt`` returned no points.
        """
        point = self._closest_point(self.scatter_plot.pointsAt(self.point_pos),
                                    self.point_pos)
        if point is not None:
            point.setBrush('b')
            point.setPen('b')

        point_prev = self._closest_point(self.scatter_plot.pointsAt(self.point_pos_prev),
                                         self.point_pos_prev)
        if point_prev is not None:
            point_prev.setPen('w')
            point_prev.setBrush(self.clust_color[clust_prev])

    def set_current_point(self, clust):
        """Cache and return the plot position of cluster index *clust*."""
        self.point_pos.setX(self.clust_amps[clust])
        self.point_pos.setY(self.clust_depths[clust])

        return self.point_pos

    def update_prev_point(self):
        """Remember and return the current point as the previous selection."""
        self.point_pos_prev.setX(self.point_pos.x())
        self.point_pos_prev.setY(self.point_pos.y())

        return self.point_pos_prev

    def initialise_scatter_index(self):
        """Select the first cluster, highlight its point in blue and return
        the (current, previous) cached positions."""
        self.point_pos = QtCore.QPointF()
        self.point_pos_prev = QtCore.QPointF()
        self.point_pos.setX(self.clust_amps[0])
        self.point_pos.setY(self.clust_depths[0])
        self.point_pos_prev.setX(self.clust_amps[0])
        self.point_pos_prev.setY(self.clust_depths[0])

        point = self._closest_point(self.scatter_plot.pointsAt(self.point_pos),
                                    self.point_pos)
        if point is not None:
            point.setBrush('b')
            point.setPen('b')

        return self.point_pos, self.point_pos_prev
127 |
--------------------------------------------------------------------------------
/data_exploration_gui/utils.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import QtGui, QtCore
2 | from iblutil.util import Bunch
3 | import copy
4 |
# Colour palette for the GUI: trial-type colours, unit-quality colours
# (KS = Kilosort label, IBL = IBL quality metric) and misc plot colours.
colours = {'all': QtGui.QColor('#808080'),
           'correct': QtGui.QColor('#1f77b4'),
           'incorrect': QtGui.QColor('#d62728'),
           'left': QtGui.QColor('#2ca02c'),
           'right': QtGui.QColor('#bcbd22'),
           'left correct': QtGui.QColor('#17becf'),
           'right correct': QtGui.QColor('#9467bd'),
           'left incorrect': QtGui.QColor('#8c564b'),
           'right incorrect': QtGui.QColor('#ff7f0e'),
           'KS good': QtGui.QColor('#1f77b4'),
           'KS mua': QtGui.QColor('#fdc086'),
           'IBL good': QtGui.QColor('#7fc97f'),
           'IBL bad': QtGui.QColor('#d62728'),
           'line': QtGui.QColor('#7732a8'),
           'no metric': QtGui.QColor('#989898')}
20 |
21 |
# Raster display options. For each trial subset ('all', 'left', ...) a Bunch
# maps every sort mode ('idx', 'side', 'choice', 'choice and side') to the
# colours and legend texts used when grouping trials that way.
# NOTE(review): the intermediate Bunches (`idx`, `side`, `choice`, ...) are
# deliberately shared between groups as aliases, not copies — several groups
# point at the very same object.
# NOTE(review): `all` shadows the builtin of the same name at module scope.
all = Bunch()
idx = Bunch()
idx['colours'] = []
idx['text'] = []
side = Bunch()
side['colours'] = [colours['left'], colours['right']]
side['text'] = ['left', 'right']
choice = Bunch()
choice['colours'] = [colours['correct'], colours['incorrect']]
choice['text'] = ['correct', 'incorrect']
choice_and_side = Bunch()
choice_and_side['colours'] = [colours['left correct'], colours['right correct'],
                              colours['left incorrect'], colours['right incorrect']]
choice_and_side['text'] = ['left correct', 'right correct', 'left incorrect', 'right incorrect']
all['idx'] = idx
all['side'] = side
all['choice'] = choice
all['choice and side'] = choice_and_side

# correct trials only: splitting by side distinguishes left/right correct
correct = Bunch()
side = Bunch()
side['colours'] = [colours['left correct'], colours['right correct']]
side['text'] = ['left correct', 'right correct']
choice = Bunch()
choice['colours'] = [colours['correct']]
choice['text'] = ['correct']
correct['idx'] = idx
correct['side'] = side
correct['choice'] = choice
correct['choice and side'] = side

# incorrect trials only
incorrect = Bunch()
side = Bunch()
side['colours'] = [colours['left incorrect'], colours['right incorrect']]
side['text'] = ['left incorrect', 'right incorrect']
choice = Bunch()
choice['colours'] = [colours['incorrect']]
choice['text'] = ['incorrect']
incorrect['idx'] = idx
incorrect['side'] = side
incorrect['choice'] = choice
incorrect['choice and side'] = side

# left-stimulus trials only: splitting by choice distinguishes correct/incorrect
left = Bunch()
side = Bunch()
side['colours'] = [colours['left']]
side['text'] = ['left']
choice = Bunch()
choice['colours'] = [colours['left correct'], colours['left incorrect']]
choice['text'] = ['left correct', 'left incorrect']
left['idx'] = idx
left['side'] = side
left['choice'] = choice
left['choice and side'] = choice

# right-stimulus trials only
right = Bunch()
side = Bunch()
side['colours'] = [colours['right']]
side['text'] = ['right']
choice = Bunch()
choice['colours'] = [colours['right correct'], colours['right incorrect']]
choice['text'] = ['right correct', 'right incorrect']
right['idx'] = idx
right['side'] = side
right['choice'] = choice
right['choice and side'] = choice

# fully-specified subsets (side + outcome): every sort mode is the same group
left_correct = Bunch()
side = Bunch()
side['colours'] = [colours['left correct']]
side['text'] = ['left correct']
left_correct['idx'] = idx
left_correct['side'] = side
left_correct['choice'] = side
left_correct['choice and side'] = side

right_correct = Bunch()
side = Bunch()
side['colours'] = [colours['right correct']]
side['text'] = ['right correct']
right_correct['idx'] = idx
right_correct['side'] = side
right_correct['choice'] = side
right_correct['choice and side'] = side

left_incorrect = Bunch()
side = Bunch()
side['colours'] = [colours['left incorrect']]
side['text'] = ['left incorrect']
left_incorrect['idx'] = idx
left_incorrect['side'] = side
left_incorrect['choice'] = side
left_incorrect['choice and side'] = side

right_incorrect = Bunch()
side = Bunch()
side['colours'] = [colours['right incorrect']]
side['text'] = ['right incorrect']
right_incorrect['idx'] = idx
right_incorrect['side'] = side
right_incorrect['choice'] = side
right_incorrect['choice and side'] = side

# lookup table keyed by the trial-subset name selected in the GUI
RASTER_OPTIONS = Bunch()
RASTER_OPTIONS['all'] = all
RASTER_OPTIONS['left'] = left
RASTER_OPTIONS['right'] = right
RASTER_OPTIONS['correct'] = correct
RASTER_OPTIONS['incorrect'] = incorrect
RASTER_OPTIONS['left correct'] = left_correct
RASTER_OPTIONS['right correct'] = right_correct
RASTER_OPTIONS['left incorrect'] = left_incorrect
RASTER_OPTIONS['right incorrect'] = right_incorrect
138 |
# PSTH display options: per trial subset, a line colour (copy.copy so it can
# be modified independently), the shared fill colour and the legend text.
all = Bunch()
all['colour'] = copy.copy(colours['all'])
all['fill'] = colours['all']
all['text'] = 'all'

left = Bunch()
left['colour'] = copy.copy(colours['left'])
left['fill'] = colours['left']
left['text'] = 'left'

right = Bunch()
right['colour'] = copy.copy(colours['right'])
right['fill'] = colours['right']
right['text'] = 'right'

correct = Bunch()
correct['colour'] = copy.copy(colours['correct'])
correct['fill'] = colours['correct']
correct['text'] = 'correct'

incorrect = Bunch()
incorrect['colour'] = copy.copy(colours['incorrect'])
incorrect['fill'] = colours['incorrect']
incorrect['text'] = 'incorrect'

left_correct = Bunch()
left_correct['colour'] = copy.copy(colours['left correct'])
left_correct['fill'] = colours['left correct']
left_correct['text'] = 'left correct'

right_correct = Bunch()
right_correct['colour'] = copy.copy(colours['right correct'])
right_correct['fill'] = colours['right correct']
right_correct['text'] = 'right correct'

left_incorrect = Bunch()
left_incorrect['colour'] = copy.copy(colours['left incorrect'])
left_incorrect['fill'] = colours['left incorrect']
left_incorrect['text'] = 'left incorrect'

right_incorrect = Bunch()
right_incorrect['colour'] = copy.copy(colours['right incorrect'])
right_incorrect['fill'] = colours['right incorrect']
right_incorrect['text'] = 'right incorrect'

# lookup table keyed by the trial-subset name selected in the GUI
PSTH_OPTIONS = Bunch()
PSTH_OPTIONS['all'] = all
PSTH_OPTIONS['left'] = left
PSTH_OPTIONS['right'] = right
PSTH_OPTIONS['correct'] = correct
PSTH_OPTIONS['incorrect'] = incorrect
PSTH_OPTIONS['left correct'] = left_correct
PSTH_OPTIONS['right correct'] = right_correct
PSTH_OPTIONS['left incorrect'] = left_incorrect
PSTH_OPTIONS['right incorrect'] = right_incorrect

# map each trial subset to the stimulus side it constrains ('all' = both)
MAP_SIDE_OPTIONS = Bunch()
MAP_SIDE_OPTIONS['all'] = 'all'
MAP_SIDE_OPTIONS['left'] = 'left'
MAP_SIDE_OPTIONS['right'] = 'right'
MAP_SIDE_OPTIONS['correct'] = 'all'
MAP_SIDE_OPTIONS['incorrect'] = 'all'
MAP_SIDE_OPTIONS['left correct'] = 'left'
MAP_SIDE_OPTIONS['right correct'] = 'right'
MAP_SIDE_OPTIONS['left incorrect'] = 'left'
MAP_SIDE_OPTIONS['right incorrect'] = 'right'

# map each trial subset to the outcome it constrains ('all' = both)
MAP_CHOICE_OPTIONS = Bunch()
MAP_CHOICE_OPTIONS['all'] = 'all'
MAP_CHOICE_OPTIONS['left'] = 'all'
MAP_CHOICE_OPTIONS['right'] = 'all'
MAP_CHOICE_OPTIONS['correct'] = 'correct'
MAP_CHOICE_OPTIONS['incorrect'] = 'incorrect'
MAP_CHOICE_OPTIONS['left correct'] = 'correct'
MAP_CHOICE_OPTIONS['right correct'] = 'correct'
MAP_CHOICE_OPTIONS['left incorrect'] = 'incorrect'
MAP_CHOICE_OPTIONS['right incorrect'] = 'incorrect'
216 |
217 | MAP_SORT_OPTIONS = Bunch()
218 | MAP_SORT_OPTIONS['all'] = 'idx'
219 | MAP_SORT_OPTIONS['left'] = 'side'
220 | MAP_SORT_OPTIONS['right'] = 'side'
221 | MAP_SORT_OPTIONS['correct'] = 'choice'
222 | MAP_SORT_OPTIONS['incorrect'] = 'choice'
223 | MAP_SORT_OPTIONS['left correct'] = 'choice and side'
224 | MAP_SORT_OPTIONS['right correct'] = 'choice and side'
225 | MAP_SORT_OPTIONS['left incorrect'] = 'choice and side'
226 | MAP_SORT_OPTIONS['right incorrect'] = 'choice and side'
227 |
228 |
229 | TRIAL_OPTIONS = ['all', 'correct', 'incorrect', 'left', 'right', 'left correct', 'left incorrect',
230 | 'right correct', 'right incorrect']
231 | CONTRAST_OPTIONS = [1, 0.25, 0.125, 0.0625, 0]
232 | ORDER_OPTIONS = ['trial num', 'reaction time']
233 | SORT_OPTIONS = ['idx', 'choice', 'side', 'choice and side']
234 | UNIT_OPTIONS = ['IBL good', 'IBL bad', 'KS good', 'KS mua']
235 | SORT_CLUSTER_OPTIONS = ['ids', 'n spikes', 'IBL good', 'KS good']
236 |
237 | SESS_QC = ['task', 'behavior', 'dlcLeft', 'dlcRight', 'videoLeft', 'videoRight']
238 | CLUSTER_QC = ['noise_cutoff', 'amp_median', 'slidingRP_viol']
239 |
240 | dlc_warning = 'WARNING: dlc points and timestamps differ in length, dlc points are ' \
241 | 'not aligned correctly'
242 |
243 |
def get_icon(col_outer, col_inner, pix_size):
    """Build a square icon pixmap: a solid outer background with a centred
    inner square covering the middle half of the icon.

    :param col_outer: colour of the outer background
    :param col_inner: colour of the centred inner square
    :param pix_size: icon width/height in pixels
    :return: composited QtGui.QPixmap
    """
    # Outer layer: solid background colour.
    outer = QtGui.QPixmap(pix_size, pix_size)
    outer.fill(col_outer)

    # Inner layer: fully transparent except for a centred half-size square.
    inner = QtGui.QPixmap(pix_size, pix_size)
    inner.fill(QtCore.Qt.transparent)
    inner_painter = QtGui.QPainter(inner)
    quarter, half = int(pix_size / 4), int(pix_size / 2)
    inner_painter.fillRect(quarter, quarter, half, half, col_inner)
    inner_painter.end()

    # Composite the transparent inner layer over the opaque outer layer.
    icon = QtGui.QPixmap(outer)
    compositor = QtGui.QPainter(icon)
    compositor.drawPixmap(QtCore.QPoint(), outer)
    compositor.setCompositionMode(QtGui.QPainter.CompositionMode_SourceOver)
    compositor.drawPixmap(icon.rect(), inner, inner.rect())
    compositor.end()

    return icon
263 |
--------------------------------------------------------------------------------
/dlc/README.md:
--------------------------------------------------------------------------------
## How to get example frames with DLC labels painted on top
To avoid downloading lengthy videos, `stream_dlc_labeled_frames.py` can be used to stream 5 example frames randomly picked across a whole video, specified by eid and video type ('body', 'left' or 'right').
3 |
4 | To use it, start an ipython session, then type `run /path/to/stream_dlc_labeled_frames.py` to load the script. Then type `stream_save_labeled_frames(eid, video_type)` after setting the path where the labelled example frames should be saved as png images locally, `save_images_folder`.
5 |
6 | ## How to make DLC-labeled video
7 | With the script DLC_labeled_video.py one can make DLC-labeled videos. The script downloads a specific IBL video ('body', 'left' or 'right') for some session (eid). It also downloads the wheel info and will print the wheel angle onto each frame.
8 | To use it, start an ipython session, then type `run /path/to/DLC_labeled_video.py` to load the script. Next type
9 |
10 | `Viewer(eid, video_type, trial_range, save_video=True, eye_zoom=False)`
11 |
12 | This will display and save a labeled video for a particular trial range. E.g. `Viewer('3663d82b-f197-4e8b-b299-7b803a155b84', 'left', [5,7])` will display and save a DLC labeled video for the left cam video of session with eid '3663d82b-f197-4e8b-b299-7b803a155b84' and range from trial 5 to trial 7. There's further the option to show a zoom of the pupil only.
13 |
See `Example_DLC_access.ipynb` for more instructions and an example of how to load DLC results for other potential applications.
15 |
16 | ## Wheel and DLC live viewer
Plotting the wheel data alongside the video and DLC can be done with the `wheel_dlc_viewer`. The viewer loops over the
18 | video frames for a given trial and shows the wheel position plotted against time, indicating the
detected wheel movements. __NB__: Besides the IBL environment dependencies, this module also
20 | requires cv2.
21 |
22 | ### Running from command line
23 | You can run the viewer from the terminal. In Windows, running from the Anaconda prompt within
24 | iblenv ensures the paths are correctly set (`conda activate iblenv`).
25 |
26 | The below code will show the wheel data for a particular session, with the tongue and paw DLC
27 | features overlaid:
28 |
29 | ```python
30 | python wheel_dlc_viewer.py --eid 77224050-7848-4680-ad3c-109d3bcd562c --dlc tongue,paws
31 | ```
32 | Pressing the space bar will toggle playing of the video and the left and right arrows allow you
33 | to step frame by frame. Key bindings also allow you to move between trials and toggle the legend.
34 |
Calling the script without the `--eid` flag will cause a random session to be selected. You can find a
36 | full list of input arguments and key bindings by calling the function with the `--help` flag:
37 | ```python
38 | python wheel_dlc_viewer.py -h
39 | ```
40 |
41 | ### Running from within Python
42 | You can also import the Viewer from within Python...
43 |
44 | Example 1 - inspect trial 100 of a given session, looking at the right camera
45 | ```python
46 | from wheel_dlc_viewer import Viewer
47 | eid = '77224050-7848-4680-ad3c-109d3bcd562c'
48 | v = Viewer(eid=eid, trial=100, camera='right')
49 | ```
50 |
51 | Example 2 - pick a random session to inspect, showing all DLC
52 | ```python
53 | from wheel_dlc_viewer import Viewer
54 | Viewer(dlc_features='all')
55 | ```
56 |
57 | For more details, see the docstring for the Viewer.
58 |
--------------------------------------------------------------------------------
/dlc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/dlc/__init__.py
--------------------------------------------------------------------------------
/dlc/get_dlc_traces.py:
--------------------------------------------------------------------------------
1 | import alf.io
2 | import numpy as np
3 | from oneibl.one import ONE
4 |
def get_DLC(eid, video_type):
    """Load the DLC traces for a given session and video type.

    :param eid: a session eid
    :param video_type: one of 'left', 'right', 'body'
    :return: (times, XYs) -- the camera timestamp array and a dict mapping
        each tracked point name to a (2, n_frames) array of x, y coordinates,
        with NaN wherever the DLC likelihood is below 0.9
    """
    one = ONE()
    # Make sure the dlc and timestamp datasets are downloaded locally.
    one.load(eid, dataset_types=['camera.dlc', 'camera.times'])
    alf_path = one.path_from_eid(eid) / 'alf'
    camera = alf.io.load_object(alf_path, '%sCamera' % video_type,
                                namespace='ibl')
    times = camera['times']
    dlc = camera['dlc']

    # Column names look like '<point>_x' / '<point>_y' / '<point>_likelihood';
    # strip the suffix to recover the unique point names.
    point_names = np.unique(['_'.join(k.split('_')[:-1]) for k in dlc.keys()])

    XYs = {}
    for name in point_names:
        # Blank out low-confidence detections with NaN.
        low_conf = dlc[name + '_likelihood'] < 0.9
        xs = np.ma.masked_where(low_conf, dlc[name + '_x']).filled(np.nan)
        ys = np.ma.masked_where(low_conf, dlc[name + '_y']).filled(np.nan)
        XYs[name] = np.array([xs, ys])

    return times, XYs
38 |
--------------------------------------------------------------------------------
/dlc/stream_dlc_labeled_frames.py:
--------------------------------------------------------------------------------
1 | from oneibl.one import ONE
2 | from ibllib.io.video import get_video_meta, get_video_frames_preload
3 | import numpy as np
4 | from pathlib import Path
5 | import cv2
6 | import alf.io
7 | import matplotlib
8 | from random import sample
9 | import time
10 |
11 |
def stream_save_labeled_frames(eid, video_type, n_frames=5,
                               save_images_folder='/home/mic/DLC_QC/example_frames/'):
    """For a given eid and camera type, stream randomly sampled video frames,
    paint the DLC labels on them and save them as png images.

    :param eid: session eid
    :param video_type: one of 'left', 'right', 'body'
    :param n_frames: number of randomly sampled frames to label (default 5)
    :param save_images_folder: directory (with trailing separator) where
        labelled frames are written as '<eid>_frame_<idx>.png'
    """
    startTime = time.time()

    one = ONE()
    # Session info string from fixed positions of the local session path
    # (assumes the standard FlatIron folder layout on this machine).
    info = '_'.join(np.array(str(one.path_from_eid(eid)).split('/'))[[5, 7, 8]])
    print(info, video_type)

    r = one.list(eid, 'dataset_types')

    # Datasets required to label frames for any of the three cameras.
    dtypes_DLC = ['_ibl_rightCamera.times.npy',
                  '_ibl_leftCamera.times.npy',
                  '_ibl_bodyCamera.times.npy',
                  '_iblrig_leftCamera.raw.mp4',
                  '_iblrig_rightCamera.raw.mp4',
                  '_iblrig_bodyCamera.raw.mp4',
                  '_ibl_leftCamera.dlc.pqt',
                  '_ibl_rightCamera.dlc.pqt',
                  '_ibl_bodyCamera.dlc.pqt']

    dtype_names = [x['name'] for x in r]
    assert all(i in dtype_names for i in dtypes_DLC), \
        'For this eid, not all data available'

    D = one.load(eid,
                 dataset_types=['camera.times', 'camera.dlc'],
                 dclass_output=True)
    alf_path = Path(D.local_path[0]).parent.parent / 'alf'

    cam0 = alf.io.load_object(alf_path, '%sCamera' % video_type,
                              namespace='ibl')
    Times = cam0['times']
    cam = cam0['dlc']

    # DLC columns are named '<point>_x' / '<point>_y' / '<point>_likelihood'.
    points = np.unique(['_'.join(x.split('_')[:-1]) for x in cam.keys()])

    # Mask low-confidence detections (likelihood < 0.9) with NaN.
    XYs = {}
    for point in points:
        x = np.ma.masked_where(
            cam[point + '_likelihood'] < 0.9, cam[point + '_x'])
        x = x.filled(np.nan)
        y = np.ma.masked_where(
            cam[point + '_likelihood'] < 0.9, cam[point + '_y'])
        y = y.filled(np.nan)
        XYs[point] = np.array([x, y])

    # The spout tube points only exist on the side cameras; drop them from
    # the painted points.
    if video_type != 'body':
        d = list(points)
        d.remove('tube_top')
        d.remove('tube_bottom')
        points = np.array(d)

    # Stream the sampled frames directly from the remote video file.
    recs = [x for x in r if f'{video_type}Camera.raw.mp4' in x['name']][0]['file_records']
    video_path = [x['data_url'] for x in recs if x['data_url'] is not None][0]
    vid_meta = get_video_meta(video_path)

    frame_idx = sample(range(vid_meta['length']), n_frames)
    print('frame indices:', frame_idx)
    frames = get_video_frames_preload(video_path,
                                      frame_idx,
                                      mask=np.s_[:, :, 0])
    size = [vid_meta['width'], vid_meta['height']]

    x0, x1 = 0, size[0]
    y0, y1 = 0, size[1]

    # The left camera is higher resolution, so use larger dots and text.
    if video_type == 'left':
        dot_s = 10  # [px] for painting DLC dots
        bottomLeftCornerOfText = (20, 1000)
        fontScale = 4
    else:
        dot_s = 5
        bottomLeftCornerOfText = (10, 500)
        fontScale = 2

    font = cv2.FONT_HERSHEY_SIMPLEX
    lineType = 2

    # assign a colour to each DLC point via the Set1 colormap
    cmap = matplotlib.cm.get_cmap('Set1')
    CR = np.arange(len(points)) / len(points)

    # square template painted at each DLC point location
    block = np.ones((2 * dot_s, 2 * dot_s, 3))

    for k, frame in enumerate(frames):

        gray = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)

        # print session info
        fontColor = (255, 255, 255)
        cv2.putText(gray,
                    info,
                    bottomLeftCornerOfText,
                    font,
                    fontScale / 4,
                    fontColor,
                    lineType)

        # print the frame time (previously computed twice; once is enough)
        Time = round(Times[frame_idx[k]], 3)
        a, b = bottomLeftCornerOfText
        bottomLeftCornerOfText0 = (int(a * 10 + b / 2), b)
        cv2.putText(gray,
                    ' time: ' + str(Time),
                    bottomLeftCornerOfText0,
                    font,
                    fontScale / 2,
                    fontColor,
                    lineType)

        # print DLC dots
        for ll, point in enumerate(points):

            # point colour legend
            fontColor = (np.array([cmap(CR[ll])]) * 255)[0][:3]
            if video_type == 'right':
                bottomLeftCornerOfText2 = (a, a * 2 * (1 + ll))
            else:
                bottomLeftCornerOfText2 = (b, a * 2 * (1 + ll))
            cv2.putText(gray, point,
                        bottomLeftCornerOfText2,
                        font,
                        fontScale / 4,
                        fontColor,
                        lineType)

            # NB: image row index = y coordinate, column index = x coordinate
            X0 = XYs[point][0][frame_idx[k]]
            Y0 = XYs[point][1][frame_idx[k]]
            X = Y0
            Y = X0

            if not np.isnan(X) and not np.isnan(Y):
                try:
                    col = (np.array([cmap(CR[ll])]) * 255)[0][:3]
                    X = X.astype(int)
                    Y = Y.astype(int)
                    # paint a coloured square centred on the DLC point
                    gray[X - dot_s:X + dot_s, Y - dot_s:Y + dot_s] = block * col
                except Exception as e:
                    # points near the frame edge can fail the block paste;
                    # report and continue with the remaining points
                    print('frame', frame_idx[k])
                    print(e)

        gray = gray[y0:y1, x0:x1]
        cv2.imwrite(f'{save_images_folder}{eid}_frame_{frame_idx[k]}.png',
                    gray)
        cv2.waitKey(1)

    print(f'{n_frames} frames done in', np.round(time.time() - startTime))
208 |
--------------------------------------------------------------------------------
/ephysfeatures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/ephysfeatures/__init__.py
--------------------------------------------------------------------------------
/ephysfeatures/prepare_features.py:
--------------------------------------------------------------------------------
from pathlib import Path
import numpy as np
import one.alf.io as alfio
import pandas as pd
from one.api import ONE
from brainbox.processing import bincount2D
one = ONE()

# Script: for every probe insertion found under a local FlatIron download,
# compute per-channel features (LFP power, AP-band RMS, firing rate, spike
# amplitude) and collect them, with the channel coordinates, into one
# dataframe per probe (accumulated in all_df_chns).
root_path = Path(r'C:\Users\Mayo\Downloads\FlatIron')
files = list(root_path.rglob('electrodeSites.mlapdv.npy'))

all_df_chns = []

for file in files:
    try:
        # NOTE(review): parts[:10] / parts[11] assume a fixed folder depth of
        # the FlatIron layout on this machine -- verify if root_path moves.
        session_path = Path(*file.parts[:10])
        ref = one.path2ref(session_path)
        eid = one.path2eid(session_path)
        probe = file.parts[11]

        # mean LFP spectral density over the first axis, in dB (last column
        # excluded)
        lfp = alfio.load_object(session_path.joinpath('raw_ephys_data', probe), 'ephysSpectralDensityLF', namespace='iblqc')
        mean_lfp = 10 * np.log10(np.mean(lfp.power[:,:-1], axis=0))

        try:
            # mean AP-band RMS per channel (first 384 channels)
            ap = alfio.load_object(session_path.joinpath('raw_ephys_data', probe), 'ephysTimeRmsAP', namespace='iblqc')
            mean_ap = np.mean(ap.rms[:, :384], axis=0)
        except Exception as err:
            # fall back to a flat profile of 50 when the RMS data is missing
            mean_ap = 50 * np.ones((384))

        try:
            spikes = alfio.load_object(session_path.joinpath(f'alf/{probe}/pykilosort'), 'spikes')
            # keep only spikes with a valid depth estimate
            kp_idx = ~np.isnan(spikes['depths'])
            # a single time bin spanning the whole recording -> per-depth totals
            T_BIN = np.max(spikes['times'])
            D_BIN = 10
            chn_min = 10
            chn_max = 3840
            nspikes, times, depths = bincount2D(spikes['times'][kp_idx],
                                                spikes['depths'][kp_idx],
                                                T_BIN, D_BIN,
                                                ylim=[chn_min, chn_max])

            amp, times, depths = bincount2D(spikes['amps'][kp_idx],
                                            spikes['depths'][kp_idx],
                                            T_BIN, D_BIN, ylim=[chn_min, chn_max],
                                            weights=spikes['amps'][kp_idx])
            # mean firing rate per depth bin over the whole recording
            mean_fr = nspikes[:, 0] / T_BIN
            # mean spike amplitude per depth bin, scaled by 1e6
            # (presumably volts -> microvolts; confirm against amps units)
            mean_amp = np.divide(amp[:, 0], nspikes[:, 0]) * 1e6
            mean_amp[np.isnan(mean_amp)] = 0
            # zero out bins with too few spikes for a reliable amplitude
            remove_bins = np.where(nspikes[:, 0] < 50)[0]
            mean_amp[remove_bins] = 0
        except Exception as err:
            # defaults when no pykilosort spike sorting is available
            mean_fr = np.ones((384))
            mean_amp = np.ones((384))
            depths = np.ones((384)) * 20



        # per-channel anatomical information from the alignment
        channels = alfio.load_object(file.parent, 'electrodeSites')
        data_chns = {}
        data_chns['x'] = channels['mlapdv'][:, 0]
        data_chns['y'] = channels['mlapdv'][:, 1]
        data_chns['z'] = channels['mlapdv'][:, 2]
        data_chns['axial_um'] = channels['localCoordinates'][:, 0]
        data_chns['lateral_um'] = channels['localCoordinates'][:, 1]
        data_chns['lfp'] = mean_lfp
        data_chns['ap'] = mean_ap
        data_chns['depth_line'] = depths
        data_chns['fr'] = mean_fr
        data_chns['amp'] = mean_amp
        data_chns['region_id'] = channels['brainLocationIds_ccf_2017']

        df_chns = pd.DataFrame.from_dict(data_chns)
        df_chns['subject'] = ref['subject']
        df_chns['date'] = str(ref['date'])
        df_chns['probe'] = probe
        # NOTE(review): 'pid' here is eid+probe concatenated, not a real
        # probe-insertion UUID
        df_chns['pid'] = eid + probe


        all_df_chns.append(df_chns)
    except Exception as err:
        # best-effort over all insertions: report the failure and continue
        print(err)
82 |
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/histology/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/histology/__init__.py
--------------------------------------------------------------------------------
/histology/atlas_mpl.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import matplotlib
3 | import numpy as np
4 | from PyQt5 import QtCore, QtWidgets, uic
5 |
6 | from qt_helpers import qt
7 | from qt_helpers.qt_matplotlib import BaseMplCanvas
8 | import iblatlas.atlas as atlas
9 |
10 | # Make sure that we are using QT5
11 | matplotlib.use('Qt5Agg')
12 |
13 |
class Model:
    """
    Container for Data and variables of the application
    """
    brain_atlas: atlas.BrainAtlas
    ap_um: float

    def __init__(self, brain_atlas=None, ap_um=0):
        """
        :param brain_atlas: optional pre-loaded atlas; the 25 um Allen atlas
            is loaded when None
        :param ap_um: anterio-posterior coordinate in micrometres
        """
        self.ap_um = ap_um
        # fall back to the default 25 um Allen atlas when none is supplied
        self.brain_atlas = atlas.AllenAtlas(res_um=25) if brain_atlas is None else brain_atlas
28 |
29 |
class MyStaticMplCanvas(BaseMplCanvas):
    """Matplotlib canvas that tracks the mouse and reports the atlas
    coordinates and brain structure under the cursor in the main window's
    labels."""

    def __init__(self, *args, **kwargs):
        super(MyStaticMplCanvas, self).__init__(*args, **kwargs)
        # refresh the coordinate/structure labels as the mouse moves
        self.mpl_connect("motion_notify_event", self.on_move)

    def on_move(self, event):
        """Update the main-window labels from the cursor position.

        :param event: matplotlib MouseEvent; xdata/ydata are None when the
            cursor is outside the axes
        """
        mw = qt.get_main_window()
        ap = mw._model.ap_um
        xlab = '' if event.xdata is None else '{: 6.0f}'.format(event.xdata)
        ylab = '' if event.ydata is None else '{: 6.0f}'.format(event.ydata)
        mw.label_ml.setText(xlab)
        mw.label_dv.setText(ylab)
        mw.label_ap.setText('{: 6.0f}'.format(ap))
        if event.xdata is None or event.ydata is None:
            return
        # if within the bounds of the atlas, label with the structure
        # currently hovered over; coordinates are converted from um to m
        xyz = np.array([[event.xdata, ap, event.ydata]]) / 1e6
        regions = mw._model.brain_atlas.regions
        id_label = mw._model.brain_atlas.get_labels(xyz)
        il = np.where(regions.id == id_label)[0]
        mw.label_acronym.setText(regions.acronym[il][0])
        mw.label_structure.setText(regions.name[il][0])
        mw.label_id.setText(str(id_label))

    def update_slice(self, volume='image'):
        """Redraw the displayed coronal slice from the given atlas volume.

        :param volume: 'image' or 'annotation'
        """
        mw = qt.get_main_window()
        ap_um = mw._model.ap_um
        im = mw._model.brain_atlas.slice(ap_um / 1e6, axis=1, volume=volume)
        # transpose the slice to match the orientation of the displayed image
        im = np.swapaxes(im, 0, 1)
        self.axes.images[0].set_data(im)
        self.draw()
62 |
63 |
class AtlasViewer(QtWidgets.QMainWindow):
    """Main window (layout loaded from atlas_mpl.ui) displaying a coronal
    atlas slice, with a push button toggling between the image and
    annotation volumes and a File->Quit menu entry."""

    def __init__(self, brain_atlas=None, ap_um=0):
        """
        :param brain_atlas: optional atlas instance, passed through to Model
        :param ap_um: anterio-posterior coordinate of the displayed slice (um)
        """
        # init the figure
        super(AtlasViewer, self).__init__()
        uic.loadUi(str(Path(__file__).parent.joinpath('atlas_mpl.ui')), self)
        self.setAttribute(QtCore.Qt.WA_DeleteOnClose)
        self.file_menu = QtWidgets.QMenu('&File', self)
        self.file_menu.addAction('&Quit', self.fileQuit,
                                 QtCore.Qt.CTRL + QtCore.Qt.Key_Q)
        self.menuBar().addMenu(self.file_menu)
        # pb_toggle comes from the .ui file; clicking it flips image/annotation
        self.pb_toggle.clicked.connect(self.update_slice)

        # init model
        self._model = Model(brain_atlas=brain_atlas, ap_um=ap_um)
        # display the coronal slice in the mpl widget
        self._model.brain_atlas.plot_cslice(ap_coordinate=ap_um / 1e6,
                                            ax=self.mpl_widget.axes)

    def fileQuit(self):
        """Close the window (triggered by the File->Quit menu action)."""
        self.close()

    def closeEvent(self, ce):
        """Qt close-event hook; delegates to fileQuit."""
        self.fileQuit()

    def update_slice(self, ce):
        """Toggle the displayed volume between image and annotation.

        The button's current text names the volume to switch TO; the canvas
        is redrawn with that volume and the button text flipped.
        """
        volume = self.pb_toggle.text()
        if volume == 'Annotation':
            self.pb_toggle.setText('Image')
        elif volume == 'Image':
            self.pb_toggle.setText('Annotation')
        self.mpl_widget.update_slice(volume.lower())
95 |
96 |
def viewatlas(brain_atlas=None, title=None, ap_um=0):
    """Create (or reuse) a Qt application and show an AtlasViewer window.

    :param brain_atlas: optional atlas instance forwarded to AtlasViewer
    :param title: window title
    :param ap_um: anterio-posterior coordinate of the displayed slice (um)
    :return: (AtlasViewer instance, matplotlib axes of its canvas)
    """
    qt.create_app()
    viewer = AtlasViewer(brain_atlas, ap_um=ap_um)
    viewer.setWindowTitle(title)
    viewer.show()
    return viewer, viewer.mpl_widget.axes
103 |
104 |
105 | if __name__ == "__main__":
106 | w = viewatlas(title="Allen Atlas - IBL")
107 | qt.run_app()
108 |
--------------------------------------------------------------------------------
/histology/atlas_mpl.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | AtlasViewer
4 |
5 |
6 |
7 | 0
8 | 0
9 | 986
10 | 860
11 |
12 |
13 |
14 |
15 | 0
16 | 0
17 |
18 |
19 |
20 | MainWindow
21 |
22 |
23 | background-color: rgb(255, 255, 255);
24 |
25 |
26 |
27 |
28 | 0
29 | 0
30 |
31 |
32 |
33 | -
34 |
35 |
36 |
37 | 16777215
38 | 20
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | -
47 |
48 |
49 |
50 | 16777215
51 | 20
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 | -
60 |
61 |
62 |
63 | 16777215
64 | 20
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 | -
73 |
74 |
75 |
76 | 16777215
77 | 20
78 |
79 |
80 |
81 | dv (um)
82 |
83 |
84 |
85 | -
86 |
87 |
88 |
89 | 16777215
90 | 20
91 |
92 |
93 |
94 | structure
95 |
96 |
97 |
98 | -
99 |
100 |
101 |
102 | 16777215
103 | 20
104 |
105 |
106 |
107 | ap (um)
108 |
109 |
110 |
111 | -
112 |
113 |
114 |
115 | 16777215
116 | 20
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 | -
125 |
126 |
127 |
128 | 16777215
129 | 20
130 |
131 |
132 |
133 | ml (um)
134 |
135 |
136 |
137 | -
138 |
139 |
140 |
141 | 16777215
142 | 20
143 |
144 |
145 |
146 | atlas id
147 |
148 |
149 |
150 | -
151 |
152 |
153 |
154 | 16777215
155 | 20
156 |
157 |
158 |
159 | acronym
160 |
161 |
162 |
163 | -
164 |
165 |
166 |
167 | 16777215
168 | 20
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 | -
177 |
178 |
179 |
180 | 16777215
181 | 20
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 | -
190 |
191 |
192 |
193 | 188
194 | 0
195 |
196 |
197 |
198 |
199 | 10
200 | 16777215
201 |
202 |
203 |
204 | Annotation
205 |
206 |
207 |
208 | -
209 |
210 |
211 |
212 | 0
213 | 316
214 |
215 |
216 |
217 |
218 |
219 |
220 |
230 |
231 |
232 |
233 | toolBar
234 |
235 |
236 | TopToolBarArea
237 |
238 |
239 | false
240 |
241 |
242 |
243 |
244 |
245 | MyStaticMplCanvas
246 | QWidget
247 | iblapps.histology.atlas_mpl
248 | 1
249 |
250 |
251 |
252 |
253 |
254 | pb_toggle
255 | released()
256 | mpl_widget
257 | update()
258 |
259 |
260 | 956
261 | 255
262 |
263 |
264 | 524
265 | 563
266 |
267 |
268 |
269 |
270 |
271 |
--------------------------------------------------------------------------------
/histology/transform-tracks.py:
--------------------------------------------------------------------------------
1 | """Re-Scale and Transform Traced Electrode Tracks
2 |
3 | This script will convert traced tracks from full resolution stacks, into
4 | the downsampled stack and finally transform them into Standard (ARA)
5 | space.
6 |
7 |
8 | DEPENDENCIES
9 |
10 | For this script to run, elastix must be installed. See instruction at
11 | elastix.isi.uu.nl/download.php.
12 |
13 | Check elastix is installed by running at command line:
14 |
15 | $ elastix --version
16 | elastix version: 4.800
17 |
18 |
19 | GENERATING TRACED TRACKS
20 |
21 | The tracks must be generated using Lasagna and the add_line_plugin.
22 |
23 |
24 | The resulting .csv files must be saved into a directory titled
25 | sparsedata/ inside the Sample ROOT Directory:
26 |
27 | - [KS006]:
28 | - sparsedata:
29 | - track01.csv
30 | - track02.csv
31 |
32 | NB: [KS006] can be any sample ID.
33 |
34 | The Sample ROOT Directory must also have the following structure and
35 | files:
36 |
37 | - [KS006]:
38 | - downsampledStacks_25:
39 | - dsKS006_190508_112940_25_25_GR.txt
40 | - ARA2sample:
41 | - TransformParameters.0.txt
42 | - TransformParameters.1.txt
43 |
44 | The script uses information in dsKS006_190508_112940_25_25_GR.txt,
45 | TransformParameters.0.txt and TransformParameters.1.txt to perform the
46 | transformation.
47 |
48 | The output of the script is a series of new csv files, with the track
49 | coordinates now in the registered ARA space.
50 |
51 | - [KS006]:
52 | - downsampledStacks_25:
53 | - rescaled-transformed-sparsedata:
54 | - track01.csv
55 | - track02.csv
56 |
57 |
58 | """
59 |
60 | import sys
61 | import os
62 | import re
63 | import glob
64 | import csv
65 | import subprocess
66 | from os import listdir
67 | from os.path import isfile, join
68 | from pathlib import Path
69 | import platform
70 |
71 | #print("Python version")
72 | #print (sys.version)
73 | #print("Version info.")
74 | #print (sys.version_info)
75 |
76 | #print( os.getcwd() )
77 |
xy = 0.0 # num to hold x/y resolution
z = 0.0 # num to hold z resolution


# The downsampled-stack info file (ds*GR.txt) records the resolution change
# between the full-resolution and downsampled stacks.
rescaleFile = glob.glob("downsampledStacks_25" + os.sep + "ds*GR.txt")[0]

print(rescaleFile)


# Parse the x/y and z scale factors: on each matching line, the last token
# that parses as a float wins.
with open (rescaleFile, 'rt') as myfile:

    for line in myfile:

        #print(line)

        if line.lower().find("x/y:") != -1:
            for t in line.split():
                try:
                    xy = float(t)
                except ValueError:
                    pass

        if line.lower().find("z:") != -1:
            for t in line.split():
                try:
                    z = float(t)
                except ValueError:
                    pass

#print(xy)

#print(z)


# NEXT - change the path in TransformParameters.1.txt file to point to correct ABSOLUTE path to TransformParameters.0.txt file:

tp = os.getcwd() + os.sep + "downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.0.txt"

#print(tp)


#tp1 = "downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt"

#for line in fileinput.input(tp1, inplace = 1):
# print line.replace("foo", "bar"),

# Rewrite TransformParameters.1.txt so its InitialTransformParametersFileName
# entry holds the absolute path computed above, then swap the new file in
# place of the old one (remove + rename below).
fp = open("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt","r+")

fg = open("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "new_TransformParameters.1.txt","w")

for line in fp:

    if line.find("(InitialTransformParametersFileName") != -1:
        new_line = line.replace(line,'(InitialTransformParametersFileName '+tp+')\n')
        fg.write(new_line)

    else:
        fg.write(line)

fg.close()

fp.close()

os.remove("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt")
os.rename("downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "new_TransformParameters.1.txt","downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt")


# have extracted the resolution change - now open each CSV file of track coords, and transform them:

# Read the CSV files:

# get list of files:
csvFiles = [f for f in listdir("sparsedata" + os.sep) if isfile(join("sparsedata" + os.sep, f))]

for csvFilePath in csvFiles:

    print(csvFilePath)
    # get the total number of lines in the file first:
    csvfile = open("sparsedata" + os.sep + csvFilePath)
    linenum = len(csvfile.readlines())
    csvfile.close()

    csvFilePathNoExt = csvFilePath[:-4]


    with open( ("sparsedata" + os.sep + csvFilePath), newline='') as csvfile:
        # write to rescaled-file:
        # NOTE(review): csvOut and pts are closed manually below; context
        # managers would be safer if this section is ever reworked.
        csvOut = open( ('downsampledStacks_25' + os.sep + 'rescaled-' + csvFilePath), 'w')
        pts = open('downsampledStacks_25' + os.sep + 'rescaled-'+csvFilePathNoExt+'.txt', 'w')
        # transformix "point" input format: header line, point count, coords
        pts.write("point\n")
        pts.write("%s\n" % linenum )
        spamreader = csv.reader(csvfile, delimiter=',', quotechar='|')
        for row in spamreader:
            # NB: CSV file is laid out ZXY!
            row[0] = round( (float(row[0])/ z), 6)

            row[1] = round( (float(row[1])/ xy), 6)

            row[2] = round( (float(row[2])/ xy), 6)

            # re-write as PTS: XYZ!
            pts.write("%s " % row[1])
            pts.write("%s " % row[2])
            pts.write("%s\n" % row[0])

            # also re-write as csv: ZXY!
            csvOut.write("%s," % row[0])
            csvOut.write("%s," % row[1])
            csvOut.write("%s\n" % row[2])

        pts.close()
        csvOut.close()


    # next, use system call to convert the points into ARA space:
    # Use the INVERSE of sample2ARA -> ARA2sample!

    pts = "downsampledStacks_25" + os.sep + "rescaled-"+csvFilePathNoExt+'.txt'
    out = "downsampledStacks_25" + os.sep + "rescaled-transformed-sparsedata" + os.sep
    # NOTE - need to use TransformParameters.1.txt - which points to TP0!
    tp = "downsampledStacks_25" + os.sep + "ARA2sample" + os.sep + "TransformParameters.1.txt"

    # make output DIR:
    if os.path.isdir(out)==False:
        os.mkdir(out)

    cmd = "transformix -def " + pts + " -out " + out + " -tp " + tp

    print(cmd)

    returned_value = subprocess.call(cmd, shell=True) # returns the exit code in unix
    print('returned value:', returned_value)

    # finally, want to convert outputpoints.txt into the ORIGINAL CSV format:
    outputPts = open(out+"outputpoints.txt", "r")

    # NOTE(review): outPtsCsv is never closed explicitly; relies on GC.
    outPtsCsv = open( (r"" + out + csvFilePathNoExt + "_transformed.csv"), 'w')

    for aline in outputPts:
        values = aline.split()
        # XYZ:
        # print("XYZ: ", values[22], values[23], values[24] )
        # NB: CSV file is laid out ZXY!
        # NOTE(review): fields 22-24 assumed to be the transformed XYZ point
        # in the transformix output line -- see commented print above.
        outPtsCsv.write("%s," % values[24])
        outPtsCsv.write("%s," % values[22])
        outPtsCsv.write("%s\n" % values[23])

    outputPts.close()

    # remove the temp files:
    os.remove('downsampledStacks_25' + os.sep + 'rescaled-' + csvFilePath)
    os.remove('downsampledStacks_25' + os.sep + 'rescaled-'+csvFilePathNoExt+'.txt')
    os.remove(out+"outputpoints.txt")
    os.remove(out+"transformix.log")
232 |
233 |
234 |
--------------------------------------------------------------------------------
/launch_phy/README.md:
--------------------------------------------------------------------------------
1 | ## Launch Phy with an eid and probe_name
2 |
3 | To launch Phy, first follow [these instructions](https://github.com//int-brain-lab/iblenv) for setting up a unified IBL conda environment.
4 |
5 | Then in your terminal, make sure you are on the 'develop' branch of this repository, activate your unified conda environment, and run either:
6 |
`python phy_launcher.py -s subject -d date -n session_no -p probe_name`
8 | e.g `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -s KS022 -d 2019-12-10 -n 1 -p probe00`
9 |
10 | or:
11 |
`python phy_launcher.py -e eid -p probe_name`
13 | e.g. `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -e a3df91c8-52a6-4afa-957b-3479a7d0897c -p probe00`
14 |
15 |
16 | or:
17 | `python phy_launcher.py -pid 2aea57ac-d3d0-4a09-b6fa-0aa9d58d1e11`
18 |
19 | Quality metrics will be computed within phy and displayed automatically.
20 | Description of metrics can be found here: https://docs.google.com/document/d/1ba_krsfm4epiAd0zbQ8hdvDN908P9VZOpTxkkH3P_ZY/edit#
21 |
22 | ## Manual Curation
23 | Manual curation comprises three different steps
24 |
25 | 1) **Labelling clusters** (assigning each cluster with a label Good, Mua or Noise)
26 | * Select a cluster within Cluster View and click on `Edit -> Move best to` and assign your chosen label.
27 | * Alternatively you can use the shortcuts alt+G, alt+M, alt+N to label the cluster as good, mua or noise respectively
28 |
29 | 2) **Additional notes associated with clusters** (extra information about clusters, for example if it looks like artifact or drift)
30 | * Select a cluster within Cluster View
31 | * Ensure snippet mode is enabled by clicking `File-> Enable Snippet Mode`
32 | * Type `:l notes your_note` and then hit enter (when typing you should see the text appear in the lower left
33 | hand corner of the main window)
34 | * A new column, with the heading **notes** should have been created in the Cluster View window
35 |
36 | Make sure you hit the save button frequently during manual curation so your results are saved and can be
37 | recovered in case Phy freezes or crashes!
38 |
39 |
40 | ## Upload manual labels to datajoint
41 |
42 | Once you have completed manual curation, the labels that you assigned clusters and any additional notes that you
43 | added can be uploaded and stored in a datajoint table by running either:
44 |
45 | `python populate_cluster_table.py -s subject -d date -n session_no -p probe_name`
46 | e.g `python int-brain-lab\iblapps\launch_phy\populate_cluster_table.py -s KS022 -d 2019-12-10 -n 1 -p probe00`
47 |
48 | or:
49 |
50 | `python populate_cluster_table.py -e eid -p probe_name`
51 | e.g. `python int-brain-lab\iblapps\launch_phy\populate_cluster_table.py -e a3df91c8-52a6-4afa-957b-3479a7d0897c -p probe00`
52 |
53 |
--------------------------------------------------------------------------------
/launch_phy/cluster_table.py:
--------------------------------------------------------------------------------
1 | import datajoint as dj
2 | from ibl_pipeline import reference
3 |
4 | schema = dj.schema('group_shared_ephys')
5 | dj.config["enable_python_native_blobs"] = True
6 |
7 |
8 | @schema
9 | class ClusterLabel(dj.Imported):
10 | definition = """
11 | cluster_uuid: uuid # uuid of cluster
12 | -> reference.LabMember #user name
13 | ---
14 | label_time: datetime # date on which labelling was done
15 | cluster_label=null: varchar(255) # user assigned label
16 | cluster_note=null: varchar(255) # user note about cluster
17 | """
18 |
19 |
20 | @schema
21 | class MergedClusters(dj.Imported):
22 | definition = """
23 | cluster_uuid: uuid # uuid of merged cluster
24 | ---
25 | merged_uuid: longblob # array of uuids of original clusters that form merged cluster
26 | """
27 |
--------------------------------------------------------------------------------
/launch_phy/phy_launcher.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import logging
3 | import os
4 | import pandas as pd
5 |
6 | from phy.apps.template import TemplateController, template_gui
7 | from phy.gui.qt import create_app, run_app
8 | from phylib import add_default_handler
9 | from one.api import ONE
10 | from brainbox.metrics.single_units import quick_unit_metrics
11 | from pathlib import Path
12 | from brainbox.io.one import SpikeSortingLoader
13 |
14 |
15 | def launch_phy(probe_name=None, eid=None, pid=None, subj=None, date=None, sess_no=None, one=None):
16 | """
17 | Launch phy given an eid and probe name.
18 | """
19 |
20 | # This is a first draft, no error handling and a draft dataset list.
21 |
22 | # Load data from probe #
23 | # -------------------- #
24 |
25 | from one.api import ONE
26 | from iblatlas.atlas import AllenAtlas
27 | from brainbox.io.one import SpikeSortingLoader
28 | import spikeglx
29 | one = one or ONE()
30 | ba = AllenAtlas()
31 |
32 | datasets = [
33 | 'spikes.times',
34 | 'spikes.clusters',
35 | 'spikes.amps',
36 | 'spikes.templates',
37 | 'spikes.samples',
38 | 'spikes.depths',
39 | 'clusters.uuids',
40 | 'clusters.metrics',
41 | 'clusters.waveforms',
42 | 'clusters.waveformsChannels',
43 | 'clusters.depths',
44 | 'clusters.amps',
45 | 'clusters.channels']
46 |
47 | if pid is None:
48 | ssl = SpikeSortingLoader(eid=eid, pname=probe_name, one=one, atlas=ba)
49 | else:
50 | ssl = SpikeSortingLoader(pid=pid, one=one, atlas=ba)
51 | ssl.download_spike_sorting(dataset_types=datasets)
52 | ssl.download_spike_sorting_object('templates')
53 | ssl.download_spike_sorting_object('spikes_subset')
54 |
55 | alf_dir = ssl.session_path.joinpath(ssl.collection)
56 |
57 | if not alf_dir.joinpath('clusters.metrics.pqt').exists():
58 | spikes, clusters, channels = ssl.load_spike_sorting()
59 | ssl.merge_clusters(spikes, clusters, channels, cache_dir=alf_dir)
60 |
61 | raw_file = next(ssl.session_path.joinpath('raw_ephys_folder', ssl.pname).glob('*.ap.*bin'), None)
62 |
63 | if raw_file is not None:
64 | sr = spikeglx.Reader(raw_file)
65 | sample_rate = sr.fs
66 | n_channel_dat = sr.nc - sr.nsync
67 | else:
68 | sample_rate = 30000
69 | n_channel_dat = 384
70 |
71 | # Launch phy #
72 | # -------------------- #
73 | add_default_handler('DEBUG', logging.getLogger("phy"))
74 | add_default_handler('DEBUG', logging.getLogger("phylib"))
75 | create_app()
76 | controller = TemplateController(dat_path=raw_file, dir_path=alf_dir, dtype=np.int16,
77 | n_channels_dat=n_channel_dat, sample_rate=sample_rate,
78 | plugins=['IBLMetricsPlugin'],
79 | plugin_dirs=[Path(__file__).resolve().parent / 'plugins'])
80 | gui = controller.create_gui()
81 | gui.show()
82 | run_app()
83 | gui.close()
84 | controller.model.close()
85 |
86 |
87 | if __name__ == '__main__':
88 | """
89 | `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -e a3df91c8-52a6-4afa-957b-3479a7d0897c -p probe00`
90 | `python int-brain-lab\iblapps\launch_phy\phy_launcher.py -pid c07d13ed-e387-4457-8e33-1d16aed3fd92`
91 | """
92 | from argparse import ArgumentParser
93 | import numpy as np
94 |
95 | parser = ArgumentParser()
96 | parser.add_argument('-s', '--subject', default=False, required=False,
97 | help='Subject Name')
98 | parser.add_argument('-d', '--date', default=False, required=False,
99 | help='Date of session YYYY-MM-DD')
100 | parser.add_argument('-n', '--session_no', default=1, required=False,
101 | help='Session Number', type=int)
102 | parser.add_argument('-e', '--eid', default=False, required=False,
103 | help='Session eid')
104 | parser.add_argument('-p', '--probe_label', default=False, required=False,
105 | help='Probe Label')
106 | parser.add_argument('-pid', '--pid', default=False, required=False,
107 | help='Probe ID')
108 |
109 | args = parser.parse_args()
110 | if args.eid:
111 | launch_phy(probe_name=str(args.probe_label), eid=str(args.eid))
112 | elif args.pid:
113 | launch_phy(pid=str(args.pid))
114 | else:
115 | if not np.all(np.array([args.subject, args.date, args.session_no],
116 | dtype=object)):
117 | print('Must give Subject, Date and Session number')
118 | else:
119 | launch_phy(probe_name=str(args.probe_label), subj=str(args.subject),
120 | date=str(args.date), sess_no=args.session_no)
121 | # launch_phy('probe00', subj='KS022',
122 | # date='2019-12-10', sess_no=1)
123 |
--------------------------------------------------------------------------------
/launch_phy/plugins/phy_plugin.py:
--------------------------------------------------------------------------------
1 | # import from plugins/cluster_metrics.py
2 | """Show how to add a custom cluster metrics."""
3 |
4 | import logging
5 | import numpy as np
6 | from phy import IPlugin
7 | import pandas as pd
8 | from pathlib import Path
9 | from brainbox.metrics.single_units import quick_unit_metrics
10 |
11 |
12 | class IBLMetricsPlugin(IPlugin):
13 | def attach_to_controller(self, controller):
14 | """Note that this function is called at initialization time, *before* the supervisor is
15 | created. The `controller.cluster_metrics` items are then passed to the supervisor when
16 | constructing it."""
17 |
18 | clusters_file = Path(controller.dir_path.joinpath('clusters.metrics.pqt'))
19 | if clusters_file.exists():
20 | self.metrics = pd.read_parquet(clusters_file)
21 | else:
22 | self.metrics = None
23 | return
24 |
25 | def amplitudes(cluster_id):
26 | amps = controller.get_cluster_amplitude(cluster_id)
27 | return amps * 1e6
28 |
29 | def amp_max(cluster_id):
30 | return self.metrics['amp_max'].iloc[cluster_id] * 1e6
31 |
32 | def amp_min(cluster_id):
33 | return self.metrics['amp_min'].iloc[cluster_id] * 1e6
34 |
35 | def amp_median(cluster_id):
36 | return self.metrics['amp_median'].iloc[cluster_id] * 1e6
37 |
38 | def amp_std_dB(cluster_id):
39 | return self.metrics['amp_std_dB'].iloc[cluster_id]
40 |
41 | def contamination(cluster_id):
42 | return self.metrics['contamination'].iloc[cluster_id]
43 |
44 | def contamination_alt(cluster_id):
45 | return self.metrics['contamination_alt'].iloc[cluster_id]
46 |
47 | def drift(cluster_id):
48 | return self.metrics['drift'].iloc[cluster_id]
49 |
50 | def missed_spikes_est(cluster_id):
51 | return self.metrics['missed_spikes_est'].iloc[cluster_id]
52 |
53 | def noise_cutoff(cluster_id):
54 | return self.metrics['noise_cutoff'].iloc[cluster_id]
55 |
56 | def presence_ratio(cluster_id):
57 | return self.metrics['presence_ratio'].iloc[cluster_id]
58 |
59 | def presence_ratio_std(cluster_id):
60 | return self.metrics['presence_ratio_std'].iloc[cluster_id]
61 |
62 | def slidingRP_viol(cluster_id):
63 | return self.metrics['slidingRP_viol'].iloc[cluster_id]
64 |
65 | def spike_count(cluster_id):
66 | return self.metrics['spike_count'].iloc[cluster_id]
67 |
68 | def firing_rate(cluster_id):
69 | return self.metrics['firing_rate'].iloc[cluster_id]
70 |
71 | def label(cluster_id):
72 | return self.metrics['label'].iloc[cluster_id]
73 |
74 | def ks2_label(cluster_id):
75 | if 'ks2_label' in self.metrics.columns:
76 | return self.metrics['ks2_label'].iloc[cluster_id]
77 | else:
78 | return 'nan'
79 |
80 | controller.cluster_metrics['amplitudes'] = controller.context.memcache(amplitudes)
81 | controller.cluster_metrics['amp_max'] = controller.context.memcache(amp_max)
82 | controller.cluster_metrics['amp_min'] = controller.context.memcache(amp_min)
83 | controller.cluster_metrics['amp_median'] = controller.context.memcache(amp_median)
84 | controller.cluster_metrics['amp_std_dB'] = controller.context.memcache(amp_std_dB)
85 | controller.cluster_metrics['contamination'] = controller.context.memcache(contamination)
86 | controller.cluster_metrics['contamination_alt'] = controller.context.memcache(contamination_alt)
87 | controller.cluster_metrics['drift'] = controller.context.memcache(drift)
88 | controller.cluster_metrics['missed_spikes_est'] = controller.context.memcache(missed_spikes_est)
89 | controller.cluster_metrics['noise_cutoff'] = controller.context.memcache(noise_cutoff)
90 | controller.cluster_metrics['presence_ratio'] = controller.context.memcache(presence_ratio)
91 | controller.cluster_metrics['presence_ratio_std'] = controller.context.memcache(presence_ratio_std)
92 | controller.cluster_metrics['slidingRP_viol'] = controller.context.memcache(slidingRP_viol)
93 | controller.cluster_metrics['spike_count'] = controller.context.memcache(spike_count)
94 | controller.cluster_metrics['firing_rate'] = controller.context.memcache(firing_rate)
95 | controller.cluster_metrics['label'] = controller.context.memcache(label)
96 | controller.cluster_metrics['ks2_label'] = controller.context.memcache(ks2_label)
97 |
--------------------------------------------------------------------------------
/launch_phy/populate_cluster_table.py:
--------------------------------------------------------------------------------
import sys
import uuid
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from one import alf
from one import params
from one.api import ONE
from tqdm import tqdm

from launch_phy import cluster_table
13 |
14 |
15 | def populate_dj_with_phy(probe_label, eid=None, subj=None, date=None,
16 | sess_no=None, one=None):
17 | if one is None:
18 | one = ONE()
19 |
20 | if eid is None:
21 | eid = one.search(subject=subj, date=date, number=sess_no)[0]
22 |
23 | sess_path = one.eid2path(eid)
24 |
25 | cols = one.list_collections(eid)
26 | if f'alf/{probe_label}/pykilosort' in cols:
27 | collection = f'alf/{probe_label}/pykilosort'
28 | else:
29 | collection = f'alf/{probe_label}'
30 |
31 |
32 | alf_path = sess_path.joinpath(collection)
33 |
34 | cluster_path = Path(alf_path, 'spikes.clusters.npy')
35 | template_path = Path(alf_path, 'spikes.templates.npy')
36 |
37 | # Compare spikes.clusters with spikes.templates to find which clusters have been merged
38 | phy_clusters = np.load(cluster_path)
39 | id_phy = np.unique(phy_clusters)
40 | orig_clusters = np.load(template_path)
41 | id_orig = np.unique(orig_clusters)
42 |
43 | uuid_list = alf.io.load_file_content(alf_path.joinpath('clusters.uuids.csv'))
44 |
45 | # First deal with merged clusters and make sure they have cluster uuids assigned
46 | # Find the original cluster ids that have been merged into a new cluster
47 | merged_idx = np.setdiff1d(id_orig, id_phy)
48 |
49 | # See if any clusters have been merged, if not skip to the next bit
50 | if np.any(merged_idx):
51 | # Make association between original cluster and new cluster id and save in dict
52 | merge_list = {}
53 | for m in merged_idx:
54 | idx = phy_clusters[np.where(orig_clusters == m)[0][0]]
55 | if idx in merge_list:
56 | merge_list[idx].append(m)
57 | else:
58 | merge_list[idx] = [m]
59 |
60 | # Create a dataframe from the dict
61 | merge_clust = pd.DataFrame(columns={'cluster_idx', 'merged_uuid', 'merged_id'})
62 | for key, value in merge_list.items():
63 | value_uuid = uuid_list['uuids'][value]
64 | merge_clust = merge_clust.append({'cluster_idx': key, 'merged_uuid': tuple(value_uuid),
65 | 'merged_idx': tuple(value)},
66 | ignore_index=True)
67 |
68 | # Get the dj table that has previously stored merged clusters and store in frame
69 | merge = cluster_table.MergedClusters()
70 | merge_dj = pd.DataFrame(columns={'cluster_uuid', 'merged_uuid'})
71 | merge_dj['cluster_uuid'] = merge.fetch('cluster_uuid').astype(str)
72 | merge_dj['merged_uuid'] = tuple(map(tuple, merge.fetch('merged_uuid')))
73 |
74 | # Merge the two dataframe to see if any merge combinations already have a cluster_uuid
75 | merge_comb = pd.merge(merge_dj, merge_clust, on=['merged_uuid'], how='outer')
76 |
77 | # Find the merged clusters that do not have a uuid assigned
78 | no_uuid = np.where(pd.isnull(merge_comb['cluster_uuid']))[0]
79 |
80 | # Assign new uuid to new merge pairs and add to the merge table
81 | for nid in no_uuid:
82 | new_uuid = str(uuid.uuid4())
83 | merge_comb['cluster_uuid'].iloc[nid] = new_uuid
84 | merge.insert1(
85 | dict(cluster_uuid=new_uuid, merged_uuid=merge_comb['merged_uuid'].iloc[nid]),
86 | allow_direct_insert=True)
87 |
88 | # Add all the uuids to the cluster_uuid frame with index according to cluster id from phy
89 | for idx, c_uuid in zip(merge_comb['cluster_idx'].values,
90 | merge_comb['cluster_uuid'].values):
91 | uuid_list.loc[idx] = c_uuid
92 |
93 | csv_path = Path(alf_path, 'merge_info.csv')
94 | merge_comb = merge_comb.reindex(columns=['cluster_idx', 'cluster_uuid', 'merged_idx',
95 | 'merged_uuid'])
96 |
97 | try:
98 | merge_comb.to_csv(csv_path, index=False)
99 | except Exception as err:
100 | print(err)
101 | print('Close merge_info.csv file and then relaunch script')
102 | sys.exit(1)
103 | else:
104 | print('No merges detected, continuing...')
105 |
106 | # Now populate datajoint with cluster labels
107 | user = params.get().ALYX_LOGIN
108 | current_date = datetime.now().replace(microsecond=0)
109 |
110 | try:
111 | cluster_group = alf.io.load_file_content(alf_path.joinpath('cluster_group.tsv'))
112 | except Exception as err:
113 | print(err)
114 | print('Could not find cluster group file output from phy')
115 | sys.exit(1)
116 |
117 | try:
118 | cluster_notes = alf.io.load_file_content(alf_path.joinpath('cluster_notes.tsv'))
119 | cluster_info = pd.merge(cluster_group, cluster_notes, on=['cluster_id'], how='outer')
120 | except Exception as err:
121 | cluster_info = cluster_group
122 | cluster_info['notes'] = None
123 |
124 | cluster_info = cluster_info.where(cluster_info.notnull(), None)
125 | cluster_info['cluster_uuid'] = uuid_list['uuids'][cluster_info['cluster_id']].values
126 |
127 | # dj table that holds data
128 | cluster = cluster_table.ClusterLabel()
129 |
130 | # Find clusters that have already been labelled by user
131 | old_clust = cluster & cluster_info & {'user_name': user}
132 |
133 | dj_clust = pd.DataFrame()
134 | dj_clust['cluster_uuid'] = (old_clust.fetch('cluster_uuid')).astype(str)
135 | dj_clust['cluster_label'] = old_clust.fetch('cluster_label')
136 |
137 | # First find the new clusters to insert into datajoint
138 | idx_new = np.where(np.isin(cluster_info['cluster_uuid'],
139 | dj_clust['cluster_uuid'], invert=True))[0]
140 | cluster_uuid = cluster_info['cluster_uuid'][idx_new].values
141 | cluster_label = cluster_info['group'][idx_new].values
142 | cluster_note = cluster_info['notes'][idx_new].values
143 |
144 | if idx_new.size != 0:
145 | print('Populating dj with ' + str(idx_new.size) + ' new labels')
146 | else:
147 | print('No new labels to add')
148 | for iClust, iLabel, iNote in zip(tqdm(cluster_uuid), cluster_label, cluster_note):
149 | cluster.insert1(dict(cluster_uuid=iClust, user_name=user,
150 | label_time=current_date,
151 | cluster_label=iLabel,
152 | cluster_note=iNote),
153 | allow_direct_insert=True)
154 |
155 | # Next look through clusters already on datajoint and check if any labels have
156 | # been changed
157 | comp_clust = pd.merge(cluster_info, dj_clust, on='cluster_uuid')
158 | idx_change = np.where(comp_clust['group'] != comp_clust['cluster_label'])[0]
159 |
160 | cluster_uuid = comp_clust['cluster_uuid'][idx_change].values
161 | cluster_label = comp_clust['group'][idx_change].values
162 | cluster_note = comp_clust['notes'][idx_change].values
163 |
164 | # Populate table
165 | if idx_change.size != 0:
166 | print('Replacing label of ' + str(idx_change.size) + ' clusters')
167 | else:
168 | print('No labels to change')
169 | for iClust, iLabel, iNote in zip(tqdm(cluster_uuid), cluster_label, cluster_note):
170 | prev_clust = cluster & {'user_name': user} & {'cluster_uuid': iClust}
171 | cluster.insert1(dict(*prev_clust.proj(),
172 | label_time=current_date,
173 | cluster_label=iLabel,
174 | cluster_note=iNote),
175 | allow_direct_insert=True, replace=True)
176 |
177 | print('Upload to datajoint complete')
178 |
179 |
180 | if __name__ == '__main__':
181 | from argparse import ArgumentParser
182 | import numpy as np
183 |
184 | parser = ArgumentParser()
185 | parser.add_argument('-s', '--subject', default=False, required=False,
186 | help='Subject Name')
187 | parser.add_argument('-d', '--date', default=False, required=False,
188 | help='Date of session YYYY-MM-DD')
189 | parser.add_argument('-n', '--session_no', default=1, required=False,
190 | help='Session Number', type=int)
191 | parser.add_argument('-e', '--eid', default=False, required=False,
192 | help='Session eid')
193 | parser.add_argument('-p', '--probe_label', default=False, required=True,
194 | help='Probe Label')
195 | args = parser.parse_args()
196 |
197 | if args.eid:
198 | populate_dj_with_phy(str(args.probe_label), eid=str(args.eid))
199 | else:
200 | if not np.all(np.array([args.subject, args.date, args.session_no],
201 | dtype=object)):
202 | print('Must give Subject, Date and Session number')
203 | else:
204 | populate_dj_with_phy(str(args.probe_label), subj=str(args.subject),
205 | date=str(args.date), sess_no=args.session_no)
206 |
207 | #populate_dj_with_phy('probe00', subj='KS022', date='2019-12-10', sess_no=1)
208 |
--------------------------------------------------------------------------------
/needles2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/needles2/__init__.py
--------------------------------------------------------------------------------
/needles2/layerUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | Form
4 |
5 |
6 |
7 | 0
8 | 0
9 | 400
10 | 82
11 |
12 |
13 |
14 | Form
15 |
16 |
17 |
18 |
19 | 20
20 | 20
21 | 341
22 | 41
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/needles2/mainUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | MainWindow
4 |
5 |
6 |
7 | 0
8 | 0
9 | 1501
10 | 884
11 |
12 |
13 |
14 | MainWindow
15 |
16 |
17 |
18 | -
19 |
20 |
21 |
22 | 0
23 | 0
24 |
25 |
26 |
27 |
-
28 |
29 |
30 | -
31 |
32 |
33 | -
34 |
35 |
36 |
37 |
38 |
39 | -
40 |
41 |
42 |
-
43 |
44 |
45 | -
46 |
47 |
48 | -
49 |
50 |
51 | -
52 |
53 |
54 |
55 |
56 |
57 | -
58 |
59 |
60 |
-
61 |
62 |
63 |
64 |
65 |
66 | -
67 |
68 |
69 |
-
70 |
71 |
72 | -
73 |
74 |
75 |
76 | -
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
94 |
95 |
96 |
97 |
98 | PlotWidget
99 | QGraphicsView
100 |
101 |
102 |
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
/needles2/needles_viewer.py:
--------------------------------------------------------------------------------
1 | from iblviewer.atlas_controller import AtlasController
2 | import vedo
3 | from iblviewer.atlas_model import AtlasModel, AtlasUIModel, CameraModel
4 | from iblviewer.slicer_model import SlicerModel
5 |
6 | from iblviewer.atlas_view import AtlasView
7 | from iblviewer.volume_view import VolumeView
8 | from iblviewer.slicer_view import SlicerView
9 | import iblviewer.utils as utils
10 |
11 | from ipyvtk_simple.viewer import ViewInteractiveWidget
12 |
13 |
14 | class NeedlesViewer(AtlasController):
15 | def __init__(self):
16 | super(NeedlesViewer, self).__init__()
17 |
18 | def initialize(self, plot, resolution=25, mapping='Allen', volume_mode=None, num_windows=1, render=False):
19 | vedo.settings.allowInteraction = False
20 | self.plot = plot
21 | self.plot_window_id = 0
22 | self.model = AtlasModel()
23 | self.model.initialize(resolution)
24 | self.model.load_allen_volume(mapping, volume_mode)
25 | self.model.initialize_slicers()
26 |
27 | self.view = AtlasView(self.plot, self.model)
28 | self.view.initialize()
29 | self.view.volume = VolumeView(self.plot, self.model.volume, self.model)
30 |
31 | #pn = SlicerModel.NAME_XYZ_POSITIVE
32 | nn = SlicerModel.NAME_XYZ_NEGATIVE
33 |
34 | #pxs_model = self.model.find_model(pn[0], self.model.slicers)
35 | #self.px_slicer = SlicerView(self.plot, self.view.volume, pxs_model, self.model)
36 | #pys_model = self.model.find_model(pn[1], self.model.slicers)
37 | #self.py_slicer = SlicerView(self.plot, self.view.volume, pys_model, self.model)
38 | #pzs_model = self.model.find_model(pn[2], self.model.slicers)
39 | #self.pz_slicer = SlicerView(self.plot, self.view.volume, pzs_model, self.model)
40 |
41 | nxs_model = self.model.find_model(nn[0], self.model.slicers)
42 | self.nx_slicer = SlicerView(self.plot, self.view.volume, nxs_model, self.model)
43 | nys_model = self.model.find_model(nn[1], self.model.slicers)
44 | self.ny_slicer = SlicerView(self.plot, self.view.volume, nys_model, self.model)
45 | nzs_model = self.model.find_model(nn[2], self.model.slicers)
46 | self.nz_slicer = SlicerView(self.plot, self.view.volume, nzs_model, self.model)
47 |
48 | self.slicers = [self.nx_slicer, self.ny_slicer, self.nz_slicer]
49 |
50 | vedo.settings.defaultFont = self.model.ui.font
51 | self.initialize_embed_ui(slicer_target=self.view.volume)
52 |
53 | self.plot.show(interactive=False)
54 | self.handle_transfer_function_update()
55 | # By default, the atlas volume is our target
56 | self.model.camera.target = self.view.volume.actor
57 | # We start with a sagittal view
58 | self.set_left_view()
59 |
60 | self.render()
61 |
62 | def initialize_embed_ui(self, slicer_target=None):
63 | s_kw = self.model.ui.slider_config
64 | d = self.view.volume.model.dimensions
65 | if d is None:
66 | return
67 |
68 | #self.add_slider('px', self.update_px_slicer, 0, int(d[0]), 0, pos=(0.05, 0.065, 0.12),
69 | # title='+X', **s_kw)
70 | #self.add_slider('py', self.update_py_slicer, 0, int(d[1]), 0, pos=(0.2, 0.065, 0.12),
71 | # title='+Y', **s_kw)
72 | #self.add_slider('pz', self.update_pz_slicer, 0, int(d[2]), 0, pos=(0.35, 0.065, 0.12),
73 | # title='+Z', **s_kw)
74 |
75 | def update_slicer(self, slicer_view, value):
76 | """
77 | Update a given slicer with the given value
78 | :param slicer_view: SlicerView instance
79 | :param value: Value
80 | """
81 | volume = self.view.volume
82 | model = slicer_view.model
83 | model.set_value(value)
84 | model.clipping_planes = volume.get_clipping_planes(model.axis)
85 | slicer_view.update(add_to_scene=self.model.slices_visible)
86 |
87 |
88 |
--------------------------------------------------------------------------------
/needles2/regionUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | Form
4 |
5 |
6 |
7 | 0
8 | 0
9 | 209
10 | 210
11 |
12 |
13 |
14 |
15 | 0
16 | 0
17 |
18 |
19 |
20 | Form
21 |
22 |
23 |
24 |
25 | 10
26 | 10
27 | 291
28 | 210
29 |
30 |
31 |
32 |
33 | 0
34 | 0
35 |
36 |
37 |
38 | QFrame::StyledPanel
39 |
40 |
41 | QFrame::Raised
42 |
43 |
44 |
45 |
46 | 10
47 | 10
48 | 100
49 | 30
50 |
51 |
52 |
53 |
54 | 75
55 | true
56 |
57 |
58 |
59 | Selected Region:
60 |
61 |
62 |
63 |
64 |
65 | 10
66 | 40
67 | 171
68 | 20
69 |
70 |
71 |
72 | QComboBox::AdjustToContentsOnFirstShow
73 |
74 |
75 | 0
76 |
77 |
78 |
79 |
80 |
81 | 10
82 | 70
83 | 90
84 | 30
85 |
86 |
87 |
88 |
89 | 75
90 | true
91 |
92 |
93 |
94 | Current Region:
95 |
96 |
97 |
98 |
99 |
100 | 110
101 | 70
102 | 80
103 | 30
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 | 10
114 | 100
115 | 30
116 | 30
117 |
118 |
119 |
120 |
121 | 75
122 | true
123 |
124 |
125 |
126 | ML:
127 |
128 |
129 |
130 |
131 |
132 | 50
133 | 100
134 | 60
135 | 30
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 | 10
146 | 130
147 | 30
148 | 30
149 |
150 |
151 |
152 |
153 | 75
154 | true
155 |
156 |
157 |
158 | AP:
159 |
160 |
161 |
162 |
163 |
164 | 50
165 | 130
166 | 60
167 | 30
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 | 10
178 | 160
179 | 50
180 | 30
181 |
182 |
183 |
184 |
185 | 75
186 | true
187 |
188 |
189 |
190 | DV:
191 |
192 |
193 |
194 |
195 |
196 | 50
197 | 160
198 | 50
199 | 30
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
--------------------------------------------------------------------------------
/needles2/tableUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | Form
4 |
5 |
6 |
7 | 0
8 | 0
9 | 400
10 | 300
11 |
12 |
13 |
14 | Form
15 |
16 |
17 |
18 |
19 | 10
20 | 10
21 | 321
22 | 271
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/qt_helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/qt_helpers/__init__.py
--------------------------------------------------------------------------------
/qt_helpers/qt.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from functools import wraps
4 |
5 | from PyQt5 import QtWidgets
6 |
7 | _logger = logging.getLogger('ibllib')
8 |
9 |
10 | def get_main_window():
11 | """ Get the Main window of a QT application"""
12 | app = QtWidgets.QApplication.instance()
13 | return [w for w in app.topLevelWidgets() if isinstance(w, QtWidgets.QMainWindow)][0]
14 |
15 |
16 | def create_app():
17 | """Create a Qt application."""
18 | global QT_APP
19 | QT_APP = QtWidgets.QApplication.instance()
20 | if QT_APP is None: # pragma: no cover
21 | QT_APP = QtWidgets.QApplication(sys.argv)
22 | return QT_APP
23 |
24 |
25 | def require_qt(func):
26 | """Function decorator to specify that a function requires a Qt application.
27 | Use this decorator to specify that a function needs a running
28 | Qt application before it can run. An error is raised if that is not
29 | the case.
30 | """
31 | @wraps(func)
32 | def wrapped(*args, **kwargs):
33 | if not QtWidgets.QApplication.instance(): # pragma: no cover
34 | _logger.warning("Creating a Qt application.")
35 | create_app()
36 | return func(*args, **kwargs)
37 | return wrapped
38 |
39 |
40 | @require_qt
41 | def run_app(): # pragma: no cover
42 | """Run the Qt application."""
43 | global QT_APP
44 | return QT_APP.exit(QT_APP.exec_())
45 |
--------------------------------------------------------------------------------
/qt_helpers/qt_matplotlib.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | # Make sure that we are using QT5
3 | matplotlib.use('Qt5Agg')
4 | from PyQt5 import QtCore, QtWidgets, uic
5 |
6 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
7 | from matplotlib.figure import Figure
8 |
9 |
10 | class BaseMplCanvas(FigureCanvas):
11 | """
12 | ui
13 | model
14 | """
15 | def __init__(self, parent=None, ui=None, model=None, width=5, height=4, dpi=100):
16 | fig = Figure(figsize=(width, height), dpi=dpi)
17 | self.axes = fig.add_subplot(111)
18 | self._ui = ui
19 | self._model = model
20 | self.compute_initial_figure()
21 |
22 | FigureCanvas.__init__(self, fig)
23 | self.setParent(parent)
24 |
25 | FigureCanvas.setSizePolicy(self,
26 | QtWidgets.QSizePolicy.Expanding,
27 | QtWidgets.QSizePolicy.Expanding)
28 | FigureCanvas.updateGeometry(self)
29 | self.setCursor(QtCore.Qt.CrossCursor)
30 |
31 | def compute_initial_figure(self):
32 | pass
33 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easyqc
2 | ibllib
3 | pyqtgraph
4 | simpleITK
5 | pyqt5
--------------------------------------------------------------------------------
/run_tests:
--------------------------------------------------------------------------------
1 | python -m unittest discover
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open('requirements.txt') as f:
4 | require = [x.strip() for x in f.readlines() if not x.startswith('git+')]
5 |
6 | setup(
7 | name='iblapps',
8 | version='0.1',
9 | python_requires='>=3.10',
10 | packages=find_packages(),
11 | include_package_data=True,
12 | install_requires=require,
13 | entry_points={
14 | 'console_scripts': [
15 | 'atlas=atlasview.atlasview:main',
16 | 'ephys-align=atlaselectrophysiology.ephys_atlas_gui:launch_offline',
17 | ]
18 | },
19 | )
20 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fixtures/data_alignmentqc_gui.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int-brain-lab/iblapps/59c50b4567b13f78aa67d4258f730f591f07592c/tests/fixtures/data_alignmentqc_gui.npz
--------------------------------------------------------------------------------
/tests/test_task_qc_viewer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from pathlib import Path
4 |
5 |
class TestTaskQCViewer(unittest.TestCase):
    """Smoke-tests that launch the task QC viewer script in a subprocess."""

    def setUp(self):
        # Resolve <repo root>/task_qc_viewer/task_qc.py relative to this test file.
        repo_root = Path(*Path(__file__).parts[:-2])
        self.task_qc_path = repo_root / 'task_qc_viewer' / 'task_qc.py'
        self.UUID = '1211f4af-d3e4-4c4e-9d0b-75a0bc2bf1f0'

    def test_build(self):
        # Launch the viewer with a known task UUID; smoke test only,
        # the exit status is not asserted.
        os.system(f'python {self.task_qc_path} {self.UUID}')

    def test_build_fail(self):
        # Launching with a bogus id must not crash the test runner.
        os.system(f'python {self.task_qc_path} "BAD_ID"')
17 |
--------------------------------------------------------------------------------
/viewspikes/README.md:
--------------------------------------------------------------------------------
1 | # View Spikes
2 | Electrophysiology display tools for IBL insertions.
3 |
4 | ## Run instructions
5 | Look at the examples file at the root of the package.
6 |
7 | ## Install Instructions
8 |
9 | Pre-requisites:
10 | - IBL environment installed https://github.com/int-brain-lab/iblenv
11 | - ONE parameters setup to download data
12 |
13 | ```
14 | conda activate iblenv
15 | pip install viewephys
16 | pip install -U pyqtgraph
17 | ```
18 |
19 |
20 | ## Roadmap
21 | - multi-probe displays for sessions with several probes
22 | - make sure NP2.0 4 shanks is supported
23 | - display LFP
24 | - speed up the raster plots in pyqtgraph
25 |
--------------------------------------------------------------------------------
/viewspikes/datoviz.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import numpy as np
4 |
5 | import datoviz as dviz
6 |
7 | # -------------------------------------------------------------------------------------------------
8 | # Raster viewer
9 | # -------------------------------------------------------------------------------------------------
10 |
11 |
class RasterView:
    """Datoviz canvas rendering a spike raster as a 3D point cloud."""

    def __init__(self):
        self.canvas = dviz.canvas(show_fps=True)
        self.panel = self.canvas.panel(controller='axes')
        self.visual = self.panel.visual('point')
        # Display parameters: marker size in pixels and point transparency.
        self.pvars = {'ms': 2., 'alpha': .03}
        self.gui = self.canvas.gui('XY')
        self.gui.control("label", "Coords", value="(0, 0)")

    def set_spikes(self, spikes):
        """Upload spike (time, depth) positions coloured by log-amplitude."""
        xyz = np.c_[spikes.times, spikes.depths, np.zeros_like(spikes.times)]
        amps_db = 20 * np.log10(spikes.amps)
        rgba = dviz.colormap(amps_db, cmap='cividis', alpha=self.pvars['alpha'])
        self.visual.data('pos', xyz)
        self.visual.data('color', rgba)
        self.visual.data('ms', np.array([self.pvars['ms']]))
27 |
28 |
class RasterController:
    """Routes canvas mouse/keyboard events to the raster view."""

    _time_select_cb = None

    def __init__(self, model, view):
        self.m = model
        self.v = view
        self.v.canvas.connect(self.on_mouse_move)
        self.v.canvas.connect(self.on_key_press)
        self.redraw()

    def redraw(self):
        """Re-upload the spikes using the current display parameters."""
        print('redraw', self.v.pvars)
        self.v.set_spikes(self.m.spikes)

    def on_mouse_move(self, x, y, modifiers=()):
        """Show the data coordinates under the cursor in the GUI label."""
        panel = self.v.canvas.panel_at(x, y)
        if not panel:
            return
        # Transform window coordinates into the data coordinate system.
        # Supported coordinate systems:
        # target_cds='data' / 'scene' / 'vulkan' / 'framebuffer' / 'window'
        data_x, data_y = panel.pick(x, y)
        self.v.gui.set_value("Coords", f"({data_x:0.2f}, {data_y:0.2f})")

    def on_key_press(self, key, modifiers=()):
        """ctrl+a / ctrl+z adjust transparency; page up/down adjust marker size."""
        print(key, modifiers)
        pvars = self.v.pvars
        if key == 'a' and modifiers == ('control',):
            pvars['alpha'] = np.minimum(pvars['alpha'] + 0.1, 1.)
        elif key == 'z' and modifiers == ('control',):
            pvars['alpha'] = np.maximum(pvars['alpha'] - 0.1, 0.)
        elif key == 'page_up':
            pvars['ms'] = np.minimum(pvars['ms'] * 1.1, 20)
        elif key == 'page_down':
            pvars['ms'] = np.maximum(pvars['ms'] / 1.1, 1)
        else:
            return
        self.redraw()
66 |
67 |
@dataclass
class RasterModel:
    # Bunch-like spikes container; set_spikes reads `times`, `depths` and `amps`.
    spikes: dict
71 |
72 |
def raster(spikes):
    """Open an interactive datoviz raster view of `spikes` and run the event loop.

    :param spikes: Bunch-like object with `times`, `depths` and `amps` arrays
    :return: the RasterController driving the view (previously created but
        discarded, which made the controller unreachable to callers)
    """
    controller = RasterController(RasterModel(spikes), RasterView())
    dviz.run()
    return controller
76 |
--------------------------------------------------------------------------------
/viewspikes/example_view_ephys_session.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from one.api import ONE
3 | from brainbox.io.one import EphysSessionLoader, SpikeSortingLoader
4 | from iblapps.viewspikes.gui import view_raster
5 |
# Local cache directory for data downloaded from the public OpenAlyx server.
PATH_CACHE = Path("/datadisk/Data/NAWG/01_lick_artefacts/openalyx")

one = ONE(base_url="https://openalyx.internationalbrainlab.org", cache_dir=PATH_CACHE)

pid = '5135e93f-2f1f-4301-9532-b5ad62548c49'
# Session eid / probe name corresponding to the insertion (for reference).
eid, pname = one.pid2eid(pid)


# `stream=False` expects the raw AP-band data to be available in the cache.
# NOTE: renamed from `self` — a misleading name at module level.
raster_window = view_raster(pid=pid, one=one, stream=False)
15 |
16 |
--------------------------------------------------------------------------------
/viewspikes/gui.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import scipy.signal
5 | from PyQt5 import QtWidgets, QtCore, QtGui, uic
6 | import pyqtgraph as pg
7 |
8 | from iblutil.numerical import bincount2D
9 | from viewephys.gui import viewephys
10 | import ibldsp
11 | from brainbox.io.one import EphysSessionLoader, SpikeSortingLoader
12 | from iblatlas.atlas import BrainRegions
13 |
14 |
# Shared brain region ontology used to annotate the ephys views.
regions = BrainRegions()
T_BIN = .007  # time bin size in secs
D_BIN = 20  # depth bin size in um

# Default seaborn/matplotlib "tab10" colour cycle as RGB tuples in [0, 1];
# indexes the trial-event overlays (goCue / error / reward).
SNS_PALETTE = [(0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
               (1.0, 0.4980392156862745, 0.054901960784313725),
               (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
               (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
               (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
               (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
               (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
               (0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
               (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
               (0.09019607843137255, 0.7450980392156863, 0.8117647058823529)]

# Depth extent (um) used when drawing vertical trial-event lines on the raster.
YMIN, YMAX = (-1, 4000)
31 |
32 |
def get_trial_events_to_display(trials):
    """Extract goCue / error / reward event times from a trials table.

    Error and go-cue events are duplicated with small offsets (+.5 s and
    +.11 s respectively) so each event can be drawn as a pair of lines.

    :param trials: trials table with pandas-style `goCue_times`,
        `feedback_times` and `feedbackType` columns
    :return: dict mapping event name to a numpy array of times
    """
    feedback = trials['feedback_times']
    error_t = feedback[trials['feedbackType'] == -1].values
    cue_t = trials['goCue_times'].values
    return dict(
        goCue_times=np.sort(np.r_[cue_t, cue_t + .11]),
        error_times=np.sort(np.r_[error_t, error_t + .5]),
        reward_times=feedback[trials['feedbackType'] == 1].values,
    )
43 |
44 |
def view_raster(pid, one, stream=True):
    """Load spike sorting and trials for a probe insertion and open a RasterView.

    :param pid: probe insertion uuid
    :param one: ONE instance
    :param stream: stream the raw AP band data rather than reading from disk
    :return: the RasterView main window
    """
    from qt import create_app
    app = create_app()
    ssl = SpikeSortingLoader(one=one, pid=pid)
    session_loader = EphysSessionLoader(one=one, eid=ssl.eid)
    session_loader.load_trials()
    spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])
    clusters = ssl.merge_clusters(spikes, clusters, channels)
    return RasterView(ssl, spikes, clusters, channels,
                      trials=session_loader.trials, stream=stream)
54 |
55 |
class RasterView(QtWidgets.QMainWindow):
    """Main window showing a binned spike raster with optional trial-event
    overlays; double-clicking opens raw AP-band views around that time."""
    # Populated by uic.loadUi from raster.ui.
    plotItem_raster: pg.PlotWidget = None

    def __init__(self, ssl, spikes, clusters, channels=None, trials=None, stream=True, *args, **kwargs):
        """
        :param ssl: SpikeSortingLoader for the insertion (provides raw ephys access)
        :param spikes: spikes Bunch with `times`, `depths`, `samples`, `clusters`
        :param clusters: clusters Bunch with `label` and `channels`
        :param channels: channels Bunch, forwarded to the viewephys windows
        :param trials: optional trials table used to overlay behaviour events
        :param stream: stream raw AP data instead of reading a local file
        """
        self.ssl = ssl
        self.spikes = spikes
        self.clusters = clusters
        self.channels = channels
        self.trials = trials
        # self.sr_lf = ssl.raw_electrophysiology(band='lf', stream=False) that would be cool to also have the LFP
        self.sr = ssl.raw_electrophysiology(band='ap', stream=stream)
        self.eqcs = []
        super(RasterView, self).__init__(*args, **kwargs)
        # wave by Diana Militano from the Noun Project
        uic.loadUi(Path(__file__).parent.joinpath('raster.ui'), self)
        background_color = self.palette().color(self.backgroundRole())
        self.plotItem_raster.setAspectLocked(False)
        self.imageItem_raster = pg.ImageItem()
        self.plotItem_raster.setBackground(background_color)
        self.plotItem_raster.addItem(self.imageItem_raster)
        self.viewBox_raster = self.plotItem_raster.getPlotItem().getViewBox()
        s = self.viewBox_raster.scene()
        # vb.scene().sigMouseMoved.connect(self.mouseMoveEvent)
        s.sigMouseClicked.connect(self.mouseClick)
        ################################################### set image
        # Bin spikes into a (depth, time) count image, dropping NaN depths.
        iok = ~np.isnan(spikes.depths)
        self.raster, self.rtimes, self.depths = bincount2D(
            spikes.times[iok], spikes.depths[iok], T_BIN, D_BIN)
        self.imageItem_raster.setImage(self.raster.T)
        # Affine image->data transform: scale by the bin sizes, offset by half a bin.
        transform = [T_BIN, 0., 0., 0., D_BIN, 0., - .5, - .5, 1.]
        self.transform = np.array(transform).reshape((3, 3)).T
        self.imageItem_raster.setTransform(QtGui.QTransform(*transform))
        self.plotItem_raster.setLimits(xMin=0, xMax=self.rtimes[-1], yMin=0, yMax=self.depths[-1])
        # set colormap
        cm = pg.colormap.get('Greys', source='matplotlib')  # prepare a linear color map
        bar = pg.ColorBarItem(values=(0, .5), colorMap=cm)  # prepare interactive color bar
        # Have ColorBarItem control colors of img and appear in 'plot':
        bar.setImageItem(self.imageItem_raster)
        ################################################## plot location
        # self.view.layers[label] = {'layer': new_scatter, 'type': 'scatter'}
        # Green cursor line showing which time slice is open in viewephys.
        self.line_eqc = pg.PlotCurveItem()
        self.plotItem_raster.addItem(self.line_eqc)
        ################################################## plot trials
        if self.trials is not None:
            # One coloured vertical-line curve per event type (goCue/error/reward).
            trial_times = get_trial_events_to_display(trials)
            self.trial_lines = {}
            for i, k in enumerate(trial_times):
                self.trial_lines[k] = pg.PlotCurveItem()
                self.plotItem_raster.addItem(self.trial_lines[k])
                # Duplicate each event time so the curve zig-zags YMIN->YMAX,
                # drawing one vertical stroke per event.
                x = np.tile(trial_times[k][:, np.newaxis], (1, 2)).flatten()
                y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(trial_times[k].shape[0] / 2 + 1))[:trial_times[k].shape[0] * 2]
                self.trial_lines[k].setData(x=x.flatten(), y=y.flatten(), pen=pg.mkPen(np.array(SNS_PALETTE[i]) * 255, width=2))
        self.show()

    def mouseClick(self, event):
        """Draws a line on the raster and display in EasyQC"""
        if not event.double():
            return
        qxy = self.imageItem_raster.mapFromScene(event.scenePos())
        x = qxy.x()
        # Open the raw ephys views starting at the clicked raster time bin.
        self.show_ephys(t0=self.rtimes[int(x - 1)])
        ymax = np.max(self.depths) + 50
        # Draw the green marker showing the selected time slice.
        self.line_eqc.setData(x=x + np.array([-.5, -.5, .5, .5]),
                              y=np.array([0, ymax, ymax, 0]),
                              pen=pg.mkPen((0, 255, 0)))


    def keyPressEvent(self, e):
        """
        page-up / ctrl + a : gain up
        page-down / ctrl + z : gain down
        :param e:
        """
        k, m = (e.key(), e.modifiers())
        # page up / ctrl + a
        if k == QtCore.Qt.Key_PageUp or (
                m == QtCore.Qt.ControlModifier and k == QtCore.Qt.Key_A):
            self.imageItem_raster.setLevels([0, self.imageItem_raster.levels[1] / 1.4])
        # page down / ctrl + z
        elif k == QtCore.Qt.Key_PageDown or (
                m == QtCore.Qt.ControlModifier and k == QtCore.Qt.Key_Z):
            self.imageItem_raster.setLevels([0, self.imageItem_raster.levels[1] * 1.4])

    def show_ephys(self, t0, tlen=1.8):
        """
        Open butterworth-filtered and destriped viewephys windows around `t0`,
        with good/bad units and trial events overlaid.

        :param t0: behaviour time in seconds at which to start the view
        :param tlen: window length in seconds
        :return:
        """
        print(t0)
        # Convert behaviour time to a raw-recording sample index.
        s0 = int(self.ssl.samples2times(t0, direction='reverse'))
        s1 = s0 + int(self.sr.fs * tlen)
        # Raw AP chunk, sync channels stripped, transposed to (channels, samples).
        raw = self.sr[s0:s1, : - self.sr.nsync].T

        butter_kwargs = {'N': 3, 'Wn': 300 / self.sr.fs * 2, 'btype': 'highpass'}

        sos = scipy.signal.butter(**butter_kwargs, output='sos')
        butt = scipy.signal.sosfiltfilt(sos, raw)
        destripe = ibldsp.voltage.destripe(raw, fs=self.sr.fs)

        self.eqc_raw = viewephys(butt, self.sr.fs, channels=self.channels, br=regions, title='butt', t0=t0, t_scalar=1)
        self.eqc_des = viewephys(destripe, self.sr.fs, channels=self.channels, br=regions, title='destripe', t0=t0, t_scalar=1)

        # Zoom both windows on a 20 ms span centred in the chunk.
        eqc_xrange = [t0 + tlen / 2 - 0.01, t0 + tlen / 2 + 0.01]
        self.eqc_des.viewBox_seismic.setXRange(*eqc_xrange)
        self.eqc_raw.viewBox_seismic.setXRange(*eqc_xrange)

        # we slice the spikes using the samples according to ephys time, but display in session times
        slice_spikes = slice(np.searchsorted(self.spikes['samples'], s0), np.searchsorted(self.spikes['samples'], s1))
        t = self.spikes['times'][slice_spikes]
        ic = self.spikes.clusters[slice_spikes]

        # label == 1 marks good units (per the clusters table loaded upstream).
        iok = self.clusters['label'][ic] == 1

        for eqc in [self.eqc_des, self.eqc_raw]:
            eqc.ctrl.add_scatter(t[~iok], self.clusters.channels[ic[~iok]], (255, 0, 0, 100), label='bad units')
            eqc.ctrl.add_scatter(t[iok], self.clusters.channels[ic[iok]], rgb=(0, 255, 0, 100), label='good units')

        if self.trials is not None:
            # Overlay vertical trial-event lines falling within the window.
            trial_events = get_trial_events_to_display(self.trials)
            for i, k in enumerate(trial_events):
                ie = np.logical_and(trial_events[k] >= t0, trial_events[k] <= (t0 + tlen))
                if np.sum(ie) == 0:
                    continue
                te = trial_events[k][ie]
                x = np.tile(te[:, np.newaxis], (1, 2)).flatten()
                y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(te.shape[0] / 2 + 1))[:te.shape[0] * 2]
                for eqc in [self.eqc_des, self.eqc_raw]:
                    eqc.ctrl.add_curve(x, y, rgb=(np.array(SNS_PALETTE[i]) * 255).astype(int), label=k)
                print(te)
186 |
187 |
--------------------------------------------------------------------------------
/viewspikes/load_ma_data.py:
--------------------------------------------------------------------------------
1 | from one.api import ONE
2 | import one.alf.io as alfio
3 | from iblutil.util import Bunch
4 | from qt_helpers import qt
5 | import numpy as np
6 | import pyqtgraph as pg
7 |
8 |
9 | import atlaselectrophysiology.ephys_atlas_gui as alignment_window
10 | import data_exploration_gui.gui_main as trial_window
11 |
12 |
13 | # some extra controls
14 |
class AlignmentWindow(alignment_window.MainWindow):
    """Alignment GUI subclass that forwards cluster and trial selections
    to a companion TrialWindow."""

    def __init__(self, offline=False, probe_id=None, one=None):
        super(AlignmentWindow, self).__init__(probe_id=probe_id, one=one)
        # Companion TrialWindow; wired up externally by `viewer`.
        self.trial_gui = None

    def cluster_clicked(self, item, point):
        """Relay a cluster selection on the raster to the trial GUI."""
        clust = super().cluster_clicked(item, point)
        print(clust)
        self.trial_gui.on_cluster_chosen(clust)

    def add_trials_to_raster(self, trial_key='feedback_times'):
        """Overlay clickable vertical lines at the given trial event times."""
        self.selected_trials = self.trial_gui.data.trials[trial_key]
        xs, ys = self.vertical_lines(self.selected_trials, 0, 3840)
        trial_curve = pg.PlotCurveItem()
        trial_curve.setData(x=xs, y=ys, pen=self.rpen_dot, connect='finite')
        trial_curve.setClickable(True)
        self.fig_img.addItem(trial_curve)
        self.fig_img.scene().sigMouseClicked.connect(self.on_mouse_clicked)
        trial_curve.sigClicked.connect(self.trial_line_clicked)

    def vertical_lines(self, x, ymin, ymax):
        """Build NaN-separated (x, y) arrays drawing one vertical segment per x."""
        x = np.tile(x, (3, 1))
        x[2, :] = np.nan
        y = np.zeros_like(x)
        y[0, :] = ymin
        y[1, :] = ymax
        y[2, :] = np.nan
        return x.T.flatten(), y.T.flatten()

    def trial_line_clicked(self, ev):
        # Remember the clicked curve so on_mouse_clicked can recognise it.
        self.clicked = ev

    def on_mouse_clicked(self, event):
        if not event.double() and type(self.clicked) == pg.PlotCurveItem:
            self.pos = self.data_plot.mapFromScene(event.scenePos())
            x = self.pos.x() * self.x_scale
            # Nearest trial to the clicked x position (currently unused).
            trial_id = np.argmin(np.abs(self.selected_trials - x))
            # highlight the trial in the trial gui
            self.clicked = None
69 |
70 |
71 |
72 |
class TrialWindow(trial_window.MainWindow):
    """Trial exploration GUI that mirrors the selected cluster's spikes
    onto the companion alignment GUI."""

    def __init__(self):
        super(TrialWindow, self).__init__()
        self.alignment_gui = None  # companion AlignmentWindow; wired up by `viewer`
        self.scat = None  # lazily created scatter overlay on the alignment raster

    def on_scatter_plot_clicked(self, scatter, point):
        super().on_scatter_plot_clicked(scatter, point)
        self.add_clust_scatter()

    def on_cluster_list_clicked(self):
        super().on_cluster_list_clicked()
        self.add_clust_scatter()

    def on_next_cluster_clicked(self):
        super().on_next_cluster_clicked()
        self.add_clust_scatter()

    def on_previous_cluster_clicked(self):
        super().on_previous_cluster_clicked()
        self.add_clust_scatter()

    def add_clust_scatter(self):
        """Highlight the selected cluster's spikes on the alignment raster."""
        if not self.scat:
            self.scat = pg.ScatterPlotItem()
            self.alignment_gui.fig_img.addItem(self.scat)
        idx = self.data.clus_idx
        self.scat.setData(self.data.spikes.times[idx],
                          self.data.spikes.depths[idx], brush='r')
102 |
103 | # need to get out spikes.times and spikes.depths
104 |
105 |
106 |
107 |
108 |
109 |
def load_data(eid, probe, one=None):
    """Load trials, spikes and clusters ALF objects for one probe of a session.

    :param eid: session eid
    :param probe: probe label, e.g. 'probe00'
    :param one: ONE instance (a default instance is created when omitted)
    :return: Bunch with 'trials', 'spikes' and 'clusters' keys
    """
    one = one or ONE()
    alf_path = one.path_from_eid(eid).joinpath('alf')
    probe_path = alf_path.joinpath(probe)
    return Bunch(
        trials=alfio.load_object(alf_path, 'trials', namespace='ibl'),
        spikes=alfio.load_object(probe_path, 'spikes'),
        clusters=alfio.load_object(probe_path, 'clusters'),
    )
120 |
def viewer(probe_id=None, one=None):
    """Launch the alignment and trial GUIs side by side for one probe insertion.

    :param probe_id: probe insertion uuid
    :param one: ONE instance
    :return: (AlignmentWindow, TrialWindow) tuple
    """
    insertion = one.alyx.rest('insertions', 'list', id=probe_id)[0]
    data = load_data(insertion['session'], insertion['name'], one=one)
    qt.create_app()
    align_gui = AlignmentWindow(probe_id=probe_id, one=one)
    trial_gui = TrialWindow()
    trial_gui.on_data_given(data)
    # Cross-link the two windows so selections propagate both ways.
    align_gui.trial_gui = trial_gui
    trial_gui.alignment_gui = align_gui

    align_gui.show()
    trial_gui.show()
    return align_gui, trial_gui
136 |
137 |
--------------------------------------------------------------------------------
/viewspikes/plots.py:
--------------------------------------------------------------------------------
1 | # import modules
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import scipy.signal
5 | import pyqtgraph as pg
6 |
7 | import iblatlas.atlas as atlas
8 | from neuropixel import SITES_COORDINATES
9 | from ibllib.pipes.ephys_alignment import EphysAlignment
10 | from ibllib.plots import wiggle, color_cycle
11 |
# Instantiate the Allen atlas once at import time; shared by the helpers below.
brain_atlas = atlas.AllenAtlas()
14 |
15 |
def show_psd(data, fs, ax=None):
    """Plot per-trace Welch power spectra (grey) and their mean (red).

    :param data: (ntraces, nsamples) array of voltage traces
    :param fs: sampling frequency in Hz
    :param ax: optional matplotlib axes; a new figure is created when omitted
    """
    n_traces = data.shape[0]
    # scipy.signal.welch with default nperseg=256 yields 129 frequency bins.
    psd = np.zeros((n_traces, 129))
    for itr in np.arange(n_traces):
        f, psd[itr, :] = scipy.signal.welch(data[itr, :], fs=fs)

    if ax is None:
        fig, ax = plt.subplots()
    ax.plot(f, 10 * np.log10(psd.T), color='gray', alpha=0.1)
    ax.plot(f, 10 * np.log10(np.mean(psd, axis=0).T), color='red')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('PSD (dB rel V/Hz)')
    ax.set_ylim(-150, -110)
    ax.set_xlim(0, fs / 2)
    plt.show()
30 |
31 |
def plot_insertion(pid, one=None):
    """Plot the ephys-aligned trajectory of a probe insertion.

    :param pid: probe insertion uuid
    :param one: ONE instance (required)
    :raises NotImplementedError: when the insertion has no aligned trajectory
    """
    # Find eid of interest
    assert one
    insertion = one.alyx.rest('insertions', 'list', id=pid)[0]
    probe_label = insertion['name']
    eid = insertion['session']

    # Download channels.localCoordinates; the returned values were previously
    # bound to unused variables, only the download side effect is needed here.
    one.load(eid, dataset_types=['channels.localCoordinates'])

    # Find the ephys aligned trajectory for eid probe combination
    trajs = one.alyx.rest('trajectories', 'list', session=eid, probe=probe_label)
    traj_aligned = next(filter(lambda x: x['provenance'] == 'Ephys aligned histology track', trajs), None)
    if traj_aligned is None:
        # note: dropped a pointless f-string prefix; message is unchanged
        raise NotImplementedError("Plots only aligned insertions so far - TODO")
    # Plot using the most recent alignment stored in the trajectory json.
    plot_alignment(insertion, traj_aligned, -1)
51 | # Extract all alignments from the json field of object
52 | # Load in the initial user xyz_picks obtained from track traccing
53 |
54 |
def plot_alignment(insertion, traj, ind=None):
    """Plot brain regions and coronal slice for the latest alignment of a trajectory.

    :param insertion: insertion record (dict) with a 'json' field containing 'xyz_picks'
    :param traj: trajectory record (dict) whose 'json' field maps alignment keys
        to (feature, track) reference-line pairs
    :param ind: NOTE(review): accepted but never read — looks intended to select
        an alignment; currently only the last alignment key is plotted. Confirm.
    """
    depths = SITES_COORDINATES[:, 1]
    # Track tracing picks are stored in um; convert to metres.
    xyz_picks = np.array(insertion['json']['xyz_picks']) / 1e6

    alignments = traj['json'].copy()
    # Keep only the most recent alignment (last key of the json dict).
    k = list(alignments.keys())[-1]  # if only I had a Walrus available !
    alignments = {k: alignments[k]}
    # Create a figure and arrange using gridspec
    widths = [1, 2.5]
    heights = [1] * len(alignments)
    gs_kw = dict(width_ratios=widths, height_ratios=heights)
    fig, axis = plt.subplots(len(alignments), 2, constrained_layout=True,
                             gridspec_kw=gs_kw, figsize=(8, 9))

    # Iterate over all alignments for trajectory
    # 1. Plot brain regions that channel pass through
    # 2. Plot coronal slice along trajectory with location of channels shown as red points
    # 3. Save results for each alignment into a dict - channels
    channels = {}
    for iK, key in enumerate(alignments):

        # Location of reference lines used for alignmnet
        feature = np.array(alignments[key][0])
        track = np.array(alignments[key][1])
        chn_coords = SITES_COORDINATES
        # Instantiate EphysAlignment object
        ephysalign = EphysAlignment(xyz_picks, depths, track_prev=track, feature_prev=feature)

        # Find xyz location of all channels
        xyz_channels = ephysalign.get_channel_locations(feature, track)
        # Find brain region that each channel is located in
        brain_regions = ephysalign.get_brain_locations(xyz_channels)
        # Add extra keys to store all useful information as one bunch object
        brain_regions['xyz'] = xyz_channels
        brain_regions['lateral'] = chn_coords[:, 0]
        brain_regions['axial'] = chn_coords[:, 1]

        # Store brain regions result in channels dict with same key as in alignment
        channel_info = {key: brain_regions}
        channels.update(channel_info)

        # For plotting -> extract the boundaries of the brain regions, as well as CCF label and colour
        region, region_label, region_colour, _ = ephysalign.get_histology_regions(xyz_channels, depths)

        # Make plot that shows the brain regions that channels pass through
        ax_regions = fig.axes[iK * 2]
        for reg, col in zip(region, region_colour):
            height = np.abs(reg[1] - reg[0])
            bottom = reg[0]
            # Colours are stored as 0-255 RGB; matplotlib expects 0-1 floats.
            color = col / 255
            ax_regions.bar(x=0.5, height=height, width=1, color=color, bottom=reg[0], edgecolor='w')
        ax_regions.set_yticks(region_label[:, 0].astype(int))
        ax_regions.yaxis.set_tick_params(labelsize=8)
        ax_regions.get_xaxis().set_visible(False)
        ax_regions.set_yticklabels(region_label[:, 1])
        ax_regions.spines['right'].set_visible(False)
        ax_regions.spines['top'].set_visible(False)
        ax_regions.spines['bottom'].set_visible(False)
        # Dashed lines mark the physical extent of the probe (0-3840 um).
        ax_regions.hlines([0, 3840], *ax_regions.get_xlim(), linestyles='dashed', linewidth=3,
                          colors='k')
        # ax_regions.plot(np.ones(channel_depths_track.shape), channel_depths_track, '*r')

        # Make plot that shows coronal slice that trajectory passes through with location of channels
        # shown in red
        ax_slice = fig.axes[iK * 2 + 1]
        brain_atlas.plot_tilted_slice(xyz_channels, axis=1, ax=ax_slice)
        ax_slice.plot(xyz_channels[:, 0] * 1e6, xyz_channels[:, 2] * 1e6, 'r*')
        ax_slice.title.set_text(insertion['id'] + '\n' + str(key))

    # Make sure the plot displays
    plt.show()
126 |
127 |
def overlay_spikes(self, spikes, clusters, channels, rgb=None, label='default', symbol='x', size=8):
    """Overlay spikes as a clickable scatter layer on a viewephys window.

    :param self: viewephys window instance (called as a free function)
    :param spikes: dict-like with 'samples' and 'clusters' arrays
    :param clusters: dict-like with 'channels' array
    :param channels: dict-like with 'rawInd' array
    :param rgb: fixed scatter colour; when None, colours cycle per cluster id
    :param label: layer name for the scatter item
    :param symbol: scatter marker symbol
    :param size: scatter marker size
    :return: (scatter item, spike times, spike x positions)
    """
    first_sample = self.ctrl.model.t0 / self.ctrl.model.si
    last_sample = first_sample + self.ctrl.model.ns

    # Select only the spikes falling within the displayed time window.
    ifirst, ilast = np.searchsorted(spikes['samples'], [first_sample, last_sample])
    tspi = spikes['samples'][ifirst:ilast].astype(np.float64) * self.ctrl.model.si
    xspi = channels['rawInd'][clusters['channels'][spikes['clusters'][ifirst:ilast]]]

    # When several banks of 384 traces are displayed side by side, replicate
    # the spikes across banks with the matching trace offset.
    n_side_by_side = int(self.ctrl.model.ntr / 384)
    print(n_side_by_side)
    if n_side_by_side > 1:
        addx = (np.zeros([1, xspi.size]) + np.array([self.ctrl.model.ntr * r for r in range(n_side_by_side)])[:, np.newaxis]).flatten()
        xspi = np.tile(xspi, n_side_by_side) + addx
        # BUGFIX: times must be tiled alongside the x positions; the original
        # computed `yspi` but never used it, leaving mismatched array lengths.
        tspi = np.tile(tspi, n_side_by_side)
    if self.ctrl.model.taxis == 1:
        self.ctrl.add_scatter(xspi, tspi, label=label)
    else:
        self.ctrl.add_scatter(tspi, xspi, label=label)
    sc = self.layers[label]['layer']
    sc.setSize(size)
    sc.setSymbol(symbol)
    # sc.setPen(pg.mkPen((0, 255, 0, 155), width=1))
    if rgb is None:
        # Colour each spike according to its cluster id.
        rgbs = [list((rgb * 255).astype(np.uint8)) for rgb in color_cycle(spikes['clusters'][ifirst:ilast])]
        sc.setBrush([pg.mkBrush(rgb) for rgb in rgbs])
        sc.setPen([pg.mkPen(rgb) for rgb in rgbs])
    else:
        sc.setBrush(pg.mkBrush(rgb))
        sc.setPen(pg.mkPen(rgb))

    def callback(sc, points, evt):
        """On scatter click, wiggle-plot the traces around the clicked point."""
        NTR = 12
        NS = 128
        qxy = self.imageItem_seismic.mapFromScene(evt.scenePos())
        # tr, _, _, s, _ = self.ctrl.cursor2timetraceamp(qxy)
        itr, itime = self.ctrl.cursor2ind(qxy)
        print(itr, itime)
        h = self.ctrl.model.header
        x = self.ctrl.model.data
        trsel = np.arange(np.max([0, itr - NTR]), np.min([self.ctrl.model.ntr, itr + NTR]))
        ordre = np.lexsort((h['x'][trsel], h['y'][trsel]))
        trsel = trsel[ordre]
        w = x[trsel, int(itime) - NS:int(itime) + NS]
        wiggle(-w.T * 10000, fs=1 / self.ctrl.model.si, t0=(itime - NS) * self.ctrl.model.si)

    # BUGFIX: in the original, the callback definition and this connect call
    # were placed after the return statement and therefore never executed.
    sc.sigClicked.connect(callback)
    return sc, tspi, xspi
179 |
--------------------------------------------------------------------------------
/viewspikes/raster.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | MainWindow
4 |
5 |
6 |
7 | 0
8 | 0
9 | 999
10 | 656
11 |
12 |
13 |
14 | MainWindow
15 |
16 |
17 |
18 | -
19 |
20 |
-
21 |
22 |
23 |
24 |
25 | -
26 |
27 |
28 | Qt::Horizontal
29 |
30 |
31 |
32 |
33 |
34 |
51 |
52 |
53 |
54 | open...
55 |
56 |
57 |
58 |
59 |
60 | PlotWidget
61 | QGraphicsView
62 |
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------