├── track2p
├── gui
│ ├── __init__.py
│ ├── window_manager.py
│ ├── toolbar.py
│ ├── main_wd.py
│ ├── import_wd.py
│ ├── custom_wd.py
│ ├── statusbar.py
│ ├── roi_plot.py
│ ├── fluo_plot.py
│ ├── central_widget.py
│ ├── data_management.py
│ ├── cell_plot.py
│ ├── t2p_wd.py
│ └── raster_wd.py
├── io
│ ├── __init__.py
│ ├── utils.py
│ ├── loaders.py
│ ├── savers.py
│ └── s2p_loaders.py
├── ops
│ ├── __init__.py
│ └── default.py
├── match
│ ├── __init__.py
│ ├── loop.py
│ └── utils.py
├── plot
│ ├── __init__.py
│ ├── progress.py
│ ├── utils.py
│ └── output.py
├── register
│ ├── __init__.py
│ ├── elastix.py
│ ├── utils.py
│ └── loop.py
├── __init__.py
├── resources
│ └── logo.png
├── __main__.py
├── eval
│ ├── io.py
│ └── plot.py
└── t2p.py
├── docs
├── media
│ └── plots
│ │ ├── ex_all_vizualizations.png
│ │ └── ex_all_vizualizations_backup.png
├── README.md
└── troubleshooting.md
├── setup.py
├── .gitignore
├── notebooks
├── run_t2p.ipynb
├── utils
│ ├── npy_to_s2p.ipynb
│ └── s2p_to_npy.ipynb
├── eval
│ └── main_eval_fig.ipynb
└── demo_t2p_ouputs.ipynb
└── README.md
/track2p/gui/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/track2p/io/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/track2p/ops/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/track2p/match/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/track2p/plot/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/track2p/register/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/track2p/__init__.py:
--------------------------------------------------------------------------------
1 | # make all submodules available
2 | from . import *
3 |
--------------------------------------------------------------------------------
/track2p/resources/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juremaj/track2p/HEAD/track2p/resources/logo.png
--------------------------------------------------------------------------------
/docs/media/plots/ex_all_vizualizations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juremaj/track2p/HEAD/docs/media/plots/ex_all_vizualizations.png
--------------------------------------------------------------------------------
/docs/media/plots/ex_all_vizualizations_backup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juremaj/track2p/HEAD/docs/media/plots/ex_all_vizualizations_backup.png
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | As of August 2024, all documentation has been moved to the GitHub Pages [repository](https://github.com/track2p/track2p.github.io/tree/main/Track2p) and can be accessed on the associated [website](https://track2p.github.io/).
--------------------------------------------------------------------------------
/track2p/io/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
def make_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Prints whether the directory was created or already existed.

    Args:
        path: Directory to create; ``str`` or any ``os.PathLike``.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # EAFP: no window between an existence check and the creation in which
        # another process could create the same directory.
        print(f'Directory already exists: {path}')
    else:
        # f-string (not '+') so os.PathLike paths are accepted too.
        print(f'Created directory: {path}')
--------------------------------------------------------------------------------
/track2p/__main__.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QApplication
2 | import time
3 | from track2p.gui.main_wd import MainWindow
4 | from PyQt5.QtGui import QIcon
5 | import os
6 |
# Entry point that launches the track2p GUI (e.g. ``python -m track2p``).

if __name__ == '__main__':
    import sys

    start_time = time.time()

    app = QApplication([])

    # Use a path relative to this file for the application icon.
    icon_path = os.path.join(os.path.dirname(__file__), 'resources', 'logo.png')
    print(icon_path)

    if not os.path.exists(icon_path):
        print(f"Icon file not found: {icon_path}")
    else:
        app.setWindowIcon(QIcon(icon_path))

    # Keep a reference so the window is not garbage-collected while running.
    mainWindow = MainWindow()
    mainWindow.setWindowTitle("track2p")

    end_time = time.time()
    print(f"Application took {end_time - start_time} seconds to open.")
    # Propagate Qt's exit status to the shell instead of always exiting with 0.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/track2p/gui/window_manager.py:
--------------------------------------------------------------------------------
1 | from track2p.gui.t2p_wd import Track2pWindow
2 | from track2p.gui.import_wd import ImportWindow
3 | from track2p.gui.raster_wd import RasterWindow
4 | #from track2p.gui.s2p_wd import Suite2pWindow
5 |
class WindowManager:
    """Create the secondary GUI windows once and expose slots that show them.

    Each window is constructed up front with the main window as its argument;
    the ``open_*`` methods only call ``show()`` on the existing instance.
    """

    def __init__(self, main_window):
        self.main_window = main_window
        host = self.main_window
        self.t2p_window = Track2pWindow(host)
        self.import_window = ImportWindow(host)
        self.raster_window = RasterWindow(host)
        # A Suite2p window (Suite2pWindow from track2p.gui.s2p_wd) is currently
        # disabled; it would be created here and shown by an open_suite2p_wd slot.

    def open_track2p_wd(self):
        """Show the window used to run the track2p algorithm."""
        self.t2p_window.show()

    def open_import_wd(self):
        """Show the window used to load processed track2p data."""
        self.import_window.show()

    def open_raster_wd(self):
        """Show the raster-plot window."""
        self.raster_window.show()
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
import os

# Resolve README.md next to this file (not the current working directory) and
# decode it explicitly as UTF-8, independent of the platform's default encoding.
_here = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(_here, 'README.md'), 'r', encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='track2p',
    version='0.6.2',
    packages=find_packages(),
    install_requires=[
        'numpy==2.0.2',
        'matplotlib==3.9.4',
        'scikit-image==0.24.0',
        'itk==5.4.3',
        'PyQt5==5.15.11',
        'qtpy==2.4.3',
        'tqdm==4.67.1',
        'scikit-learn==1.6.1',
        'openTSNE==1.0.2',
        'pandas==2.2.3',
        'itk-elastix==0.23.0',
        # Previous pins, kept for reference:
        # 'numpy==1.23.5',
        # 'matplotlib==3.5.3',
        # 'scikit-image==0.20.0',
        # 'itk==5.4rc2',
        # 'PyQt5==5.15.10',
        # 'qtpy==2.4.1',
        # 'tqdm==4.66.2',
        # 'scikit-learn==1.4.0',
        # 'openTSNE==1.0.1',
        # 'pandas==1.5.3',
    ],
    long_description=long_description,
    long_description_content_type='text/markdown',
    include_package_data=True,
    # Ship the GUI icon that track2p/__main__.py loads from resources/logo.png.
    package_data={
        '': ['resources/logo.png'],
    },
)
39 | )
--------------------------------------------------------------------------------
/track2p/plot/progress.py:
--------------------------------------------------------------------------------
1 | # contains code for generating plots while the pipeline is running
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | from track2p.plot.utils import match_hist_all
7 |
def plot_all_planes(all_ds_avg_ch, track_ops, sat_perc=99, ch='funcional'):
    """Save a planes x datasets grid of histogram-matched mean images.

    Args:
        all_ds_avg_ch: Nested sequence indexed ``[dataset][plane]`` of 2-D
            images; they are histogram-matched with ``match_hist_all`` first.
        track_ops: Options object; reads ``nplanes``, ``all_ds_path`` and
            ``save_path_fig``.
        sat_perc: Percentile of each image used as the display maximum.
        ch: Tag used in the output file name ``mean_fov_{ch}.png``. The default
            spelling is kept because it determines existing output file names.
    """
    nplanes = track_ops.nplanes
    n_ds = len(track_ops.all_ds_path)
    # squeeze=False always yields a 2-D array of axes, so axs[i, j] works for a
    # single plane and/or a single dataset (including the 1x1 case).
    fig, axs = plt.subplots(nplanes, n_ds, figsize=(3 * n_ds, 3 * nplanes), dpi=300, squeeze=False)
    try:
        all_ds_avg_ch_matched = match_hist_all(all_ds_avg_ch)

        for i in range(nplanes):
            for j in range(n_ds):
                img = all_ds_avg_ch_matched[j][i]
                ax = axs[i, j]
                ax.imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, sat_perc))
                ax.set_title('Plane ' + str(i) + ' in dataset ' + str(j))
                ax.axis('off')

        fig.savefig(track_ops.save_path_fig + f'mean_fov_{ch}.png', bbox_inches='tight', dpi=200)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/track2p/register/elastix.py:
--------------------------------------------------------------------------------
1 | import itk
2 | import numpy as np
3 |
def reg_img_elastix(ref_img, mov_img, track_ops):
    """Register ``mov_img`` onto ``ref_img`` with elastix (through ITK).

    Args:
        ref_img: Reference (fixed) image as a numpy array.
        mov_img: Image to be moved onto the reference, as a numpy array.
        track_ops: Options object; ``track_ops.transform_type`` selects the
            elastix default parameter map (e.g. ``'affine'`` or ``'rigid'``).

    Returns:
        (mov_img_reg, reg_params): the registered moving image as a numpy
        array, and the transform parameter object from elastix with
        ``FinalBSplineInterpolationOrder`` set to ``"0"`` for later
        transformix calls (presumably so binary ROI masks are resampled
        without smoothing).
    """
    # Default elastix parameter map for the requested transform type.
    param_obj = itk.ParameterObject.New()
    default_map = param_obj.GetDefaultParameterMap(track_ops.transform_type)
    param_obj.AddParameterMap(default_map)

    fixed = itk.GetImageFromArray(ref_img)
    moving = itk.GetImageFromArray(mov_img)
    registered, transform_params = itk.elastix_registration_method(
        fixed, moving, parameter_object=param_obj
    )

    registered_np = itk.GetArrayFromImage(registered)
    transform_params.SetParameter("FinalBSplineInterpolationOrder", "0")
    return registered_np, transform_params
26 |
def itk_reg_roi(roi, reg_params):
    """Apply a previously estimated transform to one ROI mask.

    The mask is cast to ``uint8``, converted to an ITK image, warped with
    ``itk.transformix_filter`` using ``reg_params``, and returned as a numpy
    array.
    """
    roi_image = itk.GetImageFromArray(roi.astype(np.uint8))
    warped_image = itk.transformix_filter(roi_image, reg_params)
    return itk.GetArrayFromImage(warped_image)
32 |
def itk_reg_all_roi(all_roi, reg_params):
    """Warp every ROI mask stacked along the third axis with ``itk_reg_roi``.

    Args:
        all_roi: Array where ``all_roi[:, :, k]`` is the k-th ROI mask.
        reg_params: Transform parameters passed to ``itk_reg_roi``.

    Returns:
        Array with the same shape and dtype as ``all_roi``; each warped mask
        is cast to that dtype when written into its slice.
    """
    warped = np.zeros_like(all_roi)
    n_rois = all_roi.shape[2]
    for k in range(n_rois):
        warped[:, :, k] = itk_reg_roi(all_roi[:, :, k], reg_params)
    return warped
--------------------------------------------------------------------------------
/track2p/gui/toolbar.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QToolBar,QMenu,QAction,QToolButton
2 |
class Toolbar(QToolBar):
    """Main-window toolbar with File, Run and Visualization drop-down buttons.

    Each drop-down holds one action with a keyboard shortcut that opens a
    window through ``main_window.window_manager``.
    """

    def __init__(self, main_window):
        super().__init__()

        self.main_window = main_window
        self.init_tool_bar()

    def init_tool_bar(self):
        """Build one menu per toolbar button and add the buttons in order."""
        manager = self.main_window.window_manager
        # (button text, action text, shortcut, slot)
        entries = (
            ("File", "Load processed data (⌘L or Ctrl+L)", "Ctrl+L", manager.open_import_wd),
            ("Run", "Run track2p algorithm (⌘R or Ctrl+R)", "Ctrl+R", manager.open_track2p_wd),
            ("Visualization", "Generate raster plot (⌘G or Ctrl+G)", "Ctrl+G", manager.open_raster_wd),
        )
        for button_text, action_text, shortcut, slot in entries:
            menu = QMenu(self)
            action = QAction(action_text, self)
            action.setShortcut(shortcut)
            action.triggered.connect(slot)
            menu.addAction(action)
            self.add_tool_menu(button_text, menu)

    def add_tool_menu(self, name, menu):
        """Add a tool button labelled ``name`` that pops up ``menu`` on click."""
        tool_button = QToolButton(self)
        tool_button.setText(name)
        tool_button.setMenu(menu)
        tool_button.setPopupMode(QToolButton.InstantPopup)
        tool_button.setStyleSheet("font-size: 13px")
        self.addWidget(tool_button)
41 |
42 |
--------------------------------------------------------------------------------
/track2p/eval/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
def load_alldays_f1_values(base_path, animals, conditions):
    """Collect the all-days F1 score for every animal and condition.

    Each ``base_path/animal/condition/metrics_t2p_all_days.npy`` is loaded
    (pickled object arrays allowed); the value in column 1 of the first row
    whose column 0 equals ``'F1'`` is taken.

    Returns:
        Dict mapping each animal to a list of F1 values in ``conditions`` order.
    """
    f1_values = {animal: [] for animal in animals}

    for animal in animals:
        for condition in conditions:
            path = os.path.join(base_path, animal, condition, 'metrics_t2p_all_days.npy')
            table = np.load(path, allow_pickle=True)
            f1_row = np.flatnonzero(table[:, 0] == 'F1')[0]
            f1_values[animal].append(table[f1_row, 1])

    return f1_values
15 |
16 |
def load_alldays_ct_values(base_path, animals, conditions, ct_type='CT'):
    """Collect the CT row and accuracy row for every animal and condition.

    Each ``base_path/animal/condition/result_{ct_type}.npy`` is loaded
    (pickled object arrays allowed) and printed. Row 0 and row 1 are kept
    with their first element dropped (presumably a row label).

    Returns:
        (ct_values, acc_values): dicts mapping each animal to a list of numpy
        arrays, one per condition, in ``conditions`` order.
    """
    ct_values = {animal: [] for animal in animals}
    acc_values = {animal: [] for animal in animals}

    for animal in animals:
        for condition in conditions:
            fname = f'result_{ct_type}.npy'
            result = np.load(os.path.join(base_path, animal, condition, fname), allow_pickle=True)

            print(result)
            print('-------------------')

            ct_row, acc_row = result[0], result[1]
            ct_values[animal].append(np.array(ct_row[1:]))
            acc_values[animal].append(np.array(acc_row[1:]))

    return ct_values, acc_values
36 |
37 | # def load_alldays_ct_gt_values(base_path, animals, conditions):
38 |
39 |
def load_pairwise_f1_values(base_path, animals, condition):
    """Load the pairwise F1 scores of each animal for one condition.

    For ``condition == 'pw_reg'`` the table is
    ``base_path/animal/metrics_table_pw_registration.npy``; otherwise it is
    ``base_path/animal/condition/metrics_table_pairs.npy``. Row 7 (without
    its first column) is read as the F1 scores and converted to float.

    Returns:
        Dict mapping each animal to a float numpy array of scores.
    """
    f1_values = {animal: [] for animal in animals}

    for animal in animals:
        if condition != 'pw_reg':
            parts = (base_path, animal, condition, 'metrics_table_pairs.npy')
        else:
            parts = (base_path, animal, 'metrics_table_pw_registration.npy')
        table = np.load(os.path.join(*parts), allow_pickle=True)
        f1_values[animal] = table[7, 1:].astype(float)

    return f1_values
--------------------------------------------------------------------------------
/track2p/io/loaders.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
def load_track_ops(track_ops_path):
    """Load the object saved in ``<track_ops_path>/track_ops_postreg.npy``.

    Args:
        track_ops_path: Directory containing ``track_ops_postreg.npy``; a
            trailing path separator is optional (it used to be required
            because the path was built by string concatenation).

    Returns:
        The pickled object stored in the file (via ``.item()``).

    Note:
        ``allow_pickle=True`` executes pickled data; only load files produced
        by the pipeline itself.
    """
    file_path = os.path.join(track_ops_path, 'track_ops_postreg.npy')
    return np.load(file_path, allow_pickle=True).item()
7 |
def load_stat_ds_plane(track_ops_path, track_ops, plane_idx=0):
    """Load suite2p ROI stats for one plane and keep only the selected ROIs.

    Reads ``stat.npy`` and ``iscell.npy`` from
    ``<track_ops_path>/suite2p/plane{plane_idx}/``. If ``track_ops.iscell_thr``
    is None, ROIs with ``iscell[:, 0] == 1`` are kept; otherwise ROIs with
    ``iscell[:, 1] > iscell_thr`` are kept (presumably suite2p's binary label
    and cell probability, respectively).

    Args:
        track_ops_path: Dataset directory containing the ``suite2p`` folder;
            a trailing separator is optional.
        track_ops: Options object; reads ``iscell_thr``.
        plane_idx: Plane to load.

    Returns:
        (stat, stat_summary): the filtered ``stat`` array and a dict with the
        ROI counts before/after filtering and the threshold used.
    """
    plane_dir = os.path.join(track_ops_path, 'suite2p', f'plane{plane_idx}')
    stat = np.load(os.path.join(plane_dir, 'stat.npy'), allow_pickle=True)
    iscell = np.load(os.path.join(plane_dir, 'iscell.npy'), allow_pickle=True)
    len_stat_allcell = len(stat)

    # filter based on track_ops.iscell_thr
    if track_ops.iscell_thr is None:
        stat = stat[iscell[:, 0] == 1]
    else:
        stat = stat[iscell[:, 1] > track_ops.iscell_thr]
    len_stat_iscell = len(stat)

    # normpath drops a trailing separator so basename yields the dataset folder name
    ds_name = os.path.basename(os.path.normpath(track_ops_path))
    print(f'Loading ROIs for plane{plane_idx} in dataset {ds_name}')
    print(f'Chose {len_stat_iscell}/{len_stat_allcell} ROIs, based on s2p iscell threshold {track_ops.iscell_thr} (see track_ops.iscell_thr)')

    stat_summary = {
        'len_stat_allcell': len_stat_allcell,
        'len_stat_iscell': len_stat_iscell,
        'iscell_thr': track_ops.iscell_thr,
    }
    return stat, stat_summary
30 |
def get_all_roi_array_from_stat(stat, track_ops):
    """Rasterise suite2p ROI pixel lists into a stack of boolean masks.

    Args:
        stat: Sequence of ROI dicts whose ``'ypix'`` entries are used as row
            indices and ``'xpix'`` entries as column indices.
        track_ops: Options object; the frame size is taken from
            ``track_ops.all_ds_avg_ch1[0][0].shape``.

    Returns:
        Boolean array of shape ``(n_rows, n_cols, len(stat))`` where slice
        ``[:, :, i]`` is True exactly on the pixels of ROI ``i``.
    """
    frame_shape = track_ops.all_ds_avg_ch1[0][0].shape
    n_rows, n_cols = frame_shape[0], frame_shape[1]

    all_roi_array = np.zeros((n_rows, n_cols, len(stat)), dtype=bool)
    for i, roi in enumerate(stat):
        # Write straight into the output slice instead of allocating and
        # copying a temporary full-frame grid for every ROI.
        all_roi_array[np.asarray(roi['ypix']), np.asarray(roi['xpix']), i] = True

    return all_roi_array
49 |
50 |
--------------------------------------------------------------------------------
/track2p/plot/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.exposure import match_histograms
3 |
4 |
def match_hist_all(all_ds_avg_ch):
    """Histogram-match every image to the first plane of the first dataset.

    Args:
        all_ds_avg_ch: Nested sequence indexed ``[dataset][plane]`` of images.

    Returns:
        Nested list with the same ``[dataset][plane]`` layout containing
        ``match_histograms(image, all_ds_avg_ch[0][0])`` for every image.
    """
    reference = all_ds_avg_ch[0][0]
    return [
        [match_histograms(plane_img, reference) for plane_img in ds_avg_ch]
        for ds_avg_ch in all_ds_avg_ch
    ]
15 |
def make_rgb_img(img1, img2):
    """Overlay two images as the red and green channels of an RGB image.

    Each image is min-max normalised to [0, 1] independently; the blue
    channel stays 0.

    Args:
        img1: 2-D image placed in the red channel.
        img2: 2-D image placed in the green channel (same shape as ``img1``).

    Returns:
        Float array of shape ``(rows, cols, 3)``. A constant input image maps
        to an all-zero channel instead of NaNs from a 0/0 division.
    """
    def _to_unit_range(img):
        # Work in float so integer inputs cannot overflow in (img - min).
        img = np.asarray(img, dtype=float)
        lo = np.min(img)
        span = np.max(img) - lo
        if span == 0:
            return np.zeros_like(img)
        return (img - lo) / span

    img_rgb = np.zeros((img1.shape[0], img1.shape[1], 3))
    img_rgb[:, :, 0] = _to_unit_range(img1)
    img_rgb[:, :, 1] = _to_unit_range(img2)

    return img_rgb
25 |
def saturate_perc(img_rgb, sat_perc=99):
    """Clip an image at a percentile and rescale it to ``uint8``.

    Values are clipped to ``[0, percentile(img_rgb, sat_perc)]`` and divided
    by the maximum of the clipped image, so the brightest value maps to 255.

    Returns:
        ``uint8`` array with the shape of ``img_rgb``. When the clipped image
        is all zeros, zeros are returned rather than casting the NaNs of a
        0/0 division to ``uint8``.
    """
    clipped = np.clip(img_rgb, 0, np.percentile(img_rgb, sat_perc))
    peak = np.max(clipped)
    if peak == 0:
        return np.zeros(np.shape(clipped), dtype=np.uint8)
    return (clipped / peak * 255).astype(np.uint8)
30 |
def get_all_wind_mean_img(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, nrn_id, plane_idx=0, win_size=64):
    """Crop a window around one tracked neuron from every dataset's mean image.

    Args:
        all_ds_mean_img: ``[dataset][plane]`` mean images.
        all_ds_centroids: ``[dataset][plane]`` centroid arrays; the row chosen
            by the match matrix gives (row, col) in the mean image.
        all_pl_match_mat: Per-plane match matrices; row ``nrn_id`` holds, per
            dataset, the ROI index of the tracked neuron.
        nrn_id: Match-matrix row (tracked neuron) to crop around.
        plane_idx: Plane to use.
        win_size: Side length of the window in pixels.

    Returns:
        List with one cropped array per dataset. Each mean image is first
        zero-padded by ``win_size`` on every side, so windows reaching past
        the image border are filled with zeros.
    """
    pad_width = ((win_size, win_size), (win_size, win_size))
    half = win_size / 2
    windows = []
    for ds_idx, ds_mean_imgs in enumerate(all_ds_mean_img):
        roi_ids = all_pl_match_mat[plane_idx][nrn_id, :]
        # shift the centroid into the coordinates of the padded image
        center = all_ds_centroids[ds_idx][plane_idx][roi_ids[ds_idx]] + win_size
        padded = np.pad(ds_mean_imgs[plane_idx], pad_width, mode='constant', constant_values=0)
        r0, r1 = int(center[0] - half), int(center[0] + half)
        c0, c1 = int(center[1] - half), int(center[1] + half)
        windows.append(padded[r0:r1, c0:c1])
    return windows
--------------------------------------------------------------------------------
/track2p/gui/main_wd.py:
--------------------------------------------------------------------------------
1 |
2 | from track2p.gui.window_manager import WindowManager
3 | from track2p.gui.toolbar import Toolbar
4 | from track2p.gui.statusbar import StatusBar
5 | from track2p.gui.data_management import DataManagement
6 | from track2p.gui.central_widget import CentralWidget
7 | from PyQt5.QtWidgets import QApplication,QMainWindow
8 | from PyQt5.QtGui import QIcon
9 |
class MainWindow(QMainWindow):
    """Top-level track2p window holding the toolbar, central widget and status bar.

    Every sub-component receives this window. The construction order matters:
    the toolbar connects its actions to ``self.window_manager`` slots, so the
    window manager must exist before ``Toolbar(self)`` runs.
    """

    def __init__(self):
        super(MainWindow, self).__init__()
        self.main_window = self
        self.window_manager = WindowManager(self)
        self.data_management = DataManagement(self)
        self.central_widget = CentralWidget(self)
        self.toolbar = Toolbar(self)
        self.status_bar = StatusBar(self)
        self.initUI()

    def initUI(self):
        """Apply the style sheet, install the widgets and show the window maximised."""
        style_rules = (
            "QTabWidget::pane { border: 1px solid #666; }",
            "QTabWidget::tab-bar { alignment: center; }",
            "QTabBar::tab { background-color: #666; color: white; }",
            "QTabBar::tab:selected { background-color: #222; color: white; }",
            "QSplitter::handle { background: #888; }",
            "QFrame { background-color: black; color: black; border: 1px solid black;}",
            "QLabel { color: black; background-color: none; border: none; font-size: 13px}",
            "QPushButton { background-color: #666; color: white; border: 1px solid #888; }",
            "QPushButton:hover { background-color: #888; color: white; }",
            "QPushButton:pressed { background-color: #333; color: white; }",
            "QToolButton:pressed { background-color: #888; }",
            "QComboBox { background-color: black; color: white; }",
            "QComboBox QAbstractItemView { background-color: #666; color: white; }",
        )
        self.setStyleSheet("".join(style_rules))

        self.setWindowTitle("track2p GUI")

        self.setCentralWidget(self.central_widget)
        self.addToolBar(self.toolbar)
        self.setStatusBar(self.status_bar)
        # NOTE(review): 'Cleanlooks' is not among the styles built into Qt 5
        # (it moved to a separate plugin); if it is unavailable this call
        # probably has no effect -- confirm, and consider 'Fusion'.
        QApplication.setStyle('Cleanlooks')
        self.showMaximized()
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/track2p/register/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def get_all_ds_img_for_reg(all_ds_avg_ch1, all_ds_avg_ch2, track_ops):
    """Pair each dataset with the next one to build registration inputs.

    The channel is chosen by ``track_ops.reg_chan`` (0 = functional ->
    ``all_ds_avg_ch1``, 1 = anatomical -> ``all_ds_avg_ch2``). For every plane,
    dataset ``i`` is the reference and dataset ``i + 1`` the moving image, so
    each day is registered to the previous one.

    Side effects:
        Stores the results on ``track_ops.all_ds_ref_img`` and
        ``track_ops.all_ds_mov_img``.

    Returns:
        (all_ds_ref_img, all_ds_mov_img): lists with
        ``len(track_ops.all_ds_path) - 1`` entries, each a list of
        ``track_ops.nplanes`` images.

    Raises:
        ValueError: if ``track_ops.reg_chan`` is neither 0 nor 1 (any other
            value would otherwise leave the channel undefined).
    """
    if track_ops.reg_chan == 0:
        all_ds_avg = all_ds_avg_ch1
    elif track_ops.reg_chan == 1:
        all_ds_avg = all_ds_avg_ch2
        print('WARNING: using anatomical channel for registration (this is not always available)')
    else:
        raise ValueError(
            f'track_ops.reg_chan must be 0 (functional) or 1 (anatomical), got {track_ops.reg_chan!r}'
        )

    n_pairs = len(track_ops.all_ds_path) - 1
    planes = range(track_ops.nplanes)
    all_ds_ref_img = [[all_ds_avg[i][j] for j in planes] for i in range(n_pairs)]
    all_ds_mov_img = [[all_ds_avg[i + 1][j] for j in planes] for i in range(n_pairs)]

    track_ops.all_ds_ref_img = all_ds_ref_img
    track_ops.all_ds_mov_img = all_ds_mov_img

    return all_ds_ref_img, all_ds_mov_img
28 |
29 |
def get_ref_reg_inters(all_roi_array_ref, all_roi_array_nonref):
    """Render where the ROI footprints of two mask stacks overlap.

    Each stack (ROIs along axis 2) is collapsed into a footprint: pixels where
    the sum over ROIs is positive. The result is an RGB float image that is
    white (1, 1, 1) except where both footprints overlap, which is tinted
    (1, 5/6, 0).
    """
    ref_proj = np.sum(all_roi_array_ref, axis=2) > 0
    nonref_proj = np.sum(all_roi_array_nonref, axis=2) > 0
    inters = np.logical_and(ref_proj, nonref_proj)

    ref_reg_inters = np.ones((inters.shape[0], inters.shape[1], 3))
    # Red stays at 1; slightly lowering green and removing blue gives an
    # orange tint on the intersection.
    ref_reg_inters[:, :, 1] -= inters / 6
    ref_reg_inters[:, :, 2] -= inters

    return ref_reg_inters
45 |
def get_all_ref_nonref_inters(all_ds_all_roi_array_ref, all_ds_all_roi_array_nonref, track_ops):
    """Apply ``get_ref_reg_inters`` to every consecutive dataset pair and plane.

    Returns:
        Nested list indexed ``[pair][plane]`` with
        ``len(track_ops.all_ds_path) - 1`` pairs and ``track_ops.nplanes``
        planes per pair.
    """
    n_pairs = len(track_ops.all_ds_path) - 1
    return [
        [
            get_ref_reg_inters(all_ds_all_roi_array_ref[pair][plane], all_ds_all_roi_array_nonref[pair][plane])
            for plane in range(track_ops.nplanes)
        ]
        for pair in range(n_pairs)
    ]
--------------------------------------------------------------------------------
/track2p/eval/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
def plot_alldays_f1(animals, conditions, f1_values, symbols, colors, xshift=0, animals_names=None):
    """Plot each animal's F1 score per condition as markers on vertical stems.

    Args:
        animals: Keys into ``f1_values``, ``symbols`` and ``colors``.
        conditions: Condition names used as x tick labels.
        f1_values: Mapping animal -> sequence of F1 scores, one per condition.
        symbols: Mapping animal -> matplotlib format string.
        colors: Mapping animal -> colour.
        xshift: Horizontal offset added per animal so markers do not overlap.
        animals_names: Optional legend labels, parallel to ``animals``.

    Prints each animal and its values, then shows the figure.
    """
    plt.figure(figsize=(4, 2), dpi=300)
    for idx, animal in enumerate(animals):
        print(animal)
        print(f1_values[animal])
        scores = f1_values[animal]
        xs = np.arange(len(conditions)) + idx * xshift
        color = colors[animal]

        label = animals_names[idx] if animals_names is not None else animal
        plt.plot(xs, scores, symbols[animal], label=label, color=color)
        # vertical stem from 0 up to each score
        for x, score in zip(xs, scores):
            plt.vlines(x, 0, score, color=color, linewidth=2)
        # Tick positions are reset on every iteration; the last animal's win.
        plt.xticks(xs - xshift, conditions, rotation=45)

        # dashed grey reference line at y=0
        plt.axhline(0, color='grey', linestyle='--', alpha=0.5)

    plt.ylabel('F1 Score')
    plt.yticks([0, 0.5, 1])
    ax = plt.gca()
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    plt.legend(loc='upper right', fontsize=8)
    plt.show()
32 |
33 |
def _age_tick_labels(n_days, last_age=14):
    """Return tick labels naming only the first and last postnatal day.

    Sessions are assumed to be consecutive days ending at ``P{last_age}``;
    intermediate ticks get empty labels, e.g. 7 days ending at P14 give
    ``['P8', '', '', '', '', '', 'P14']``.
    """
    if n_days <= 0:
        return []
    if n_days == 1:
        return [f'P{last_age}']
    first_age = last_age - n_days + 1
    return [f'P{first_age}'] + [''] * (n_days - 2) + [f'P{last_age}']


def plot_pairwise_f1(animals, condition, pairwise_f1_values, symbols, colors, show_d0=True, show_legend=False, last_age=14):
    """Plot pairwise F1 scores ("Prop. correct") across days for each animal.

    Args:
        animals: Keys into ``pairwise_f1_values``, ``symbols`` and ``colors``.
        condition: Used as the plot title.
        pairwise_f1_values: Mapping animal -> sequence of pairwise scores.
        symbols: Mapping animal -> matplotlib marker.
        colors: Mapping animal -> colour.
        show_d0: If True, a score of 1 is prepended for the first day.
        show_legend: If True, draw a legend.
        last_age: Postnatal day of the last session, used for the last x tick
            label (default 14, i.e. 'P14').

    Only the first and last x ticks are labelled. The labels are derived from
    the number of plotted points, so any number of days can be plotted
    (matplotlib raises when the label count differs from the tick count).
    """
    plt.figure(figsize=(2, 2), dpi=300)
    plt.title(f'{condition}')
    plt.xlabel('Days')
    plt.ylabel('Prop. correct')

    x_data = np.arange(1, 1)  # stays empty if no animal is plotted
    for i, animal in enumerate(animals):
        if show_d0:
            y_data = np.concatenate([[1], pairwise_f1_values[animal]])
        else:
            y_data = pairwise_f1_values[animal]

        x_data = np.arange(1, len(y_data) + 1)

        # Draw animal 'jm039' on top of the others.
        zorder = 10 if animal == 'jm039' else i + 1

        plt.plot(x_data, y_data, label=animal, marker=symbols[animal], color=colors[animal], markersize=4, zorder=zorder)

    # Dashed grey reference line at y=0, drawn once: repeating it per animal
    # stacks the alpha and makes the line progressively darker.
    plt.axhline(0, color='grey', linestyle='--', alpha=0.5)

    plt.ylim(-0.05, 1.05)
    plt.yticks([0, 0.5, 1])
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Ticks follow the last plotted animal's points.
    ax.set_xticks(x_data)
    ax.set_xticklabels(_age_tick_labels(len(x_data), last_age))

    if show_legend:
        plt.legend(loc='upper right', fontsize=8)
    plt.show()
--------------------------------------------------------------------------------
/track2p/gui/import_wd.py:
--------------------------------------------------------------------------------
1 |
2 | from PyQt5.QtWidgets import QWidget, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout,QComboBox
3 |
class ImportWindow(QWidget):
    """Form for loading already-processed track2p output into the GUI.

    The user picks the directory containing the track2p folder, a plane
    index, the channel shown in mean images and the trace type. *Run* passes
    these to ``main_window.central_widget.data_management.import_files``.
    """

    def __init__(self, main_wd):
        super(ImportWindow, self).__init__()

        self.main_window = main_wd
        self.path_to_t2p = None
        self.plane = None
        # Set by run(); initialised so the attributes always exist.
        self.trace_type = None
        self.channel = None

        layout = QFormLayout()

        import_label = QLabel("Import the directory containing the track2p folder:")
        self.import_button = QPushButton("Import", self)
        self.import_button.clicked.connect(self.save_t2p_path)
        layout.addRow(import_label, self.import_button)
        path_label = QLabel("Here is the path of the imported directory:")
        self.path = QLabel()
        layout.addRow(path_label, self.path)
        plane_label = QLabel("Choose the plane to analyze:")
        self.textbox = QLineEdit(self)
        self.textbox.setFixedWidth(50)
        self.textbox.setText('0')
        trace_label = QLabel("Choose trace type:")
        self.trace_choice = QComboBox()
        self.trace_choice.addItem("F")
        self.trace_choice.addItem("dF/F0")
        self.trace_choice.addItem("spks")
        channel_label = QLabel("Choose the channel to show in the mean images:")
        self.channel_choice = QComboBox()
        self.channel_choice.addItem("0")
        self.channel_choice.addItem("1")
        self.channel_choice.addItem("Vcorr")
        self.channel_choice.addItem("max_proj")
        layout.addRow(plane_label, self.textbox)
        layout.addRow(channel_label, self.channel_choice)
        layout.addRow(trace_label, self.trace_choice)
        label_run = QLabel("Run the analysis:")
        self.run_button = QPushButton("Run", self)
        self.run_button.clicked.connect(self.run)
        layout.addRow(label_run, self.run_button)

        self.setLayout(layout)

    def save_t2p_path(self):
        """Ask for a directory and remember it unless the dialog was cancelled."""
        path_to_t2p = QFileDialog.getExistingDirectory(self, "Select Directory")
        if path_to_t2p:
            self.path_to_t2p = path_to_t2p
            self.path.setText(f'{self.path_to_t2p}')

    def run(self):
        """Validate the form and import the selected data.

        Invalid input is reported in a message box instead of raising: an
        exception escaping a PyQt5 slot terminates the whole application.
        """
        from PyQt5.QtWidgets import QMessageBox

        if not self.path_to_t2p:
            QMessageBox.warning(self, "No directory selected", "Please import the directory containing the track2p folder first.")
            return
        try:
            plane = int(self.textbox.text())
        except ValueError:
            QMessageBox.warning(self, "Invalid plane", "The plane must be a whole number (e.g. 0).")
            return

        self.plane = plane
        self.trace_type = self.trace_choice.currentText()
        self.channel = self.channel_choice.currentText()
        self.main_window.central_widget.data_management.import_files(
            self.path_to_t2p, plane=self.plane, trace_type=self.trace_type, channel=self.channel
        )
59 |
--------------------------------------------------------------------------------
/track2p/io/savers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | from track2p.io.utils import make_dir
5 |
def save_track_ops(track_ops):
    """Save ``track_ops`` as a dict to ``<track_ops.save_path>/track_ops.npy``.

    The large ROI-array attributes ``all_ds_all_roi_array_mov``,
    ``all_ds_all_roi_array_ref`` and ``all_ds_all_roi_array_reg`` are deleted
    from ``track_ops`` in place first to keep the file small. Attributes that
    are already absent (e.g. on a second save) are skipped rather than
    raising ``AttributeError``.
    """
    for attr in ('all_ds_all_roi_array_mov', 'all_ds_all_roi_array_ref', 'all_ds_all_roi_array_reg'):
        try:
            delattr(track_ops, attr)
        except AttributeError:
            pass

    track_ops_dict = track_ops.to_dict()  # convert to dictionary for compatibility

    np.save(os.path.join(track_ops.save_path, 'track_ops.npy'), track_ops_dict, allow_pickle=True)
    print(f'Saved track_ops.npy in {track_ops.save_path}')
16 |
def save_all_pl_match_mat(all_pl_match_mat, track_ops):
    """Save each plane's match matrix as ``plane{i}_match_mat.npy``.

    ``all_pl_match_mat`` holds one matrix per plane; its position gives the
    plane index ``i``. Files are written to ``track_ops.save_path``.
    """
    for plane_idx, match_mat in enumerate(all_pl_match_mat):
        out_path = os.path.join(track_ops.save_path, f'plane{plane_idx}_match_mat.npy')
        np.save(out_path, match_mat, allow_pickle=True)
20 |
21 |
def npy_to_s2p(track_ops):
    """Convert raw numpy inputs into a minimal suite2p-style folder layout.

    For every plane and dataset, reads ``F.npy``, ``fov.npy`` and ``rois.npy``
    from ``<ds>/data_npy/plane{p}/`` and writes to ``<ds>/suite2p/plane{p}/``:
    ``F.npy``, ``ops.npy`` (``meanImg``, ``nchannels``, ``fs``, ``nframes``),
    ``stat.npy`` (per-ROI ``xpix``, ``ypix``, ``med``) and ``iscell.npy`` (all
    ones, two columns). Existing suite2p plane folders are skipped.

    Args:
        track_ops: Options object; reads ``nplanes`` and ``all_ds_path``.
    """
    for plane in range(track_ops.nplanes):
        print(f'Processing plane {plane + 1}/{track_ops.nplanes}...')

        for ds_path in track_ops.all_ds_path:
            # 1) Define numpy and suite2p data paths. The suite2p path is built
            #    from ds_path directly: str.replace would also rewrite any
            #    'data_npy' substring that appears elsewhere in ds_path.
            npy_path = os.path.join(ds_path, 'data_npy', f'plane{plane}')
            s2p_path = os.path.join(ds_path, 'suite2p', f'plane{plane}')

            if os.path.exists(s2p_path):
                print(f"Directory {s2p_path} already exists, skipping... (Delete or rename it if you want to overwrite)")
                continue

            # 2) Load numpy data before creating the output folder, so a missing
            #    input does not leave an empty folder that later runs would skip.
            F = np.load(os.path.join(npy_path, 'F.npy'))
            fov = np.load(os.path.join(npy_path, 'fov.npy'))
            rois = np.load(os.path.join(npy_path, 'rois.npy'))

            os.makedirs(s2p_path)

            # 3) Convert and save data in suite2p format
            np.save(os.path.join(s2p_path, 'F.npy'), F)

            ops = {
                'meanImg': fov,
                'nchannels': 1,  # Assuming single channel
                'fs': 30,  # Assuming a sampling frequency of 30 Hz
                'nframes': F.shape[1],
            }
            np.save(os.path.join(s2p_path, 'ops.npy'), ops)

            # stat is a list of dictionaries, each with keys 'xpix', 'ypix', 'med'
            # NOTE(review): an ROI with no positive pixel makes np.median return
            # NaN and int() raise ValueError.
            stat = []
            for roi in rois:
                ypix, xpix = np.where(roi > 0)
                med = [int(np.median(ypix).item()), int(np.median(xpix).item())]
                stat.append({'xpix': xpix, 'ypix': ypix, 'med': med})
            np.save(os.path.join(s2p_path, 'stat.npy'), stat)

            # iscell: every ROI marked as a cell (two columns of ones)
            iscell = np.ones((len(stat), 2), dtype=int)
            np.save(os.path.join(s2p_path, 'iscell.npy'), iscell)
--------------------------------------------------------------------------------
/track2p/gui/custom_wd.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PyQt5.QtWidgets import QVBoxLayout, QWidget, QHBoxLayout, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout, QListWidget, QMessageBox,QListWidgetItem, QInputDialog,QCheckBox,QSizePolicy,QComboBox,QDialog
3 | from PyQt5.QtCore import Qt
4 | from track2p.t2p import run_t2p
5 | from track2p.ops.default import DefaultTrackOps
6 |
class CustomDialog(QDialog):
    """Dialog asking which plane, trace type and channel to open in the GUI.

    *OK* validates the inputs and passes them, together with
    ``save_directory``, to
    ``main_window.central_widget.data_management.import_files``.
    *Cancel* sets ``cancel_clicked`` and closes the dialog.
    """

    def __init__(self, main_window, save_directory, channel):
        super(CustomDialog, self).__init__()
        self.main_window = main_window
        self.save_directory = save_directory
        self.channel = channel
        self.cancel_clicked = False

        # Layout
        layout = QVBoxLayout(self)

        # Plane input, pre-filled with 0 (as in the import window's plane box)
        self.plane_label = QLabel("Enter your plane:", self)
        layout.addWidget(self.plane_label)
        self.plane_input = QLineEdit(self)
        self.plane_input.setText('0')
        layout.addWidget(self.plane_input)

        # Trace type input
        self.trace_label = QLabel("Choose trace type:", self)
        layout.addWidget(self.trace_label)
        self.trace_combo = QComboBox(self)
        self.trace_combo.addItems(["F", "spks", "dF/F0"])
        layout.addWidget(self.trace_combo)

        self.channel_label = QLabel("Choose the channel to show in the mean images:", self)
        layout.addWidget(self.channel_label)
        self.channel_combo = QComboBox(self)
        self.channel_combo.addItems(["0", "1", "Vcorr", "max_proj"])
        layout.addWidget(self.channel_combo)

        # OK and Cancel buttons
        self.ok_button = QPushButton("OK", self)
        self.ok_button.clicked.connect(self.import_data)
        layout.addWidget(self.ok_button)

        self.cancel_button = QPushButton("Cancel", self)
        self.cancel_button.clicked.connect(self.on_cancel_clicked)
        layout.addWidget(self.cancel_button)

    def get_inputs(self):
        """Return (plane text, trace type, channel) as currently entered."""
        return self.plane_input.text(), self.trace_combo.currentText(), self.channel_combo.currentText()

    def on_cancel_clicked(self):
        """Record the cancellation and close the dialog."""
        self.cancel_clicked = True
        self.import_data()

    def import_data(self):
        """Close on cancel; otherwise validate the inputs and import the data.

        A non-integer plane is reported in a message box and the dialog stays
        open: an exception escaping a PyQt5 slot would terminate the app.
        """
        if self.cancel_clicked:
            self.close()
            return

        plane_text, trace_type, channel = self.get_inputs()
        try:
            plane = int(plane_text)
        except ValueError:
            QMessageBox.warning(self, "Invalid plane", "The plane must be a whole number (e.g. 0).")
            return

        print("Opening GUI...")
        self.plane = plane
        self.trace_type = trace_type
        self.channel = channel
        print(f"Converted plane: {self.plane}, Trace type: {self.trace_type}")  # Debugging print
        self.main_window.central_widget.data_management.import_files(t2p_folder_path=self.save_directory, plane=self.plane, trace_type=self.trace_type, channel=self.channel)
        self.close()
66 |
67 |
--------------------------------------------------------------------------------
/track2p/ops/default.py:
--------------------------------------------------------------------------------
1 | # make dummy track_ops object (would be input from command line or gui)
2 | import os
3 | from track2p.io.utils import make_dir
4 |
class DefaultTrackOps:
    """Container for track2p run parameters, pre-filled with default values.

    Typically instantiated with defaults, customised by the caller (GUI, command
    line or notebook), then passed to the pipeline. ``to_dict``/``from_dict``
    convert to/from a plain dict so saved options can be reloaded without this
    class definition.
    """

    def __init__(self):
        # Dataset folders to track across, in session order; each holds a 'suite2p' folder.
        self.all_ds_path = [
            'data/ac/ac444118/2022-09-14_a',
            'data/ac/ac444118/2022-09-15_a',
            'data/ac/ac444118/2022-09-16_a'
        ]

        # 'suite2p' = standard suite2p output; 'npy' = raw numpy arrays (F.npy, fov.npy, rois.npy).
        self.input_format = 'suite2p'

        # Output root; init_save_paths() appends 'track2p/' to it.
        self.save_path = 'data/ac/ac444118/track2p/'

        # --- registration ---
        self.reg_chan = 0  # 0 = functional, 1 = anatomical (channel 1 is not always available)
        self.transform_type = 'affine'  # 'affine' or 'rigid'
        # Keep only ROIs with iscell > iscell_thr. Lowering it is usually harmless:
        # artefacts are unlikely to be present consistently in all datasets.
        self.iscell_thr = 0.50

        # --- matching ---
        # 'iou' (slower, more accurate), 'cent' (faster, less accurate) or 'cent_int-filt'.
        self.matching_method = 'iou'
        # Centroid distance (pixels) above which IoU is skipped to save time (only for 'iou').
        self.iou_dist_thr = 16

        # --- automatic threshold on the matching metric ---
        # Drop zeros before thresholding; many zeros can skew the automatic threshold.
        self.thr_remove_zeros = False
        self.thr_method = 'otsu'  # 'otsu' or 'min' (local minimum of the metric's pdf)

        # Leave unchanged: relies on plt.contour, which is slow and memory hungry
        # (can crash), although the resulting figures are nice for presentations.
        self.show_roi_reg_output = False

        # --- plotting ---
        self.win_size = 48  # crop size of the mean image when visualising matched ROIs across days
        self.sat_perc = 99.9  # saturation percentile; affects visualisation only

        self.colors = None

        # Also write the output in suite2p format (handy for downstream suite2p analysis).
        self.save_in_s2p_format = False

    def init_save_paths(self):
        """Derive output folders from ``save_path`` and create them.

        Sets ``save_path`` to ``<save_path>/track2p/`` and ``save_path_fig`` to
        ``<that>/fig/``, both with forward slashes, then creates both via
        ``make_dir``. Each call appends another 'track2p/' level.
        """
        root = os.path.join(self.save_path, 'track2p/').replace("\\", "/")
        self.save_path = root
        self.save_path_fig = os.path.join(root, 'fig/').replace("\\", "/")
        for folder in (self.save_path, self.save_path_fig):
            make_dir(folder)

    def to_dict(self):
        """Return every non-dunder, non-callable attribute as a plain dict (keys in ``dir`` order)."""
        return {
            name: getattr(self, name)
            for name in dir(self)
            if not name.startswith('__') and not callable(getattr(self, name))
        }

    def from_dict(self, track_ops_dict):
        """Set one attribute per key of ``track_ops_dict`` (inverse of ``to_dict``)."""
        for name in track_ops_dict:
            setattr(self, name, track_ops_dict[name])
--------------------------------------------------------------------------------
/track2p/gui/statusbar.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QStatusBar,QWidget, QHBoxLayout, QSpinBox,QPushButton,QLabel
2 | import numpy as np
3 | import os
4 |
class StatusBar(QStatusBar):
    """Status bar for browsing tracked ROIs and curating them as cell (1) / not cell (0).

    The spin box selects an ROI index and shows its curation value. The two
    buttons change that value in ``data_management.vector_curation_t2p`` and
    immediately persist it as ``track_ops_dict['vector_curation_plane_<plane>']``
    in ``<track_ops.save_path>/track_ops.npy``. "Apply curation" calls
    ``central_widget.create_mean_img_from_curation``.
    """

    def __init__(self, main_window):
        super().__init__()

        self.main_window = main_window
        self.central_widget = self.main_window.central_widget
        # Cached curation vector; refreshed from data_management before each use.
        self.vector_curation_t2p = None
        self.init_status_bar()

    def init_status_bar(self):
        """Build the spin box, ROI-state labels and curation buttons and add them to the bar."""
        status_widget = QWidget()
        layout = QHBoxLayout()

        self.spin_box = QSpinBox()
        self.spin_box.setFixedWidth(100)
        self.spin_box.valueChanged.connect(self.iterate_all_rois)

        self.roi_state = QLabel("state of ROI: ")
        self.roi_state.setFixedWidth(100)
        self.roi_state_value = QLabel()
        self.roi_state_value.setFixedWidth(30)

        not_cell_button = QPushButton('✖️')
        not_cell_button.setFixedSize(20, 20)
        not_cell_button.setStyleSheet("background-color: red; color: white; border: none;")
        not_cell_button.clicked.connect(self.set_roi_as_not_cell)

        cell_button = QPushButton('✓')
        cell_button.setFixedSize(20, 20)
        cell_button.setStyleSheet("background-color: green; color: white; border: none;")
        cell_button.clicked.connect(self.set_roi_as_cell)

        apply_button = QPushButton('Apply curation')
        apply_button.setFixedSize(100, 20)
        apply_button.setStyleSheet("background-color: grey; color: white; border: none;")
        apply_button.clicked.connect(self.main_window.central_widget.create_mean_img_from_curation)

        for widget in (self.spin_box, self.roi_state, self.roi_state_value,
                       not_cell_button, cell_button, apply_button):
            layout.addWidget(widget)

        status_widget.setLayout(layout)
        self.addWidget(status_widget)

    def _refresh_curation_vector(self):
        """Return data_management's current curation vector (and cache it), or the cached one if unavailable."""
        vector = getattr(self.main_window.central_widget.data_management, 'vector_curation_t2p', None)
        if vector is not None:
            self.vector_curation_t2p = vector
        return self.vector_curation_t2p

    def iterate_all_rois(self):
        """Spin-box slot: show the selected ROI's curation value and select it in the central widget.

        The vector is read from data_management, so browsing works even before
        any curation button has been pressed (the cached copy starts as None).
        """
        current_roi = self.spin_box.value()
        vector = self._refresh_curation_vector()
        if vector is not None:
            self.roi_state_value.setText(f"{vector[current_roi]}")
        self.central_widget.update_selection(current_roi)

    def _set_roi_state(self, from_value, to_value):
        """Change the selected ROI from ``from_value`` to ``to_value`` (if it has ``from_value``) and save.

        The (possibly unchanged) vector is written back to ``track_ops_dict`` and
        ``track_ops.npy`` in every case, as a button press always persists.
        """
        plane = self.central_widget.data_management.plane
        key = 'vector_curation_plane_' + str(plane)
        vector = self._refresh_curation_vector()
        if vector is None:
            return  # nothing loaded yet: nothing to curate
        current_roi = self.spin_box.value()
        if vector[current_roi] == from_value:
            vector[current_roi] = to_value
        self.roi_state_value.setText(f"{vector[current_roi]}")
        self.central_widget.track_ops_dict[key] = vector
        np.save(os.path.join(self.central_widget.data_management.track_ops.save_path, "track_ops.npy"),
                self.central_widget.track_ops_dict)

    def set_roi_as_not_cell(self):
        """Mark the selected ROI as not a cell (1 -> 0) and persist the curation."""
        self._set_roi_state(1, 0)

    def set_roi_as_cell(self):
        """Mark the selected ROI as a cell (0 -> 1) and persist the curation."""
        self._set_roi_state(0, 1)
82 |
83 |
84 |
85 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # data
2 | data/
3 | data_proc
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 | .DS_Store
166 | data
167 |
--------------------------------------------------------------------------------
/notebooks/run_t2p.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Run track2p programmatically \n",
8 | "\n",
9 | "Track2p can also be easily launched through a notebook or a script as shown below"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from track2p.t2p import run_t2p # main function that launches track2p\n",
19 | "from track2p.ops.default import DefaultTrackOps # default track2p options\n",
20 | "\n",
21 | "# load default settings / parameters\n",
22 | "track_ops = DefaultTrackOps()"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "After importing the algorithm function and initialising default parameters, we can set the paths and modify parameters as shown below (Note: the parameters follow the same naming convention as in the 'Run algorithm' window of the GUI):"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "# overwrite some defaults\n",
39 | "track_ops.all_ds_path = [ # list of paths to datasets containing a `suite2p` folder\n",
40 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-04-30_a',\n",
41 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-01_a',\n",
42 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-02_a',\n",
43 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-03_a',\n",
44 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-04_a',\n",
45 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-05_a',\n",
46 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-06_a'\n",
47 | " ]\n",
48 | "\n",
49 | "track_ops.save_path = '/Users/jure/Documents/cossart_lab/data/jm/jm038/' # path where to save the outputs of algorithm \n",
50 | " # (a 'track2p' folder will be created where figures for\n",
51 | " # visualisation and matrices of matches would be saved)\n",
52 | "\n",
53 | "track_ops.reg_chan = 1 # channel to use for registration (0=functional, 1=anatomical) (use 0 if only recording gcamp!)\n",
54 | "track_ops.iscell_thr = 0.5 # threshold for iscell (0.5 is a good value)\n"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "Before running the algorithm, we can also check the settings track2p will use to make sure everything is correct:"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "# print all the settings / parameters used for running the algorithm\n",
71 | "for attr, value in track_ops.__dict__.items():\n",
72 | " print(attr, '=', value)"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "We can then simply launch the algorithm by passing the `track_ops` object to the `run_t2p` function."
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# run the algorithm\n",
89 | "run_t2p(track_ops)\n"
90 | ]
91 | }
92 | ],
93 | "metadata": {
94 | "kernelspec": {
95 | "display_name": "track2p",
96 | "language": "python",
97 | "name": "python3"
98 | },
99 | "language_info": {
100 | "codemirror_mode": {
101 | "name": "ipython",
102 | "version": 3
103 | },
104 | "file_extension": ".py",
105 | "mimetype": "text/x-python",
106 | "name": "python",
107 | "nbconvert_exporter": "python",
108 | "pygments_lexer": "ipython3",
109 | "version": "3.9.18"
110 | }
111 | },
112 | "nbformat": 4,
113 | "nbformat_minor": 2
114 | }
115 |
--------------------------------------------------------------------------------
/track2p/register/loop.py:
--------------------------------------------------------------------------------
1 | # for now the only algorithm is elastix, TODO: add other algorithms and run them in the same way within this loop
2 |
3 | from track2p.register.elastix import reg_img_elastix
4 | from track2p.io.loaders import load_stat_ds_plane, get_all_roi_array_from_stat
5 | from track2p.register.elastix import itk_reg_all_roi
6 |
def run_reg_loop(all_ds_ref_img, all_ds_mov_img, track_ops):
    """Register each moving image onto its reference, plane by plane, with elastix.

    Args:
        all_ds_ref_img: reference images nested [dataset][plane].
        all_ds_mov_img: moving images nested [dataset][plane], aligned with ``all_ds_ref_img``.
        track_ops: options object; ``nplanes`` is read and passed to ``reg_img_elastix``.

    Returns:
        (all_ds_mov_img_reg, all_ds_reg_params), both nested [dataset][plane].

    Side effect: ``track_ops.all_ds_mov_img_reg`` is set. The elastix transforms
    are returned but not stored (TODO: make them serialisable so they can be pickled).
    """
    registered_imgs = []
    registration_params = []

    for ds_idx, ref_planes in enumerate(all_ds_ref_img):
        mov_planes = all_ds_mov_img[ds_idx]
        reg_planes = []
        param_planes = []
        for plane in range(track_ops.nplanes):
            img_reg, plane_params = reg_img_elastix(ref_planes[plane], mov_planes[plane], track_ops)
            reg_planes.append(img_reg)
            param_planes.append(plane_params)
        registered_imgs.append(reg_planes)
        registration_params.append(param_planes)

    track_ops.all_ds_mov_img_reg = registered_imgs
    return registered_imgs, registration_params
30 |
31 |
def reg_all_ds_all_roi(all_ds_reg_params, track_ops):
    """Load ROIs of each consecutive session pair and transform the moving session's ROIs.

    For pair ``i`` (reference = session ``i``, moving = session ``i+1``) and each
    plane ``j``, ROIs of both sessions are loaded with ``load_stat_ds_plane`` and
    converted with ``get_all_roi_array_from_stat``. The moving ROIs are then
    transformed by ``itk_reg_all_roi`` with ``all_ds_reg_params[i][j]``.

    Args:
        all_ds_reg_params: registration parameters nested [pair][plane] (as returned by ``run_reg_loop``).
        track_ops: options object; reads ``all_ds_path`` and ``nplanes``.

    Returns:
        (all_ds_all_roi_array_ref, all_ds_all_roi_array_mov, all_ds_all_roi_array_reg, all_ds_roi_counter).
        The first three are nested [pair][plane]. ``all_ds_roi_counter`` is nested
        [session][plane]: one entry per reference session plus, from the last
        pair, the final session (which never appears as a reference).
        All four are also stored as attributes of ``track_ops``.
    """
    n_pairs = len(track_ops.all_ds_path) - 1

    all_ds_all_roi_array_ref = []
    all_ds_all_roi_array_mov = []
    all_ds_all_roi_array_reg = []
    all_ds_roi_counter = []  # keeps track of how many cells make it through the registration

    for i in range(n_pairs):
        # 1-based progress, consistent with the matching loop
        print(f'...\nTransforming ROIs for registration {i + 1}/{n_pairs}')

        # Pair-level paths do not depend on the plane
        ref_ds_path = track_ops.all_ds_path[i]
        reg_ds_path = track_ops.all_ds_path[i + 1]

        ds_all_roi_array_ref = []  # one entry per plane for this pair
        ds_all_roi_array_mov = []
        ds_all_roi_array_reg = []
        ds_roi_counter_ref = []
        ds_roi_counter_mov = []

        for j in range(track_ops.nplanes):
            reg_params = all_ds_reg_params[i][j]

            # Load ROIs of both sessions for plane j
            # TODO: add (non-s2p dependent) Cellpose ROIs compatibility
            stat_ref, roi_counter_ref = load_stat_ds_plane(ref_ds_path, track_ops, plane_idx=j)
            stat_mov, roi_counter_mov = load_stat_ds_plane(reg_ds_path, track_ops, plane_idx=j)

            all_roi_array_ref = get_all_roi_array_from_stat(stat_ref, track_ops)
            all_roi_array_mov = get_all_roi_array_from_stat(stat_mov, track_ops)

            # Apply this pair/plane's registration transform to the moving ROIs
            all_roi_array_reg = itk_reg_all_roi(all_roi_array_mov, reg_params)

            ds_all_roi_array_ref.append(all_roi_array_ref)
            ds_all_roi_array_mov.append(all_roi_array_mov)
            ds_all_roi_array_reg.append(all_roi_array_reg)

            ds_roi_counter_ref.append(roi_counter_ref)
            ds_roi_counter_mov.append(roi_counter_mov)  # only kept for the last pair (below)

        print('Done with dataset...')

        all_ds_all_roi_array_ref.append(ds_all_roi_array_ref)
        all_ds_all_roi_array_mov.append(ds_all_roi_array_mov)
        all_ds_all_roi_array_reg.append(ds_all_roi_array_reg)

        # The last session never appears as a reference, so take its counters from the last pair.
        all_ds_roi_counter.append(ds_roi_counter_ref)
        if i == n_pairs - 1:
            all_ds_roi_counter.append(ds_roi_counter_mov)

    track_ops.all_ds_all_roi_array_ref = all_ds_all_roi_array_ref
    track_ops.all_ds_all_roi_array_mov = all_ds_all_roi_array_mov
    track_ops.all_ds_all_roi_array_reg = all_ds_all_roi_array_reg
    track_ops.all_ds_roi_counter = all_ds_roi_counter

    return all_ds_all_roi_array_ref, all_ds_all_roi_array_mov, all_ds_all_roi_array_reg, all_ds_roi_counter
--------------------------------------------------------------------------------
/track2p/match/loop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.filters import threshold_otsu, threshold_minimum
3 | from scipy.optimize import linear_sum_assignment
4 |
5 | from track2p.match.utils import get_cost_mat, get_iou, init_all_pl_match_mat
6 |
7 | # assigment of ROIs in each ref-reg pair
8 |
def get_all_ds_assign(track_ops, all_ds_all_roi_ref, all_ds_all_roi_reg):
    """Match ROIs within each consecutive (reference, registered) session pair, plane by plane.

    For every pair and plane:
      1. ``get_cost_mat`` builds a cost matrix (method chosen by ``track_ops``);
      2. ``linear_sum_assignment`` finds the optimal one-to-one assignment;
      3. indices are mapped back to the post-iscell ROI indices;
      4. the IoU of each matched pair is computed as thresholding metric;
      5. an automatic threshold (``track_ops.thr_method``) is fitted to it and
         only pairs with metric > threshold are kept in the thresholded output.

    Returns:
        (all_ds_assign, all_ds_assign_thr, all_ds_thr_met, all_ds_thr), each nested
        [pair][plane]; assignments are ``[ref_ind, reg_ind]`` index arrays.

    Raises:
        ValueError: if ``track_ops.thr_method`` is neither 'otsu' nor 'min'
            (previously this left ``thr`` unbound or silently reused the last plane's value).
    """
    if track_ops.thr_method not in ('otsu', 'min'):
        raise ValueError(f"Unknown thr_method {track_ops.thr_method!r}: expected 'otsu' or 'min'")

    n_pairs = len(track_ops.all_ds_path) - 1

    all_ds_assign = []
    all_ds_assign_thr = []
    all_ds_thr_met = []
    all_ds_thr = []

    for i in range(n_pairs):
        print(f'Finding matches in ref-reg pair: {i + 1}/{n_pairs}')
        ds_assign = []
        ds_assign_thr = []
        ds_thr_met = []
        ds_thr = []
        for j in range(track_ops.nplanes):
            all_roi_ref = all_ds_all_roi_ref[i][j]
            all_roi_reg = all_ds_all_roi_reg[i][j]

            # 1) cost matrix (method selected in track_ops, see DefaultTrackOps)
            cost_mat, all_inds_ref_filt, all_inds_reg_filt = get_cost_mat(all_roi_ref, all_roi_reg, track_ops)

            # 2) optimal one-to-one assignment
            ref_ind_filt, reg_ind_filt = linear_sum_assignment(cost_mat)

            # 3) back to pre-filtered indices (ROI indices after iscell filtering)
            ref_ind = all_inds_ref_filt[ref_ind_filt]
            reg_ind = all_inds_reg_filt[reg_ind_filt]

            # 4) thresholding metric (IoU) for every matched pair; filtering happens
            #    later in the all-day assignment
            thr_met = get_iou(all_roi_ref[:, :, ref_ind], all_roi_reg[:, :, reg_ind])
            # optionally ignore zeros, which can skew the automatic threshold
            thr_met_compute = thr_met[thr_met > 0] if track_ops.thr_remove_zeros else thr_met

            # 5) automatic threshold on the metric
            if track_ops.thr_method == 'otsu':
                thr = threshold_otsu(thr_met_compute)
            else:  # 'min' (validated above)
                thr = threshold_minimum(thr_met_compute)

            keep = thr_met > thr
            ds_assign.append([ref_ind, reg_ind])
            ds_assign_thr.append([ref_ind[keep], reg_ind[keep]])
            ds_thr_met.append(thr_met)
            ds_thr.append(thr)

        all_ds_assign.append(ds_assign)
        all_ds_assign_thr.append(ds_assign_thr)
        all_ds_thr_met.append(ds_thr_met)
        all_ds_thr.append(ds_thr)
        print(f'Done ref-reg pair: {i + 1}/{n_pairs}')

    return all_ds_assign, all_ds_assign_thr, all_ds_thr_met, all_ds_thr
59 |
60 |
61 | # propagating matches across all days
62 |
def get_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops):
    """Chain thresholded pairwise matches across all sessions into per-plane tracks.

    ``init_all_pl_match_mat`` provides, per plane, a match matrix whose rows are
    ROIs of the first session and whose columns are sessions (unmatched entries
    are ``None``). Starting from the column-0 ROI, each row is followed through
    ``all_ds_assign_thr[pair][plane] = [ref_ind, reg_ind]``; tracking of a row
    stops at the first pair without a match.

    Args:
        all_ds_all_roi_ref: per-pair, per-plane reference ROI arrays (passed to ``init_all_pl_match_mat``).
        all_ds_assign_thr: per-pair, per-plane ``[ref_ind, reg_ind]`` arrays of accepted matches.
        track_ops: options object; reads ``nplanes``; sets ``all_pl_match_mat`` and ``n_tracked``.

    Returns:
        List of per-plane match matrices (filled in place).
    """
    all_pl_match_mat = init_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops)

    n_tracked = 0  # defined even when there are no planes
    for i in range(track_ops.nplanes):
        pl_match_mat = all_pl_match_mat[i]
        n_pairs = pl_match_mat.shape[1] - 1

        # One ref->reg lookup per session pair, built once per plane, instead of an
        # np.where scan over ref_ind for every ROI at every step. ref_ind comes from a
        # one-to-one assignment, so its entries are unique and a dict loses nothing.
        pair_maps = []
        for ds_ind in range(n_pairs):
            matches = all_ds_assign_thr[ds_ind][i]
            pair_maps.append(dict(zip(np.asarray(matches[0]).tolist(), np.asarray(matches[1]).tolist())))

        for roi_idx in range(pl_match_mat.shape[0]):  # roi_idx indexes the first session
            if pl_match_mat[roi_idx, 0] is None:
                continue
            track_roi = np.asarray(pl_match_mat[roi_idx, 0]).item()
            for ds_ind, pair_map in enumerate(pair_maps):
                track_roi = pair_map.get(track_roi)
                if track_roi is None:
                    break  # no match in this pair: stop tracking this ROI
                pl_match_mat[roi_idx, ds_ind + 1] = track_roi

        # Elementwise comparison on the object array (``is not None`` would not broadcast).
        n_tracked = np.sum(np.all(pl_match_mat != None, axis=1))  # noqa: E711
        print(f'Number of ROIs tracked in plane{i} across all days: {n_tracked}')

    track_ops.all_pl_match_mat = all_pl_match_mat
    # NOTE(review): only the last plane's count is stored here.
    track_ops.n_tracked = n_tracked

    return all_pl_match_mat
--------------------------------------------------------------------------------
/track2p/io/s2p_loaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
def check_nplanes(track_ops):
    """Count the ``plane*`` folders of every dataset and store the common count on ``track_ops``.

    Looks in ``<ds_path>/suite2p`` (input_format 'suite2p') or
    ``<ds_path>/data_npy`` (input_format 'npy').

    Side effects: sets ``track_ops.all_nplanes`` (per-dataset counts) and, when
    all datasets agree, ``track_ops.nplanes``.

    Raises:
        ValueError: for an unknown ``input_format`` or an empty ``all_ds_path``.
        SystemExit: (exit status 1) if datasets have different numbers of planes.
    """
    import sys

    subdir_by_format = {'suite2p': 'suite2p', 'npy': 'data_npy'}
    if track_ops.input_format not in subdir_by_format:
        raise ValueError(f"Unknown input_format {track_ops.input_format!r}: expected 'suite2p' or 'npy'")
    if not track_ops.all_ds_path:
        raise ValueError('track_ops.all_ds_path is empty: no datasets to check')
    subdir = subdir_by_format[track_ops.input_format]

    all_nplanes = []
    for ds_path in track_ops.all_ds_path:
        n_planes = sum(1 for name in os.listdir(os.path.join(ds_path, subdir)) if name.startswith('plane'))
        print(f'Found {n_planes} planes in {ds_path}')
        all_nplanes.append(n_planes)
    track_ops.all_nplanes = all_nplanes

    if len(set(all_nplanes)) == 1:
        track_ops.nplanes = all_nplanes[0]
        print(f'Found {track_ops.nplanes} planes in all datasets')
    else:
        print('Found different number of planes in different datasets')
        print('Please check your dataset paths')
        print('Exiting...')
        # sys.exit rather than the interactive-only site ``exit``; non-zero status signals failure
        sys.exit(1)
25 |
26 |
27 | # loads mean images
def load_all_imgs(track_ops):
    """Load suite2p mean images (channel 1 and, for two-channel planes, channel 2).

    Reads ``<ds_path>/suite2p/plane<i>/ops.npy`` for every dataset in
    ``track_ops.all_ds_path`` and every plane ``i < track_ops.nplanes``.

    Side effects: sets ``track_ops.all_ds_avg_ch1``, ``all_ds_avg_ch2``,
    ``all_ds_nchannels`` and ``nchannels``.

    Returns:
        (all_ds_avg_ch1, all_ds_avg_ch2), nested [dataset][plane]; channel-2
        entries are ``None`` for planes whose ``nchannels`` is not 2.

    Raises:
        ValueError: if there are no datasets or no planes to load.
        SystemExit: (exit status 1) if datasets differ in their per-plane channel counts.
    """
    import sys

    all_ds_avg_ch1 = []
    all_ds_avg_ch2 = []
    all_ds_nchannels = []

    for ds_path in track_ops.all_ds_path:
        ds_avg_ch1 = []
        ds_avg_ch2 = []
        ds_nchannels = []

        for i in range(track_ops.nplanes):
            ops_path = os.path.join(ds_path, 'suite2p', f'plane{i}', 'ops.npy')
            ops = np.load(ops_path, allow_pickle=True).item()
            nchannels = ops['nchannels']
            print(f'nchannels: {nchannels} for plane {i} in dataset {ds_path}')
            ds_avg_ch1.append(ops['meanImg'])
            # the channel-2 mean image is only read for two-channel planes
            ds_avg_ch2.append(ops['meanImg_chan2'] if nchannels == 2 else None)
            ds_nchannels.append(nchannels)

        all_ds_avg_ch1.append(ds_avg_ch1)
        all_ds_avg_ch2.append(ds_avg_ch2)
        all_ds_nchannels.append(ds_nchannels)

    track_ops.all_ds_avg_ch1 = all_ds_avg_ch1
    track_ops.all_ds_avg_ch2 = all_ds_avg_ch2
    track_ops.all_ds_nchannels = all_ds_nchannels

    if not all_ds_nchannels or not all_ds_nchannels[0]:
        raise ValueError('No images to load: check track_ops.all_ds_path and track_ops.nplanes')

    # every dataset must have the same per-plane channel counts
    first = all_ds_nchannels[0]
    if all(ds_nch == first for ds_nch in all_ds_nchannels):
        track_ops.nchannels = first[0]
        print(f'Found {track_ops.nchannels} channels in all datasets')
    else:
        print('Found different number of channels in different datasets')
        print('Please check your dataset paths')
        print('Exiting...')
        sys.exit(1)

    return all_ds_avg_ch1, all_ds_avg_ch2
65 |
def load_all_ds_stat_iscell(track_ops):
    """Load suite2p ``stat.npy`` per dataset/plane, keeping only ROIs selected by ``iscell.npy``.

    If ``track_ops.iscell_thr`` is None, rows whose iscell column 0 equals 1 are
    kept; otherwise rows whose column 1 (cell probability) exceeds
    ``iscell_thr`` are kept.

    Returns:
        Filtered stat arrays nested [dataset][plane].
    """
    all_ds_stat_iscell = []
    for ds_path in track_ops.all_ds_path:
        ds_stat_iscell = []
        for j in range(track_ops.nplanes):
            plane_dir = os.path.join(ds_path, 'suite2p', f'plane{j}')
            stat = np.load(os.path.join(plane_dir, 'stat.npy'), allow_pickle=True)
            iscell = np.load(os.path.join(plane_dir, 'iscell.npy'), allow_pickle=True)
            if track_ops.iscell_thr is None:
                keep = iscell[:, 0] == 1
            else:
                keep = iscell[:, 1] > track_ops.iscell_thr
            ds_stat_iscell.append(stat[keep])
        all_ds_stat_iscell.append(ds_stat_iscell)

    return all_ds_stat_iscell
81 |
def load_all_ds_ops(track_ops):
    """Return the suite2p ``ops`` dict of every plane of every dataset, nested [dataset][plane]."""
    return [
        [
            np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'ops.npy'), allow_pickle=True).item()
            for plane in range(track_ops.nplanes)
        ]
        for ds_path in track_ops.all_ds_path
    ]
92 |
def load_all_ds_mean_img(track_ops, ch=1):
    """Return suite2p mean images nested [dataset][plane].

    ``ch == 1`` selects ``ops['meanImg']``; any other value selects ``ops['meanImg_chan2']``.
    """
    img_key = 'meanImg' if ch == 1 else 'meanImg_chan2'
    return [[plane_ops[img_key] for plane_ops in ds_ops] for ds_ops in load_all_ds_ops(track_ops)]
104 |
def load_all_ds_centroids(all_ds_stat_iscell, track_ops):
    """Return ROI median coordinates (``stat['med']``) as arrays nested [dataset][plane].

    One entry per dataset in ``track_ops.all_ds_path``; each plane's array has one row per ROI.
    """
    n_datasets = len(track_ops.all_ds_path)
    return [
        [np.array([roi['med'] for roi in plane_stat]) for plane_stat in all_ds_stat_iscell[ds_idx]]
        for ds_idx in range(n_datasets)
    ]
117 |
--------------------------------------------------------------------------------
/track2p/gui/roi_plot.py:
--------------------------------------------------------------------------------
1 |
2 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import skimage
6 |
7 |
class ZoomPlotWidget(FigureCanvas):
    """Display a zoom on the selected tracked cell in every recording (one panel per day).

    Each panel is a (2*wind) x (2*wind) crop of that day's image from ``imgs``,
    centred on the ROI's median pixel (``stat['med']``), with the ROI outline drawn
    in the cell's colour. Under each panel, the ROI's index in that day's suite2p
    files (``i``) and its iscell probability (``p``) are shown.
    """

    def __init__(self, all_ops=None, all_stat_t2p=None, colors=None, all_iscell_t2p=None, t2p_match_mat_allday=None, track_ops=None, imgs=None):
        nb_plot = len(all_iscell_t2p)
        self.fig, axes = plt.subplots(1, nb_plot, figsize=(10 * nb_plot, nb_plot),
                                      gridspec_kw={'width_ratios': [1] * nb_plot}, facecolor='black')
        # plt.subplots returns a single Axes (not an array) when nb_plot == 1;
        # wrap it so self.ax_zoom[i] works for any number of days.
        self.ax_zoom = np.atleast_1d(axes)
        super().__init__(self.fig)
        self.track_ops = track_ops
        self.all_ops = all_ops
        self.all_stat_t2p = all_stat_t2p  # per day: stat entries indexed by tracked cell
        self.colors = colors  # per tracked cell: outline colour
        self.all_is_cell = all_iscell_t2p  # per day: suite2p iscell array
        self.t2p_match_mat_allday = t2p_match_mat_allday  # [tracked cell, day] -> index among iscell-filtered ROIs
        self.imgs = imgs  # per day: 2D image used as zoom background
        self.coord_dict = {}

    @staticmethod
    def _roi_mask(ypix, xpix, med_y, med_x, wind):
        """Binary (2*wind, 2*wind) mask of the ROI pixels centred on (med_y, med_x); pixels outside the window are dropped."""
        rows = np.asarray(ypix, dtype=int) - med_y + wind
        cols = np.asarray(xpix, dtype=int) - med_x + wind
        # without this, large ROIs raise IndexError or wrap around via negative indices
        inside = (rows >= 0) & (rows < 2 * wind) & (cols >= 0) & (cols < 2 * wind)
        mask = np.zeros((2 * wind, 2 * wind))
        mask[rows[inside], cols[inside]] = 1
        return mask

    def _suite2p_index_and_prob(self, day, selected_cell_index):
        """Map a tracked cell back to its row in that day's iscell array; return (row, probability rounded to 2 decimals)."""
        iscell = self.all_is_cell[day]
        # same ROI selection as used when loading stat (iscell_thr None -> column 0 == 1)
        if self.track_ops.iscell_thr is None:
            kept_rows = np.where(iscell[:, 0] == 1)[0]
        else:
            kept_rows = np.where(iscell[:, 1] > self.track_ops.iscell_thr)[0]
        true_index = kept_rows[self.t2p_match_mat_allday[selected_cell_index, day]]
        return true_index, round(iscell[true_index, 1], 2)

    def display_zooms(self, selected_cell_index):
        """Redraw every day's zoom for the tracked cell ``selected_cell_index``."""
        self.coord_dict = {}
        if self.all_ops is None or self.all_stat_t2p is None or self.all_is_cell is None:
            return

        wind = 20  # half-width of the zoom window, in pixels
        margin = wind  # zero padding so crops near the border keep the full 2*wind size

        for i in range(len(self.imgs)):
            mean_img = self.imgs[i]
            roi_stat = self.all_stat_t2p[i][selected_cell_index]
            med_y, med_x = int(roi_stat['med'][0]), int(roi_stat['med'][1])

            # Pad the image with zeros, then crop the window around the median pixel
            Ly, Lx = mean_img.shape
            padded = np.zeros((Ly + 2 * margin, Lx + 2 * margin))
            padded[margin:margin + Ly, margin:margin + Lx] = mean_img
            cy, cx = med_y + margin, med_x + margin
            crop = padded[cy - wind:cy + wind, cx - wind:cx + wind]

            true_index, prob = self._suite2p_index_and_prob(i, selected_cell_index)
            mask = self._roi_mask(roi_stat['ypix'], roi_stat['xpix'], med_y, med_x, wind)

            ax = self.ax_zoom[i]
            color = self.colors[selected_cell_index]
            ax.clear()
            ax.contour(mask, levels=[0.5], colors=[color], linewidths=2)
            ax.imshow(crop, cmap='gray')
            ax.set_title(f'Day {i + 1}', color='white', fontsize=10)
            ax.text(0.5, -0.2, f'i: {true_index}', color='white', fontsize=10, ha='center', va='center', transform=ax.transAxes)
            ax.text(0.5, -0.4, f'p: {prob}', color='white', fontsize=10, ha='center', va='center', transform=ax.transAxes)
            ax.axis('off')

        self.draw()
107 |
--------------------------------------------------------------------------------
/notebooks/utils/npy_to_s2p.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "f0538be5",
6 | "metadata": {},
7 | "source": [
8 | "# Conversion of npy to s2p \n",
9 | "\n",
10 | "This is used to implement and test the feature suggested by reviewer 2 - allowing easier use of track2p by people using pipelines other than suite2p\n",
11 | "\n",
12 | "Inputs:\n",
13 | "- rois.npy (bool ndarray of shape: (n_roi, n_px_y, n_px_x))\n",
14 | "- fov.npy (float32 ndarray of shape: (n_px_y, n_px_x)) - IMPORTANT: ROIs and FOV must be aligned (this is done by default in suite2p outputs, but might need to be registered manually in other cases)\n",
15 | "- F.npy (float32 ndarray of shape: (n_roi, n_tstamps)) - fluorescence traces\n",
16 | "\n",
17 | "Outputs:\n",
18 | "- suite2p folder"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "4e3cd55a",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import os\n",
29 | "import numpy as np\n",
30 | "import matplotlib.pyplot as plt"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "b008c170",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "nplanes = 1\n",
41 | "plane = 0"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "bed3a691",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "# base folder whose subfolders ending in _a each hold one npy dataset to convert\n",
52 | "subject_npy_path = '/Users/jure/Documents/cossart_lab/data/jm/jm039_npy/'\n",
53 | "# npy_paths are all the subfolders of subject_path ending in _a\n",
54 | "npy_paths = [os.path.join(f.path, 'data_npy', f'plane{plane}') for f in os.scandir(subject_npy_path) if f.is_dir() and f.name.endswith('_a')]\n",
55 | "npy_paths.sort()\n",
56 | "print(npy_paths)\n",
57 | "\n",
58 | "\n",
59 | "s2p_paths = []\n",
60 | "for npy_path in npy_paths:\n",
61 | " print(npy_path)\n",
62 | " # remove _npy\n",
63 | " s2p_path = npy_path.replace('jm039_npy', 'jm039')\n",
64 | " s2p_path = s2p_path.replace('data_npy', 'suite2p')\n",
65 | " s2p_paths.append(s2p_path)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "id": "6ab62d00",
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "path = npy_paths[0] # Example path, you can loop through npy_paths if needed\n",
76 | "\n"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "c9c13434",
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "for (npy_path, s2p_path) in zip(npy_paths, s2p_paths):\n",
87 | " print(s2p_path)\n",
88 | " \n",
89 | " F = np.load(os.path.join(npy_path, 'F.npy'))\n",
90 | " fov = np.load(os.path.join(npy_path, 'fov.npy'))\n",
91 | " rois = np.load(os.path.join(npy_path, 'rois.npy'))\n",
92 | "\n",
93 | "\n",
94 | "\n",
95 | " if not os.path.exists(s2p_path):\n",
96 | " os.makedirs(s2p_path)\n",
97 | " else:\n",
98 | " print(f\"Directory {s2p_path} already exists, skipping... (Delete or rename it if you want to overwrite)\")\n",
99 | " continue\n",
100 | " \n",
101 | " np.save(os.path.join(s2p_path, 'F.npy'), F)\n",
102 | "\n",
103 | " ops = {'meanImg': fov}\n",
104 | " ops['nchannels'] = 1 # Assuming single channel\n",
105 | " ops['fs'] = 30 # Assuming a sampling frequency of 30 Hz\n",
106 | " ops['nframes'] = F.shape[1]\n",
107 | " np.save(os.path.join(s2p_path, 'ops.npy'), ops)\n",
108 | "\n",
109 | " # stat is a list of dictionaries, each with keys 'xpix', 'ypix'\n",
110 | " stat = []\n",
111 | " for i in range(rois.shape[0]):\n",
112 | " # TODO: make sure using the correct axis convention!!!\n",
113 | " ypix, xpix = np.where(rois[i] > 0)\n",
114 | " med = [int(np.median(ypix).item()), int(np.median(xpix).item())]\n",
115 | " stat.append({'xpix': xpix, 'ypix': ypix, 'med': med})\n",
116 | "\n",
117 | " np.save(os.path.join(s2p_path, 'stat.npy'), stat)\n",
118 | " \n",
119 | " # iscell is two columns of 1 the columns are length n_cells\n",
120 | " n_cells = len(stat)\n",
121 | " iscell = np.ones((n_cells, 2), dtype=int)\n",
122 | " np.save(os.path.join(s2p_path, 'iscell.npy'), iscell)\n",
123 | " print(f\"Saved to {s2p_path}\")\n"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "9563eb37",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "med = [int(np.median(ypix).item()), int(np.median(xpix).item())]\n",
134 | "\n",
135 | "med"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "id": "aebc8434",
142 | "metadata": {},
143 | "outputs": [],
144 | "source": []
145 | }
146 | ],
147 | "metadata": {
148 | "kernelspec": {
149 | "display_name": "track2p",
150 | "language": "python",
151 | "name": "python3"
152 | },
153 | "language_info": {
154 | "codemirror_mode": {
155 | "name": "ipython",
156 | "version": 3
157 | },
158 | "file_extension": ".py",
159 | "mimetype": "text/x-python",
160 | "name": "python",
161 | "nbconvert_exporter": "python",
162 | "pygments_lexer": "ipython3",
163 | "version": "3.9.21"
164 | }
165 | },
166 | "nbformat": 4,
167 | "nbformat_minor": 5
168 | }
169 |
--------------------------------------------------------------------------------
/notebooks/utils/s2p_to_npy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "192b95be",
6 | "metadata": {},
7 | "source": [
8 | "# Conversion of s2p to npy\n",
9 | "\n",
10 | "This is used to implement and test the feature suggested by reviewer 2 - allowing easier use of track2p by people using pipelines other than suite2p\n",
11 | "\n",
12 | "Inputs:\n",
13 | "- suite2p folder\n",
14 | "\n",
15 | "Outputs:\n",
16 | "- rois.npy (bool ndarray of shape: (n_roi, n_px_y, n_px_x))\n",
17 | "- fov.npy (float32 ndarray of shape: (n_px_y, n_px_x)) - IMPORTANT: ROIs and FOV must be aligned (this is done by default in suite2p outputs, but might need to be registered manually in other cases)\n",
18 | "- F.npy (float32 ndarray of shape: (n_roi, n_tstamps)) - fluorescence traces\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "1e4cb8a3",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import os\n",
29 | "import numpy as np\n",
30 | "import matplotlib.pyplot as plt"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "5c6c7926",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "# base folder whose subfolders ending in _a each hold one suite2p dataset to convert\n",
41 | "subject_s2p_path = '/Volumes/data_jm_share/data_proc/el/el017/'\n",
42 | "# s2p_paths are all the subfolders of subject_path ending in _a\n",
43 | "s2p_paths = [f.path for f in os.scandir(subject_s2p_path) if f.is_dir() and f.name.endswith('_a')]\n",
44 | "s2p_paths.sort()\n",
45 | "print(s2p_paths)\n",
46 | "\n",
47 | "# \n",
48 | "npy_paths = []\n",
49 | "for s2p_path in s2p_paths:\n",
50 | " \n",
51 | " # add _npy after el017 to the path\n",
52 | " npy_path = s2p_path.replace('el017', 'el017_npy')\n",
53 | " npy_path = os.path.join(npy_path, 'data_npy')\n",
54 | " # make this directory if it does not exist\n",
55 | "\n",
56 | " if not os.path.exists(npy_path):\n",
57 | " os.makedirs(npy_path)\n",
58 | "\n",
59 | " npy_paths.append(npy_path)"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "f8811710",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "# now write code to load the s2p data and saves it to npy\n",
70 | "plane = 0\n",
71 | "for plane in range(2):\n",
72 | " print(f'Processing plane {plane}...')\n",
73 | " for s2p_path, npy_path in zip(s2p_paths, npy_paths):\n",
74 | " print(f'Loading {s2p_path} and saving to {npy_path}')\n",
75 | "\n",
76 | " s2p_path_full = os.path.join(s2p_path, 'suite2p', f'plane{plane}')\n",
77 | "\n",
78 | " ops = np.load(os.path.join(s2p_path_full, 'ops.npy'), allow_pickle=True).item()\n",
79 | " stat = np.load(os.path.join(s2p_path_full, 'stat.npy'), allow_pickle=True)\n",
80 | " iscell = np.load(os.path.join(s2p_path_full, 'iscell.npy'), allow_pickle=True)\n",
81 | " F = np.load(os.path.join(s2p_path_full, 'F.npy'), allow_pickle=True)\n",
82 | "\n",
83 | " iscell_bool = iscell[:, 0] == 1\n",
84 | " print(sum(iscell_bool), 'cells in this dataset')\n",
85 | "\n",
86 | " # get mean image\n",
87 | " mean_img = ops['meanImg']\n",
88 | " print('Mean image shape:', mean_img.shape)\n",
89 | " print(type(mean_img[0,0]))\n",
90 | " plt.imshow(mean_img, cmap='gray')\n",
91 | " plt.show()\n",
92 | "\n",
93 | " # now filter stat and F by iscell_bool\n",
94 | " stat = stat[iscell_bool]\n",
95 | " F = F[iscell_bool, :]\n",
96 | " print('Stat shape:', stat.shape)\n",
97 | " print('F shape:', F.shape)\n",
98 | " print(type(F[0, 0]))\n",
99 | "\n",
100 | " rois = np.zeros((len(stat), mean_img.shape[0], mean_img.shape[1]), dtype=np.bool)\n",
101 | " for i, s in enumerate(stat):\n",
102 | " # create a mask for the roi\n",
103 | " mask = np.zeros((mean_img.shape[0], mean_img.shape[1]), dtype=np.bool)\n",
104 | " mask[s['ypix'], s['xpix']] = True\n",
105 | " rois[i] = mask\n",
106 | " print('ROIs shape:', rois.shape)\n",
107 | " print(type(rois[0, 0, 0]))\n",
108 | "\n",
109 | " plt.imshow(rois[0], cmap='gray')\n",
110 | " plt.show()\n",
111 | " plt.imshow(rois[1], cmap='gray')\n",
112 | " plt.show()\n",
113 | "\n",
114 | " # save to npy\n",
115 | " npy_path = os.path.join(npy_path, f'plane{plane}')\n",
116 | " \n",
117 | " if not os.path.exists(npy_path):\n",
118 | " os.makedirs(npy_path)\n",
119 | "\n",
120 | " np.save(os.path.join(npy_path, 'rois.npy'), rois)\n",
121 | " np.save(os.path.join(npy_path, 'F.npy'), F)\n",
122 | " np.save(os.path.join(npy_path, 'fov.npy'), mean_img)\n",
123 | "\n",
124 | " # TODO: save these to npy files ...\n"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "id": "211c1aac",
130 | "metadata": {},
131 | "source": []
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "id": "f5ff8c6d",
136 | "metadata": {},
137 | "source": []
138 | }
139 | ],
140 | "metadata": {
141 | "kernelspec": {
142 | "display_name": "track2p",
143 | "language": "python",
144 | "name": "python3"
145 | },
146 | "language_info": {
147 | "codemirror_mode": {
148 | "name": "ipython",
149 | "version": 3
150 | },
151 | "file_extension": ".py",
152 | "mimetype": "text/x-python",
153 | "name": "python",
154 | "nbconvert_exporter": "python",
155 | "pygments_lexer": "ipython3",
156 | "version": "3.9.21"
157 | }
158 | },
159 | "nbformat": 4,
160 | "nbformat_minor": 5
161 | }
162 |
--------------------------------------------------------------------------------
/notebooks/eval/main_eval_fig.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import os \n",
11 | "import matplotlib.pyplot as plt\n",
12 | "\n",
13 | "# import SimpleNamespace\n",
14 | "from types import SimpleNamespace\n",
15 | "\n",
16 | "from track2p.eval.io import load_pairwise_f1_values, load_alldays_ct_values\n",
17 | "from track2p.eval.plot import plot_alldays_f1, plot_pairwise_f1\n",
18 | "\n",
19 | "# avoid restarting notebook to import changes\n",
20 | "%load_ext autoreload\n",
21 | "%autoreload 2"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "## Define loading and plotting functions"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "## Set paths and parameters"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "# Paths to the NumPy files\n",
45 | "base_path = '/Volumes/data_jm_share/data_proc/jm' # Replace with the correct path\n",
46 | "conditions = ['chan0', 'chan1', 'rigid', 'cellreg'] # ['chan0', 'chan1', 'rigid', 'cellreg']\n",
47 | "conditions_names = ['Anatomical', 'Functional', 'Rigid', 'CellReg']\n",
48 | "animals = ['jm038', 'jm039', 'jm046'] \n",
49 | "symbols = {'jm038': 'o', 'jm039': 'o', 'jm046': 'o'}\n",
50 | "colors = {'jm038': (0.8, 0.8, 0.8), 'jm039': 'C0', 'jm046': (0.7, 0.7, 0.7)}\n"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "## All days evaluation"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "ct, acc = load_alldays_ct_values(base_path, animals, conditions, ct_type='CT')\n",
67 | "\n",
68 | "ct_gt, acc_gt = load_alldays_ct_values(base_path, animals, conditions, ct_type='CT_GT')\n",
69 | "\n",
70 | "# in paper CT should be used in 'all day evaluation' - it takes into account the false positives as well\n",
71 | "# and CT_GT should be used in 'pairwise evaluation' - it can't take these into account (it only reports proportion of correctly reconstructed traces)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "# iterate over keys in ct\n",
81 | "f1_values = {animal: [] for animal in animals}\n",
82 | "\n",
83 | "for animal in ct.keys():\n",
84 | " for (i, condition) in enumerate(conditions):\n",
85 | " # get the last value of associated array\n",
86 | " f1_val = ct[animal][i][-1]\n",
87 | " f1_values[animal].append(f1_val)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "# TODO: remove - just manually computing average F1 scores for publication\n",
97 | "mn_f1_anat = np.mean(np.array([0.89, 0.98, 0.93]))\n",
98 | "mn_f1_func = np.mean(np.array([0.85, 0.98, 0.9]))\n",
99 | "mn_f1_rigi = np.mean(np.array([0.16, 0.29, 0.21]))\n",
100 | "mn_f1_creg = 0\n",
101 | "\n",
102 | "# print them to two digits each in its own row\n",
103 | "print(f'Mean values \\nAnatomical: {mn_f1_anat:.2f} \\nFunctional: {mn_f1_func:.2f} \\nRigid: {mn_f1_rigi:.2f} \\nCellReg: {mn_f1_creg:.2f}')"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "plot_alldays_f1(animals, conditions_names, f1_values, symbols, colors, xshift=0.2)\n",
113 | "\n",
114 | "animals_names = ['mouse C', 'mouse D', 'mouse F']\n",
115 | "plot_alldays_f1(animals, conditions_names, f1_values, symbols, colors, xshift=0.2, animals_names=animals_names)\n"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 6,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# print mean F1 values for conditions\n"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "## Pairwise evaluation"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "\n",
141 | "\n",
142 | "for (i, condition_name) in enumerate(conditions_names):\n",
143 | " \n",
144 | " # initialise dictionary\n",
145 | " pairwise_ct_values = {}\n",
146 | "\n",
147 | " for animal in animals:\n",
148 | " pairwise_ct_values[animal] = ct_gt[animal][i]\n",
149 | "\n",
150 | " plot_pairwise_f1(animals, condition_name, pairwise_ct_values, symbols, colors)\n"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "condition = 'pw_reg'\n",
160 | "symbols_pw = {'jm038': 's', 'jm039': 's', 'jm046': 's'} # to differentiate from all day registration\n",
161 | "\n",
162 | "pairwise_f1_values = load_pairwise_f1_values(base_path, animals, condition)\n",
163 | "plot_pairwise_f1(animals, condition, pairwise_f1_values, symbols_pw, colors)\n"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": []
172 | }
173 | ],
174 | "metadata": {
175 | "kernelspec": {
176 | "display_name": "track2p",
177 | "language": "python",
178 | "name": "python3"
179 | },
180 | "language_info": {
181 | "codemirror_mode": {
182 | "name": "ipython",
183 | "version": 3
184 | },
185 | "file_extension": ".py",
186 | "mimetype": "text/x-python",
187 | "name": "python",
188 | "nbconvert_exporter": "python",
189 | "pygments_lexer": "ipython3",
190 | "version": "3.9.19"
191 | }
192 | },
193 | "nbformat": 4,
194 | "nbformat_minor": 2
195 | }
196 |
--------------------------------------------------------------------------------
/track2p/match/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import cdist
3 | from skimage import measure
4 | from skimage.filters import threshold_otsu
5 |
def get_centroids(all_roi):
    """Return one (row, col) centroid per ROI of a stacked ROI array.

    ``all_roi`` holds one ROI image per index along axis 2. Each ROI image is
    connected-component labelled with ``skimage.measure.label`` and the
    centroid of the first region returned by ``measure.regionprops`` is used,
    so for an ROI made of several disconnected pieces only that piece
    contributes. An ROI without any non-zero pixel gets the placeholder
    centroid ``[0, 0]``.

    Returns:
        ndarray of shape (n_roi, 2). The shape is kept 2-D even when there are
        no ROIs, so downstream ``cdist`` calls still accept it.
    """
    centroids = []
    for i in range(all_roi.shape[2]):
        regions = measure.regionprops(measure.label(all_roi[:, :, i]))
        centroids.append(regions[0].centroid if regions else [0, 0])

    # np.array([]) would be 1-D (shape (0,)), which cdist rejects.
    return np.array(centroids).reshape(-1, 2)
19 |
def get_cent_dist_mat(all_roi_ref, all_roi_reg):
    """Pairwise Euclidean distances between the ROI centroids of two ROI stacks.

    Entry ``[i, j]`` of the returned (n_ref, n_reg) array is the distance
    between the centroid of ROI ``i`` of ``all_roi_ref`` and the centroid of
    ROI ``j`` of ``all_roi_reg`` (ROIs indexed along axis 2, centroids from
    ``get_centroids``).
    """
    ref_centroids = get_centroids(all_roi_ref)
    reg_centroids = get_centroids(all_roi_reg)
    return cdist(ref_centroids, reg_centroids)
26 |
def filt_non_overlap(all_roi1, all_roi2, cent_dist_mat):
    """Return indices of ROIs in ``all_roi1`` that overlap their nearest ROI in ``all_roi2``.

    For every ROI ``i`` of ``all_roi1`` (indexed along axis 2), the closest ROI
    of ``all_roi2`` is taken as ``argmin`` of row ``i`` of ``cent_dist_mat``.
    ROI ``i`` is kept when the element-wise product of the two ROI images has
    a non-zero sum, i.e. they share at least one pixel. Only the closest
    neighbour is checked; overlap with other neighbours is not considered.

    Args:
        all_roi1: ROI stack to filter, one ROI per index along axis 2.
        all_roi2: ROI stack to compare against, one ROI per index along axis 2.
        cent_dist_mat: centroid distances of shape (n_roi1, n_roi2).

    Returns:
        Sorted integer ndarray of the kept ROI indices of ``all_roi1``. It is
        empty when either stack has no ROIs.
    """
    n_roi1 = all_roi1.shape[2]
    if n_roi1 == 0 or cent_dist_mat.shape[1] == 0:
        # No candidate neighbours: argmin over an empty row would raise.
        return np.arange(0)

    closest = np.argmin(cent_dist_mat, axis=1)
    kept = [
        i for i in range(n_roi1)
        if np.sum(all_roi1[:, :, i] * all_roi2[:, :, closest[i]]) != 0
    ]
    return np.array(kept, dtype=np.arange(0).dtype)
42 |
def get_cent_dist_mat_non_overlap(all_roi_ref, all_roi_reg):
    """Centroid-distance cost matrix restricted to ROIs that overlap their nearest neighbour.

    ROIs that share no pixel with their closest ROI in the other stack (see
    ``filt_non_overlap``) have no chance to match and are dropped from both
    sides before the sub-matrix is taken.

    Returns:
        (cost_mat, all_inds_ref_filt, all_inds_reg_filt): the filtered distance
        matrix and the original ROI indices its rows and columns correspond to.
    """
    dist = get_cent_dist_mat(all_roi_ref, all_roi_reg)

    # ref -> reg uses the matrix as-is; reg -> ref uses its transpose.
    keep_ref = filt_non_overlap(all_roi_ref, all_roi_reg, dist)
    keep_reg = filt_non_overlap(all_roi_reg, all_roi_ref, dist.T)

    return dist[np.ix_(keep_ref, keep_reg)], keep_ref, keep_reg
55 |
def get_cost_mat(all_roi_ref, all_roi_reg, track_ops):
    """Build the ROI-to-ROI matching cost matrix for one reference/registered pair.

    The method is selected by ``track_ops.matching_method``:

    * ``'cent'``: Euclidean distance between ROI centroids; no ROI is filtered.
    * ``'cent_int-filt'``: centroid distances restricted to ROIs that overlap
      their closest neighbour in the other stack. Rows and columns of the
      returned matrix map to the returned filtered indices.
    * ``'iou'``: ``1 - IoU``. Pairs whose centroids are farther apart than
      ``track_ops.iou_dist_thr`` keep IoU 0, i.e. cost 1.

    Args:
        all_roi_ref: ROI stack of the reference recording, ROIs along axis 2.
        all_roi_reg: ROI stack of the registered recording, ROIs along axis 2.
        track_ops: options object exposing ``matching_method`` (and
            ``iou_dist_thr`` for ``'iou'``).

    Returns:
        (cost_mat, all_inds_ref_filt, all_inds_reg_filt): the cost matrix and
        the ROI indices (into axis 2 of each stack) its rows and columns map to.

    Raises:
        ValueError: if ``track_ops.matching_method`` is not recognised. It is
            a subclass of ``Exception``, so handlers for the previous generic
            exception still catch it.
    """
    method = track_ops.matching_method
    if method == 'cent':
        # Simple centroid assignment (probably works with sparse data but not in development).
        cost_mat = get_cent_dist_mat(all_roi_ref, all_roi_reg)
        all_inds_ref_filt = np.arange(all_roi_ref.shape[2])
        all_inds_reg_filt = np.arange(all_roi_reg.shape[2])
    elif method == 'cent_int-filt':
        # ROIs with no intersection with their closest neighbour are dropped first,
        # so cost_mat is smaller than with 'cent' and must be read through the
        # returned filtered indices.
        cost_mat, all_inds_ref_filt, all_inds_reg_filt = get_cent_dist_mat_non_overlap(all_roi_ref, all_roi_reg)
    elif method == 'iou':
        cost_mat = 1 - get_cross_iou_mat(all_roi_ref, all_roi_reg, dist_thr=track_ops.iou_dist_thr)
        all_inds_ref_filt = np.arange(all_roi_ref.shape[2])
        all_inds_reg_filt = np.arange(all_roi_reg.shape[2])
    else:
        raise ValueError(f'Matching method not implemented: {method!r}')

    print(f'cost_mat computed with method: {method}')
    print(f'cost_mat shape: {cost_mat.shape}')
    if cost_mat.size:
        print(f'cost_mat min: {np.min(cost_mat)}')
        print(f'cost_mat max: {np.max(cost_mat)}')
    else:
        # np.min/np.max raise on zero-size arrays, e.g. when no ROI survives filtering.
        print('cost_mat is empty')

    return cost_mat, all_inds_ref_filt, all_inds_reg_filt
80 |
def get_iou(all_roi_ref, all_roi_reg):
    """Intersection-over-union of same-index ROI pairs from two ROI stacks.

    For every index ``k`` along axis 2 of ``all_roi_ref``, compares ROI ``k``
    of ``all_roi_ref`` with ROI ``k`` of ``all_roi_reg``. A pair whose union
    is empty yields NaN (NumPy 0/0).

    Returns:
        1-D float ndarray of length ``all_roi_ref.shape[2]``.
    """
    def _pair_iou(k):
        ref_mask = all_roi_ref[:, :, k]
        reg_mask = all_roi_reg[:, :, k]
        overlap = np.sum(np.logical_and(ref_mask, reg_mask))
        combined = np.sum(np.logical_or(ref_mask, reg_mask))
        return overlap / combined

    return np.array([_pair_iou(k) for k in range(all_roi_ref.shape[2])])
92 |
def get_cross_iou_mat(all_roi_ref, all_roi_reg, dist_thr=16):
    """Intersection-over-union between every reference ROI and every registered ROI.

    Computing IoU for all pairs is expensive, so pairs whose centroids (see
    ``get_cent_dist_mat``) are more than ``dist_thr`` apart are assumed not to
    be the same cell and are left at 0 without computing the overlap.

    Args:
        all_roi_ref: reference ROI stack, one ROI per index along axis 2.
        all_roi_reg: registered ROI stack, one ROI per index along axis 2.
        dist_thr: maximum centroid distance for which IoU is computed.

    Returns:
        ndarray of shape (n_ref, n_reg) with IoU values in [0, 1]. Pairs of two
        empty ROIs get 0 instead of NaN, so the cost matrix stays finite.
    """
    distances = get_cent_dist_mat(all_roi_ref, all_roi_reg)

    cross_iou_mat = np.zeros((all_roi_ref.shape[2], all_roi_reg.shape[2]))
    # `~(d > thr)` keeps exactly the pairs the original `d > thr` skip rule kept.
    near_i, near_j = np.nonzero(~(distances > dist_thr))
    for i, j in zip(near_i, near_j):
        roi_ref = all_roi_ref[:, :, i]
        roi_reg = all_roi_reg[:, :, j]
        union = np.sum(np.logical_or(roi_ref, roi_reg))
        if union == 0:
            # Both ROIs empty: avoid 0/0 -> NaN, which would poison the assignment.
            continue
        cross_iou_mat[i, j] = np.sum(np.logical_and(roi_ref, roi_reg)) / union

    return cross_iou_mat
109 |
def init_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops):
    """Create one match matrix per plane, seeded from the first dataset pair.

    For plane ``p`` the matrix is an object array filled with ``None``:

    * one row per ROI of ``all_ds_all_roi_ref[0][p]`` (ROIs along axis 2);
    * one column per entry of ``track_ops.all_ds_path``.

    Column 0 is then set to the row index itself at the rows listed in
    ``all_ds_assign_thr[0][p][0]``. These are presumably the reference-side
    indices of the first pair's thresholded assignments; TODO confirm against
    the caller.

    Returns:
        list with one ndarray per plane (``track_ops.nplanes`` entries).
    """
    plane_mats = []
    for plane in range(track_ops.nplanes):
        n_ref_roi = all_ds_all_roi_ref[0][plane].shape[2]
        n_datasets = len(track_ops.all_ds_path)
        plane_mat = np.full((n_ref_roi, n_datasets), None)

        ref_rows = all_ds_assign_thr[0][plane][0]  # zeroth dataset, this plane
        plane_mat[ref_rows, 0] = ref_rows

        plane_mats.append(plane_mat)
    return plane_mats
124 |
def filt_by_otsu(vect_filt, vect_comp):
    """Keep the entries of ``vect_filt`` whose paired ``vect_comp`` value exceeds Otsu's threshold.

    The threshold is computed with ``threshold_otsu`` on ``vect_comp`` only.
    ``vect_filt`` is then boolean-indexed with ``vect_comp > threshold``
    (strictly greater), so both are expected to have the same length.
    """
    above_thr = vect_comp > threshold_otsu(vect_comp)
    return vect_filt[above_thr]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # track2p
2 |
3 |
4 |
5 |
6 | Cell tracking for longitudinal calcium imaging recordings.
7 |
8 | [](https://pypi.org/project/track2p/)
9 | [](https://pepy.tech/project/track2p)
10 | [](https://www.gnu.org/licenses/gpl-3.0)
11 |
12 | For more detailed information on installation and use visit:
13 |
14 | https://track2p.github.io/
15 |
16 |
17 | Or read the [paper](https://elifesciences.org/articles/107540).
18 |
19 |
20 |
21 | # Installation
22 |
23 | ## Installing via pip
24 |
25 | First we need to set up a conda environment with python 3.9:
26 |
27 | ```
28 | conda create --name track2p python=3.9
29 | conda activate track2p
30 | ```
31 |
32 | Then simply install the track2p package using pip:
33 |
34 | ```
35 | pip install track2p
36 | ```
37 |
38 | That's it, track2p should be successfully set up :)
39 | You can simply run it by:
40 |
41 | ```
42 | python -m track2p
43 | ```
44 |
45 | This opens a GUI allowing the user to launch the algorithm and visualise the results interactively.
46 |
47 | (For instructions on running track2p without the GUI, see 'Run track2p via script' under the 'Usage' section.)
48 |
49 | Note: For common installation issues see ['Installation > Common issues'](https://track2p.github.io/install_common_issues.html) in documentation.
50 |
51 | ## Reinstall
52 |
53 | To reinstall, open anaconda and remove the environment with:
54 |
55 | ```
56 | conda env remove -n track2p
57 | ```
58 |
59 | Then follow the 'Installing via pip' instructions above :)
60 |
61 |
62 | # Usage
63 |
64 | ## Run track2p through the GUI
65 |
66 | After activating the GUI through `python -m track2p` the user should navigate to the 'Run' tab on the top left of the window and select 'Run track2p algorithm' from the dropdown menu.
67 |
68 | This will open a pop-up window that will allow the user to set the paths to suite2p datasets and to set the algorithm parameters. After configuring these settings, the user can click 'Run' to run the track2p algorithm, and the progress will be displayed in the terminal.
69 |
70 | Once the algorithm finishes a subsequent pop-up window will prompt the user to decide whether they wish to visualize the results within the interface.
71 |
72 | For more details on how to run the algorithm through the GUI see [run track2p](https://track2p.github.io/run_track2p_gui.html) and for more description of parameters see documentation [parameters](https://track2p.github.io/run_inputs_and_parameters.html#parameters).
73 |
74 | ## Run track2p via script
75 |
76 | To run via script you can use the `run_track2p.py` script in the root of this repo as a template. It is exactly the same as running through the GUI, except that the paths and the parameters are defined within the script (for more on parameters etc. see the documentation). When running, make sure you are within the track2p environment, for example:
77 |
78 | ```
79 | conda activate track2p
80 | python -m run_track2p
81 | ```
82 |
83 |
84 | ## Visualising track2p outputs within the GUI
85 |
86 | After activating the GUI through `python -m track2p`, the user can import the results of any previous analysis by clicking on the 'File' tab on the top left of the window and selecting 'Load processed data' from the dropdown menu. This will open a pop-up window that will allow the user to set the path to the track2p folder (containing the results of the algorithm) and the plane they want to open.
87 |
88 | Once completed, the interface showcases multiple visualizations:
89 |
90 | 
91 |
92 | In this example we are using the track2p GUI to visualise the outputs for an experiment containing 7 consecutive daily recordings in mouse barrel cortex (between P8 and P14).
93 |
94 | In the upper left, the GUI visualises the mean image of the motion-corrected functional channel (usually green / GCaMP). The image is overlaid with ROIs of the cells detected by track2p across all days, with the color of a particular cell matching across days. These images are interactive, allowing the user to click on a cell, which displays the fluorescence traces on each day at the bottom of the window (sorted from the first day to the last).
95 |
96 | In addition, a zoomed-in image of the cell for each day is shown in the top right. Underneath each zoomed-in image, the GUI displays this cell's index in the corresponding suite2p dataset and the 'iscell' probability suite2p has assigned to it on that day.
97 |
98 | Finally, the user can browse all the putative matches detected by the algorithm using the bar at the bottom to toggle through matches, or alternatively they can enter the index of a specific match to display it within the GUI. This bar is also used for manual curation, where we allow the user to evaluate the quality of the tracking for each individual match.
99 |
100 | For more details on how to use the GUI see [GUI usage](https://track2p.github.io/gui_overview.html).
101 |
102 |
103 | # Outputs
104 |
105 | All the outputs of the script will be saved in a `track2p` folder created within the `track_ops.save_path` directory specified by the user when running the algorithm. For an introduction on how to use the outputs for further downstream analysis, we provide a demo notebook, `notebooks/demo_t2p_ouputs.ipynb`, in this repository. Note: you will additionally need to install Jupyter for this to work. For example:
106 |
107 | ```
108 | conda install conda-forge::jupyterlab
109 | ```
110 |
111 | For more information see documentation relating to track2p [visualisations](https://track2p.github.io/outputs_visualisations.html) and [outputs](https://track2p.github.io/outputs_matches.html).
112 |
113 | # Troubleshooting
114 |
115 | A brief troubleshooting guide including some common issues is included [here](https://github.com/juremaj/track2p/blob/main/docs/troubleshooting.md).
116 |
117 | # Reference
118 |
119 | If you use the algorithm please reference the [eLife paper](https://elifesciences.org/articles/107540):
120 |
121 | **Majnik, J., Mantez, M., Zangila, S., Bugeon, S., Guignard, L., Platel, J.-C., & Cossart, R. (2025). Longitudinal tracking of neuronal activity from the same cells in the developing brain using Track2p. eLife.**
122 |
123 | You can also see a Youtube recording of a talk related to the paper: [Link to video (starting at 47:20)](https://youtu.be/Tr97HwgQ9ik?t=2839)
124 |
125 | The data associated with the paper is also available on [Zenodo](https://zenodo.org/records/17091226)
126 |
127 | ___
128 |
129 |
130 | By Cossart Lab (Jure Majnik & Manon Mantez)
131 |
132 | Logo by Eleonora Ambrad Giovannetti
133 |
134 | © Copyright 2025.
135 |
--------------------------------------------------------------------------------
/track2p/gui/fluo_plot.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from scipy.stats import zscore
6 | from PyQt5.QtCore import Qt
7 | import matplotlib.patches as patches
8 | from PyQt5 import QtCore
9 | from scipy.ndimage import maximum_filter1d, minimum_filter1d, gaussian_filter
10 |
11 |
12 | class FluorescencePlotWidget(FigureCanvas):
13 | """this class is used to display the fluorescence of the selected cell across days. It also allows to select a region of interest (ROI) on the fluorescence plot and zoom in on the selected ROI"""
def __init__(self, all_f_t2p=None, all_ops=None, colors=None, all_stat_t2p=None):
    """Create a single-axes figure canvas and connect its mouse and key handlers.

    The data arguments are stored unchanged on the instance for the plotting
    methods: ``all_f_t2p``, ``all_ops``, ``colors`` and ``all_stat_t2p``.
    """
    self.fig, self.ax_fluorescence = plt.subplots(1, 1)
    super().__init__(self.fig)

    # Data used by the plotting methods.
    self.all_f_t2p = all_f_t2p
    self.all_ops = all_ops
    self.colors = colors
    self.all_stat_t2p = all_stat_t2p
    self.fig.set_facecolor('black')

    # Placeholder selection rectangle; it is not added to the axes here.
    self.rect = patches.Rectangle((0, 0), 1, 1, color='white', linewidth=2)

    # Two handlers share 'key_press_event'; they are connected in this order.
    handlers = (
        ('button_press_event', self.on_press),
        ('button_release_event', self.on_release),
        ('key_press_event', self.on_key_press),
        ('key_press_event', self.on_enter_pressed),
    )
    for event_name, callback in handlers:
        self.fig.canvas.mpl_connect(event_name, callback)

    # Let the canvas widget receive keyboard events.
    self.setFocusPolicy(Qt.StrongFocus)
    self.setFocus()

    # Zoom/interaction state.
    self.initial_xlim = None
    self.initial_ylim = None
    self.cmd_pressed = False
36 |
37 | def draw_rectangle(self):
38 | """it draws a rectangle on the fluorescence plot"""
39 | #if the rectangle is already on the plot, it is removed
40 | if self.rect in self.ax_fluorescence.patches:
41 | self.ax_fluorescence.patches.remove(self.rect)
42 | self.rect = patches.Rectangle((self.x0, self.y0), self.x1 - self.x0, self.y1 - self.y0, fill=None, linewidth=2, edgecolor='white') #create a rectangle, fill is set to None to make the rectangle transparent
43 | self.rect.set_zorder(10)
44 | self.ax_fluorescence.add_patch(self.rect)
45 | self.draw()
46 |
47 | def on_key_press(self, event):
48 | """it allows to reset the zoom of the fluorescence plot. It is called when the Command key and - key are pressed."""
49 | if event.key == 'r':
50 | self.ax_fluorescence.set_xlim(self.initial_xlim)
51 | self.ax_fluorescence.set_ylim(self.initial_ylim)
52 | self.fig.canvas.draw()
53 |
54 | def draw_point(self):
55 | if hasattr(self, 'point') and self.point in self.ax_fluorescence.collections:
56 | print(self.ax_fluorescence.collections)
57 | self.ax_fluorescence.collections.remove(self.point)
58 | self.point = self.ax_fluorescence.scatter([self.x0], [self.y0], s=5, color='w') # Create a new point
59 | self.draw()
60 |
61 | def on_press(self, event):
62 | """it allows to draw a point on the fluorescence plot. It is called when the mouse button is pressed."""
63 | self.x0 = event.xdata #x coordinate of the mouse cursor
64 | self.y0 = event.ydata #y coordinate of the mouse cursor
65 | self.draw_point()
66 |
67 | def on_release(self, event):
68 | """It allows to draw a rectangle on the fluorescence plot. It is called when the mouse button is released."""
69 | self.x1 = event.xdata
70 | self.y1 = event.ydata
71 | self.rect.set_width(self.x1 - self.x0)
72 | self.rect.set_height(self.y1 - self.y0)
73 | self.rect.set_xy((self.x0, self.y0)) #set the position of the rectangle
74 | self.draw_rectangle()
75 |
76 |
77 | def on_enter_pressed(self, event):
78 | """it allows to zoom in of the fluorescence plot. It is called when the Command key and + key are pressed."""
79 | if event.key == 'enter':
80 | self.ax_fluorescence.set_xlim(self.x0, self.x1)
81 | self.ax_fluorescence.set_ylim(self.y0, self.y1)
82 | self.rect.set_visible(False)
83 | self.point.set_visible(False)
84 | self.fig.canvas.draw()
85 |
86 |
87 | def display_all_f_t2p(self, selected_cell_index):
88 | """it plots the fluroescence of the selected cell across days where each curve being a different day (the curve at the top of the plot is the first day)"""
89 |
90 | if self.all_f_t2p is not None and selected_cell_index is not None:
91 | self.ax_fluorescence.clear()
92 | self.ax_fluorescence.set_facecolor('black')
93 | self.ax_fluorescence.tick_params(axis='x', colors='white')
94 | self.ax_fluorescence.tick_params(axis='y', colors='white')
95 | self.ax_fluorescence.xaxis.label.set_color('white')
96 | self.ax_fluorescence.yaxis.label.set_color('white')
97 | self.ax_fluorescence.spines['bottom'].set_color('#666')
98 |
99 |
100 | for i, fluorescence_data in list(enumerate(reversed(self.all_f_t2p))):
101 | #print(i)
102 | #print(fluorescence_data[selected_cell_index, :])
103 | fluorescence_zscore = zscore(fluorescence_data, axis=1, ddof=1) #zscore is used to normalize the fluorescence data
104 | offset = i * 12 #
105 | y_values = fluorescence_zscore[selected_cell_index, :] + offset
106 | #print(y_values)
107 | color = self.colors[selected_cell_index]
108 | #create a gradient of colors for the curves (the darkest shade is for the last day and the lightest shade is for the first day)
109 | if i == 0:
110 | color = color
111 | else:
112 | h, l, s = colorsys.rgb_to_hls(*color)
113 | l_range = 1 - (l + 0.05)
114 | l_add = l_range/len(self.all_f_t2p)
115 |
116 | adjusted_luminosity = l + (l_add *i)
117 | color = colorsys.hls_to_rgb(h, adjusted_luminosity, s)
118 |
119 | ops=self.all_ops[i]
120 | fs = ops['fs']
121 | tstamps = np.arange(len(y_values))
122 |
123 | if len(tstamps) != len(y_values):
124 | raise ValueError(f"tstamps and y_values must have the same length, but have lengths {len(tstamps)} and {len(y_values)}")
125 | self.ax_fluorescence.plot(tstamps,y_values, label=f'Curve {i + 1}', color= color)
126 |
127 | self.ax_fluorescence.set_xticks([0, int(len(tstamps)/2), len(tstamps)])
128 | self.ax_fluorescence.set_yticklabels([])
129 | self.ax_fluorescence.get_yaxis().set_visible(False)
130 | self.ax_fluorescence.set_xlabel('Frames')
131 | self.initial_xlim=self.ax_fluorescence.get_xlim()
132 | self.initial_ylim=self.ax_fluorescence.get_ylim()
133 | self.fig.tight_layout()
134 | self.draw()
135 |
136 |
137 | pass
138 |
139 |
--------------------------------------------------------------------------------
/track2p/gui/central_widget.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtCore import Qt
2 | from PyQt5.QtWidgets import QTabWidget, QVBoxLayout, QWidget, QSplitter, QHBoxLayout, QFrame, QFrame
3 | from track2p.gui.fluo_plot import FluorescencePlotWidget
4 | from track2p.gui.roi_plot import ZoomPlotWidget
5 | from track2p.gui.cell_plot import CellPlotWidget
6 | from track2p.gui.data_management import DataManagement
7 | from track2p.gui.raster_wd import RasterWindow
8 |
class CentralWidget(QWidget):
    """Central area of the track2p GUI.

    Layout (built in ``init_central_widget``): the top row holds the per-day
    image tabs next to the ROI zoom panel, and the bottom row holds the
    fluorescence traces. Loaded data lives in ``self.data_management``.
    """

    def __init__(self, main_window):
        super().__init__()

        self.main_window = main_window
        self.fluorescences_plotting = None  # FluorescencePlotWidget, created lazily
        self.rois_plotting = None  # ZoomPlotWidget, created lazily
        self.selected_roi = None
        self.cell_plot = None  # CellPlotWidget of the most recently created tab
        self.track_ops_dict = None
        self.data_management = DataManagement(self)
        self.vector_curation_t2p = self.data_management.vector_curation_t2p
        self.init_central_widget()

    def init_central_widget(self):
        """Build the splitter layout: [image tabs | ROI zooms] above [fluorescence traces]."""
        self.top = QFrame()
        self.top.setFrameShape(QFrame.StyledPanel)
        self.top_layout = QHBoxLayout(self.top)

        self.top_right = QFrame()
        self.top_right.setFrameShape(QFrame.StyledPanel)
        self.top_layout_right = QVBoxLayout(self.top_right)

        self.tabs = QTabWidget(self)

        self.splitter1 = QSplitter(Qt.Horizontal)
        self.splitter1.addWidget(self.tabs)
        self.splitter1.addWidget(self.top)
        self.splitter1.setSizes([100, 100])

        self.splitter2 = QSplitter(Qt.Horizontal)
        self.splitter2.addWidget(self.top_right)

        self.splitter3 = QSplitter(Qt.Vertical)
        self.splitter3.addWidget(self.splitter1)
        self.splitter3.addWidget(self.splitter2)
        self.splitter3.setSizes([100, 100])

        central_layout = QVBoxLayout()
        central_layout.addWidget(self.splitter3)
        self.setLayout(central_layout)

    def create_mean_img(self, channel):
        """Create one "Day N" tab with a CellPlotWidget per loaded recording day.

        Args:
            channel: image type forwarded to CellPlotWidget (e.g. '0', '1', 'max_proj', 'Vcorr').
        """
        dm = self.data_management
        for i, (ops, stat_t2p) in enumerate(zip(dm.all_ops, dm.all_stat_t2p)):
            tab = QWidget()
            self.cell_plot = CellPlotWidget(tab, ops=ops, stat_t2p=stat_t2p, f_t2p=dm.all_f_t2p[i],
                                            colors=dm.colors, update_selection_callback=self.update_selection,
                                            all_f_t2p=dm.all_f_t2p, all_ops=dm.all_ops, channel=channel)
            layout = QVBoxLayout(tab)
            layout.addWidget(self.cell_plot)
            tab.setLayout(layout)
            self.tabs.addTab(tab, f"Day {i + 1}")
            self.cell_plot.cell_selected.connect(self.update_selection)

    def create_mean_img_from_curation(self):
        """Re-import the data using whichever window (import or track2p run) provided it."""
        import_window = self.main_window.window_manager.import_window
        t2p_window = self.main_window.window_manager.t2p_window

        if import_window is not None and import_window.plane is not None:
            self.data_management.import_files(import_window.path_to_t2p, import_window.plane,
                                              import_window.trace_type, import_window.channel)
        elif t2p_window is not None and t2p_window.saved_directory is not None:
            self.data_management.import_files(t2p_window.saved_directory, t2p_window.dialog.plane,
                                              t2p_window.dialog.trace_type, t2p_window.dialog.channel)
        else:
            print("Both import_window and t2p_window are None or not properly initialized.")

    def clear(self):
        """Drop loaded data and remove the plot panels and all image tabs."""
        self.data_management.reset_attributes()
        if self.fluorescences_plotting is not None:
            self.top_layout_right.removeWidget(self.fluorescences_plotting)
            self.fluorescences_plotting.deleteLater()
            self.fluorescences_plotting = None
        if self.rois_plotting is not None:
            self.top_layout.removeWidget(self.rois_plotting)
            self.rois_plotting.deleteLater()
            self.rois_plotting = None
        # removeTab() only detaches a page; schedule each one for deletion so
        # repeated imports do not accumulate hidden widgets.
        while self.tabs.count():
            page = self.tabs.widget(0)
            self.tabs.removeTab(0)
            if page is not None:
                page.deleteLater()

    def _ensure_plot_widgets(self):
        """Create the fluorescence and ROI-zoom panels on first use."""
        dm = self.data_management
        if self.fluorescences_plotting is None:
            self.fluorescences_plotting = FluorescencePlotWidget(all_f_t2p=dm.all_f_t2p,
                                                                 all_ops=dm.all_ops,
                                                                 colors=dm.colors,
                                                                 all_stat_t2p=dm.all_stat_t2p)
            self.top_layout_right.addWidget(self.fluorescences_plotting)
        if self.rois_plotting is None:
            imgs = self.cell_plot.all_img if self.cell_plot is not None else None
            self.rois_plotting = ZoomPlotWidget(all_ops=dm.all_ops,
                                                all_stat_t2p=dm.all_stat_t2p,
                                                colors=dm.colors,
                                                all_iscell_t2p=dm.all_iscell,
                                                t2p_match_mat_allday=dm.t2p_match_mat_allday,
                                                track_ops=dm.track_ops,
                                                imgs=imgs)
            self.top_layout.addWidget(self.rois_plotting)

    def update_selection(self, selected_cell_index):
        """Make ``selected_cell_index`` the selected cell across status bar, tabs and plots."""
        self.selected_roi = selected_cell_index
        self.main_window.status_bar.spin_box.setValue(selected_cell_index)
        self.main_window.status_bar.roi_state_value.setText(f"{self.vector_curation_t2p[selected_cell_index]}")
        # Remove the previous underline on every tab, including tabs that are not currently visible.
        for i in range(self.tabs.count()):
            cell_object = self.tabs.widget(i).findChild(CellPlotWidget)
            if cell_object is not None:
                cell_object.remove_previous_underline()
        current_tab_widget = self.tabs.currentWidget()
        cell_plot = current_tab_widget.findChild(CellPlotWidget) if current_tab_widget is not None else None
        if cell_plot is not None:
            cell_plot.underline_cell(selected_cell_index)
        self._ensure_plot_widgets()
        self.fluorescences_plotting.display_all_f_t2p(selected_cell_index)
        self.rois_plotting.display_zooms(selected_cell_index)

    def display_first_ROI(self, index):
        """Show cell ``index`` (underline on the first tab, traces and zooms) right after data is loaded.

        Creates the FluorescencePlotWidget and ZoomPlotWidget panels if they do not exist yet.
        """
        tab_widget = self.tabs.widget(0)
        cell_object = tab_widget.findChild(CellPlotWidget)  # the CellPlotWidget inside the first tab
        cell_object.underline_cell(index)
        cell_object.draw()
        self._ensure_plot_widgets()
        self.fluorescences_plotting.display_all_f_t2p(index)
        self.rois_plotting.display_zooms(index)
        self.main_window.status_bar.roi_state_value.setText(
            f"{self.vector_curation_t2p[self.main_window.status_bar.spin_box.value()]}")
143 |
--------------------------------------------------------------------------------
/docs/troubleshooting.md:
--------------------------------------------------------------------------------
1 | # Troubleshooting Guide: Low Cell Counts and Tracking Failures in Track2p
2 |
3 | This guide provides practical steps for diagnosing and resolving common issues encountered when running Track2p, especially when tracking neurons across multiple days. It covers:
4 |
5 | - Low numbers of tracked cells
6 | - Potential tracking and registration failures
7 | - Recommended parameter adjustments
8 | - Handling potentially problematic sessions
9 |
10 | Whenever you encounter unexpected results, it is strongly recommended to inspect the **Track2p output figures** (see documentation: *Outputs & Visualisations*) and, if possible, view your raw or Suite2p-processed data side-by-side across days.
11 |
12 | In an ideal world, the solution is to track a subset of cells manually (see the paper for one way of doing this) and use that as a 'validation set' to help you choose the right parameters for tracking. Generating this ground-truth dataset will also give you an idea of how many cells you should expect to be successfully tracked across all days in your experimental setting.
13 |
14 | ---
15 |
16 | ## 1. Low Numbers of Tracked Cells Across Days
17 |
18 | Low cell counts across days can arise from issues in **cell detection** (relating to **cell activity**, **FOV consistency** etc.), or **tracking quality**. Track2p only reports cells that can be matched across *all* selected days, so variability in detection or registration can significantly reduce the final cell set.
19 |
20 | ### 1.1 Issues with cell detection ('segmentation')
21 |
22 | #### Important points
23 | - **Suite2p** only detects ROIs that are “active” in that session and classified as “good” (based on the outputs of a classifier).
24 | - **Track2p** only includes cells that appear on *every* selected day.
25 | If a cell is missing on any day (due to low activity or detection thresholds), it will not be included in the tracked set.
26 |
27 | #### Common causes
28 | 1. **Different active cells across days**
29 | Some neurons may be active only on some days, making them undetectable in others.
30 |
31 | 2. **Conservative Suite2p parameters**
32 | High thresholds can reduce the number of detected ROIs, especially those with low activity or with classifier scores below the threshold.
33 |
34 | 3. **Changes or drift in the recording**
35 | These can cause ROI detection inconsistencies:
36 | - z-shift
37 | - moving blood vessels
38 | - debris or accumulated tissue along FOV edges
39 | - subtle optical changes across sessions
40 |
41 | ### 1.2 Recommended Solutions
42 |
43 | #### 1. Inspect Track2p output figures
44 | See the *Outputs & Visualisations* documentation.
45 | These plots help identify whether failures arise during cell detection or during registration, matching, or thresholding.
46 |
47 | #### 2. Inspect Suite2p outputs across days
48 | Open two or more Suite2p GUIs side-by-side and check:
49 | - Are the same ROIs detected across days?
50 | - Do obvious cells appear on one day but not another?
51 |
52 | (Optional but ideal): manually track a small set of neurons to generate “ground truth” and verify Track2p performance.
53 |
54 | #### 3. Adjust Suite2p parameters to detect more cells
55 | If detection seems too strict, reduce conservativeness by tuning Suite2p parameters (see Suite2p documentation).
56 | This should increase the cell count, though probably at the expense of more false positives.
57 | False ROIs can be removed manually before or after Track2p processing (using either the Suite2p or Track2p curation capabilities).
58 |
59 | #### 4. Include more borderline ROIs in Track2p
60 | Lowering the Track2p parameter `iscell_thr` allows inclusion of ROIs rejected by the Suite2p classifier.
61 |
62 | - Suite2p default: **0.5**
63 | - Recommended more permissive option: **0.25**
64 |
65 | This has a similar effect to the previous step, so also make sure to remove any additional false positives it introduces.
66 |
67 | ---
68 |
69 | ## 2. Issues with cell tracking ('linking')
70 |
71 | Track2p’s tracking pipeline involves three stages:
72 |
73 | 1. **Registration** of consecutive FOVs
74 | 2. **Matching** ROIs based on spatial overlap
75 | 3. **Automatic thresholding** of IoU values to filter reliable matches
76 |
77 | A problem in any stage can reduce the number of tracked cells.
78 |
79 | ### 2.1 Common Causes of Tracking Failure
80 |
81 | #### 1. Registration failure
82 | Check `reg_img_output.png`.
83 | The bottom row overlays red/green images across sessions. They should align well.
84 | If they do not, all subsequent steps will fail!
85 |
86 | #### 2. Poor matching quality
87 | If registration is poor or ROIs do not overlap sufficiently, spatial matching breaks down.
88 | Matching can also fail even when registration is successful, but this is only expected for samples with extremely dense somata and thick optical sectioning (leading to 'overlapping' ROIs).
89 |
90 | #### 3. Thresholding failure
91 | If registration/matching fail, the IoU histogram (`thr_met_hist.png`) may:
92 | - be unimodal, for example looking like it is decaying 'exponentially' from 0
93 | - lack a visible separation between matched/unmatched populations (the two usual modes)
94 |
95 | This causes the automatic threshold to be placed arbitrarily, so that most matches end up either discarded or accepted without any statistical justification.
96 |
97 | ### 2.2 Recommended Solutions
98 |
99 | #### 1. Examine diagnostic figures
100 | - **Registration** → `reg_img_output.png`
101 | - **Matching & thresholding** → `thr_met_hist.png`
102 |
103 | These files usually reveal the underlying issue.
104 |
105 | Registration: the last row of `reg_img_output.png` should show good overlap between the red and green images; if it does not, registration did not work properly. Registration can also succeed in only part of the field of view rather than homogeneously, so check the whole image! If the algorithm has failed, inspect the FOV images and check whether they resemble each other (e.g. could they be aligned manually?).
106 |
107 | Matching & thresholding: check that `thr_met_hist.png` shows a bimodal histogram for all pairs of consecutive sessions. Also check that the statistical threshold, marked as a vertical line, 'makes sense' (e.g. it should separate the two assumed distributions corresponding to the two modes).
108 |
109 | #### 2. Try alternate registration methods
110 | Track2p supports:
111 | - `'affine'` (default)
112 | - `'rigid'`
113 |
114 | Rigid has fewer degrees of freedom, so in theory it is easier to find a good alignment. However, it cannot account for expansion or shearing across sessions, so in the case of early development `'affine'` would be preferred (based on results from the eLife paper; this might depend on the particular preparation). We have not tested this in adults, but it might be that `'rigid'` would work better - we encourage users to try both :)
115 |
116 | #### 3. Switch thresholding method
117 | If your IoU histograms are bimodal but you think the threshold does not appropriately separate the two distributions, you can try a different thresholding method; the current options are:
118 | - `'otsu'`
119 | - `'min'`
120 |
121 | Depending on histogram shape, one or the other may be more appropriate.
122 |
123 | #### 4. Remove problematic recording days
124 | If a single session is very different (e.g., z-shift, optical debris, abnormal brightness), it may disrupt registration between the surrounding sessions.
125 | Removing the problematic day often restores normal tracking across the remaining days, but it of course leads to missing values, which might be an issue in downstream analysis.
126 |
127 | #### 5. Consider shorter intervals
128 | For some scientific questions it might be sufficient to only track across two (or a few) neighbouring days. If this is the case, the user can also perform shorter tracks that are expected to yield more accurately tracked neurons.
129 |
130 | ---
131 |
132 | ## 3. Summary of Recommended Workflow for Troubleshooting
133 |
134 | 1. **Inspect Track2p output figures** to locate the stage causing failure.
135 | 2. **Check Suite2p results side-by-side** across days for ROI consistency.
136 | 3. **Adjust Suite2p thresholds** or Track2p’s `iscell_thr` to increase detected ROI overlap across days.
137 | 4. **Try different registration methods** and thresholding strategies.
138 | 5. **Identify and remove problematic sessions** (e.g., z-shift or debris).
139 | 6. **Prefer shorter, overlapping blocks** if long-span tracking is inconsistent and if it does not interfere with your scientific goals :)
140 |
141 | This approach usually resolves the most common problems involving low tracked cell counts and tracking failures.
142 |
--------------------------------------------------------------------------------
/track2p/gui/data_management.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | import matplotlib.colors as mcolors
5 | import random
6 | from types import SimpleNamespace
7 | from scipy.ndimage import maximum_filter1d, minimum_filter1d, gaussian_filter
8 |
class DataManagement:
    """Load track2p and suite2p outputs for the GUI and hold them in memory.

    After ``import_files`` the ``all_*`` lists hold one entry per recording
    day. Trace and stat rows are restricted to the cells tracked on every day
    (the rows of ``t2p_match_mat_allday``), in the same order on all days.
    """

    # suite2p file loaded for each supported trace type.
    _TRACE_FILES = {'F': 'F.npy', 'spks': 'spks.npy', 'dF/F0': 'F.npy'}

    def __init__(self, central_widget):
        self.central_widget = central_widget
        self.main_window = central_widget.main_window
        self.all_f_t2p = []
        self.all_ops = []
        self.all_stat_t2p = []
        self.all_iscell = []
        self.all_fneu = []
        self.colors = None
        self.t2p_match_mat_allday = None
        self.track_ops = None
        self.vector_curation_t2p = None
        self.curation_npy = None
        self.colors_copy = None
        self.plane = None
        self.trace_type = None

    def reset_attributes(self):
        """Drop the loaded per-day data before a new import."""
        self.all_f_t2p = []
        self.all_ops = []
        self.all_stat_t2p = []
        self.all_iscell = []
        self.all_fneu = []
        self.colors = None
        self.t2p_match_mat_allday = None
        self.track_ops = None
        self.vector_curation_t2p = None

    @staticmethod
    def _iscell_indices(iscell, iscell_thr):
        """Return the indices of the suite2p ROIs that track2p kept.

        With no threshold, the classifier label (column 0) must equal 1;
        otherwise the cell probability (column 1) must exceed ``iscell_thr``.
        """
        if iscell_thr is None:
            return np.where(iscell[:, 0] == 1)[0]
        return np.where(iscell[:, 1] > iscell_thr)[0]

    @staticmethod
    def _days_present_histogram(days_present, n_days):
        """Count cells by the number of days (0..n_days) they are present on."""
        counts = {n: 0 for n in range(n_days + 1)}
        for n in days_present:
            if n in counts:
                counts[n] += 1
        return counts

    def import_files(self, t2p_folder_path, plane, trace_type, channel):
        """Load the track2p outputs found in ``t2p_folder_path`` and refresh the GUI.

        Args:
            t2p_folder_path: folder containing the ``track2p`` output directory.
            plane: suite2p plane number (selects the ``plane{plane}`` files).
            trace_type: 'F', 'spks' or 'dF/F0' (F passed through ``F_processing``).
            channel: image type forwarded to ``central_widget.create_mean_img``.

        Raises:
            ValueError: if ``trace_type`` is not one of the supported values
                (checked before anything currently displayed is cleared).
        """
        if trace_type not in self._TRACE_FILES:
            raise ValueError(f"Unsupported trace_type {trace_type!r}; expected one of {list(self._TRACE_FILES)}")

        self.plane = plane
        self.trace_type = trace_type

        if self.central_widget.fluorescences_plotting is not None:
            self.central_widget.clear()

        t2p_dir = os.path.join(t2p_folder_path, "track2p")
        track_ops_path = os.path.join(t2p_dir, "track_ops.npy")

        # load track2p outputs
        t2p_match_mat = np.load(os.path.join(t2p_dir, f"plane{plane}_match_mat.npy"), allow_pickle=True)
        # Keep only the cells matched on every day (rows without any None).
        self.t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :]  # noqa: E711 (element-wise)
        track_ops_dict = np.load(track_ops_path, allow_pickle=True).item()
        track_ops = SimpleNamespace(**track_ops_dict)

        # process suite2p files, one recording day at a time
        for i, ds_path in enumerate(track_ops.all_ds_path):
            plane_dir = os.path.join(ds_path, 'suite2p', f'plane{plane}')
            ops = np.load(os.path.join(plane_dir, 'ops.npy'), allow_pickle=True).item()
            stat = np.load(os.path.join(plane_dir, 'stat.npy'), allow_pickle=True)
            iscell = np.load(os.path.join(plane_dir, 'iscell.npy'), allow_pickle=True)
            print(f'{trace_type} trace')
            f = np.load(os.path.join(plane_dir, self._TRACE_FILES[trace_type]), allow_pickle=True)

            # Raw suite2p row of each tracked cell, in match-matrix order.
            kept = self._iscell_indices(iscell, track_ops.iscell_thr)
            rows = kept[self.t2p_match_mat_allday[:, i].astype(int)]

            if trace_type == 'dF/F0':
                fneu = np.load(os.path.join(plane_dir, 'Fneu.npy'), allow_pickle=True)
                self.all_fneu.append(fneu[rows, :])

            self.all_stat_t2p.append(stat[rows])
            self.all_f_t2p.append(f[rows, :])
            self.all_ops.append(ops)
            self.all_iscell.append(iscell)

        if trace_type == 'dF/F0':
            for i in range(len(self.all_f_t2p)):
                self.all_f_t2p[i] = self.F_processing(F=self.all_f_t2p[i], Fneu=self.all_fneu[i],
                                                      fs=self.all_ops[i]['fs'])

        # Curation vector (1 = cell, 0 = "not cell"); created and persisted on first import.
        curation_key = 'vector_curation_plane_' + str(plane)
        if hasattr(track_ops, curation_key):
            self.vector_curation_t2p = track_ops_dict[curation_key]
        else:
            vector_curation_keys = np.arange(self.t2p_match_mat_allday.shape[0])
            self.vector_curation_t2p = np.ones_like(vector_curation_keys)
            self.vector_curation_t2p_dict = dict(zip(vector_curation_keys, self.vector_curation_t2p))
            track_ops_dict[curation_key] = self.vector_curation_t2p
            np.save(track_ops_path, track_ops_dict)

        # Per-cell colours; created and persisted on first import.
        color_key = 'colors_plane_' + str(plane)
        if hasattr(track_ops, color_key):
            self.colors = track_ops_dict[color_key]
        else:
            self.colors = self.generate_vibrant_colors(len(self.all_stat_t2p[0]))
            track_ops_dict[color_key] = self.colors
            np.save(track_ops_path, track_ops_dict)

        self.track_ops = track_ops
        # Re-read from disk so the central widget sees exactly what is persisted.
        track_ops_dict = np.load(track_ops_path, allow_pickle=True).item()
        self.central_widget.track_ops_dict = track_ops_dict

        status_bar = self.main_window.status_bar
        status_bar.vector_curation_t2p = self.vector_curation_t2p
        n_tracked = len(self.t2p_match_mat_allday)
        status_bar.spin_box.setSuffix(f'/{n_tracked-1}')
        status_bar.spin_box.setMinimum(0)
        status_bar.spin_box.setMaximum(n_tracked - 1)

        # For every row of the full match matrix, count the days on which the
        # matched ROI has a suite2p classifier label of 1.
        n_days = t2p_match_mat.shape[1]
        kept_per_day = [self._iscell_indices(iscell, track_ops.iscell_thr) for iscell in self.all_iscell]
        num_ones = {}
        for cell, line in enumerate(t2p_match_mat):
            all_iscell_value = []
            for day, index_match in enumerate(line):
                if index_match is None:
                    all_iscell_value.append(0)
                else:
                    true_index = kept_per_day[day][int(index_match)]
                    all_iscell_value.append(self.all_iscell[day][true_index, 0])
            num_ones[cell] = all_iscell_value.count(1)

        for day, count_cells_day in self._days_present_histogram(num_ones.values(), n_days).items():
            print(f'Number of cells present over {day} day(s): {count_cells_day}')
        print(t2p_match_mat.shape)

        # Grey out cells curated as "not cell".
        for i in range(len(self.t2p_match_mat_allday)):
            if self.vector_curation_t2p[i] == 0:
                self.colors[i] = (0.78, 0.78, 0.78)
                print(f'ROI {i} has been considered as "not cell"')

        self.central_widget.create_mean_img(channel)
        self.central_widget.vector_curation_t2p = self.vector_curation_t2p
        self.central_widget.display_first_ROI(0)

    def generate_vibrant_colors(self, num_colors):
        """Return ``num_colors`` random RGB colours with full saturation and HSV value in [0.55, 0.80)."""
        vibrant_colors = []
        for _ in range(num_colors):
            value = np.random.uniform(0.55, 0.80)
            color = mcolors.hsv_to_rgb((random.random(), 1, value))
            vibrant_colors.append(color)
        return vibrant_colors

    def F_processing(self, F, Fneu, fs, neucoeff=0.0, baseline='maximin', sig_baseline=10.0, win_baseline=60.0,
                     prctile_baseline: float = 8):
        """Return ``F - neucoeff * Fneu`` minus a per-trace baseline.

        Args:
            F, Fneu: 2-D arrays of traces, one row per cell.
            fs: sampling rate; the baseline window is ``int(win_baseline * fs)`` samples.
            neucoeff: neuropil coefficient (default 0.0: no neuropil subtraction).
            baseline: 'maximin' (Gaussian smoothing, then sliding min and max),
                'constant' (global minimum of the smoothed traces),
                'constant_prctile' (per-row percentile) or anything else (no baseline).

        NOTE(review): the result is baseline-subtracted, not divided by the
        baseline; confirm this is the intended "dF/F0".
        """
        print(F.shape)
        print(Fneu.shape)
        # neuropil subtraction
        Fc = F - neucoeff * Fneu

        # baseline operation
        win = int(win_baseline * fs)
        if baseline == "maximin":
            Flow = gaussian_filter(Fc, [0., sig_baseline])
            Flow = minimum_filter1d(Flow, win)
            Flow = maximum_filter1d(Flow, win)
        elif baseline == "constant":
            Flow = gaussian_filter(Fc, [0., sig_baseline])
            Flow = np.amin(Flow)
        elif baseline == "constant_prctile":
            Flow = np.percentile(Fc, prctile_baseline, axis=1)
            Flow = np.expand_dims(Flow, axis=1)
        else:
            Flow = 0.

        F = Fc - Flow

        print('DONE')

        return F
212 |
--------------------------------------------------------------------------------
/track2p/gui/cell_plot.py:
--------------------------------------------------------------------------------
1 | import time
2 | from qtpy.QtCore import Signal
3 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import skimage
7 |
8 | class CellPlotWidget(FigureCanvas):
9 | '''This class is used to view and interact with the mean image of each recording (day)'''
10 | cell_selected = Signal(int)
11 |
def __init__(self, tab=None, ops=None, stat_t2p=None, f_t2p=None, colors=None, update_selection_callback=None,
             all_f_t2p=None, all_stat_t2p=None, all_ops=None, initial_colors=None, channel=None):
    """Build the image canvas for one recording day and draw its cells.

    Args:
        tab: parent tab widget (accepted but not used in this constructor).
        ops, stat_t2p, f_t2p: suite2p ops dict, stat array and traces for this day.
        colors: per-cell colours used for the cell contours.
        update_selection_callback: stored for use by the interaction handlers.
        all_f_t2p, all_stat_t2p, all_ops: the same data for every day.
        initial_colors: stored as-is.
        channel: which background image ``load_all_imgs`` loads.
    """
    figure, axes = plt.subplots(1, 1)
    self.fig = figure
    self.ax_image = axes
    figure.set_facecolor('black')
    super().__init__(figure)

    # Data for this tab's day.
    self.ops, self.stat_t2p, self.f_t2p = ops, stat_t2p, f_t2p
    # Data for every day.
    self.all_fluorescence, self.all_stat_t2p, self.all_ops = all_f_t2p, all_stat_t2p, all_ops
    self.colors, self.initial_colors = colors, initial_colors
    self.channel = channel

    # Images depend on ops/all_ops/channel, so they are loaded only after those are set.
    self.all_img, self.img = self.load_all_imgs()
    self.selected_cell_index = None
    self.mpl_connect('button_press_event', self.on_mouse_press)
    self.update_selection_callback = update_selection_callback
    self.nb_cells = len(self.stat_t2p)

    self.plot_cells()
    self.initialize_interactions()
34 |
def load_all_imgs(self):
    """Return ``(all_img, img)``: the background image of every day and of this tab's day.

    ``self.channel`` selects the image:
      * 'max_proj' / 'Vcorr': the suite2p image of that name, which covers only
        the ``yrange`` x ``xrange`` crop, pasted into a zero ``(Ly, Lx)`` frame;
      * '0': ``ops['meanImg']``;
      * '1': ``ops['meanImg_chan2']``.
    Any other value yields ``([], None)``.
    """
    def _uncropped(ops, key):
        # Paste the cropped image back into a full-size frame so all days share the ops['Ly'] x ops['Lx'] geometry.
        yr = ops['yrange']  # first and last pixel of crop in y
        xr = ops['xrange']  # first and last pixel of crop in x
        full = np.zeros((ops['Ly'], ops['Lx']))
        full[yr[0]:yr[1], xr[0]:xr[1]] = ops[key]
        return full

    print('loading all images')
    print('channel img:', self.channel)
    all_img = []
    img = None
    if self.channel in ('max_proj', 'Vcorr'):
        print(self.channel)
        # Previously the 'Vcorr' branch read ops['max_proj']; use the requested key.
        img = _uncropped(self.ops, self.channel)
        all_img = [_uncropped(ops, self.channel) for ops in self.all_ops]
    elif self.channel == '0':
        all_img = [ops['meanImg'] for ops in self.all_ops]
        img = self.ops['meanImg']
    elif self.channel == '1':
        all_img = [ops['meanImg_chan2'] for ops in self.all_ops]
        img = self.ops['meanImg_chan2']

    return all_img, img
90 |
91 | def plot_cells(self):
92 | """It plots the mean image of the recording and the contours of the cells. It also sets the axis to be invisible and the title of the plot. It uses the colors attribute to color the contours of the cells.
93 | the match_histograms function of the skimage library is used to match the histograms of the mean image of the recording and the last mean image of the recordings. . This is done to make the mean images of the recordings more comparable."""
94 | self.ax_image.clear()
95 | start = time.time()
96 | match_mean_img=skimage.exposure.match_histograms(self.img,self.all_img[-1], channel_axis=None)
97 | self.ax_image.imshow(match_mean_img, cmap='gray')
98 | cell_count = 0
99 | for cell in range(self.nb_cells):
100 | # print(cell)
101 | bin_mask = np.zeros_like(self.img) #create a binary mask with the same shape as the mean image of the recording
102 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1
103 | color_cell=self.colors[cell]
104 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=1)
105 | cell_count += 1
106 | self.ax_image.axis('off')
107 | print(f'time for plotting cells on mean image for recording : {time.time()-start}')
108 | print(f'Total cells plotted: {cell_count}')
109 | self.draw()
110 |
111 | def plot_cells_remix(self,keys):
112 | self.ax_image.clear()
113 | start = time.time()
114 | match_mean_img=skimage.exposure.match_histograms(self.img,self.all_img[-1], channel_axis=None)
115 | self.ax_image.imshow(match_mean_img, cmap='gray')
116 | cell_count = 0
117 | for cell in range(self.nb_cells):
118 | if cell in keys:
119 | continue
120 | bin_mask = np.zeros_like(self.img) #create a binary mask with the same shape as the mean image of the recording
121 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1
122 | color_cell=self.colors[cell]
123 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=1)
124 | cell_count += 1
125 | self.ax_image.axis('off')
126 | print(f'time for plotting cells on mean image for recording : {time.time()-start}')
127 | print(f'Total cells plotted: {cell_count}')
128 | self.draw()
129 |
130 |
131 |
132 | def underline_cell_remix(self,colors):
133 |
134 | for cell in range(self.nb_cells):
135 | bin_mask = np.zeros_like(self.img)
136 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1
137 | color_cell=colors[cell]
138 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=1)
139 | self.draw()
140 |
141 |
142 |
143 | def underline_cell(self,selected_cell_index):
144 | """It underlines the selected cell by increasing the linewidth of the contour of the cell"""
145 | for cell in range(self.nb_cells):
146 | if cell == selected_cell_index:
147 | bin_mask = np.zeros_like(self.img)
148 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1
149 | color_cell=self.colors[cell]
150 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=3)
151 | self.draw()
152 |
153 | def remove_previous_underline(self):
154 | """It removes the underline of the previously selected cell by decreasing the linewidth of the contour of the cell."""
155 | for collection in self.ax_image.collections:
156 | collection.set_linewidth(1)
157 | # important, don't forget it ! (update the plot)
158 |
159 | def initialize_interactions(self):
160 | """This method is used to initialize user interactions with the mean image. It connects the scroll event to the on_scroll method and records the initial xlim and ylim of the mean image."""
161 | self.cid_scroll = self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
162 | self.initial_xlim = self.ax_image.get_xlim()
163 | self.initial_ylim = self.ax_image.get_ylim()
164 |
165 | def on_scroll(self, event):
166 | """it allows to zoom in and out of the mean image of the recording. It is called when the mouse wheel is scrolled. It uses the base_scale to zoom in and out of the mean image of the recording. """
167 | if event.inaxes == self.ax_image:
168 | current_xlim = self.ax_image.get_xlim()
169 | current_ylim = self.ax_image.get_ylim()
170 | base_scale = 0.9
171 | scale_factor = base_scale if event.button == 'up' else 1/base_scale
172 | # Get the coordinates of the mouse cursor in data coordinates
173 | x_data, y_data = event.xdata, event.ydata
174 | # Compute the new limits centered on the mouse cursor
175 | new_xlim = [x_data - (x_data - x) * scale_factor for x in current_xlim]
176 | new_ylim = [y_data - (y_data - y) * scale_factor for y in current_ylim]
177 | # Ensure the zoom doesn't exceed the initial image boundaries
178 | new_xlim = [max(self.initial_xlim[0], min(self.initial_xlim[1], x)) for x in new_xlim]
179 | new_ylim = [max(self.initial_ylim[1], min(self.initial_ylim[0], y)) for y in new_ylim]
180 |
181 | self.ax_image.set_xlim(new_xlim)
182 | self.ax_image.set_ylim(new_ylim)
183 | self.fig.canvas.draw_idle()
184 |
185 | def on_mouse_press(self, event):
186 | """It allows to select a cell by clicking on it. It is called when the mouse is clicked. It uses the x and y coordinates of the mouse cursor to determine if a cell is clicked. If a cell is clicked, the selected_cell_index attribute is updated and the cell_selected signal is emitted.
187 | The update_selection method of the MainWindow class is called and used to update the fluorescence and zoom plots with the selected cell. """
188 | start = time.time()
189 | if event.inaxes == self.ax_image:
190 | x, y = event.xdata, event.ydata
191 | for j, cell_info in enumerate(self.stat_t2p):
192 | ypix = cell_info['ypix'] #ypix are the y coordinates of the pixels of the cell
193 | xpix = cell_info['xpix'] #xpix are the x coordinates of the pixels of the cell
194 | if np.any((xpix == int(x)) & (ypix == int(y))):
195 | self.selected_cell_index = j
196 | self.update_selection_callback(j)
197 | print(f"Cell selected: {j}", flush=True)
198 | break
199 | print(f'time taken for updating: {time.time()-start}')
200 |
201 |
202 |
--------------------------------------------------------------------------------
/track2p/gui/t2p_wd.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from PyQt5.QtWidgets import QVBoxLayout, QWidget, QHBoxLayout, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout, QListWidget, QMessageBox,QListWidgetItem, QInputDialog,QCheckBox,QSizePolicy,QComboBox,QDialog
4 | from PyQt5.QtCore import Qt
5 | from track2p.t2p import run_t2p
6 | from track2p.ops.default import DefaultTrackOps
7 | from track2p.gui.custom_wd import CustomDialog
class Track2pWindow(QWidget):
    """Form used to set the parameters of the track2p algorithm and launch a run from the GUI."""

    def __init__(self, main_wd):
        super().__init__()
        self.main_window = main_wd
        layout = QFormLayout()
        self.setLayout(layout)
        self.track_ops = DefaultTrackOps()
        # Default ROI-selection threshold, restored when neither selection box is ticked so that
        # a run never silently inherits the value chosen for a previous run.
        self._default_iscell_thr = self.track_ops.iscell_thr
        self.saved_directory = None
        self.plane = None

        instruction1 = QLabel("Import the directory containing subfolders for each session of a given subject:")
        self.import_recording_button = QPushButton("Import", self)
        self.import_recording_button.clicked.connect(self.import_path_to_recordings)
        layout.addRow(instruction1, self.import_recording_button)

        instruction2 = QLabel("Imported path:")
        instruction2.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self.path_recording = QLabel()
        self.path_recording.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        layout.addRow(instruction2, self.path_recording)

        instruction3 = QLabel("Once loaded press '->' to add to the list of paths to use for track2p (in the correct order):")

        # input format of the sessions (suite2p output or raw npy)
        instruction_format = QLabel("Input format:")
        self.format = QComboBox()
        self.format.addItem("suite2p")
        self.format.addItem("npy")
        self.format.setCurrentIndex(0)
        self.format.currentIndexChanged.connect(self.display_suite2p_options)

        instruction4 = QLabel("Method for selecting suite2p ROIs:")
        field_checkbox = QVBoxLayout()
        self.checkbox1 = QCheckBox('manually curated', self)
        self.checkbox2 = QCheckBox('iscell threshold', self)
        self.checkbox2.stateChanged.connect(self.display_iscell)
        field_checkbox.addWidget(self.checkbox1)
        field_checkbox.addWidget(self.checkbox2)

        self.is_cell_thr = QLineEdit()
        self.is_cell_thr.setVisible(False)
        self.is_cell_thr.setText('0.5')
        self.is_cell_thr.setFixedWidth(50)

        # two lists: available session folders <-> ordered sessions used by track2p
        file_layout = QHBoxLayout()
        self.computer_file_list = QListWidget(self)
        self.computer_file_list.setFixedHeight(200)
        self.move_to_computer_list = QPushButton("<-", self)
        self.move_to_computer_list.clicked.connect(self.move_file_to_computer_list)
        self.move_to_paths_list = QPushButton("->", self)
        self.move_to_paths_list.clicked.connect(self.move_file_to_paths_list)
        self.paths_list = QListWidget(self)
        self.paths_list.setFixedHeight(200)

        file_layout.addWidget(self.computer_file_list)
        file_layout.addWidget(self.move_to_computer_list)
        file_layout.addWidget(self.move_to_paths_list)
        file_layout.addWidget(self.paths_list)
        layout.addRow(instruction3, file_layout)
        layout.addRow(instruction_format, self.format)
        layout.addRow(instruction4, field_checkbox)
        layout.addRow("suite2p iscell threshold:", self.is_cell_thr)

        instruction5 = QLabel("Channel to use for registration (0 : functional, 1 : anatomical (if available))")
        self.reg_chan = QLineEdit()
        self.reg_chan.setFixedWidth(50)
        self.reg_chan.setText('0')
        layout.addRow(instruction5, self.reg_chan)

        trsfrm_type = QLabel("Choose the type of transformation to use for registration:")
        self.trsfrm_type = QComboBox()
        self.trsfrm_type.addItem("affine")
        self.trsfrm_type.addItem("rigid")
        self.trsfrm_type.setCurrentIndex(0)
        layout.addRow(trsfrm_type, self.trsfrm_type)

        # NOTE: iou_dist_thr is not exposed in the form; the DefaultTrackOps value is used.

        thr_method = QLabel("Thresholding method for filtering IoU histogram:")
        self.thr_method = QComboBox()
        self.thr_method.addItem("min")
        self.thr_method.addItem("otsu")
        self.thr_method.setCurrentIndex(1)
        layout.addRow(thr_method, self.thr_method)

        instruction6 = QLabel("Import the directory where outputs will be saved (a 'track2p' sub-folder will be created):")
        self.t2p_path_button = QPushButton("Import", self)
        self.t2p_path_button.clicked.connect(self.save_directoy)
        layout.addRow(instruction6, self.t2p_path_button)

        instruction7 = QLabel("Path to access the output 'track2p' folder:")
        instruction7.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self.save_path = QLabel()
        self.save_path.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        layout.addRow(instruction7, self.save_path)

        instruction8 = QLabel("Save the outputs in suite2p format (containing cells tracked on all days):")
        self.checkbox3 = QCheckBox(self)
        layout.addRow(instruction8, self.checkbox3)

        self.run_button = QPushButton("Run", self)
        self.run_button.clicked.connect(self.run)
        layout.addRow("Run the algorithm:", self.run_button)

        terminal_intruction = QLabel("To monitor progress see outputs in the terminal where the GUI was launched from.")
        layout.addRow(terminal_intruction)

    def display_iscell(self, state):
        """Show the iscell threshold field only while the 'iscell threshold' box is ticked."""
        self.is_cell_thr.setVisible(state == Qt.Checked)

    def display_suite2p_options(self):
        """Show the suite2p-specific options only when the 'suite2p' input format is selected."""
        is_suite2p = self.format.currentText() == "suite2p"
        self.checkbox1.setVisible(is_suite2p)
        self.checkbox2.setVisible(is_suite2p)
        self.reg_chan.setVisible(is_suite2p)
        self.is_cell_thr.setVisible(is_suite2p and self.checkbox2.isChecked())

    def run(self):
        """Copy the form values into ``self.track_ops``, run track2p and offer to open the results.

        Nothing is run (a warning is shown instead) when no session or no output directory has
        been chosen, or when a numeric field cannot be parsed. Errors raised by track2p are
        reported in a dialog (traceback in the terminal) instead of escaping the Qt slot.
        """
        import traceback

        if self.paths_list.count() == 0:
            QMessageBox.warning(self, "track2p", "Add at least one session to the list of paths before running.")
            return
        if not self.saved_directory:
            QMessageBox.warning(self, "track2p", "Select the directory where outputs will be saved before running.")
            return
        try:
            reg_chan = int(self.reg_chan.text())
            iscell_thr = float(self.is_cell_thr.text()) if self.checkbox2.isChecked() else None
        except ValueError as err:
            QMessageBox.warning(self, "track2p", f"Invalid numeric parameter: {err}")
            return

        # normalise Windows separators so the stored paths are portable
        self.track_ops.all_ds_path = [
            self.paths_list.item(i).data(Qt.UserRole).replace("\\", "/")
            for i in range(self.paths_list.count())
        ]
        self.track_ops.save_path = self.saved_directory.replace("\\", "/")
        self.track_ops.input_format = self.format.currentText()
        self.track_ops.reg_chan = reg_chan
        self.track_ops.transform_type = self.trsfrm_type.currentText()
        self.track_ops.thr_method = self.thr_method.currentText()
        print("transformation type:", self.track_ops.transform_type)
        print("iou_dist_thr:", self.track_ops.iou_dist_thr)
        print("thr_method:", self.track_ops.thr_method)
        # ROI selection: the threshold wins if both boxes are ticked; neither -> default value
        if self.checkbox2.isChecked():
            self.track_ops.iscell_thr = iscell_thr
        elif self.checkbox1.isChecked():
            self.track_ops.iscell_thr = None
        else:
            self.track_ops.iscell_thr = self._default_iscell_thr
        # always follow the checkbox so a previous run's choice does not stick
        self.track_ops.save_in_s2p_format = self.checkbox3.isChecked()
        print("All parameters have been recorded ! The track2p algorithm is running...")
        try:
            run_t2p(self.track_ops)
        except Exception as err:  # GUI boundary: report instead of letting the exception abort Qt
            traceback.print_exc()
            QMessageBox.critical(self, "track2p", f"track2p stopped with an error:\n{err}\n\nSee the terminal for details.")
            return
        self.open_track2p_in_gui()

    def open_track2p_in_gui(self):
        """Ask whether to open the freshly computed results in the viewer and do so if confirmed."""
        reply = QMessageBox.question(self, "", "Run completed successfully!\nDo you want to launch the gui?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            self.dialog = CustomDialog(self.main_window, self.saved_directory, self.reg_chan.text())
            self.dialog.exec_()

    def save_directoy(self):
        """Let the user pick the output directory (cancelling keeps the current choice)."""
        saved_directory = QFileDialog.getExistingDirectory(self, "Select Directory")
        if saved_directory:
            self.saved_directory = saved_directory
            self.save_path.setText(f'{self.saved_directory}')

    def import_path_to_recordings(self):
        """Let the user pick the subject directory and list its entries as candidate sessions.

        The chosen directory also becomes the default output directory. Cancelling the dialog
        leaves everything unchanged.
        """
        directory = QFileDialog.getExistingDirectory(self, "Select Directory")
        if not directory:
            return
        self.saved_directory = directory
        self.save_path.setText(f'{self.saved_directory}')
        self.path_recording.setText(f'{directory}')
        self.computer_file_list.clear()
        for file in sorted(os.listdir(directory)):
            item = QListWidgetItem(file)
            item.setData(Qt.UserRole, os.path.join(directory, file))  # full path kept with the item
            self.computer_file_list.addItem(item)

    def move_file_to_paths_list(self):
        """Move the selected entries to the (ordered) list of sessions used by track2p."""
        for item in self.computer_file_list.selectedItems():
            self.computer_file_list.takeItem(self.computer_file_list.row(item))
            self.paths_list.addItem(item)

    def move_file_to_computer_list(self):
        """Move the selected sessions back to the list of available entries."""
        for item in self.paths_list.selectedItems():
            self.paths_list.takeItem(self.paths_list.row(item))
            self.computer_file_list.addItem(item)
221 |
222 |
223 |
224 |
--------------------------------------------------------------------------------
/track2p/t2p.py:
--------------------------------------------------------------------------------
1 | from track2p.ops.default import DefaultTrackOps
2 | from types import SimpleNamespace
3 |
4 | from track2p.io.s2p_loaders import load_all_imgs, check_nplanes, load_all_ds_stat_iscell, load_all_ds_mean_img, load_all_ds_centroids
5 | from track2p.io.savers import npy_to_s2p, save_track_ops, save_all_pl_match_mat
6 |
7 | from track2p.register.loop import run_reg_loop, reg_all_ds_all_roi
8 | from track2p.register.utils import get_all_ds_img_for_reg, get_all_ref_nonref_inters
9 |
10 | from track2p.plot.progress import plot_all_planes
11 | from track2p.plot.output import plot_reg_img_output, plot_thr_met_hist, plot_n_matched_roi, plot_roi_reg_output, plot_roi_match_multiplane, plot_allroi_match_multiplane
12 |
13 | from track2p.match.loop import get_all_ds_assign, get_all_pl_match_mat
14 | import numpy as np
15 | import os
16 | import scipy as spicy
17 | import pandas as pd
18 |
19 |
def run_t2p(track_ops):
    """Run the whole track2p pipeline for the sessions configured in ``track_ops``.

    In order:
    1. create the output folders;
    2. load the data (npy input is converted first);
    3. plot the available planes;
    4. register the images;
    5. transform all ROIs;
    6. compute the overlap overlays;
    7. match ROIs;
    8. build the match matrices;
    9. save them, together with the suite2p indices;
    10. optionally export in suite2p format;
    11. draw the summary figures.

    ``track_ops`` is mutated along the way; for example, the intersection overlays are stored on it.
    """
    # 1) output folders for figures and matched-cell results
    track_ops.init_save_paths()

    # 2) load the data
    check_nplanes(track_ops)
    if track_ops.input_format == 'npy':
        print('Converting npy data to track2p-compatible format...')
        npy_to_s2p(track_ops)
    avg_ch1, avg_ch2 = load_all_imgs(track_ops)

    # 3) overview of the planes available for registration
    plot_all_planes(avg_ch1, track_ops)
    if track_ops.nchannels == 2:
        plot_all_planes(avg_ch2, track_ops, ch='anatomical')

    # 4) register the images of the chosen channel
    ref_imgs, mov_imgs = get_all_ds_img_for_reg(avg_ch1, avg_ch2, track_ops)
    _, reg_params = run_reg_loop(ref_imgs, mov_imgs, track_ops)
    plot_reg_img_output(track_ops)

    # 5) apply the computed transforms to every ROI
    roi_ref, roi_mov, roi_reg, _ = reg_all_ds_all_roi(reg_params, track_ops)

    # 6) 'yellow intersection' overlays (only needed by the plots)
    reg_inters = get_all_ref_nonref_inters(roi_ref, roi_reg, track_ops)
    mov_inters = get_all_ref_nonref_inters(roi_ref, roi_mov, track_ops)
    track_ops.all_ds_ref_mov_inters = mov_inters
    track_ops.all_ds_ref_reg_inters = reg_inters
    if track_ops.show_roi_reg_output:
        plot_roi_reg_output(track_ops)

    # 7) optimal assignments for all pairs of recordings (first to last)
    _, assign_thr, thr_met, thr = get_all_ds_assign(track_ops, roi_ref, roi_reg)
    plot_thr_met_hist(thr_met, thr, track_ops)
    plot_n_matched_roi(thr_met, thr, track_ops)

    # 8) match matrices for all pairs of recordings (first to last)
    match_mats = get_all_pl_match_mat(roi_ref, assign_thr, track_ops)

    # 9) save the results
    save_track_ops(track_ops)
    save_all_pl_match_mat(match_mats, track_ops)
    print('Generating suite2p indices')
    generate_suite2p_indices(track_ops)

    # 10) optional export in suite2p format
    if track_ops.save_in_s2p_format:
        print('Saving in suite2p format...')
        save_in_s2p_format(track_ops)

    # 11) summary figures, once per available channel
    print('Finished with algorithm!\n\nGenerating plots (this can take some time)...\n\n')
    stat_iscell = load_all_ds_stat_iscell(track_ops)
    centroids = load_all_ds_centroids(stat_iscell, track_ops)
    per_channel = [(load_all_ds_mean_img(track_ops), {})]
    if track_ops.nchannels == 2:
        per_channel.append((load_all_ds_mean_img(track_ops, ch=2), {'ch': 2}))
    for mean_imgs, ch_kwargs in per_channel:
        plot_roi_match_multiplane(mean_imgs, centroids, match_mats, track_ops, win_size=track_ops.win_size, **ch_kwargs)
        plot_allroi_match_multiplane(mean_imgs, match_mats, track_ops, **ch_kwargs)

    print('\n\n\nDone!\n\n\n')
108 |
109 |
110 |
def generate_suite2p_indices(track_ops):
    """Map track2p match indices back to the original suite2p ROI indices and save them.

    Entry ``k`` in ``plane#_match_mat.npy`` is the k-th ROI *kept as a cell* in that
    session, not the k-th suite2p ROI. Cells are kept when ``iscell[:, 0] == 1`` if
    ``track_ops.iscell_thr`` is None, and when ``iscell[:, 1] > iscell_thr`` otherwise.
    This function converts every entry to its row in the session's ``iscell.npy``/``stat.npy``.

    For each plane it writes, into ``track_ops.save_path``:
      - ``plane#_suite2p_indices.npy``: indices with None for missing matches
        (object dtype when any None is present);
      - ``plane#_suite2p_indices_nan.npy``: the same indices as floats with NaN;
      - ``plane#_suite2p_indices.mat``: the NaN version under the key ``'data'``;
      - ``plane#_suite2p_indices.csv``: ``;``-separated, one column per session folder name.

    Args:
        track_ops: object with ``nplanes``, ``save_path``, ``all_ds_path`` and ``iscell_thr``.
    """
    # Import the submodule explicitly: `import scipy` alone only exposes scipy.io on SciPy versions
    # with lazy submodule loading.
    from scipy.io import savemat

    column_names = [os.path.basename(ds_path) for ds_path in track_ops.all_ds_path]
    n_sessions = len(column_names)

    for plane in range(track_ops.nplanes):
        t2p_match_mat = np.load(os.path.join(track_ops.save_path, f"plane{plane}_match_mat.npy"), allow_pickle=True)

        # Per session: suite2p row indices of the ROIs kept as cells (computed once, not per match).
        all_cell_rows = []
        for ds_path in track_ops.all_ds_path:
            iscell = np.load(os.path.normpath(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'iscell.npy')), allow_pickle=True)
            if track_ops.iscell_thr is None:
                cell_rows = np.where(iscell[:, 0] == 1)[0]  # manually curated cells
            else:
                cell_rows = np.where(iscell[:, 1] > track_ops.iscell_thr)[0]  # classifier probability > threshold
            all_cell_rows.append(cell_rows)

        true_indices = [
            [None if index_match is None else int(all_cell_rows[day][index_match]) for day, index_match in enumerate(line)]
            for line in t2p_match_mat
        ]

        if true_indices:
            true_indices_arr = np.array(true_indices)
            true_indices_nan = np.array([[np.nan if x is None else float(x) for x in row] for row in true_indices])
        else:
            # no matches: keep a 2-D (0, n_sessions) shape so the CSV/MAT exports still work
            true_indices_arr = np.empty((0, n_sessions), dtype=object)
            true_indices_nan = np.empty((0, n_sessions), dtype=float)

        np.save(os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices.npy"), true_indices_arr)
        np.save(os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices_nan.npy"), true_indices_nan)
        savemat(os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices.mat"), {'data': true_indices_nan})

        csv_path = os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices.csv")
        df = pd.DataFrame(true_indices_arr, columns=column_names)
        df.to_csv(csv_path, index=False, sep=';', na_rep='NaN')

        print(true_indices_arr.dtype)
        print(true_indices_nan.dtype)
157 |
158 |
159 |
160 |
161 |
162 |
def _load_tracked(plane_dir, filename, cell_mask, match_idx):
    """Load ``plane_dir/filename`` and return the rows of the tracked cells, in match order.

    ``cell_mask`` selects the ROIs kept as cells (the index space of the match matrix);
    ``match_idx`` then picks and orders the tracked cells within that subset.
    """
    data = np.load(os.path.join(plane_dir, filename), allow_pickle=True)
    return data[cell_mask][match_idx]


def save_in_s2p_format(track_ops):
    """Export the cells tracked on *all* sessions as a suite2p-style folder tree.

    For every plane, the rows of ``plane#_match_mat.npy`` without any None (cells matched on
    every session) are used to pick, from each session's suite2p output, the matching entries
    of ``stat``, ``F``, ``Fneu``, ``spks`` and ``iscell``. For 2-channel data, ``F_chan2``,
    ``Fneu_chan2`` and ``redcell`` are exported as well, and each session's ``ops`` is copied.
    Files are written to
    ``<save_path>/matched_suite2p/<session folder name>/suite2p/plane#/``. Row k of every
    exported array refers to the same tracked cell in every session.

    Args:
        track_ops: only ``all_ds_path`` and ``save_path`` are read from it. The remaining
            settings (``nplanes``, ``nchannels``, ``iscell_thr``, ``all_ds_path``) are re-read
            from ``<save_path>/track_ops.npy``.
    """
    for ds_path in track_ops.all_ds_path:
        # informational: how many plane folders suite2p produced for this session
        n_planes = len([name for name in os.listdir(os.path.join(ds_path, 'suite2p')) if name.startswith('plane')])
        print(f'Found {n_planes} planes in {ds_path}')

    folderpath = track_ops.save_path
    track_ops_dict = np.load(os.path.join(folderpath, "track_ops.npy"), allow_pickle=True).item()
    track_ops = SimpleNamespace(**track_ops_dict)
    iscell_thr = track_ops.iscell_thr

    output_folderpath = os.path.join(folderpath, 'matched_suite2p')
    session_names = [os.path.basename(path) for path in track_ops.all_ds_path]

    per_roi_files = ['stat.npy', 'F.npy', 'Fneu.npy', 'spks.npy']
    if track_ops.nchannels == 2:
        per_roi_files += ['F_chan2.npy', 'Fneu_chan2.npy', 'redcell.npy']

    for j in range(track_ops.nplanes):
        print(f'nplanes{j}')
        t2p_match_mat = np.load(os.path.join(folderpath, f'plane{j}_match_mat.npy'), allow_pickle=True)
        # keep only the cells matched on every session (rows without any None);
        # `== None` is intentional: it compares element-wise on the object array
        t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :]  # noqa: E711
        print(session_names)

        # Each session is processed and written on its own, so its channel-2 data
        # can never end up in another session's folder.
        for i, (ds_path, session_name) in enumerate(zip(track_ops.all_ds_path, session_names)):
            plane_dir = os.path.join(ds_path, 'suite2p', f'plane{j}')
            iscell = np.load(os.path.join(plane_dir, 'iscell.npy'), allow_pickle=True)
            if iscell_thr is None:
                cell_mask = iscell[:, 0] == 1  # manually curated cells
            else:
                cell_mask = iscell[:, 1] > iscell_thr  # classifier probability above threshold
            match_idx = t2p_match_mat_allday[:, i].astype(int)

            out_dir = os.path.join(output_folderpath, session_name, 'suite2p', f'plane{j}')
            os.makedirs(out_dir, exist_ok=True)

            ops = np.load(os.path.join(plane_dir, 'ops.npy'), allow_pickle=True).item()
            np.save(os.path.join(out_dir, "ops.npy"), ops)
            np.save(os.path.join(out_dir, "iscell.npy"), iscell[cell_mask][match_idx])
            for filename in per_roi_files:
                np.save(os.path.join(out_dir, filename), _load_tracked(plane_dir, filename, cell_mask, match_idx))
--------------------------------------------------------------------------------
/notebooks/demo_t2p_ouputs.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Analysing output of t2p\n",
8 | "This is a short demo explaining how to explore the output of track2p and use the matched neurons/traces for custom downstream analysis."
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "The example here is for a single-plane recording with simultaneous videography (the example dataset is jm032)."
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "# imports\n",
25 | "import os\n",
26 | "from types import SimpleNamespace\n",
27 | "\n",
28 | "import numpy as np\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "\n",
31 | "from scipy.stats import zscore\n"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "### Step by step guide (more detailed explanations below):\n",
39 | "\n",
40 | "Each point from this list matches one section of this notebook\n",
41 | "\n",
42 | "1) Load the output of track2p\n",
43 | "2) Find cells that are present in all recordings ('matched cells')\n",
44 | "3) Load the data from one example dataset and visualise it\n",
45 | "4) Load the activity of the matched cells\n",
46 | "5) Visualise the activity of matched cells"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "### 1) Load the output of track2p\n",
54 | "We will load the `.npy` files: `t2p_output_path/track2p/plane#_match_mat.npy` and `t2p_output_path/track2p/track_ops.npy`. These are, respectively, the matrix of cell matches across all days and the track2p settings. For more info see the repository README and documentation.\n",
55 | "\n",
56 | "Note: In this demo a single-plane recording is used, but it can easily be adapted to multiplane recordings (just repeat the same procedure while looping through planes)"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# this is the directory that contains a /track2p folder that is output by running the track2p algorithm\n",
66 | "t2p_save_path = '/Users/manonmantez/Desktop/el' # (change this based on your data)\n",
67 | "plane = 'plane0' # which plane to process (the example dataset is single-plane)"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "# np.load() the match matrix (plane0_match_mat.npy)\n",
77 | "t2p_match_mat = np.load(os.path.join(t2p_save_path, 'track2p', f'{plane}_match_mat.npy'), allow_pickle=True)\n",
78 | "\n",
79 | "# np.load() settings (this contains suite2p paths etc.) (track_ops.npy)\n",
80 | "track_ops_dict = np.load(os.path.join(t2p_save_path, 'track2p', 'track_ops.npy'), allow_pickle=True).item()\n",
81 | "track_ops = SimpleNamespace(**track_ops_dict) # create dummy object from the track_ops dictionary"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "### 2) Find cells that are present in all recordings ('matched cells')\n",
89 | "\n"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "Now from this matrix get the matches that are present on all days:\n",
97 | "\n",
98 | "- A matrix (`plane#_match_mat.npy`) containing the indices of matched neurons across the sessions for a given plane (`#` is the index of the plane). Since matching is done from the first day to the last, some neurons will not be successfully tracked after one or a few days; for those neurons the matrix contains `None` values. To get the neurons tracked across all days, only take the rows of the matrix that contain no `None` values.\n",
99 | "\n",
100 | "Note: cells that are not present on all days can of course also be used, but for now cells tracked on all days are the intended use case for downstream analysis."
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "# get the rows that do not contain any Nones (if track2p doesnt find a match for a cell across two consecutive days it will append a None) -> cells with no Nones are cells matched across all days\n",
110 | "t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat==None, axis=1), :]\n",
111 | "\n",
112 | "print(f'Shape of match matrix for cells present on all days: {t2p_match_mat_allday.shape} (cells, days)')"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "### 3) Load the data from one example dataset and visualise it\n",
120 | "\n",
121 | "Note: `track_ops.npy` (the 'settings file') contains the paths to all suite2p folders used when running track2p (see the cell below)."
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "print('Datasets used for t2p:\\n')\n",
131 | "for ds_path in track_ops.all_ds_path:\n",
132 | " print(ds_path)"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "Now, just to test that the paths work, we can look at the data of one of the recordings (below we use the last one). For this part it is important to know a bit about how suite2p structures its outputs: https://suite2p.readthedocs.io/en/latest/outputs.html (the important files are `ops.npy`, `stat.npy`, `iscell.npy` and `F.npy`). There are separate tutorials and demos for this, so we won't go into much detail here."
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "# lets take the last dataset\n",
149 | "last_ds_path = track_ops.all_ds_path[-1]\n",
150 | "print(f'We will look at the dataset saved at: {last_ds_path}')"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "# load the three files\n",
160 | "last_ops = np.load(os.path.join(last_ds_path, 'suite2p', plane, 'ops.npy'), allow_pickle=True).item()\n",
161 | "last_f = np.load(os.path.join(last_ds_path, 'suite2p', plane, 'F.npy'), allow_pickle=True)\n",
162 | "iscell = np.load(os.path.join(last_ds_path, 'suite2p', plane, 'iscell.npy'), allow_pickle=True)"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "# we filter the traces based on suite2p's iscell probability (note: it is crucial to use the same probability as in the track2p settings to keep the correct indexing of matches)\n",
172 | "iscell_thr = track_ops.iscell_thr\n",
173 | "\n",
174 | "print(f'The iscell threshold used when running track2p was: {iscell_thr}')\n",
175 | "\n",
176 | "if track_ops.iscell_thr==None:\n",
177 | " last_f_iscell = last_f[iscell[:, 0] == 1, :]\n",
178 | "\n",
179 | "else:\n",
180 | " last_f_iscell = last_f[iscell[:, 1] > iscell_thr, :]"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "# now first plot the mean image of the movie (it is saved in ops.npy, for more info see the suite2p outputs documentation)\n",
190 | "plt.imshow(last_ops['meanImg'], cmap='gray')\n",
191 | "plt.axis('off')\n",
192 | "plt.title('Mean image')\n",
193 | "plt.show()\n",
194 | "\n",
195 | "plt.figure(figsize=(10, 1))\n",
196 | "nonmatch_nrn_idx = 0\n",
197 | "plt.plot(last_f[nonmatch_nrn_idx, :])\n",
198 | "plt.xlabel('Frame')\n",
199 | "plt.ylabel('F')\n",
200 | "plt.title(f'Example trace (nrn_idx: {nonmatch_nrn_idx})')\n",
201 | "plt.show()\n",
202 | "\n",
203 | "plt.figure(figsize=(10, 3))\n",
204 | "plt.imshow(zscore(last_f_iscell, axis=1), aspect='auto', cmap='Greys', vmin=0, vmax=1.96)\n",
205 | "plt.xlabel('Frame')\n",
206 | "plt.ylabel('ROI')\n",
207 | "plt.title('Raster plot')\n",
208 | "plt.show()\n"
209 | ]
210 | },
211 | {
212 | "cell_type": "markdown",
213 | "metadata": {},
214 | "source": [
215 | "### 4) Load the activity of the matched cells\n",
216 | "\n",
217 | "Now that we know how to look at data in one recording we will use the output from track2p to look at activity of the same cells across all datasets."
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {},
223 | "source": [
224 | "To do this we need to loop through all datasets and:\n",
225 | "- load the files described above\n",
226 | "- filter `stat.npy` and `F.npy` by the iscell threshold used when running track2p (classical suite2p filtering)\n",
227 | "- filter `stat.npy` and `F.npy` by the appropriate indices from the matrix of neurons matched on all days (additional filtering step after track2p)\n",
228 | "\n",
229 | "This produces a data structure in which the index of a cell is the same in the stat and fluorescence (F) arrays of every day. Sorting the data in this way allows very straightforward extraction of matched data (see the cells below)."
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "iscell_thr = track_ops.iscell_thr # use the same threshold as when running the algo (to be consistent with indexing)\n",
239 | "\n",
240 | "all_stat_t2p = []\n",
241 | "all_f_t2p = []\n",
242 | "all_ops = [] # ops dont change\n",
243 | "\n",
244 | "for (i, ds_path) in enumerate(track_ops.all_ds_path):\n",
245 | " ops = np.load(os.path.join(ds_path, 'suite2p', plane, 'ops.npy'), allow_pickle=True).item()\n",
246 | " stat = np.load(os.path.join(ds_path, 'suite2p', plane, 'stat.npy'), allow_pickle=True)\n",
247 | " f = np.load(os.path.join(ds_path, 'suite2p', plane, 'F.npy'), allow_pickle=True)\n",
248 | " iscell = np.load(os.path.join(ds_path, 'suite2p', plane, 'iscell.npy'), allow_pickle=True)\n",
249 | " \n",
250 | " \n",
251 | " if track_ops.iscell_thr==None:\n",
252 | " stat_iscell = stat[iscell[:, 0] == 1]\n",
253 | " f_iscell = f[iscell[:, 0] == 1, :]\n",
254 | "\n",
255 | " else:\n",
256 | " stat_iscell = stat[iscell[:, 1] > iscell_thr]\n",
257 | " f_iscell = f[iscell[:, 1] > iscell_thr, :]\n",
258 | " \n",
259 | " \n",
260 | " stat_t2p = stat_iscell[t2p_match_mat_allday[:,i].astype(int)]\n",
261 | " f_t2p = f_iscell[t2p_match_mat_allday[:,i].astype(int), :]\n",
262 | "\n",
263 | " all_stat_t2p.append(stat_t2p)\n",
264 | " all_f_t2p.append(f_t2p)\n",
265 | " all_ops.append(ops)\n",
266 | "\n"
267 | ]
268 | },
269 | {
270 | "cell_type": "markdown",
271 | "metadata": {},
272 | "source": [
273 | "### 5) Visualise the ROIs and the activity of (a) matched cell(s)\n",
274 | "\n",
275 | "\n",
276 | "This example shows how to extract the information of an ROI from `all_stat_t2p`. We first index by day to get `stat_t2p` from `all_stat_t2p` (the sorted stat object for that day). We can then get the ROI information by indexing `stat_t2p` with the index of the matched cell (because of the re-sorting, the same index refers to the same cell on all days)."
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "wind = 24\n",
286 | "nrn_idx = 0\n",
287 | "\n",
288 | "for i in range(len(track_ops.all_ds_path)):\n",
289 | " mean_img = all_ops[i]['meanImg']\n",
290 | " stat_t2p = all_stat_t2p[i]\n",
291 | " median_coord = stat_t2p[nrn_idx]['med']\n",
292 | "\n",
293 | " plt.figure(figsize=(1.5,1.5))\n",
294 | " plt.imshow(mean_img[int(median_coord[0])-wind:int(median_coord[0])+wind, int(median_coord[1])-wind:int(median_coord[1])+wind], cmap='gray') # plot a short window around the ROI centroid\n",
295 | " plt.scatter(wind, wind)\n",
296 | " plt.axis('off')\n",
297 | " plt.show()"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "# first plot the trace of cell c for all days\n",
307 | "nrn_idx = 0 # the activity of the ROI visualised above on all days\n",
308 | "\n",
309 | "for i in range(len(track_ops.all_ds_path)):\n",
310 | " plt.figure(figsize=(10, 1)) # make a wide figure\n",
311 | " plt.plot(all_f_t2p[i][nrn_idx, :])\n",
312 | " plt.xlabel('Frame')\n",
313 | " plt.ylabel('F')\n",
314 | " plt.show()\n"
315 | ]
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {},
320 | "source": [
321 | "Visualising the rasters is now a simple exercise: the rows are already sorted so that each row represents the same cell across days, so we only need to loop through all_f_t2p and plot each element as we did before."
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "metadata": {},
328 | "outputs": [],
329 | "source": [
330 | "for i in range(len(track_ops.all_ds_path)):\n",
331 | " plt.figure(figsize=(10, 3)) # make a wide figure\n",
332 | " f_plot = zscore(all_f_t2p[i], axis=1)\n",
333 | " plt.imshow(f_plot, aspect='auto', cmap='Greys', vmin=0, vmax=1.96)\n",
334 | " plt.xlabel('Frame') "
335 | ]
336 | },
337 | {
338 | "cell_type": "markdown",
339 | "metadata": {},
340 | "source": [
341 | "### The End!\n",
342 | "\n",
343 | "Congrats! Hopefully this notebook was a clear and useful way of showing how to interact with the track2p outputs.\n",
344 | "\n",
345 | "From here on, custom analysis pipelines can easily be applied (for example, to look at the stability of assemblies, representational drift, etc.).\n",
346 | "\n",
347 | "The most straightforward way of doing this is to run an already implemented pipeline on the data loaded as shown here. Alternatively, the loaded match indices can be used to look at already-processed data as a way of post-hoc matching.\n",
348 | "\n",
349 | "Thanks and have fun with analysis :)"
350 | ]
351 | }
352 | ],
353 | "metadata": {
354 | "kernelspec": {
355 | "display_name": "track2p",
356 | "language": "python",
357 | "name": "python3"
358 | },
359 | "language_info": {
360 | "codemirror_mode": {
361 | "name": "ipython",
362 | "version": 3
363 | },
364 | "file_extension": ".py",
365 | "mimetype": "text/x-python",
366 | "name": "python",
367 | "nbconvert_exporter": "python",
368 | "pygments_lexer": "ipython3",
369 | "version": "3.9.19"
370 | }
371 | },
372 | "nbformat": 4,
373 | "nbformat_minor": 2
374 | }
375 |
--------------------------------------------------------------------------------
/track2p/plot/output.py:
--------------------------------------------------------------------------------
1 | # contains code for generating plots based on the track_ops object after the pipeline is run (track_ops is saved as an npy file)
2 | # every function takes the track_ops object as input; some also take precomputed arrays (thresholds, match matrices, mean images)
3 | import os
4 |
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | from skimage.exposure import match_histograms
9 | from track2p.plot.utils import make_rgb_img, saturate_perc, get_all_wind_mean_img
10 | from track2p.io.loaders import load_stat_ds_plane, get_all_roi_array_from_stat
11 |
def plot_reg_img_output(track_ops):
    """Plot per-dataset mean images and registration overlays, then save the figure.

    Layout (``3 * nplanes`` rows, one column per dataset):

    * rows ``0 .. nplanes-1``: the mean image of each plane for every dataset
      (``all_ds_avg_ch1`` when ``track_ops.reg_chan == 0``, ``all_ds_avg_ch2``
      otherwise);
    * rows ``nplanes + 2*p`` and ``nplanes + 2*p + 1``: for plane ``p``, RGB
      overlays of the reference/moving images of each consecutive pair, before
      and after registration. Column ``i`` holds the pair (dataset ``i``,
      dataset ``i+1``), so the last column of these rows stays empty.

    The figure is written to ``<track_ops.save_path_fig>/reg_img_output.png``.

    Parameters
    ----------
    track_ops : object
        Settings/results object. Reads ``nplanes``, ``all_ds_path``, ``reg_chan``,
        ``all_ds_avg_ch1`` / ``all_ds_avg_ch2``, ``all_ds_ref_img``,
        ``all_ds_mov_img``, ``all_ds_mov_img_reg``, ``sat_perc`` and
        ``save_path_fig``.
    """
    nplanes = track_ops.nplanes
    n_ds = len(track_ops.all_ds_path)
    n_row = nplanes + 2*nplanes  # one image row + two overlay rows per plane
    n_col = n_ds  # one column per dataset
    figsize = (10/3 * n_col, 10/3 * n_row)
    # squeeze=False keeps axs two-dimensional even with a single dataset (one column)
    fig, axs = plt.subplots(n_row, n_col, figsize=figsize, dpi=300, squeeze=False)

    # top rows: the mean image of each plane (registration channel) for each dataset
    for i in range(nplanes):
        axs[i, 0].set_ylabel(f'plane{i}\nchan{track_ops.reg_chan}', rotation=0, size='large', ha='right', va='center', labelpad=20)
        for j in range(n_ds):
            img = track_ops.all_ds_avg_ch1[j][i] if track_ops.reg_chan == 0 else track_ops.all_ds_avg_ch2[j][i]
            axs[i, j].imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, 99))

    # overlay rows: one column per consecutive pair (the last dataset has no successor)
    for i in range(n_ds - 1):
        for j in range(nplanes):
            row_nonreg = nplanes + 2*j  # skip the image rows, then two rows per plane
            row_reg = row_nonreg + 1  # after-registration row directly below

            ref_img = track_ops.all_ds_ref_img[i][j]
            mov_img = track_ops.all_ds_mov_img[i][j]
            mov_img_reg = track_ops.all_ds_mov_img_reg[i][j]

            # match the moving images' intensity distribution to the reference
            mov_img = match_histograms(mov_img, ref_img)
            mov_img_reg = match_histograms(mov_img_reg, ref_img)

            # assemble and saturate the overlays
            img_rgb = make_rgb_img(ref_img, mov_img)
            img_rgb_reg = make_rgb_img(ref_img, mov_img_reg)
            img_rgb = saturate_perc(img_rgb, sat_perc=track_ops.sat_perc)
            img_rgb_reg = saturate_perc(img_rgb_reg, sat_perc=track_ops.sat_perc)

            axs[row_nonreg, i].imshow(img_rgb)
            axs[row_reg, i].imshow(img_rgb_reg)

    # row labels of the overlay rows, on the first column only (set once, not once per plane)
    if n_ds > 1:
        for k in range(nplanes):
            axs[nplanes + 2*k, 0].set_ylabel(f'plane{k}\nchan{track_ops.reg_chan}\nnon-reg', rotation=0, size='large', ha='right', va='center', labelpad=20)
            axs[nplanes + 2*k + 1, 0].set_ylabel(f'plane{k}\nchan{track_ops.reg_chan}\nreg', rotation=0, size='large', ha='right', va='center', labelpad=20)

    # remove ticks and all spines from every panel
    for ax in axs.flat:
        ax.set(xticks=[], yticks=[])
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)

    # dashed grey separator drawn above row 2*nplanes, extending across the columns
    axs[2*nplanes, 0].annotate('', xy=(0, 1.1), xytext=(n_col + 0.5, 1.1),
                               xycoords='axes fraction', textcoords='axes fraction',
                               arrowprops=dict(arrowstyle='-', linestyle='dashed', color='grey'),
                               annotation_clip=False)

    # red arrow below dataset i and green arrow from dataset i+1, both pointing at the
    # overlay of that pair in column i
    for i in range(n_ds - 1):
        axs[nplanes-1, i].annotate('', xy=(0.5, -0.17), xytext=(0.5, -0.02),
                                   xycoords='axes fraction', textcoords='axes fraction',
                                   arrowprops=dict(facecolor=(1, 0, 0), edgecolor=(1, 0, 0), shrink=0.05),
                                   annotation_clip=False)
        axs[nplanes-1, i+1].annotate('', xy=(-0.6, -0.17), xytext=(0.5, -0.02),
                                     xycoords='axes fraction', textcoords='axes fraction',
                                     arrowprops=dict(facecolor=(0, 1, 0), edgecolor=(0, 1, 0), shrink=0.05),
                                     annotation_clip=False)

    fig.savefig(os.path.join(track_ops.save_path_fig, 'reg_img_output.png'), bbox_inches='tight', dpi=200)
    plt.close(fig)
88 |
def plot_roi_reg_output(track_ops):
    """Plot ROI contours on mean images and before/after-registration overlays, then save.

    Layout (``3 * nplanes`` rows, one column per dataset):

    * rows ``0 .. nplanes-1``: the ``all_ds_avg_ch1`` mean image of each plane
      (alpha 0.5) with the ROI contours of that dataset;
    * rows ``nplanes + 2*p`` and ``nplanes + 2*p + 1``: for plane ``p`` and each
      consecutive pair (column ``i`` = datasets ``i`` and ``i+1``), reference ROI
      contours in red and moving (resp. registered) ROI contours in green, drawn
      over the stored ``all_ds_ref_mov_inters`` (resp. ``all_ds_ref_reg_inters``)
      image.

    The figure is written to ``<track_ops.save_path_fig>/reg_roi_output.png``.

    Parameters
    ----------
    track_ops : object
        Settings/results object. Reads ``nplanes``, ``all_ds_path``, ``reg_chan``,
        ``all_ds_avg_ch1``, ``all_ds_all_roi_array_ref`` / ``_mov`` / ``_reg``,
        ``all_ds_ref_mov_inters``, ``all_ds_ref_reg_inters`` and ``save_path_fig``.
    """
    nplanes = track_ops.nplanes
    n_ds = len(track_ops.all_ds_path)
    n_row = nplanes + 2*nplanes  # one image row + two overlay rows per plane
    n_col = n_ds  # one column per dataset
    figsize = (10/3 * n_col, 10/3 * n_row)
    # squeeze=False keeps axs two-dimensional even with a single dataset (one column)
    fig, axs = plt.subplots(n_row, n_col, figsize=figsize, dpi=300, squeeze=False)

    # top rows: mean image + ROI contours of every dataset
    for i in range(nplanes):
        # NOTE(review): the label mentions reg_chan but the image below is always channel 1
        axs[i, 0].set_ylabel(f'plane{i}\nchan{track_ops.reg_chan}', rotation=0, size='large', ha='right', va='center', labelpad=20)
        for j in range(n_ds):
            img = track_ops.all_ds_avg_ch1[j][i]
            axs[i, j].imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, 99), alpha=0.5)
            # each dataset is the reference of the following pair, except the last one,
            # which only appears as the moving dataset of the final pair
            if j < n_ds - 1:
                all_roi = track_ops.all_ds_all_roi_array_ref[j][i]
            else:
                all_roi = track_ops.all_ds_all_roi_array_mov[j-1][i]
            for k in range(all_roi.shape[2]):
                axs[i, j].contour(all_roi[:, :, k], colors='C0', linewidths=0.3)

    # overlay rows: contours of each consecutive pair before and after registration
    for i in range(n_ds - 1):
        print(f'Plotting contours for dataset {i}/{n_ds-1}')
        for j in range(nplanes):
            row_nonreg = nplanes + 2*j  # skip the image rows, then two rows per plane
            row_reg = row_nonreg + 1  # after-registration row directly below

            all_roi_array_ref = track_ops.all_ds_all_roi_array_ref[i][j]
            all_roi_array_mov = track_ops.all_ds_all_roi_array_mov[i][j]
            all_roi_array_reg = track_ops.all_ds_all_roi_array_reg[i][j]

            # before registration: reference (red) and moving (green) ROIs
            axs[row_nonreg, i].imshow(track_ops.all_ds_ref_mov_inters[i][j])
            for k in range(all_roi_array_ref.shape[2]):
                axs[row_nonreg, i].contour(all_roi_array_ref[:, :, k], colors='r', linewidths=0.3)
            for k in range(all_roi_array_mov.shape[2]):
                axs[row_nonreg, i].contour(all_roi_array_mov[:, :, k], colors='g', linewidths=0.3)

            # after registration: reference (red) and registered (green) ROIs
            axs[row_reg, i].imshow(track_ops.all_ds_ref_reg_inters[i][j])
            for k in range(all_roi_array_ref.shape[2]):
                axs[row_reg, i].contour(all_roi_array_ref[:, :, k], colors='r', linewidths=0.3)
            for k in range(all_roi_array_reg.shape[2]):
                axs[row_reg, i].contour(all_roi_array_reg[:, :, k], colors='g', linewidths=0.3)

    # row labels of the overlay rows, on the first column only (set once, not once per plane)
    if n_ds > 1:
        for k in range(nplanes):
            axs[nplanes + 2*k, 0].set_ylabel(f'plane{k}\nnon-reg', rotation=0, size='large', ha='right', va='center', labelpad=20)
            axs[nplanes + 2*k + 1, 0].set_ylabel(f'plane{k}\nreg', rotation=0, size='large', ha='right', va='center', labelpad=20)

    # remove ticks and all spines from every panel
    for ax in axs.flat:
        ax.set(xticks=[], yticks=[])
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)

    # dashed grey separator drawn above row 2*nplanes, extending across the columns
    axs[2*nplanes, 0].annotate('', xy=(0, 1.1), xytext=(n_col + 0.5, 1.1),
                               xycoords='axes fraction', textcoords='axes fraction',
                               arrowprops=dict(arrowstyle='-', linestyle='dashed', color='grey'),
                               annotation_clip=False)

    # red arrow below dataset i and green arrow from dataset i+1, both pointing at the
    # overlay of that pair in column i
    for i in range(n_ds - 1):
        axs[nplanes-1, i].annotate('', xy=(0.5, -0.17), xytext=(0.5, -0.02),
                                   xycoords='axes fraction', textcoords='axes fraction',
                                   arrowprops=dict(facecolor='r', edgecolor='r', shrink=0.05),
                                   annotation_clip=False)
        axs[nplanes-1, i+1].annotate('', xy=(-0.6, -0.17), xytext=(0.5, -0.02),
                                     xycoords='axes fraction', textcoords='axes fraction',
                                     arrowprops=dict(facecolor='g', edgecolor='g', shrink=0.05),
                                     annotation_clip=False)

    # os.path.join (as in the other plot functions): plain concatenation wrote the file
    # next to the figure folder whenever save_path_fig lacked a trailing separator
    fig.savefig(os.path.join(track_ops.save_path_fig, 'reg_roi_output.png'), bbox_inches='tight', dpi=200)
    plt.close(fig)
173 |
174 |
def plot_thr_met_hist(all_ds_thr_met, all_ds_thr, track_ops):
    """Plot histograms of the matching metric per recording pair and plane, then save.

    One panel per (plane, pair): the histogram (20 bins) of
    ``all_ds_thr_met[i][j]`` with a dashed grey line at ``all_ds_thr[i][j]``.
    The title reports how many values lie strictly above the threshold. The
    figure is saved as ``thr_met_hist.png`` in ``track_ops.save_path_fig``.

    Parameters
    ----------
    all_ds_thr_met : sequence
        ``all_ds_thr_met[i][j]`` holds the metric values for pair ``i``
        (ds ``i`` -> ds ``i+1``) and plane ``j``; compared element-wise with ``>``.
    all_ds_thr : sequence
        ``all_ds_thr[i][j]`` is the threshold of the same pair/plane.
    track_ops : object
        Reads ``nplanes``, ``thr_method`` (threshold label) and ``save_path_fig``.
    """
    n_pairs = len(all_ds_thr_met)
    nplanes = track_ops.nplanes
    # squeeze=False: axs is always (nplanes, n_pairs). The former ad-hoc indexing used
    # axs[i] whenever there were exactly two datasets, which drew every plane of a
    # multi-plane run into the same panel.
    fig, axs = plt.subplots(nplanes, n_pairs, figsize=(6*n_pairs, 6*nplanes), sharey=True, sharex=True, squeeze=False)

    for i in range(n_pairs):
        for j in range(nplanes):
            this_ax = axs[j, i]
            thr_met = all_ds_thr_met[i][j]
            thr = all_ds_thr[i][j]

            n_reg_roi = len(thr_met)
            n_abovethr_roi = np.sum(thr_met > thr)
            # an empty pair has no ROI to match: report 0% instead of dividing by zero
            pct_abovethr = n_abovethr_roi / n_reg_roi * 100 if n_reg_roi else 0.0

            this_ax.hist(thr_met, bins=20)
            this_ax.axvline(thr, color='grey', linestyle='--')
            # label the threshold line with the thresholding method and its value
            this_ax.text(thr + 0.02, this_ax.get_ylim()[1]*0.9, f'{track_ops.thr_method} thr.: {thr:.2f}')
            this_ax.set_title(f'ds{i} (ref) to ds{i+1} (reg); matched {n_abovethr_roi}/{n_reg_roi} ({pct_abovethr:.1f}%)')

            if i == 0:
                this_ax.set_ylabel(f'ROI count (plane{j})')

    # remove the top and right spines
    for a in axs.flat:
        a.spines['top'].set_visible(False)
        a.spines['right'].set_visible(False)

    fig.tight_layout()
    fig.savefig(os.path.join(track_ops.save_path_fig, 'thr_met_hist.png'), dpi=200)
    plt.close(fig)
203 |
204 |
205 |
def plot_n_matched_roi(all_ds_thr_met, all_ds_thr, track_ops):
    """Summarise, per plane, how many registered ROIs pass the matching threshold.

    Each plane gets two panels with one x-position per recording pair
    (tick labels ``r<p>-<p+1>``):

    * left: a grey bar with the number of ROIs ('total'), overdrawn by a C0 bar
      with the number strictly above threshold ('above threshold');
    * right: the proportion above threshold, with the y-axis fixed to [0, 1].

    The figure is saved as ``n_prop_matched_roi.png`` in ``track_ops.save_path_fig``.
    """
    nplanes = track_ops.nplanes
    n_pairs = len(all_ds_thr_met)
    x = np.arange(n_pairs)
    tick_labels = [f'r{p}-{p+1}' for p in range(n_pairs)]

    # squeeze=False only changes the container: axs[plane] is always a row of two Axes
    fig, axs = plt.subplots(nplanes, 2, figsize=(8, 4*nplanes), sharey='col', sharex=True, squeeze=False)

    for plane in range(nplanes):
        totals = [len(all_ds_thr_met[p][plane]) for p in range(n_pairs)]
        above = [np.sum(all_ds_thr_met[p][plane] > all_ds_thr[p][plane]) for p in range(n_pairs)]
        proportions = [a / t for a, t in zip(above, totals)]

        count_ax, prop_ax = axs[plane]

        # left panel: absolute counts
        count_ax.bar(x, totals, color='grey', label='total')
        count_ax.bar(x, above, color='C0', label='above threshold')
        count_ax.set_ylabel(f'ROI count (plane {plane})')
        count_ax.set_xticks(x)
        count_ax.set_xticklabels(tick_labels)
        count_ax.legend(loc='lower right')

        # right panel: proportion above threshold
        prop_ax.plot(x, proportions, color='C0', marker='o')
        prop_ax.set_ylabel(f'ROI prop. (plane {plane})')
        prop_ax.set_xticks(x)
        prop_ax.set_xticklabels(tick_labels)
        prop_ax.set_ylim(0, 1)

        for ax in (count_ax, prop_ax):
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

    fig.savefig(os.path.join(track_ops.save_path_fig, 'n_prop_matched_roi.png'), dpi=200, bbox_inches='tight')
    plt.close(fig)
239 |
def plot_roi_match(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, neuron_ids, track_ops, plane_idx=0, win_size=64, k=0, n=None, ch=1):
    """Plot mean-image windows around tracked ROIs (one row per neuron) and save the figure.

    Row ``i`` shows, for neuron ``neuron_ids[i]``, the windows returned by
    ``get_all_wind_mean_img`` (one per dataset), histogram-matched to that row's
    first window, with a marker at ``(win_size/2, win_size/2)``. The first window
    of every row after the first plotted one is itself matched to the first
    plotted row, so rows share an intensity scale. Rows of neurons whose
    match-matrix row contains ``None`` are left blank.

    Parameters
    ----------
    all_ds_mean_img : sequence
        One entry per dataset; its length sets the number of columns.
    all_ds_centroids, all_pl_match_mat :
        Forwarded to ``get_all_wind_mean_img``; ``all_pl_match_mat[plane_idx]`` is
        also indexed as ``[neuron_id, :]`` to detect untracked neurons.
    neuron_ids : sequence of int
        Match-matrix rows to plot, one figure row each. Nothing is drawn or saved
        when it is empty.
    track_ops : object
        Only ``save_path_fig`` (output directory) is read.
    plane_idx : int
        Plane to plot.
    win_size : int
        Forwarded to ``get_all_wind_mean_img``; also positions the centre marker.
    k : int
        Offset of this batch in the full neuron list (row labels, file name).
    n : int, optional
        Total number of neurons over all batches (row labels); defaults to
        ``len(neuron_ids)``.
    ch : int
        1 or 2, selects the output file name; other values save nothing.
    """
    if n is None:
        n = len(neuron_ids)

    nrows = len(neuron_ids)
    ncols = len(all_ds_mean_img)
    if nrows == 0:
        # plt.subplots() rejects zero rows; there is nothing to plot
        return

    # squeeze=False keeps axs 2-D for a single neuron (e.g. a final batch of one)
    # or a single dataset, so axs[i, j] is always valid
    fig, axs = plt.subplots(nrows, ncols, figsize=(2*ncols, 2*nrows), dpi=50, squeeze=False)

    # column titles on the top row (also when that row's neuron is skipped)
    for j in range(ncols):
        axs[0, j].set_title(f'ds {j} (pl {plane_idx})')

    pl_match_mat = all_pl_match_mat[plane_idx]
    fig_ref = None  # first plotted window: histogram reference shared by all rows
    for (i, nrn_id) in enumerate(neuron_ids):
        # skip neurons that are not matched on every day (a None in their row)
        if any(pl_match_mat[nrn_id, :] == None):
            continue

        # small window around the centroid of this neuron, for every dataset
        all_wind_mean_img = get_all_wind_mean_img(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, nrn_id, plane_idx=plane_idx, win_size=win_size)

        if fig_ref is None:
            # set from the first neuron actually plotted (the old `i == 0` test left
            # fig_ref unbound when the first neuron was skipped)
            fig_ref = all_wind_mean_img[0]
        else:
            all_wind_mean_img[0] = match_histograms(all_wind_mean_img[0], fig_ref)

        ref_img = all_wind_mean_img[0]  # every window of the row is matched to this one
        for j in range(ncols):
            matched = match_histograms(all_wind_mean_img[j], ref_img)
            axs[i, j].imshow(matched, cmap='gray')
            # mark the middle pixel of the window
            axs[i, j].scatter(win_size/2, win_size/2, color='C0')
            if j == 0:
                axs[i, j].set_ylabel(f'ROI {nrn_id} ({i+k}/{n})')

    # remove axis ticks
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    if ch == 1:
        fig.savefig(os.path.join(track_ops.save_path_fig, f'roi_match_plane{plane_idx}_idx{k}-{k+len(neuron_ids)}.png'), dpi=50)
    elif ch == 2:
        fig.savefig(os.path.join(track_ops.save_path_fig, f'roi_match_chan2_plane{plane_idx}_idx{k}-{k+len(neuron_ids)}.png'), dpi=50)
    plt.close(fig)
285 |
def plot_roi_match_multiplane(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, track_ops, win_size=48, ch=1, batch_size=100):
    """Plot the ROI-match windows of every plane, at most ``batch_size`` neurons per figure.

    For each of the ``track_ops.nplanes`` planes, only neurons whose match-matrix
    row contains no ``None`` (tracked on all days) are plotted, through
    ``plot_roi_match``, in consecutive batches starting at offsets
    ``0, batch_size, 2*batch_size, ...``. A plane without such neurons produces
    no figure.

    Parameters
    ----------
    all_ds_mean_img, all_ds_centroids, all_pl_match_mat, track_ops, win_size, ch :
        Forwarded to ``plot_roi_match``.
    batch_size : int, optional
        Maximum number of neurons (figure rows) per figure; must be positive.
        Defaults to 100.
    """
    for i in range(track_ops.nplanes):
        pl_match_mat = all_pl_match_mat[i]
        pl_neuron_ids = np.arange(pl_match_mat.shape[0])
        # keep the rows without None, i.e. neurons matched on every day
        neuron_ids = pl_neuron_ids[~np.any(pl_match_mat == None, axis=1)]

        # One pass over batches also covers small planes (a single batch with k=0).
        # The former outer loop over range(len(track_ops.save_path)) never used its
        # index and only repeated the same plots (or skipped them if save_path was empty).
        for k in range(0, len(neuron_ids), batch_size):
            plot_roi_match(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, neuron_ids[k:k+batch_size], track_ops, plane_idx=i, win_size=win_size, k=k, n=len(neuron_ids), ch=ch)
299 |
300 |
def plot_allroi_match_multiplane(all_ds_mean_img, all_pl_match_mat, track_ops, ch=1):
    """Draw the contours of all ROIs tracked on every day over each dataset's mean image.

    One panel per (plane, dataset). Each mean image is histogram-matched to a
    saturated copy (clipped at the ``track_ops.sat_perc`` percentile) of the first
    dataset's mean image of the same plane. Every neuron whose match-matrix row
    contains no ``None`` gets one random colour, reused on all datasets so the
    same cell can be followed across days.

    The figure is saved in ``track_ops.save_path_fig`` as ``all_roi_match.png`` when
    ``ch == 1`` and ``all_roi_match_chan2.png`` otherwise.

    Parameters
    ----------
    all_ds_mean_img : sequence
        ``all_ds_mean_img[ds][plane]`` is the mean image of a dataset/plane. It is
        not modified.
    all_pl_match_mat : sequence
        ``all_pl_match_mat[plane]`` is indexed ``[neuron, dataset]``; entries are the
        ROI index in that recording, or ``None`` when the neuron was not matched.
    track_ops : object
        Reads ``nplanes``, ``all_ds_path``, ``sat_perc`` and ``save_path_fig``; also
        passed to ``load_stat_ds_plane`` / ``get_all_roi_array_from_stat``.
    ch : int
        Selects the output file name.
    """
    nplanes = track_ops.nplanes
    n_ds = len(track_ops.all_ds_path)
    # squeeze=False: axs is always (nplanes, n_ds), even for one plane and/or one dataset
    fig, axs = plt.subplots(nplanes, n_ds, figsize=(4*n_ds, 4*nplanes), dpi=300, squeeze=False)

    for plane_idx in range(nplanes):
        # Saturated histogram reference, computed once per plane. np.minimum returns a
        # new array, so the caller's image stays untouched; the former in-place clipping
        # mutated all_ds_mean_img[0] and re-clipped it (ever lower) once per dataset.
        ref_mean_img = all_ds_mean_img[0][plane_idx]
        ref_mean_img = np.minimum(ref_mean_img, np.percentile(ref_mean_img, track_ops.sat_perc))

        for ds_idx in range(n_ds):
            ax = axs[plane_idx, ds_idx]
            mean_img_matched = match_histograms(all_ds_mean_img[ds_idx][plane_idx], ref_mean_img)
            ax.imshow(mean_img_matched, cmap='gray')
            ax.set_title(f'ds{ds_idx} (plane{plane_idx})')

        # contours of the neurons tracked on all days (match-matrix rows without None)
        pl_match_mat = all_pl_match_mat[plane_idx]
        pl_neuron_ids = np.arange(pl_match_mat.shape[0])
        neuron_ids = pl_neuron_ids[~np.any(pl_match_mat == None, axis=1)]
        neuron_colors = [np.random.rand(3) for _ in range(len(neuron_ids))]

        for ds_idx, ds_path in enumerate(track_ops.all_ds_path):
            stat_ds_plane, _ = load_stat_ds_plane(ds_path, track_ops, plane_idx=plane_idx)
            all_roi_array = get_all_roi_array_from_stat(stat_ds_plane, track_ops)
            ax = axs[plane_idx, ds_idx]

            for color, neuron_idx in zip(neuron_colors, neuron_ids):
                idx_orig = int(pl_match_mat[neuron_idx, ds_idx])  # ROI index in this recording
                # pass the colour directly instead of editing ContourSet.collections,
                # which is deprecated since matplotlib 3.8
                ax.contour(all_roi_array[:, :, idx_orig], colors=[tuple(color)], linewidths=0.5)

        axs[plane_idx, 0].set_ylabel(f'plane{plane_idx} (n={len(neuron_ids)})')

    # remove axes elements
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    fig.tight_layout()
    fname = 'all_roi_match.png' if ch == 1 else 'all_roi_match_chan2.png'
    fig.savefig(os.path.join(track_ops.save_path_fig, fname), dpi=300)
    plt.close(fig)
357 |
--------------------------------------------------------------------------------
/track2p/gui/raster_wd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | from PyQt5.QtWidgets import QVBoxLayout, QWidget, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout, QCheckBox, QComboBox,QGraphicsView,QGraphicsScene,QSplitter,QGroupBox
4 | from PyQt5.QtCore import Qt
5 | from PyQt5.QtGui import QPixmap
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from tqdm import tqdm
9 | from scipy.stats import zscore
10 | from sklearn.decomposition import PCA
11 | from openTSNE import TSNE
12 | import matplotlib.pyplot as plt
13 | import copy
14 |
15 |
16 |
class RasterWindow(QWidget):
    """Panel that plots activity rasters of the neurons tracked by track2p.

    The left-hand side holds the controls: the sorting method, an optional
    "combined" mode that stacks the neurons of all imaging planes, advanced
    options (bin size, vmin, vmax) and the run/save buttons. The right-hand
    side is a ``QGraphicsView`` showing the rendered matplotlib figure.

    All data is read from ``mainWindow.central_widget.data_management``.
    """

    def __init__(self, mainWindow):
        super().__init__()
        self.main_window = mainWindow
        self.raster_type = None  # tag of the last sorting method used
        self.bin_size = None  # frames averaged per bin (set by preprocessing)
        self.filename = None
        self.all_f_t2p = None  # traces to plot, one array per recording
        self.all_f_t2p_preproc = None  # binned, z-scored, NaN-row-filtered traces
        self.all_stat_t2p = None
        self.plane = None
        self.results_by_plane = {}  # per-plane traces loaded by concatenate_rasters
        self.vmin_value = None
        self.vmax_value = None
        self.save_path = None
        self.fig = None  # last raster figure; this is what "Save" writes

        # Controls panel (left-hand side of the splitter).
        layout = QFormLayout()

        label_checkbox = QLabel("Choose the sorting method:")
        field_checkbox = QVBoxLayout()
        self.checkbox1 = QCheckBox('without sorting', self)
        self.checkbox2 = QCheckBox('sorting by PCA', self)
        self.checkbox3 = QCheckBox('sorting by PCA on given day', self)
        self.checkbox4 = QCheckBox('sorting by tSNE', self)
        self.checkbox5 = QCheckBox('sorting by tSNE on given day', self)
        self.checkboxes = [self.checkbox1, self.checkbox2, self.checkbox3,
                           self.checkbox4, self.checkbox5]

        # The "on given day" options need the recording-index combo box filled.
        self.checkbox3.stateChanged.connect(self.update_day_choice)
        self.checkbox5.stateChanged.connect(self.update_day_choice)

        self.day_choice = QComboBox(self)
        self.day_choice.addItem('Choose recording index (for sorting on given day)')
        for checkbox in self.checkboxes:
            field_checkbox.addWidget(checkbox)
        field_checkbox.addWidget(self.day_choice)
        layout.addRow(label_checkbox, field_checkbox)

        # Makes the sorting checkboxes mutually exclusive.
        for checkbox in self.checkboxes:
            checkbox.stateChanged.connect(self.handle_checkbox_state)

        # Not wired to concatenate_rasters: run() loads and concatenates the
        # planes when this box is checked. Doing it on every toggle (including
        # un-checking) re-read every plane from disk only to discard the result.
        self.combined_checkbox = QCheckBox('Combined (put neurons of all planes together) ', self)
        layout.addRow(" ", self.combined_checkbox)

        label_check = QLabel("Advanced options:")
        self.check = QCheckBox(self)
        self.check.stateChanged.connect(self.display_advanced_options)
        layout.addRow(label_check, self.check)

        self.advanced_options_group = QGroupBox()
        self.advanced_options_group.setVisible(False)
        group_layout = QFormLayout()
        self.bin = QLineEdit()
        self.bin.setText('1')
        self.bin.setFixedWidth(50)
        group_layout.addRow("Averaging bin size:", self.bin)
        self.vmin = QLineEdit()
        self.vmin.setText('0')
        self.vmin.setFixedWidth(50)
        group_layout.addRow("vmin:", self.vmin)
        self.vmax = QLineEdit()
        self.vmax.setText('1.96')
        self.vmax.setFixedWidth(50)
        group_layout.addRow("vmax:", self.vmax)
        self.advanced_options_group.setLayout(group_layout)
        layout.addRow('', self.advanced_options_group)

        # Hidden until "Advanced options" is checked.
        self.bin.setVisible(False)
        self.vmin.setVisible(False)
        self.vmax.setVisible(False)

        label_run = QLabel("Run the analysis:")
        field_run = QPushButton("Run")
        field_run.clicked.connect(self.run)
        layout.addRow(label_run, field_run)

        label_save = QLabel("Save the figure:")
        field_save = QPushButton("Save")
        field_save.clicked.connect(self.get_output_save_path)
        layout.addRow(label_save, field_save)

        label_output_path = QLabel("The figure has been saved here:")
        self.field_output_path = QLabel()
        layout.addRow(label_output_path, self.field_output_path)

        # Figure display (right-hand side of the splitter).
        self.view = QGraphicsView(self)

        splitter = QSplitter(Qt.Horizontal)
        right_widget = QWidget()
        right_widget.setLayout(layout)
        splitter.addWidget(right_widget)
        splitter.addWidget(self.view)

        main_layout = QVBoxLayout(self)
        main_layout.addWidget(splitter)
        self.setLayout(main_layout)

    ############################################################################

    def update_day_choice(self):
        """Fill the recording-index combo box with ``1..n`` for per-day sorting.

        The list is rebuilt whenever a per-day option is checked and the items
        no longer match the loaded data (e.g. after loading another dataset),
        so the choices never go stale. ``n`` is the number of entries in
        ``data_management.all_stat_t2p``.
        """
        if not (self.checkbox3.isChecked() or self.checkbox5.isChecked()):
            return
        data_management = self.main_window.central_widget.data_management
        all_stat_t2p = getattr(data_management, 'all_stat_t2p', None)
        if all_stat_t2p is None:
            print('No tracked data loaded yet: cannot list the recording indices.')
            return
        self.all_stat_t2p = all_stat_t2p
        wanted = [str(i + 1) for i in range(len(all_stat_t2p))]
        current = [self.day_choice.itemText(k) for k in range(self.day_choice.count())]
        if current != wanted:
            self.day_choice.clear()
            self.day_choice.addItems(wanted)
            print('ComboBox updated')

    def handle_checkbox_state(self):
        """Uncheck every other sorting checkbox when one of them is checked."""
        sender = self.sender()
        if sender.isChecked():
            for checkbox in self.checkboxes:
                if checkbox is not sender:
                    checkbox.setChecked(False)

    def run(self):
        """Collect the traces (current plane or all planes), preprocess and plot them.

        Invalid user input (bin size, vmin/vmax, recording index) or an
        unsupported trace type is reported on the console. It is not allowed
        to escape the Qt slot, because PyQt5 aborts the application on
        unhandled slot exceptions.
        """
        data_management = self.main_window.central_widget.data_management
        try:
            if self.combined_checkbox.isChecked():
                self.all_f_t2p = self.concatenate_rasters()
            else:
                self.all_f_t2p = data_management.all_f_t2p
            self.plane = data_management.plane
            self.preprocessing()
            self.get_checkbox_choice()
        except ValueError as err:
            print(f'Could not plot the rasters: {err}')

    def display_advanced_options(self, state):
        """Show the advanced-options widgets when ``state`` is ``Qt.Checked``, else hide them."""
        visible = state == Qt.Checked
        for widget in (self.advanced_options_group, self.bin, self.vmin, self.vmax):
            widget.setVisible(visible)

    def fit_pca_1d(self, data):
        """Fit a 1-component PCA on ``data`` (samples x features).

        Returns ``pca.components_.T``, i.e. a (n_features, 1) array of loadings
        on the first principal component. It is called with ``f.T``, so it
        yields one value per neuron (row of ``f``).
        """
        print('fitting 1d-PCA...')
        pca = PCA(n_components=1)
        pca.fit(data)
        embedding = pca.components_.T
        return embedding

    def fit_tsne_1d(self, data):
        """Embed the columns of ``data`` in 1-D with t-SNE.

        ``data`` is transposed before fitting. Since it is called with ``f.T``,
        the embedding has one value per neuron (row of ``f``). The constructor
        arguments (``initialization``, ``n_jobs``) and ``fit`` returning the
        embedding look like the openTSNE API. NOTE(review): confirm against
        the import at the top of the file.
        """
        print('fitting 1d-tSNE...')
        tsne = TSNE(
            n_components=1,
            perplexity=30,
            initialization="pca",
            metric="euclidean",
            n_jobs=8,
            random_state=3
        )
        tsne_emb = tsne.fit(data.T)
        return tsne_emb

    def get_output_save_path(self):
        """Ask for a file path and save the currently displayed raster figure there.

        Nothing is written if no figure has been plotted yet or if the dialog
        is cancelled. The chosen path is shown in the "saved here" label.
        """
        if self.fig is None:
            print('Nothing to save yet: run the analysis first.')
            return
        save_path, _ = QFileDialog.getSaveFileName(self, "Select Save Path")
        if not save_path:
            return  # dialog cancelled: do not overwrite a previous file
        self.save_path = save_path
        print(f'Selected save path: {self.save_path}')
        self.fig.savefig(self.save_path)
        self.field_output_path.setText(self.save_path)

    def _selected_day_index(self):
        """Return the 0-based recording index chosen in ``day_choice``.

        Raises:
            ValueError: if the selection is not a valid recording number.
        """
        text = self.day_choice.currentText()
        try:
            day_index = int(text) - 1
        except ValueError:
            raise ValueError(f'invalid recording index {text!r}; choose one from the list') from None
        if not 0 <= day_index < len(self.all_f_t2p_preproc):
            raise ValueError(f'recording index {text} is out of range')
        return day_index

    def _sort_rasters(self, embed_1d, day_index=None):
        """Sort the neurons (rows) of every preprocessed raster by a 1-D embedding.

        Args:
            embed_1d: callable taking a (frames, neurons) array and returning
                one value per neuron (``fit_pca_1d`` or ``fit_tsne_1d``).
            day_index: if None, each recording is sorted by its own embedding.
                Otherwise the order computed on recording ``day_index`` is
                applied to all recordings. Only that recording is embedded,
                which matters for t-SNE.

        Returns:
            List of row-sorted arrays, one per recording.
        """
        rasters = self.all_f_t2p_preproc
        to_embed = rasters if day_index is None else [rasters[day_index]]
        embeddings = []
        for f in tqdm(to_embed):
            print(f.shape)
            embeddings.append(np.asarray(embed_1d(f.T)).squeeze())
        if day_index is None:
            orders = [np.argsort(emb) for emb in embeddings]
        else:
            orders = [np.argsort(embeddings[0])] * len(rasters)
        return [f[order, :].squeeze() for f, order in zip(rasters, orders)]

    def get_checkbox_choice(self):
        """Sort the preprocessed rasters with the checked method and plot them.

        Raises:
            ValueError: if vmin/vmax or the chosen recording index is invalid.
        """
        vmin = float(self.vmin.text())
        vmax = float(self.vmax.text())
        if self.checkbox1.isChecked():
            self.raster_type = 'without_sorting'
            rasters = self.all_f_t2p_preproc
        elif self.checkbox2.isChecked():
            self.raster_type = 'sorting_by_PCA'
            rasters = self._sort_rasters(self.fit_pca_1d)
        elif self.checkbox3.isChecked():
            self.raster_type = 'sorting_by_PCA_and_by_day_' + str(self.day_choice.currentText())
            rasters = self._sort_rasters(self.fit_pca_1d, day_index=self._selected_day_index())
        elif self.checkbox4.isChecked():
            self.raster_type = 'sorting_by_tSNE'
            rasters = self._sort_rasters(self.fit_tsne_1d)
        elif self.checkbox5.isChecked():
            self.raster_type = 'sorting_tSNE_and_by_day_' + str(self.day_choice.currentText())
            rasters = self._sort_rasters(self.fit_tsne_1d, day_index=self._selected_day_index())
        else:
            print('No sorting method selected.')
            return
        self.plot_track2p_rasters(rasters, bin_size=self.bin_size, vmin=vmin, vmax=vmax)

    def preprocessing(self):
        """Bin, z-score and clean ``self.all_f_t2p`` into ``self.all_f_t2p_preproc``.

        Each (neurons x frames) array is averaged over consecutive bins of
        ``bin_size`` frames. Trailing frames that do not fill a whole bin are
        dropped rather than making the reshape fail. Each neuron is then
        z-scored, and neurons whose z-scored trace contains NaN in any
        recording are removed from all recordings, so rows stay matched.

        Raises:
            ValueError: if the bin size is not a positive integer or is larger
                than a recording.
        """
        self.bin_size = int(self.bin.text())  # 1 = no averaging
        if self.bin_size < 1:
            raise ValueError(f'bin size must be a positive integer, got {self.bin_size}')
        print(f'bin_size: {self.bin_size}')

        binned = []
        for f in self.all_f_t2p:
            n_frames = (f.shape[1] // self.bin_size) * self.bin_size
            if n_frames == 0:
                raise ValueError(f'bin size {self.bin_size} is larger than the recording ({f.shape[1]} frames)')
            binned.append(f[:, :n_frames].reshape(f.shape[0], -1, self.bin_size).mean(axis=2))

        self.all_f_t2p_preproc = self.zscore_all_f_t2p(binned)
        # Constant rows (zero std) come out of z-scoring as NaN; drop such
        # neurons everywhere so that row i is the same neuron in every recording.
        nan_rows = np.any([np.isnan(f).any(axis=1) for f in self.all_f_t2p_preproc], axis=0)
        print(f'Number of zero rows: {np.sum(nan_rows)}')
        self.all_f_t2p_preproc = [f[~nan_rows, :] for f in self.all_f_t2p_preproc]

    def zscore(self, f, axis=1):
        '''Return the z-score of ``f`` along ``axis``, ignoring NaN values.'''
        return (f - np.nanmean(f, axis=axis, keepdims=True)) / np.nanstd(f, axis=axis, keepdims=True)

    def zscore_all_f_t2p(self, all_f_t2p):
        """Z-score each neuron's activity (each row, axis=1) of every array in ``all_f_t2p``.

        Uses the module-level ``zscore`` imported at the top of the file, not
        the ``zscore`` method above. NOTE(review): presumably
        ``scipy.stats.zscore``; confirm.
        """
        return [zscore(f, axis=1) for f in all_f_t2p]

    def plot_track2p_rasters(self, all_f_t2p, bin_size=1, vmin=None, vmax=None):
        """Draw one raster per recording, stacked vertically, and show it in the view.

        Args:
            all_f_t2p: sequence of (neurons x bins) arrays, one per recording.
            bin_size: frames per bin, used to label the x axis in frames.
            vmin, vmax: colour-scale limits passed to ``imshow``.

        The figure is kept in ``self.fig`` so "Save" writes exactly what is
        displayed. The previous figure is closed so runs do not accumulate
        open pyplot figures.
        """
        n_rasters = len(all_f_t2p)
        if n_rasters == 0:
            print('Nothing to plot.')
            return
        # squeeze=False keeps a 2-D Axes array even for a single recording.
        fig, axes = plt.subplots(n_rasters, 1, figsize=(6, n_rasters * 1), dpi=150, squeeze=False)

        for i, f in enumerate(all_f_t2p):
            ax = axes[i, 0]
            ax.imshow(zscore(f, axis=1), aspect='auto', cmap='Greys', vmin=vmin, vmax=vmax)
            # Only first and last y ticks; the last is labelled with the neuron count.
            ax.set_yticks([0, f.shape[0] - 1])
            ax.set_yticklabels([0, f.shape[0]], rotation=0, va='center', ha='right')

            # x ticks (start, middle, end in frames) only on the bottom raster.
            if i == n_rasters - 1:
                ax.set_xlabel('Frames')
                ax.set_xticks([0, f.shape[1] / 2, f.shape[1] - 1])
                ax.set_xticklabels([0, int((f.shape[1] * bin_size) / 2), int(f.shape[1] * bin_size)])
            else:
                ax.set_xticks([])
                ax.set_xticklabels([])

        fig.subplots_adjust(bottom=0.2)
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        pixmap = QPixmap()
        pixmap.loadFromData(buf.getvalue())
        scene = QGraphicsScene()
        scene.addPixmap(pixmap)
        self.view.setScene(scene)

        if self.fig is not None and self.fig is not fig:
            plt.close(self.fig)
        self.fig = fig
        print('Done')

    def concatenate_rasters(self):
        """Stack, per recording, the tracked neurons of every imaging plane.

        The current plane's traces are taken from ``data_management``. The
        other planes are loaded from disk with ``process_plane``. With a
        single plane there is nothing to combine, so the current plane's
        traces are returned as they are.

        Returns:
            List with one array per recording: the vertical concatenation
            (plane 0 first) of that recording's traces from all planes.
        """
        data_management = self.main_window.central_widget.data_management
        track_ops = data_management.track_ops
        if track_ops.nplanes <= 1:
            print('Only one plane: nothing to concatenate, using the current plane.')
            return data_management.all_f_t2p

        t2p_folder_path = os.path.dirname(track_ops.all_ds_path[0])
        print(t2p_folder_path)
        current_plane = data_management.plane
        trace_type = data_management.trace_type  # common to all planes
        self.results_by_plane = {}
        for plane in range(track_ops.nplanes):
            if plane == current_plane:
                # Already loaded by data_management: reuse it instead of reloading.
                print('plane is equal to the current plane, skipping to the next plane')
                self.results_by_plane[plane] = {
                    'all_ft2p': data_management.all_f_t2p,
                    'all_fneu2p': data_management.all_fneu
                }
                continue
            t2p_match_mat = np.load(os.path.join(t2p_folder_path, "track2p", f"plane{plane}_match_mat.npy"), allow_pickle=True)
            # Keep only cells matched in every recording. `== None` is an
            # element-wise comparison on the object array (`is None` would not be).
            t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :]  # noqa: E711
            print(f"Processing plane {plane}")
            self.process_plane(plane, track_ops, t2p_match_mat_allday, trace_type)

        for plane, data in self.results_by_plane.items():
            print(f"Plane {plane}:")
            print(f" all_ft2p: {len(data['all_ft2p'])}")
            print(f" {len(data['all_ft2p'][0])}")

        # Concatenate recording i of every plane into one array.
        concatenated_elements = []
        num_elements = len(self.results_by_plane[0]['all_ft2p'])
        print(f"Number of elements: {num_elements}")
        for i in range(num_elements):
            elements_to_concatenate = []
            for plane in range(track_ops.nplanes):
                if 'all_ft2p' in self.results_by_plane[plane] and isinstance(self.results_by_plane[plane]['all_ft2p'], list):
                    # Element i of this plane's 'all_ft2p'.
                    elements_to_concatenate.append(self.results_by_plane[plane]['all_ft2p'][i])
                else:
                    # The 'all_ft2p' key is missing or is not a list: skip this plane.
                    print("La clé 'all_ft2p' n'existe pas ou n'est pas une liste.")
            if elements_to_concatenate:
                concatenated_elements.append(np.vstack(elements_to_concatenate))

        for idx, concatenated_element in enumerate(concatenated_elements):
            print(f"Concatenated element {idx}:")
            print(concatenated_element.shape)

        return concatenated_elements

    def process_plane(self, plane, track_ops, t2p_match_mat_allday, trace_type):
        """Load the tracked-cell traces of one plane for every recording.

        For each recording, the suite2p traces of ``plane`` are restricted to
        cells (``iscell[:, 0] == 1``, or probability > ``iscell_thr`` when a
        threshold is set). They are then reordered by column ``i`` of
        ``t2p_match_mat_allday``. The result is stored in
        ``self.results_by_plane[plane]`` as
        ``{'all_ft2p': [...], 'all_fneu2p': [...] or None}``.

        NOTE(review): for 'dF/F0' this stores the raw F (plus Fneu). No dF/F0
        is computed here, while the current plane's traces come from
        data_management. Confirm whether dF/F0 should be computed so that the
        planes are comparable.

        Raises:
            ValueError: if ``trace_type`` is not 'F', 'spks' or 'dF/F0'.
        """
        trace_files = {'F': 'F.npy', 'spks': 'spks.npy', 'dF/F0': 'F.npy'}
        if trace_type not in trace_files:
            raise ValueError(f"unsupported trace type {trace_type!r}; expected one of {list(trace_files)}")

        all_ft2p = []
        all_fneu2p = [] if trace_type == 'dF/F0' else None
        for i, ds_path in enumerate(track_ops.all_ds_path):
            plane_dir = os.path.join(ds_path, 'suite2p', f'plane{plane}')
            iscell = np.load(os.path.join(plane_dir, 'iscell.npy'), allow_pickle=True)
            if track_ops.iscell_thr is None:
                cell_mask = iscell[:, 0] == 1
            else:
                cell_mask = iscell[:, 1] > track_ops.iscell_thr
            match_inds = t2p_match_mat_allday[:, i].astype(int)

            print(f'{trace_type} trace')
            f = np.load(os.path.join(plane_dir, trace_files[trace_type]), allow_pickle=True)
            if all_fneu2p is not None:
                fneu = np.load(os.path.join(plane_dir, 'Fneu.npy'), allow_pickle=True)
                all_fneu2p.append(fneu[cell_mask, :][match_inds, :])
            all_ft2p.append(f[cell_mask, :][match_inds, :])

        self.results_by_plane[plane] = {
            'all_ft2p': all_ft2p,
            'all_fneu2p': all_fneu2p
        }
409 |
--------------------------------------------------------------------------------