├── track2p ├── gui │ ├── __init__.py │ ├── window_manager.py │ ├── toolbar.py │ ├── main_wd.py │ ├── import_wd.py │ ├── custom_wd.py │ ├── statusbar.py │ ├── roi_plot.py │ ├── fluo_plot.py │ ├── central_widget.py │ ├── data_management.py │ ├── cell_plot.py │ ├── t2p_wd.py │ └── raster_wd.py ├── io │ ├── __init__.py │ ├── utils.py │ ├── loaders.py │ ├── savers.py │ └── s2p_loaders.py ├── ops │ ├── __init__.py │ └── default.py ├── match │ ├── __init__.py │ ├── loop.py │ └── utils.py ├── plot │ ├── __init__.py │ ├── progress.py │ ├── utils.py │ └── output.py ├── register │ ├── __init__.py │ ├── elastix.py │ ├── utils.py │ └── loop.py ├── __init__.py ├── resources │ └── logo.png ├── __main__.py ├── eval │ ├── io.py │ └── plot.py └── t2p.py ├── docs ├── media │ └── plots │ │ ├── ex_all_vizualizations.png │ │ └── ex_all_vizualizations_backup.png ├── README.md └── troubleshooting.md ├── setup.py ├── .gitignore ├── notebooks ├── run_t2p.ipynb ├── utils │ ├── npy_to_s2p.ipynb │ └── s2p_to_npy.ipynb ├── eval │ └── main_eval_fig.ipynb └── demo_t2p_ouputs.ipynb └── README.md /track2p/gui/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /track2p/io/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /track2p/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /track2p/match/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /track2p/plot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /track2p/register/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /track2p/__init__.py: -------------------------------------------------------------------------------- 1 | # make all submodules available 2 | from . 
import * 3 | -------------------------------------------------------------------------------- /track2p/resources/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juremaj/track2p/HEAD/track2p/resources/logo.png -------------------------------------------------------------------------------- /docs/media/plots/ex_all_vizualizations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juremaj/track2p/HEAD/docs/media/plots/ex_all_vizualizations.png -------------------------------------------------------------------------------- /docs/media/plots/ex_all_vizualizations_backup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juremaj/track2p/HEAD/docs/media/plots/ex_all_vizualizations_backup.png -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | As of August 2024, all docs have been moved to the GitHub pages [repo](https://github.com/track2p/track2p.github.io/tree/main/Track2p) and can be accessed on the associated [website](https://track2p.github.io/). -------------------------------------------------------------------------------- /track2p/io/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # make a directory based on path if it doesn't exist yet 4 | 5 | def make_dir(path): 6 | if not os.path.exists(path): 7 | os.makedirs(path) 8 | print('Created directory: ' + path) 9 | else: 10 | print('Directory already exists: ' + path) -------------------------------------------------------------------------------- /track2p/__main__.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QApplication 2 | import time 3 | from track2p.gui.main_wd import MainWindow 4 | from PyQt5.QtGui import QIcon 5 | import os 6 | 7 | # the same script as track2p/gui/run_gui.py 8 | 9 | if __name__ == '__main__': 10 | start_time = time.time() 11 | 12 | app = QApplication([]) 13 | 14 | 15 | # use a relative path for the icon 16 | icon_path = os.path.join(os.path.dirname(__file__), 'resources', 'logo.png') 17 | print(icon_path) 18 | 19 | if not os.path.exists(icon_path): 20 | print(f"Icon file not found: {icon_path}") 21 | else: 22 | app.setWindowIcon(QIcon(icon_path)) 23 | 24 | mainWindow = MainWindow() 25 | mainWindow.setWindowTitle("track2p") 26 | 27 | end_time = time.time() 28 | print(f"Application took {end_time - start_time} seconds to open.") 29 | app.exec_() -------------------------------------------------------------------------------- /track2p/gui/window_manager.py: -------------------------------------------------------------------------------- 1 | from track2p.gui.t2p_wd import Track2pWindow 2 | from track2p.gui.import_wd import ImportWindow 3 | from track2p.gui.raster_wd import RasterWindow 4 | #from track2p.gui.s2p_wd import Suite2pWindow 5 | 6 | class WindowManager: 7 | def __init__(self, main_window): 8 | self.main_window = main_window 9 | self.t2p_window = Track2pWindow(self.main_window) 10 | self.import_window = ImportWindow(self.main_window) 11 | self.raster_window = RasterWindow(self.main_window) 12 | #self.suite2p_window = Suite2pWindow(self.main_window) 13 | 14 | def open_track2p_wd(self): 15 | self.t2p_window.show() 16 | 17 | def 
open_import_wd(self): 18 | self.import_window.show() 19 | 20 | def open_raster_wd(self): 21 | self.raster_window.show() 22 | 23 | 24 | 25 | #def open_suite2p_wd(self): 26 | # self.suite2p_window.show() 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open('README.md', 'r') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name='track2p', 8 | version='0.6.2', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'numpy==2.0.2', 12 | 'matplotlib==3.9.4', 13 | 'scikit-image==0.24.0', 14 | 'itk==5.4.3', 15 | 'PyQt5==5.15.11', 16 | 'qtpy==2.4.3', 17 | 'tqdm==4.67.1', 18 | 'scikit-learn==1.6.1', 19 | 'openTSNE==1.0.2', 20 | 'pandas==2.2.3', 21 | 'itk-elastix==0.23.0' 22 | # 'numpy==1.23.5', 23 | # 'matplotlib==3.5.3', 24 | # 'scikit-image==0.20.0', 25 | # 'itk==5.4rc2', 26 | # 'PyQt5==5.15.10', 27 | # 'qtpy==2.4.1', 28 | # 'tqdm==4.66.2', 29 | # 'scikit-learn==1.4.0', 30 | # 'openTSNE==1.0.1', 31 | # 'pandas==1.5.3', 32 | ], 33 | long_description=long_description, 34 | long_description_content_type='text/markdown', 35 | include_package_data=True, 36 | package_data={ 37 | '': ['resources/logo.png'], 38 | }, 39 | ) -------------------------------------------------------------------------------- /track2p/plot/progress.py: -------------------------------------------------------------------------------- 1 | # contains code for generating plots while the pipeline is running 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from track2p.plot.utils import match_hist_all 7 | 8 | def plot_all_planes(all_ds_avg_ch, track_ops, sat_perc=99, ch='functional'): 9 | nplanes = track_ops.nplanes 10 | fig, axs = plt.subplots(nplanes, len(track_ops.all_ds_path), figsize=(3 * len(track_ops.all_ds_path), 3 * nplanes), dpi=300) 11 | # add dummy dimension to axs if only one plane 12 | if nplanes==1: 13 | axs = np.expand_dims(axs, axis=0) 14 | all_ds_avg_ch_matched = match_hist_all(all_ds_avg_ch) 15 | 16 | 17 | for i in range(nplanes): 18 | for j in range(len(track_ops.all_ds_path)): 19 | img = all_ds_avg_ch_matched[j][i] 20 | axs[i, j].imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, sat_perc)) 21 | axs[i, j].set_title('Plane ' + str(i) + ' in dataset ' + str(j)) 22 | axs[i, j].axis('off') 23 | 24 | fig.savefig(track_ops.save_path_fig + f'mean_fov_{ch}.png', bbox_inches='tight', dpi=200) 26 | 27 | plt.close(fig) 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /track2p/register/elastix.py: -------------------------------------------------------------------------------- 1 | import itk 2 | import numpy as np 3 | 4 | def reg_img_elastix(ref_img, mov_img, track_ops): 5 | # convert to itk images 6 | ref_img_itk = itk.GetImageFromArray(ref_img) 7 | mov_img_itk = itk.GetImageFromArray(mov_img) 8 | 9 | # import default parameter map 10 | parameter_object = itk.ParameterObject.New() 11 | parameter_map = parameter_object.GetDefaultParameterMap(track_ops.transform_type) 12 | parameter_object.AddParameterMap(parameter_map) 13 | 14 | # call registration function 15 | mov_img_reg_itk, reg_params = itk.elastix_registration_method( 16 | ref_img_itk, 17 | mov_img_itk, 18 | parameter_object = parameter_object 19 | ) 20 | 21 | # convert back to numpy array 22 | mov_img_reg = itk.GetArrayFromImage(mov_img_reg_itk) 23 | 
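# setting the final B-spline interpolation order to 0 makes transformix use nearest-neighbour interpolation when reg_params is reused below (itk_reg_roi), so the warped binary ROI masks stay binary: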
reg_params.SetParameter("FinalBSplineInterpolationOrder", "0") 24 | 25 | return mov_img_reg, reg_params 26 | 27 | def itk_reg_roi(roi, reg_params): 28 | roi_itk = itk.GetImageFromArray(roi.astype(np.uint8)) 29 | roi_itk_trans = itk.transformix_filter(roi_itk, reg_params) 30 | roi_trans = itk.GetArrayFromImage(roi_itk_trans) 31 | return roi_trans 32 | 33 | def itk_reg_all_roi(all_roi, reg_params): 34 | all_roi_array_reg = np.zeros_like(all_roi) 35 | for i in range(all_roi.shape[2]): 36 | roi_array = all_roi[:,:,i] 37 | roi_array_reg = itk_reg_roi(roi_array, reg_params) 38 | all_roi_array_reg[:,:,i] = roi_array_reg 39 | return all_roi_array_reg -------------------------------------------------------------------------------- /track2p/gui/toolbar.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QToolBar,QMenu,QAction,QToolButton 2 | 3 | class Toolbar(QToolBar): 4 | def __init__(self, main_window): 5 | super().__init__() 6 | 7 | self.main_window = main_window 8 | self.init_tool_bar() 9 | 10 | def init_tool_bar(self): 11 | data_menu = QMenu(self) 12 | run_menu = QMenu(self) 13 | visualization_menu = QMenu(self) 14 | 15 | import_action = QAction("Load processed data (⌘L or Ctrl+L)", self) 16 | import_action.setShortcut("Ctrl+L") 17 | import_action.triggered.connect(self.main_window.window_manager.open_import_wd) 18 | data_menu.addAction(import_action) 19 | 20 | track2p_action = QAction("Run track2p algorithm (⌘R or Ctrl+R)", self) 21 | track2p_action.setShortcut("Ctrl+R") 22 | track2p_action.triggered.connect(self.main_window.window_manager.open_track2p_wd) 23 | run_menu.addAction(track2p_action) 24 | 25 | raster_action = QAction("Generate raster plot (⌘G or Ctrl+G)", self) 26 | raster_action.setShortcut("Ctrl+G") 27 | raster_action.triggered.connect(self.main_window.window_manager.open_raster_wd) 28 | visualization_menu.addAction(raster_action) 29 | 30 | self.add_tool_menu("File", data_menu) 31 | self.add_tool_menu("Run", run_menu) 32 | self.add_tool_menu("Visualization", visualization_menu) 33 | 34 | def add_tool_menu(self, name, menu): 35 | button = QToolButton(self) 36 | button.setText(name) 37 | button.setMenu(menu) 38 | button.setPopupMode(QToolButton.InstantPopup) 39 | button.setStyleSheet("font-size: 13px") 40 | self.addWidget(button) 41 | 42 | -------------------------------------------------------------------------------- /track2p/eval/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def load_alldays_f1_values(base_path, animals, conditions): 5 | 6 | f1_values = {animal: [] for animal in animals} 7 | 8 | for animal in animals: 9 | for condition in conditions: 10 | metrics = np.load(os.path.join(base_path, animal, condition,'metrics_t2p_all_days.npy'), allow_pickle=True) 11 | f1_value = metrics[np.where(metrics[:, 0] == 'F1')[0][0], 1] 12 | f1_values[animal].append(f1_value) 13 | 14 | return f1_values 15 | 16 | 17 | def load_alldays_ct_values(base_path, animals, conditions, ct_type='CT'): 18 | 19 | ct_values = {animal: [] for animal in animals} 20 | acc_values = {animal: [] for animal in animals} 21 | 22 | for animal in animals: 23 | for condition in conditions: 24 | metrics = np.load(os.path.join(base_path, animal, condition,f'result_{ct_type}.npy'), allow_pickle=True) 25 | 26 | print(metrics) 27 | print('-------------------') 28 | 29 | ct = metrics[0][1:] 30 | acc = metrics[1][1:] 31 | 32 | ct_values[animal].append(np.array(ct)) 33 | 
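# inferred from the slicing above: result_{ct_type}.npy appears to hold cell counts in row 0 and accuracies in row 1, with a label in column 0 (hence the [1:] slices)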
acc_values[animal].append(np.array(acc)) 34 | 35 | return ct_values, acc_values 36 | 37 | # def load_alldays_ct_gt_values(base_path, animals, conditions): 38 | 39 | 40 | def load_pairwise_f1_values(base_path, animals, condition): 41 | 42 | f1_values = {animal: [] for animal in animals} 43 | 44 | for animal in animals: 45 | if condition == 'pw_reg': 46 | file_path = os.path.join(base_path, animal, 'metrics_table_pw_registration.npy') 47 | else: 48 | file_path = os.path.join(base_path, animal, condition, 'metrics_table_pairs.npy') 49 | metrics = np.load(file_path, allow_pickle=True) 50 | 51 | f1_scores = metrics[7, 1:].astype(float) 52 | f1_values[animal] = f1_scores 53 | 54 | return f1_values -------------------------------------------------------------------------------- /track2p/io/loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | def load_track_ops(track_ops_path): 5 | track_ops = np.load(track_ops_path + 'track_ops_postreg.npy', allow_pickle=True).item() 6 | return track_ops 7 | 8 | def load_stat_ds_plane(track_ops_path, track_ops, plane_idx=0): 9 | stat = np.load(track_ops_path + f'/suite2p/plane{plane_idx}/stat.npy', allow_pickle=True) 10 | iscell = np.load(track_ops_path + f'/suite2p/plane{plane_idx}/iscell.npy', allow_pickle=True) 11 | # filter based on track_ops.iscell_thr 12 | len_stat_allcell = len(stat) 13 | 14 | if track_ops.iscell_thr is None: 15 | stat = stat[iscell[:,0] == 1] 16 | else: 17 | stat = stat[iscell[:,1] > track_ops.iscell_thr] 18 | len_stat_iscell = len(stat) 19 | 20 | print(f'Loading ROIs for plane{plane_idx} in dataset {track_ops_path.split(os.path.sep)[-1]}') 21 | 22 | print(f'Selected {len_stat_iscell}/{len_stat_allcell} ROIs based on the s2p iscell threshold {track_ops.iscell_thr} (see track_ops.iscell_thr)') 23 | # make a stat_summary dictionary 24 | stat_summary = { 25 | 'len_stat_allcell': len_stat_allcell, 26 | 'len_stat_iscell': len_stat_iscell, 27 | 'iscell_thr': track_ops.iscell_thr 28 | } 29 | return stat, stat_summary 30 | 31 | def get_all_roi_array_from_stat(stat, track_ops): 32 | n_xpix = track_ops.all_ds_avg_ch1[0][0].shape[0] 33 | n_ypix = track_ops.all_ds_avg_ch1[0][0].shape[1] 34 | 35 | all_roi_array = np.zeros((n_xpix, n_ypix, len(stat)), bool) 36 | 37 | for i in range(len(stat)): 38 | 39 | roi_xpix = stat[i]['xpix'] 40 | roi_ypix = stat[i]['ypix'] 41 | 42 | # Convert the ROI coordinates into a grid of values 43 | roi_grid = np.zeros((n_xpix, n_ypix), dtype=np.float32) 44 | roi_grid[np.array(roi_ypix), np.array(roi_xpix)] = 1 45 | 46 | all_roi_array[:,:,i] = roi_grid 47 | 48 | return all_roi_array 49 | 50 | -------------------------------------------------------------------------------- /track2p/plot/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.exposure import match_histograms 3 | 4 | 5 | def match_hist_all(all_ds_avg_ch): 6 | ref = all_ds_avg_ch[0][0] 7 | all_ds_avg_ch_matched = [] 8 | for ds_avg_ch in all_ds_avg_ch: 9 | ds_avg_ch_matched = [] 10 | for i in range(len(ds_avg_ch)): 11 | ds_avg_ch_matched.append(match_histograms(ds_avg_ch[i], ref)) 12 | all_ds_avg_ch_matched.append(ds_avg_ch_matched) 13 | 14 | return all_ds_avg_ch_matched 15 | 16 | def make_rgb_img(img1, img2): 17 | img_rgb = np.zeros((img1.shape[0], img1.shape[1], 3)) 18 | # normalise images to 0-1 19 | img1_norm = (img1 - np.min(img1)) / (np.max(img1) - np.min(img1)) 20 | img2_norm = (img2 - np.min(img2)) / 
(np.max(img2) - np.min(img2)) 21 | img_rgb[:, :, 0] = img1_norm 22 | img_rgb[:, :, 1] = img2_norm 23 | 24 | return img_rgb 25 | 26 | def saturate_perc(img_rgb, sat_perc=99): 27 | img_rgb = np.clip(img_rgb, 0, np.percentile(img_rgb, sat_perc)) 28 | img_rgb = (img_rgb / np.max(img_rgb) * 255).astype(np.uint8) 29 | return img_rgb 30 | 31 | def get_all_wind_mean_img(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, nrn_id, plane_idx=0, win_size=64): 32 | 33 | all_wind_mean_img = [] 34 | for i in range(len(all_ds_mean_img)): 35 | centroid_ids = all_pl_match_mat[plane_idx][nrn_id,:] 36 | cent = all_ds_centroids[i][plane_idx][centroid_ids[i]] 37 | mean_img = all_ds_mean_img[i][plane_idx] 38 | 39 | # pad mean image with mean and shift the centroid appropriately 40 | # mean_img = np.pad(mean_img, ((win_size,win_size),(win_size,win_size))) 41 | mean_img = np.pad(mean_img, ((win_size,win_size),(win_size,win_size)), mode='constant', constant_values=0) 42 | cent = cent + win_size 43 | 44 | wind_mean_img = mean_img[int(cent[0]-win_size/2):int(cent[0]+win_size/2), int(cent[1]-win_size/2):int(cent[1]+win_size/2)] 45 | 46 | 47 | all_wind_mean_img.append(wind_mean_img) 48 | 49 | return all_wind_mean_img -------------------------------------------------------------------------------- /track2p/gui/main_wd.py: -------------------------------------------------------------------------------- 1 | 2 | from track2p.gui.window_manager import WindowManager 3 | from track2p.gui.toolbar import Toolbar 4 | from track2p.gui.statusbar import StatusBar 5 | from track2p.gui.data_management import DataManagement 6 | from track2p.gui.central_widget import CentralWidget 7 | from PyQt5.QtWidgets import QApplication,QMainWindow 8 | from PyQt5.QtGui import QIcon 9 | 10 | class MainWindow(QMainWindow): 11 | 12 | def __init__(self): 13 | super(MainWindow, self).__init__() 14 | self.main_window = self 15 | self.window_manager = WindowManager(self) 16 | self.data_management = DataManagement(self) 17 | self.central_widget = CentralWidget(self) 18 | self.toolbar = Toolbar(self) 19 | self.status_bar = StatusBar(self) 20 | 21 | 22 | 23 | self.initUI() 24 | 25 | def initUI(self): 26 | 27 | 28 | self.setStyleSheet( 29 | "QTabWidget::pane { border: 1px solid #666; }" 30 | "QTabWidget::tab-bar { alignment: center; }" 31 | "QTabBar::tab { background-color: #666; color: white; }" 32 | "QTabBar::tab:selected { background-color: #222; color: white; }" 33 | "QSplitter::handle { background: #888; }" 34 | "QFrame { background-color: black; color: black; border: 1px solid black;}" 35 | "QLabel { color: black; background-color: none; border: none; font-size: 13px}" 36 | "QPushButton { background-color: #666; color: white; border: 1px solid #888; }" 37 | "QPushButton:hover { background-color: #888; color: white; }" 38 | "QPushButton:pressed { background-color: #333; color: white; }" 39 | "QToolButton:pressed { background-color: #888; }" 40 | "QComboBox { background-color: black; color: white; }" 41 | "QComboBox QAbstractItemView { background-color: #666; color: white; }" 42 | ) 43 | 44 | self.setWindowTitle("track2p GUI") 45 | 46 | self.setCentralWidget(self.central_widget) 47 | self.addToolBar(self.toolbar) 48 | self.setStatusBar(self.status_bar) 49 | QApplication.setStyle('Cleanlooks') 50 | self.showMaximized() 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /track2p/register/utils.py: -------------------------------------------------------------------------------- 1 | 
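# registration helpers: build (reference, moving) image pairs shifted by one session, and visualise the overlap of reference vs registered ROI masks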
import numpy as np 2 | 3 | def get_all_ds_img_for_reg(all_ds_avg_ch1, all_ds_avg_ch2, track_ops): # chooses which channel to use for registration and returns all_ref_img and all_mov_img (shifted by one day to always register to previous day) 4 | if track_ops.reg_chan==0: 5 | all_ds_avg = all_ds_avg_ch1 6 | elif track_ops.reg_chan==1: 7 | all_ds_avg = all_ds_avg_ch2 8 | print('WARNING: using anatomical channel for registration (this is not always available)') 9 | 10 | all_ds_ref_img = [] 11 | all_ds_mov_img = [] 12 | 13 | for i in range(len(track_ops.all_ds_path)-1): 14 | ds_ref_img = [] 15 | ds_mov_img = [] 16 | 17 | for j in range(track_ops.nplanes): 18 | ds_ref_img.append(all_ds_avg[i][j]) 19 | ds_mov_img.append(all_ds_avg[i+1][j]) 20 | 21 | all_ds_ref_img.append(ds_ref_img) 22 | all_ds_mov_img.append(ds_mov_img) 23 | 24 | track_ops.all_ds_ref_img = all_ds_ref_img 25 | track_ops.all_ds_mov_img = all_ds_mov_img 26 | 27 | return all_ds_ref_img, all_ds_mov_img 28 | 29 | 30 | def get_ref_reg_inters(all_roi_array_ref, all_roi_array_nonref): 31 | # get the projection of all rois 32 | all_roi_array_ref_proj = np.sum(all_roi_array_ref, axis=2) > 0 33 | all_roi_array_nonref_proj = np.sum(all_roi_array_nonref, axis=2) > 0 34 | 35 | # now get the intersection of the reg and ref 36 | all_roi_array_inters = np.logical_and(all_roi_array_ref_proj, all_roi_array_nonref_proj) 37 | 38 | # make rgb image of intersection 39 | ref_reg_inters = np.ones((all_roi_array_inters.shape[0], all_roi_array_inters.shape[1], 3)) 40 | # the red channel (index 0) is left at 1 so the intersection shows up orange 41 | ref_reg_inters[:,:,1] -= all_roi_array_inters/6 # orange tint 42 | ref_reg_inters[:,:,2] -= all_roi_array_inters 43 | 44 | return ref_reg_inters 45 | 46 | def get_all_ref_nonref_inters(all_ds_all_roi_array_ref, all_ds_all_roi_array_nonref, track_ops): 47 | all_ds_all_ref_nonref_inters = [] 48 | for i in range(len(track_ops.all_ds_path)-1): 49 | ds_all_ref_nonref_inters = [] 50 | for j in range(track_ops.nplanes): 51 | all_roi_array_ref = all_ds_all_roi_array_ref[i][j] 52 | all_roi_array_nonref = all_ds_all_roi_array_nonref[i][j] 53 | ref_nonref_inters = get_ref_reg_inters(all_roi_array_ref, all_roi_array_nonref) 54 | ds_all_ref_nonref_inters.append(ref_nonref_inters) 55 | all_ds_all_ref_nonref_inters.append(ds_all_ref_nonref_inters) 56 | return all_ds_all_ref_nonref_inters -------------------------------------------------------------------------------- /track2p/eval/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def plot_alldays_f1(animals, conditions, f1_values, symbols, colors, xshift=0, animals_names=None): 5 | 6 | plt.figure(figsize=(4, 2), dpi=300) 7 | for (i, animal) in enumerate(animals): 8 | print(animal) 9 | print(f1_values[animal]) 10 | y_data = f1_values[animal] 11 | x_data = np.arange(len(conditions)) + i * xshift 12 | 13 | label = animal if animals_names is None else animals_names[i] 14 | plt.plot(x_data, y_data, symbols[animal], label=label, color=colors[animal]) 15 | # add vertical lines to each point 16 | for (k, x) in enumerate(x_data): 17 | plt.vlines(x, 0, y_data[k], color=colors[animal], linewidth=2) 18 | # label x ticks 19 | plt.xticks(x_data-xshift, conditions, rotation=45) 20 | 21 | # add a dashed grey line at y=0 22 | plt.axhline(0, color='grey', linestyle='--', alpha=0.5) 23 | 24 | 25 | plt.ylabel('F1 Score') 26 | plt.yticks([0, 0.5, 1]) 27 | ax = plt.gca() 28 | ax.spines['top'].set_visible(False) 29 | 
ax.spines['right'].set_visible(False) 30 | plt.legend(loc='upper right', fontsize=8) 31 | plt.show() 32 | 33 | 34 | def plot_pairwise_f1(animals, condition, pairwise_f1_values, symbols, colors, show_d0=True, show_legend=False): 35 | plt.figure(figsize=(2, 2), dpi=300) 36 | plt.title(f'{condition}') 37 | plt.xlabel('Days') 38 | plt.ylabel('Prop. correct') 39 | 40 | for (i, animal) in enumerate(animals): 41 | 42 | if show_d0: 43 | y_data = np.concatenate([[1], pairwise_f1_values[animal]]) 44 | else: 45 | y_data = pairwise_f1_values[animal] 46 | 47 | x_data = np.arange(1, len(y_data) + 1) 48 | 49 | zorder = 10 if animal == 'jm039' else i + 1 50 | 51 | plt.plot(x_data, y_data, label=animal, marker=symbols[animal], color=colors[animal], markersize=4, zorder=zorder) 52 | 53 | # add a dashed grey line at y=0 54 | plt.axhline(0, color='grey', linestyle='--', alpha=0.5) 55 | 56 | plt.ylim(-0.05, 1.05) 57 | plt.yticks([0, 0.5, 1]) 58 | ax = plt.gca() 59 | ax.spines['top'].set_visible(False) 60 | ax.spines['right'].set_visible(False) 61 | 62 | if show_d0: 63 | ax.set_xticks(x_data) 64 | # set 'P8' as first and 'P14' as last label 65 | ax.set_xticklabels(['P8', '', '', '', '', '', 'P14']) 66 | else: 67 | ax.set_xticks(x_data) 68 | ax.set_xticklabels(['P9', '', '', '', '', 'P14']) 69 | 70 | 71 | 72 | # plt.legend() 73 | # make legend smaller 74 | if show_legend: 75 | plt.legend(loc='upper right', fontsize=8) 76 | plt.show() -------------------------------------------------------------------------------- /track2p/gui/import_wd.py: -------------------------------------------------------------------------------- 1 | 2 | from PyQt5.QtWidgets import QWidget, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout,QComboBox 3 | 4 | class ImportWindow(QWidget): 5 | 6 | def __init__(self, main_wd): 7 | super(ImportWindow,self).__init__() 8 | 9 | self.main_window = main_wd 10 | self.path_to_t2p = None 11 | self.plane=None 12 | 13 | layout = QFormLayout() 14 | 15 | import_label=QLabel("Import the directory containing the track2p folder:") 16 | self.import_button = QPushButton("Import", self) 17 | self.import_button.clicked.connect(self.save_t2p_path) 18 | layout.addRow(import_label,self.import_button) 19 | path_label=QLabel("Here is the path of the imported directory:") 20 | self.path = QLabel() 21 | layout.addRow(path_label,self.path) 22 | plane_label= QLabel("Choose the plane to analyze:") 23 | self.textbox = QLineEdit(self) 24 | self.textbox.setFixedWidth(50) 25 | self.textbox.setText('0') 26 | trace_label=QLabel("Choose trace type:") 27 | self.trace_choice=QComboBox() 28 | self.trace_choice.addItem("F") 29 | self.trace_choice.addItem("dF/F0") 30 | self.trace_choice.addItem("spks") 31 | channel_label=QLabel("Choose the channel to show in the mean images:") 32 | self.channel_choice=QComboBox() 33 | self.channel_choice.addItem("0") 34 | self.channel_choice.addItem("1") 35 | self.channel_choice.addItem("Vcorr") 36 | self.channel_choice.addItem("max_proj") 37 | layout.addRow(plane_label,self.textbox) 38 | layout.addRow(channel_label,self.channel_choice) 39 | layout.addRow(trace_label,self.trace_choice) 40 | label_run= QLabel("Run the analysis:") 41 | self.run_button = QPushButton("Run", self) 42 | self.run_button.clicked.connect(self.run) 43 | layout.addRow(label_run,self.run_button) 44 | 45 | self.setLayout(layout) 46 | 47 | 48 | def save_t2p_path(self): 49 | path_to_t2p= QFileDialog.getExistingDirectory(self, "Select Directory") 50 | if path_to_t2p: 51 | self.path_to_t2p=path_to_t2p 52 | 
self.path.setText(f'{self.path_to_t2p}') 53 | 54 | def run(self): 55 | self.plane = int(self.textbox.text()) 56 | self.trace_type = self.trace_choice.currentText() 57 | self.channel = self.channel_choice.currentText() 58 | self.main_window.central_widget.data_management.import_files(self.path_to_t2p, plane=self.plane, trace_type=self.trace_type, channel=self.channel_choice.currentText()) 59 | -------------------------------------------------------------------------------- /track2p/io/savers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from track2p.io.utils import make_dir 5 | 6 | def save_track_ops(track_ops): 7 | # remove attributes taking a lot of memory (e.g. rois etc.) 8 | del track_ops.all_ds_all_roi_array_mov 9 | del track_ops.all_ds_all_roi_array_ref 10 | del track_ops.all_ds_all_roi_array_reg 11 | 12 | track_ops_dict = track_ops.to_dict() # convert to dictionary for compatibility 13 | 14 | np.save(os.path.join(track_ops.save_path, 'track_ops.npy'), track_ops_dict, allow_pickle=True) 15 | print('Saved track_ops.npy in ' + track_ops.save_path) 16 | 17 | def save_all_pl_match_mat(all_pl_match_mat, track_ops): 18 | for (i, pl_match_mat) in enumerate(all_pl_match_mat): 19 | np.save(os.path.join(track_ops.save_path, f'plane{i}_match_mat.npy'), pl_match_mat, allow_pickle=True) 20 | 21 | 22 | def npy_to_s2p(track_ops): 23 | 24 | for plane in range(track_ops.nplanes): 25 | print(f'Processing plane {plane + 1}/{track_ops.nplanes}...') 26 | 27 | for ds_path in track_ops.all_ds_path: 28 | # 1) define numpy and suite2p data paths (+ make sure they exist) 29 | npy_path = os.path.join(ds_path, 'data_npy', f'plane{plane}') 30 | s2p_path = npy_path.replace('data_npy', 'suite2p') 31 | 32 | if not os.path.exists(s2p_path): 33 | os.makedirs(s2p_path) 34 | else: 35 | print(f"Directory {s2p_path} already exists, skipping... 
(Delete or rename it if you want to overwrite)") 36 | continue 37 | 38 | # 2) Load numpy data 39 | F = np.load(os.path.join(npy_path, 'F.npy')) 40 | fov = np.load(os.path.join(npy_path, 'fov.npy')) 41 | rois = np.load(os.path.join(npy_path, 'rois.npy')) 42 | 43 | # 3) Convert and save data in suite2p format 44 | 45 | np.save(os.path.join(s2p_path, 'F.npy'), F) 46 | 47 | ops = {'meanImg': fov} 48 | ops['nchannels'] = 1 # Assuming single channel 49 | ops['fs'] = 30 # Assuming a sampling frequency of 30 Hz 50 | ops['nframes'] = F.shape[1] 51 | np.save(os.path.join(s2p_path, 'ops.npy'), ops) 52 | 53 | # stat is a list of dictionaries, each with keys 'xpix', 'ypix' 54 | stat = [] 55 | for i in range(rois.shape[0]): 56 | ypix, xpix = np.where(rois[i] > 0) 57 | med = [int(np.median(ypix).item()), int(np.median(xpix).item())] 58 | stat.append({'xpix': xpix, 'ypix': ypix, 'med': med}) 59 | 60 | np.save(os.path.join(s2p_path, 'stat.npy'), stat) 61 | 62 | # iscell is two columns of 1 the columns are length n_cells 63 | n_cells = len(stat) 64 | iscell = np.ones((n_cells, 2), dtype=int) 65 | np.save(os.path.join(s2p_path, 'iscell.npy'), iscell) -------------------------------------------------------------------------------- /track2p/gui/custom_wd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PyQt5.QtWidgets import QVBoxLayout, QWidget, QHBoxLayout, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout, QListWidget, QMessageBox,QListWidgetItem, QInputDialog,QCheckBox,QSizePolicy,QComboBox,QDialog 3 | from PyQt5.QtCore import Qt 4 | from track2p.t2p import run_t2p 5 | from track2p.ops.default import DefaultTrackOps 6 | 7 | class CustomDialog(QDialog): 8 | def __init__(self, main_window, save_directory, channel): 9 | super(CustomDialog,self).__init__() 10 | self.main_window = main_window 11 | self.save_directory = save_directory 12 | self.channel = channel 13 | self.cancel_clicked = False 14 | 15 | # Layout 16 | layout = QVBoxLayout(self) 17 | 18 | # Plane input 19 | self.plane_label = QLabel("Enter your plane:", self) 20 | layout.addWidget(self.plane_label) 21 | self.plane_input = QLineEdit(self) 22 | layout.addWidget(self.plane_input) 23 | 24 | # Trace type input 25 | self.trace_label = QLabel("Choose trace type:", self) 26 | layout.addWidget(self.trace_label) 27 | self.trace_combo = QComboBox(self) 28 | self.trace_combo.addItems(["F", "spks", "dF/F0"]) 29 | layout.addWidget(self.trace_combo) 30 | 31 | self.channel_label = QLabel("Choose the channel to show in the mean images:", self) 32 | layout.addWidget(self.channel_label) 33 | self.channel_combo = QComboBox(self) 34 | self.channel_combo.addItems(["0", "1", "Vcorr","max_proj"]) 35 | layout.addWidget(self.channel_combo) 36 | 37 | # OK and Cancel buttons 38 | self.ok_button = QPushButton("OK", self) 39 | self.ok_button.clicked.connect(self.import_data) 40 | layout.addWidget(self.ok_button) 41 | 42 | self.cancel_button = QPushButton("Cancel", self) 43 | self.cancel_button.clicked.connect(self.on_cancel_clicked) 44 | layout.addWidget(self.cancel_button) 45 | 46 | def get_inputs(self): 47 | return self.plane_input.text(), self.trace_combo.currentText(), self.channel_combo.currentText() 48 | 49 | def on_cancel_clicked(self): 50 | self.cancel_clicked = True 51 | self.import_data() 52 | 53 | def import_data(self): 54 | if self.cancel_clicked: 55 | self.close() 56 | pass 57 | else: 58 | print("Opening GUI...") 59 | plane_text, trace_type , channel = self.get_inputs() 60 | self.plane = 
int(plane_text) 61 | self.trace_type = trace_type 62 | self.channel= channel 63 | print(f"Converted plane: {self.plane}, Trace type: {self.trace_type}") # Debugging print 64 | self.main_window.central_widget.data_management.import_files(t2p_folder_path = self.save_directory, plane=self.plane, trace_type=self.trace_type, channel= self.channel) 65 | self.close() 66 | 67 | -------------------------------------------------------------------------------- /track2p/ops/default.py: -------------------------------------------------------------------------------- 1 | # make dummy track_ops object (would be input from command line or gui) 2 | import os 3 | from track2p.io.utils import make_dir 4 | 5 | class DefaultTrackOps: 6 | def __init__(self): 7 | # input list of dataset paths (each contains a 'suite2p' folder) 8 | self.all_ds_path = [ 9 | 'data/ac/ac444118/2022-09-14_a', 10 | 'data/ac/ac444118/2022-09-15_a', 11 | 'data/ac/ac444118/2022-09-16_a' 12 | ] 13 | 14 | self.input_format = 'suite2p' # 'suite2p' or 'npy' (suite2p is the default suite2p format, npy is raw numpy arrays (F.npy, fov.npy, rois.npy)) 15 | 16 | self.save_path = 'data/ac/ac444118/track2p/' 17 | 18 | self.reg_chan = 0 # channel to use for registration (0=functional, 1=anatomical) (1 is not always available) 19 | self.transform_type = 'affine' # 'affine' or 'rigid' 20 | self.iscell_thr = 0.50 # threshold for iscell.npy (only keep ROIs with iscell > iscell_thr) (here lowering this can be good and non-detrimental -> artefacts are unlikely to be consistently present in all datasets) 21 | 22 | self.matching_method='iou' # 'iou', 'cent' or 'cent_int-filt' (iou takes longer but is more accurate, cent is faster but less accurate) 23 | self.iou_dist_thr = 16 # distance between centroids (in pixels) above which to skip iou computation (to save time) (this is only relevant if self.matching_method=='iou') 24 | 25 | self.thr_remove_zeros = False # remove zeros from thr_met before computing automatic threshold (this is useful when there are many zeros in thr_met, which can skew the thresholding) 26 | self.thr_method = 'otsu' # 'otsu' or 'min' (min is just local minimum of pdf of thr_met) 27 | 28 | # do not change these 29 | self.show_roi_reg_output = False # this is slow because plt.contour is slow and also very memory intensive(it can easily crash) but the visualisation is nice for presentations (for example by increasing self.iscell_thr) 30 | 31 | # plotting parameters 32 | self.win_size = 48 # window size for visualising matched ROIs across days (crop of mean image) 33 | self.sat_perc = 99.9 # percentile to saturate image at (only affects visualisation not the registration/matching) 34 | 35 | 36 | self.colors = None 37 | 38 | #self.vector_curation_plane_0 = None 39 | #self.vector_curation_plane_1 = None 40 | 41 | self.save_in_s2p_format = False # save the output in suite2p format (this is useful for downstream analysis with suite2p) 42 | 43 | # make the output directories when initialising the object 44 | 45 | def init_save_paths(self): 46 | self.save_path = os.path.join(self.save_path, 'track2p/') 47 | self.save_path=self.save_path.replace("\\", "/") 48 | self.save_path_fig = os.path.join(self.save_path, 'fig/') 49 | self.save_path_fig=self.save_path_fig.replace("\\", "/") 50 | make_dir(self.save_path) 51 | make_dir(self.save_path_fig) 52 | 53 | 54 | def to_dict(self): 55 | # this is useful for saving the object to avoid needing class definition in downstream analysis 56 | track_ops_dict = {} 57 | for attr in dir(self): 58 | if not 
attr.startswith('__') and not callable(getattr(self, attr)): 59 | track_ops_dict[attr] = getattr(self, attr) 60 | return track_ops_dict 61 | 62 | def from_dict(self, track_ops_dict): 63 | # loop through all the keys and set the attributes 64 | for key in track_ops_dict: 65 | setattr(self, key, track_ops_dict[key]) -------------------------------------------------------------------------------- /track2p/gui/statusbar.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QStatusBar,QWidget, QHBoxLayout, QSpinBox,QPushButton,QLabel 2 | import numpy as np 3 | import os 4 | 5 | class StatusBar(QStatusBar): 6 | def __init__(self, main_window): 7 | super().__init__() 8 | 9 | self.main_window = main_window 10 | self.central_widget = self.main_window.central_widget 11 | self.vector_curation_t2p=None 12 | self.init_status_bar() 13 | 14 | def init_status_bar(self): 15 | status_widget = QWidget() 16 | layout = QHBoxLayout() 17 | 18 | self.spin_box = QSpinBox() 19 | self.spin_box.setFixedWidth(100) 20 | self.spin_box.valueChanged.connect(self.iterate_all_rois) 21 | 22 | self.roi_state = QLabel("state of ROI: ") 23 | self.roi_state.setFixedWidth(100) 24 | self.roi_state_value = QLabel() 25 | self.roi_state_value.setFixedWidth(30) 26 | 27 | not_cell_button = QPushButton('✖️') 28 | not_cell_button.setFixedSize(20, 20) 29 | not_cell_button.setStyleSheet("background-color: red; color: white; border: none;") 30 | not_cell_button.clicked.connect(self.set_roi_as_not_cell) 31 | 32 | cell_button = QPushButton('✓') 33 | cell_button.setFixedSize(20, 20) 34 | cell_button.setStyleSheet("background-color: green; color: white; border: none;") 35 | cell_button.clicked.connect(self.set_roi_as_cell) 36 | 37 | reset_button = QPushButton('Apply curation') 38 | reset_button.setFixedSize(100, 20) 39 | reset_button.setStyleSheet("background-color: grey; color: white; border: none;") 40 | reset_button.clicked.connect(self.main_window.central_widget.create_mean_img_from_curation) 41 | 42 | 43 | layout.addWidget(self.spin_box) 44 | layout.addWidget(self.roi_state) 45 | layout.addWidget(self.roi_state_value) 46 | layout.addWidget(not_cell_button) 47 | layout.addWidget(cell_button) 48 | layout.addWidget(reset_button) 49 | 50 | status_widget.setLayout(layout) 51 | self.addWidget(status_widget) 52 | 53 | def iterate_all_rois(self): 54 | current_ROI = self.spin_box.value() 55 | value=self.vector_curation_t2p[current_ROI] 56 | self.roi_state_value.setText(f"{value}") 57 | self.central_widget.update_selection(current_ROI) 58 | 59 | def set_roi_as_not_cell(self): 60 | plane=self.central_widget.data_management.plane 61 | key='vector_curation_plane_' + str(plane) 62 | self.vector_curation_t2p = self.main_window.central_widget.data_management.vector_curation_t2p 63 | if self.vector_curation_t2p[self.spin_box.value()] ==1: 64 | self.vector_curation_t2p[self.spin_box.value()]= 0 65 | current_ROI = self.spin_box.value() 66 | value=self.vector_curation_t2p [current_ROI] 67 | self.roi_state_value.setText(f"{value}") 68 | self.central_widget.track_ops_dict[key] = self.vector_curation_t2p 69 | np.save(os.path.join(self.central_widget.data_management.track_ops.save_path, "track_ops.npy"), self.central_widget.track_ops_dict) 70 | 71 | def set_roi_as_cell(self): 72 | plane=self.central_widget.data_management.plane 73 | key='vector_curation_plane_' + str(plane) 74 | self.vector_curation_t2p = self.main_window.central_widget.data_management.vector_curation_t2p 75 | if 
self.vector_curation_t2p[self.spin_box.value()] ==0: 76 | self.vector_curation_t2p[self.spin_box.value()]= 1 77 | current_ROI = self.spin_box.value() 78 | value=self.vector_curation_t2p[current_ROI] 79 | self.roi_state_value.setText(f"{value}") 80 | self.central_widget.track_ops_dict[key] = self.vector_curation_t2p 81 | np.save(os.path.join(self.central_widget.data_management.track_ops.save_path, "track_ops.npy"), self.central_widget.track_ops_dict) 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # data 2 | data/ 3 | data_proc 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | .DS_Store 166 | data 167 | -------------------------------------------------------------------------------- /notebooks/run_t2p.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Run track2p programmatically\n", 8 | "\n", 9 | "Track2p can also be easily launched through a notebook or a script, as shown below." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from track2p.t2p import run_t2p # main function that launches track2p\n", 19 | "from track2p.ops.default import DefaultTrackOps # default track2p options\n", 20 | "\n", 21 | "# load default settings / parameters\n", 22 | "track_ops = DefaultTrackOps()" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "After importing the algorithm function and initialising default parameters, we can set the paths and modify parameters as shown below (Note: the parameters follow the same naming convention as in the 'Run algorithm' window of the GUI):" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# overwrite some defaults\n", 39 | "track_ops.all_ds_path = [ # list of paths to datasets containing a `suite2p` folder\n", 40 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-04-30_a',\n", 41 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-01_a',\n", 42 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-02_a',\n", 43 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-03_a',\n", 44 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-04_a',\n", 45 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-05_a',\n", 46 | " '/Users/jure/Documents/cossart_lab/data/jm/jm038/2023-05-06_a'\n", 47 | " ]\n", 48 | "\n", 49 | "track_ops.save_path = '/Users/jure/Documents/cossart_lab/data/jm/jm038/' # path where to save the outputs of the algorithm \n", 50 | " # (a 'track2p' folder will be created where figures for\n", 51 | " # visualisation and matrices of matches will be saved)\n", 52 | "\n", 53 | "track_ops.reg_chan = 1 # channel to use for registration (0=functional, 1=anatomical) (use 0 if only recording gcamp!)\n",
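"# track_ops.transform_type = 'affine' # optionally switch the registration model ('affine' by default, or 'rigid'; see DefaultTrackOps)\n",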
"track_ops.iscell_thr = 0.5 # threshold for iscell (0.5 is a good value)\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "Before running algorithm we can also check the settings track2p will use, to be sure everything is correct:" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "# print all the settings / parameters used for running the algorithm\n", 71 | "for attr, value in track_ops.__dict__.items():\n", 72 | " print(attr, '=', value)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "We can then simply launch the algorithm by passing the `track_ops` object to the `run_t2p` function." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# run the algorithm\n", 89 | "run_t2p(track_ops)\n" 90 | ] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "track2p", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.9.18" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 2 114 | } 115 | -------------------------------------------------------------------------------- /track2p/register/loop.py: -------------------------------------------------------------------------------- 1 | # for now the only algorithm is elastix, TODO: add other algorithms and run them in the same way within this loop 2 | 3 | from track2p.register.elastix import reg_img_elastix 4 | from track2p.io.loaders import load_stat_ds_plane, get_all_roi_array_from_stat 5 | from track2p.register.elastix import itk_reg_all_roi 6 | 7 | def run_reg_loop(all_ds_ref_img, all_ds_mov_img, track_ops): 8 | all_ds_mov_img_reg = [] 9 | all_ds_reg_params = [] 10 | 11 | for (i, ds_ref_img) in enumerate(all_ds_ref_img): 12 | ds_mov_img = all_ds_mov_img[i] 13 | ds_mov_img_reg = [] 14 | ds_reg_params = [] 15 | 16 | for j in range(track_ops.nplanes): 17 | ref_img = ds_ref_img[j] 18 | mov_img = ds_mov_img[j] 19 | mov_img_reg,reg_params = reg_img_elastix(ref_img, mov_img, track_ops) 20 | ds_mov_img_reg.append(mov_img_reg) 21 | ds_reg_params.append(reg_params) 22 | 23 | all_ds_mov_img_reg.append(ds_mov_img_reg) 24 | all_ds_reg_params.append(ds_reg_params) 25 | 26 | track_ops.all_ds_mov_img_reg = all_ds_mov_img_reg 27 | # track_ops.all_ds_reg_params = all_ds_reg_params # for now not saving elastix transform object (TODO: transforme it to serializable to be able to pickle (for example by saving params to a dictionary)) 28 | 29 | return all_ds_mov_img_reg, all_ds_reg_params 30 | 31 | 32 | def reg_all_ds_all_roi(all_ds_reg_params, track_ops): 33 | all_ds_all_roi_array_ref = [] 34 | all_ds_all_roi_array_mov = [] 35 | all_ds_all_roi_array_reg = [] 36 | all_ds_roi_counter = [] # this will keep track of how many cells make it thorugh the registration 37 | 38 | for i in range(len(track_ops.all_ds_path)-1): 39 | print(f'...\nTransforming ROIs for registration {i}/{len(track_ops.all_ds_path)-1}') 40 | 41 | ds_all_roi_array_ref = [] # for one dataset (all planes) 42 | ds_all_roi_array_mov = [] 43 | ds_all_roi_array_reg = [] 44 | ds_roi_counter_ref = [] 45 | 
ds_roi_counter_mov = [] 46 | 47 | for j in range(track_ops.nplanes): 48 | 49 | # 1) Set paths and transformation 50 | ref_ds_path = track_ops.all_ds_path[i] 51 | reg_ds_path = track_ops.all_ds_path[i+1] 52 | 53 | reg_params = all_ds_reg_params[i][j] 54 | 55 | # 2) Load ROIs # TODO: add (non-s2p dependent) Cellpose ROIs compatibility 56 | stat_ref, roi_counter_ref = load_stat_ds_plane(ref_ds_path, track_ops, plane_idx=j) # loading rois for one dataset, one plane 57 | stat_mov, roi_counter_mov = load_stat_ds_plane(reg_ds_path, track_ops, plane_idx=j) # loading rois for one dataset, one plane 58 | 59 | all_roi_array_ref = get_all_roi_array_from_stat(stat_ref, track_ops) # for dataset i, plane j 60 | all_roi_array_mov = get_all_roi_array_from_stat(stat_mov, track_ops) # for dataset i, plane j 61 | 62 | # 3) Apply transformation 63 | all_roi_array_reg = itk_reg_all_roi(all_roi_array_mov, reg_params) 64 | 65 | # 4) append 66 | ds_all_roi_array_ref.append(all_roi_array_ref) 67 | ds_all_roi_array_mov.append(all_roi_array_mov) 68 | ds_all_roi_array_reg.append(all_roi_array_reg) 69 | 70 | ds_roi_counter_ref.append(roi_counter_ref) 71 | ds_roi_counter_mov.append(roi_counter_mov) # only keep this one if it is the last pair 72 | 73 | print('Done with dataset...') 74 | 75 | all_ds_all_roi_array_ref.append(ds_all_roi_array_ref) 76 | all_ds_all_roi_array_mov.append(ds_all_roi_array_mov) 77 | all_ds_all_roi_array_reg.append(ds_all_roi_array_reg) 78 | 79 | all_ds_roi_counter.append(ds_roi_counter_ref) # keep both if it's the last pair (the last recording doesn't appear as a ref) 80 | if i == len(track_ops.all_ds_path)-2: 81 | all_ds_roi_counter.append(ds_roi_counter_mov) 82 | 83 | # save all to track ops 84 | track_ops.all_ds_all_roi_array_ref = all_ds_all_roi_array_ref 85 | track_ops.all_ds_all_roi_array_mov = all_ds_all_roi_array_mov 86 | track_ops.all_ds_all_roi_array_reg = all_ds_all_roi_array_reg 87 | track_ops.all_ds_roi_counter = all_ds_roi_counter 88 | 89 | return all_ds_all_roi_array_ref, all_ds_all_roi_array_mov, all_ds_all_roi_array_reg, all_ds_roi_counter -------------------------------------------------------------------------------- /track2p/match/loop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.filters import threshold_otsu, threshold_minimum 3 | from scipy.optimize import linear_sum_assignment 4 | 5 | from track2p.match.utils import get_cost_mat, get_iou, init_all_pl_match_mat 6 | 7 | # assignment of ROIs in each ref-reg pair 8 | 9 | def get_all_ds_assign(track_ops, all_ds_all_roi_ref, all_ds_all_roi_reg): 10 | 11 | all_ds_assign = [] 12 | all_ds_assign_thr = [] 13 | all_ds_thr_met = [] 14 | all_ds_thr = [] 15 | 16 | for i in range(len(track_ops.all_ds_path)-1): 17 | print(f'Finding matches in ref-reg pair: {i+1}/{len(track_ops.all_ds_path)-1}') 18 | ds_assign = [] 19 | ds_assign_thr = [] 20 | ds_thr_met = [] 21 | ds_thr = [] 22 | for j in range(track_ops.nplanes): 23 | all_roi_ref = all_ds_all_roi_ref[i][j] 24 | all_roi_reg = all_ds_all_roi_reg[i][j] 25 | 26 | # 1) compute cost matrix (currently two methods available, see DefaultTrackOps) 27 | cost_mat, all_inds_ref_filt, all_inds_reg_filt = get_cost_mat(all_roi_ref, all_roi_reg, track_ops) 28 | 29 | # 2) optimally assign pairs 30 | ref_ind_filt, reg_ind_filt = linear_sum_assignment(cost_mat) 31 | 32 | # 3) convert them to pre-filtered indices (these are the indices of the ROIs after iscell) 33 | ref_ind = all_inds_ref_filt[ref_ind_filt] 34 | reg_ind = 
all_inds_reg_filt[reg_ind_filt] 35 | 36 | # 4) for each matched pair (len(all_roi_ref)) compute thresholding metric (in this case IOU, the filtering will be done afterwards in the all-day assignment) 37 | thr_met = get_iou(all_roi_ref[:,:,ref_ind], all_roi_reg[:,:,reg_ind]) 38 | thr_met_compute = thr_met[thr_met>0] if track_ops.thr_remove_zeros else thr_met # remove zeros before computing the threshold (Otsu thresholding is skewed when many values are zero) 39 | 40 | # 5) compute Otsu threshold on thr_met 41 | if track_ops.thr_method == 'otsu': 42 | thr = threshold_otsu(thr_met_compute) 43 | elif track_ops.thr_method == 'min': 44 | thr = threshold_minimum(thr_met_compute) 45 | 46 | 47 | ds_assign.append([ref_ind, reg_ind]) 48 | ds_assign_thr.append([ref_ind[thr_met>thr], reg_ind[thr_met>thr]]) 49 | ds_thr_met.append(thr_met) 50 | ds_thr.append(thr) 51 | 52 | all_ds_assign.append(ds_assign) 53 | all_ds_assign_thr.append(ds_assign_thr) 54 | all_ds_thr_met.append(ds_thr_met) 55 | all_ds_thr.append(ds_thr) 56 | print(f'Done ref-reg pair: {i+1}/{len(track_ops.all_ds_path)-1}') 57 | 58 | return all_ds_assign, all_ds_assign_thr, all_ds_thr_met, all_ds_thr 59 | 60 | 61 | # propagating matches across all days 62 | 63 | def get_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops): 64 | 65 | all_pl_match_mat = init_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops) 66 | 67 | for i in range(track_ops.nplanes): 68 | pl_match_mat = all_pl_match_mat[i] 69 | # now for each row in the match matrix (each ROI in the ref recording) we need to find the match across all days; if there is none then we leave it as None 70 | 71 | for roi_idx in range(pl_match_mat.shape[0]): # roi_idx is the index on the first session 72 | # if the first column is None then we skip this row 73 | if pl_match_mat[roi_idx, 0] is None: 74 | continue 75 | # otherwise we find the match in all_ds_assign_thr 76 | else: 77 | ref_roi_ds0 = pl_match_mat[roi_idx, 0] 78 | track_roi = np.array(ref_roi_ds0) 79 | for ds_ind in range(pl_match_mat.shape[1]-1): 80 | matches = all_ds_assign_thr[ds_ind][i] 81 | 82 | ref_ind = matches[0] 83 | reg_ind = matches[1] 84 | 85 | reg_ind_ind = np.where(ref_ind==track_roi.item())[0] 86 | 87 | # if there is a match then we update the track_roi 88 | if reg_ind_ind.size>0: 89 | track_roi = reg_ind[reg_ind_ind] 90 | pl_match_mat[roi_idx, ds_ind+1] = track_roi.item() 91 | 92 | # if there is no match then we stop tracking this ROI 93 | else: 94 | break 95 | 96 | # compute how many ROIs are tracked across all days 97 | n_tracked = np.sum(np.all(pl_match_mat!=None, axis=1)) 98 | print(f'Number of ROIs tracked in plane{i} across all days: {n_tracked}') 99 | track_ops.all_pl_match_mat = all_pl_match_mat 100 | track_ops.n_tracked = n_tracked 101 | 102 | return all_pl_match_mat -------------------------------------------------------------------------------- /track2p/io/s2p_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def check_nplanes(track_ops): 5 | all_nplanes = [] 6 | 7 | for ds_path in track_ops.all_ds_path: 8 | # check how many subfolders starting with plane* in suite2p folder 9 | if track_ops.input_format == 'suite2p': 10 | n_planes = len([name for name in os.listdir(ds_path + '/suite2p') if name.startswith('plane')]) 11 | elif track_ops.input_format == 'npy': 12 | n_planes = len([name for name in os.listdir(ds_path + '/data_npy') if name.startswith('plane')]) 13 | print(f'Found {n_planes} planes in {ds_path}') 14 | 
all_nplanes.append(n_planes) 15 | track_ops.all_nplanes = all_nplanes 16 | # if all elements in all_nplanes are the same, then save it in track_ops.nplanes 17 | if all_nplanes.count(all_nplanes[0]) == len(all_nplanes): 18 | track_ops.nplanes = all_nplanes[0] 19 | print(f'Found {track_ops.nplanes} planes in all datasets') 20 | else: 21 | print('Found different number of planes in different datasets') 22 | print('Please check your dataset paths') 23 | print('Exiting...') 24 | exit() 25 | 26 | 27 | # loads mean images 28 | def load_all_imgs(track_ops): 29 | all_ds_avg_ch1 = [] 30 | all_ds_avg_ch2 = [] 31 | all_ds_nchannels = [] 32 | 33 | for ds_path in track_ops.all_ds_path: 34 | ds_nchannels = [] 35 | ds_avg_ch1 = [] 36 | ds_avg_ch2 = [] 37 | 38 | for i in range(track_ops.nplanes): 39 | ops = np.load(ds_path + '/suite2p/plane' + str(i) + '/ops.npy', allow_pickle=True).item() 40 | nchannels = ops['nchannels'] 41 | print('nchannels: ' + str(nchannels) + ' for plane ' + str(i) + ' in dataset ' + ds_path) 42 | ds_avg_ch1.append(ops['meanImg']) 43 | ds_avg_ch2.append(ops['meanImg_chan2'] if nchannels==2 else None) 44 | ds_nchannels.append(nchannels) 45 | 46 | all_ds_avg_ch1.append(ds_avg_ch1) 47 | all_ds_avg_ch2.append(ds_avg_ch2) 48 | all_ds_nchannels.append(ds_nchannels) 49 | 50 | track_ops.all_ds_avg_ch1 = all_ds_avg_ch1 51 | track_ops.all_ds_avg_ch2 = all_ds_avg_ch2 52 | track_ops.all_ds_nchannels = all_ds_nchannels 53 | 54 | # if all elements in all_ds_nchannels are the same then it's fine, otherwise exit 55 | if all_ds_nchannels.count(all_ds_nchannels[0]) == len(all_ds_nchannels): 56 | track_ops.nchannels = all_ds_nchannels[0][0] 57 | print(f'Found {track_ops.nchannels} channels in all datasets') 58 | else: 59 | print('Found different number of channels in different datasets') 60 | print('Please check your dataset paths') 61 | print('Exiting...') 62 | exit() 63 | 64 | return all_ds_avg_ch1, all_ds_avg_ch2 65 | 66 | def load_all_ds_stat_iscell(track_ops): 67 | all_ds_stat_iscell = [] 68 | for (i, ds_path) in enumerate(track_ops.all_ds_path): 69 | ds_stat_iscell = [] 70 | for j in range(track_ops.nplanes): 71 | stat = np.load(os.path.join(ds_path, 'suite2p', f'plane{j}', 'stat.npy'), allow_pickle=True) 72 | iscell = np.load(os.path.join(ds_path, 'suite2p', f'plane{j}', 'iscell.npy'), allow_pickle=True) 73 | if track_ops.iscell_thr is None: 74 | stat_iscell = stat[iscell[:,0]==1] 75 | else: 76 | stat_iscell = stat[iscell[:,1]>track_ops.iscell_thr] 77 | ds_stat_iscell.append(stat_iscell) 78 | all_ds_stat_iscell.append(ds_stat_iscell) 79 | 80 | return all_ds_stat_iscell 81 | 82 | def load_all_ds_ops(track_ops): 83 | all_ds_ops = [] 84 | for ds_path in track_ops.all_ds_path: 85 | ds_ops = [] 86 | for j in range(track_ops.nplanes): 87 | ops = np.load(os.path.join(ds_path, 'suite2p', f'plane{j}', 'ops.npy'), allow_pickle=True).item() 88 | ds_ops.append(ops) 89 | all_ds_ops.append(ds_ops) 90 | 91 | return all_ds_ops 92 | 93 | def load_all_ds_mean_img(track_ops, ch=1): 94 | all_ds_ops = load_all_ds_ops(track_ops) 95 | all_ds_mean_img = [] 96 | for ds_ops in all_ds_ops: 97 | ds_mean_img = [] 98 | for ops in ds_ops: 99 | mean_img = ops['meanImg'] if ch==1 else ops['meanImg_chan2'] 100 | ds_mean_img.append(mean_img) 101 | all_ds_mean_img.append(ds_mean_img) 102 | 103 | return all_ds_mean_img 104 | 105 | def load_all_ds_centroids(all_ds_stat_iscell, track_ops): 106 | all_ds_centroids = [] 107 | for i in range(len(track_ops.all_ds_path)): 108 | ds_centroids = [] 109 | for
stat_iscell in all_ds_stat_iscell[i]: 110 | centroids = [] 111 | for roi_stat in stat_iscell: 112 | centroids.append(roi_stat['med']) 113 | ds_centroids.append(np.array(centroids)) 114 | all_ds_centroids.append(ds_centroids) 115 | 116 | return all_ds_centroids 117 | -------------------------------------------------------------------------------- /track2p/gui/roi_plot.py: -------------------------------------------------------------------------------- 1 | 2 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import skimage 6 | 7 | 8 | class ZoomPlotWidget(FigureCanvas): 9 | """Displays the ROI of the selected cell across days, with each zoom being a different day. 10 | The ROI is centered on the median coordinates of the selected cell, and the mean image of the recording is used to create the zooms. 11 | The 'iscell' probability and the index of the cell in the suite2p files associated with each recording (day) are also displayed under the zooms.""" 12 | 13 | def __init__(self, all_ops=None, all_stat_t2p=None, colors=None, all_iscell_t2p=None,t2p_match_mat_allday =None,track_ops=None, imgs=None): 14 | nb_plot=len(all_iscell_t2p) 15 | self.fig, self.ax_zoom = plt.subplots(1, nb_plot, figsize = (10*nb_plot, nb_plot), gridspec_kw={'width_ratios': [1] * nb_plot},facecolor='black') 16 | super().__init__(self.fig) 17 | self.track_ops=track_ops 18 | self.all_ops = all_ops 19 | self.all_stat_t2p = all_stat_t2p 20 | self.colors = colors 21 | self.all_is_cell=all_iscell_t2p 22 | self.t2p_match_mat_allday =t2p_match_mat_allday 23 | self.imgs= imgs 24 | self.coord_dict={} 25 | 26 | def display_zooms(self, selected_cell_index): 27 | """It is called when the application is opened and whenever a cell is selected.""" 28 | # used to store the ROI and the median coordinates of the selected cell for each recording (day) 29 | self.coord_dict={} 30 | if self.all_ops is not None and self.all_stat_t2p is not None and self.all_is_cell is not None: 31 | for i in range(len(self.imgs)): 32 | wind = 20 33 | #mean_img = self.all_ops[i]['meanImg'] 34 | mean_img=self.imgs[i] 35 | stat_t2p = self.all_stat_t2p[i] 36 | median_coord = stat_t2p[selected_cell_index]['med'] 37 | print('median_coord:', median_coord) 38 | 39 | print(mean_img.shape) 40 | # Define the margin size 41 | margin = 20 42 | 43 | # Get the dimensions of the original image 44 | Ly, Lx = mean_img.shape 45 | 46 | # Create a new image with enlarged dimensions 47 | new_Ly = Ly + 2 * margin 48 | new_Lx = Lx + 2 * margin 49 | range_img = np.zeros((new_Ly, new_Lx)) 50 | 51 | 52 | # Copy mean_img into the center of the new image 53 | range_img[margin:margin + Ly, margin:margin + Lx] = mean_img 54 | 55 | # Add the margin to the median coordinates 56 | median_x = int(median_coord[1]) + margin 57 | median_y = int(median_coord[0]) + margin 58 | 59 | # Compute the new coordinates of the ROI 60 | x_start = median_x - wind 61 | x_end = median_x + wind 62 | y_start = median_y - wind 63 | y_end = median_y + wind 64 | 65 | # Extract the ROI from the padded image 66 | roi = range_img[y_start:y_end, x_start:x_end] 67 | 68 | print(range_img.shape) 69 | print(roi.shape) 70 | 71 | 72 | # prob and index 73 | iscell=self.all_is_cell[i] 74 | if self.track_ops.iscell_thr is None: 75 | indices_lignes_1 = np.where(iscell[:,0]==1)[0] 76 | match_index=self.t2p_match_mat_allday[selected_cell_index,i] 77 |
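# note: t2p_match_mat_allday indexes into the iscell-filtered ROI arrays, so indices_lignes_1 (the positions of accepted ROIs in the raw suite2p arrays) is used below to recover the original suite2p index ('true_index')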
true_index=indices_lignes_1[match_index] 78 | else: 79 | indices_lignes_1= np.where(iscell[:,1]>self.track_ops.iscell_thr)[0] 80 | match_index=self.t2p_match_mat_allday[selected_cell_index,i] 81 | true_index=indices_lignes_1[match_index] 82 | prob=round(iscell[true_index,1],2) 83 | 84 | ypix=stat_t2p[selected_cell_index]['ypix'] 85 | xpix=stat_t2p[selected_cell_index]['xpix'] 86 | 87 | 88 | ax = self.ax_zoom[i] 89 | color = self.colors[selected_cell_index] 90 | 91 | mask=np.zeros((2*wind,2*wind)) 92 | mask[ypix-median_coord[0]+wind,xpix-median_coord[1]+wind]=1 93 | ax.clear() 94 | ax.contour(mask,levels=[0.5], colors=[color],linewidths=2) 95 | 96 | #last_img=list(self.coord_dict.values())[-1][0] 97 | #match_roi=skimage.exposure.match_histograms(list[0], last_img, channel_axis=None) 98 | 99 | ax.imshow(roi, cmap='gray') 100 | ax.set_title(f'Day {i + 1}', color='white', fontsize=10) 101 | ax.text(0.5, -0.2, f'i: {true_index}', color='white', fontsize=10, ha='center', va='center', transform=ax.transAxes) 102 | ax.text(0.5, -0.4, f'p: {prob}', color='white', fontsize=10, ha='center', va='center', transform=ax.transAxes) 103 | 104 | ax.axis('off') 105 | 106 | self.draw() 107 | -------------------------------------------------------------------------------- /notebooks/utils/npy_to_s2p.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f0538be5", 6 | "metadata": {}, 7 | "source": [ 8 | "# Conversion of npy to s2p \n", 9 | "\n", 10 | "This is used to implement and test the feature suggested by reviewer 2 - allowing easier use of track2p by people using pipelines other than suite2p\n", 11 | "\n", 12 | "Inputs:\n", 13 | "- rois.npy (bool ndarray of shape: (n_roi, n_px_y, n_px_x))\n", 14 | "- fov.npy (float32 ndarray of shape: (n_px_y, n_px_x)) - IMPORTANT: ROIs and FOV must be aligned (this is done by default in suite2p outputs, but might need to be registered manually in other cases)\n", 15 | "- traces.npy (float32 ndarray of shape: (n_roi, n_tstamps))\n", 16 | "\n", 17 | "Outputs:\n", 18 | "- suite2p folder" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "4e3cd55a", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import os\n", 29 | "import numpy as np\n", 30 | "import matplotlib.pyplot as plt" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "b008c170", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "nplanes = 1\n", 41 | "plane = 0" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "bed3a691", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# these are the paths to the datasets that will be used to \n", 52 | "subject_npy_path = '/Users/jure/Documents/cossart_lab/data/jm/jm039_npy/'\n", 53 | "# npy_paths are all the subfolders of subject_path ending in _a\n", 54 | "npy_paths = [os.path.join(f.path, 'data_npy', f'plane{plane}') for f in os.scandir(subject_npy_path) if f.is_dir() and f.name.endswith('_a')]\n", 55 | "npy_paths.sort()\n", 56 | "print(npy_paths)\n", 57 | "\n", 58 | "\n", 59 | "s2p_paths = []\n", 60 | "for npy_path in npy_paths:\n", 61 | " print(npy_path)\n", 62 | " # remove _npy\n", 63 | " s2p_path = npy_path.replace('jm039_npy', 'jm039')\n", 64 | " s2p_path = s2p_path.replace('data_npy', 'suite2p')\n", 65 | " s2p_paths.append(s2p_path)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | 
"id": "6ab62d00", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "path = npy_paths[0] # Example path, you can loop through npy_paths if needed\n", 76 | "\n" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "c9c13434", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "for (npy_path, s2p_path) in zip(npy_paths, s2p_paths):\n", 87 | " print(s2p_path)\n", 88 | " \n", 89 | " F = np.load(os.path.join(npy_path, 'F.npy'))\n", 90 | " fov = np.load(os.path.join(npy_path, 'fov.npy'))\n", 91 | " rois = np.load(os.path.join(npy_path, 'rois.npy'))\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | " if not os.path.exists(s2p_path):\n", 96 | " os.makedirs(s2p_path)\n", 97 | " else:\n", 98 | " print(f\"Directory {s2p_path} already exists, skipping... (Delete or rename it if you want to overwrite)\")\n", 99 | " continue\n", 100 | " \n", 101 | " np.save(os.path.join(s2p_path, 'F.npy'), F)\n", 102 | "\n", 103 | " ops = {'meanImg': fov}\n", 104 | " ops['nchannels'] = 1 # Assuming single channel\n", 105 | " ops['fs'] = 30 # Assuming a sampling frequency of 30 Hz\n", 106 | " ops['nframes'] = F.shape[1]\n", 107 | " np.save(os.path.join(s2p_path, 'ops.npy'), ops)\n", 108 | "\n", 109 | " # stat is a list of dictionaries, each with keys 'xpix', 'ypix'\n", 110 | " stat = []\n", 111 | " for i in range(rois.shape[0]):\n", 112 | " # TODO: make sure using the correct axis convention!!!\n", 113 | " ypix, xpix = np.where(rois[i] > 0)\n", 114 | " med = [int(np.median(ypix).item()), int(np.median(xpix).item())]\n", 115 | " stat.append({'xpix': xpix, 'ypix': ypix, 'med': med})\n", 116 | "\n", 117 | " np.save(os.path.join(s2p_path, 'stat.npy'), stat)\n", 118 | " \n", 119 | " # iscell is two columns of 1 the columns are length n_cells\n", 120 | " n_cells = len(stat)\n", 121 | " iscell = np.ones((n_cells, 2), dtype=int)\n", 122 | " np.save(os.path.join(s2p_path, 'iscell.npy'), iscell)\n", 123 | " print(f\"Saved to {s2p_path}\")\n" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "9563eb37", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "med = [int(np.median(ypix).item()), int(np.median(xpix).item())]\n", 134 | "\n", 135 | "med" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "id": "aebc8434", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "track2p", 150 | "language": "python", 151 | "name": "python3" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.9.21" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 5 168 | } 169 | -------------------------------------------------------------------------------- /notebooks/utils/s2p_to_npy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "192b95be", 6 | "metadata": {}, 7 | "source": [ 8 | "# Conversion of s2p to npy\n", 9 | "\n", 10 | "This is used to implement and test the feature suggested by reviewer 2 - allowing easier use of track2p by people using pipelines other than suite2p\n", 11 | "\n", 12 | "Inputs:\n", 13 | "- suite2p folder\n", 14 | "\n", 15 | "Outputs:\n", 16 
| "- rois.npy (bool ndarray of shape: (n_roi, n_px_y, n_px_x))\n", 17 | "- fov.npy (float32 ndarray of shape: (n_px_y, n_px_x)) - IMPORTANT: ROIs and FOV must be aligned (this is done by default in suite2p outputs, but might need to be registered manually in other cases)\n", 18 | "- traces.npy (float32 ndarray of shape: (n_roi, n_tstamps))\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "1e4cb8a3", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import os\n", 29 | "import numpy as np\n", 30 | "import matplotlib.pyplot as plt" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "5c6c7926", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# these are the paths to the datasets that will be used to \n", 41 | "subject_s2p_path = '/Volumes/data_jm_share/data_proc/el/el017/'\n", 42 | "# s2p_paths are all the subfolders of subject_path ending in _a\n", 43 | "s2p_paths = [f.path for f in os.scandir(subject_s2p_path) if f.is_dir() and f.name.endswith('_a')]\n", 44 | "s2p_paths.sort()\n", 45 | "print(s2p_paths)\n", 46 | "\n", 47 | "# \n", 48 | "npy_paths = []\n", 49 | "for s2p_path in s2p_paths:\n", 50 | " \n", 51 | " # add _npy after jm039 to the path\n", 52 | " npy_path = s2p_path.replace('el017', 'el017_npy')\n", 53 | " npy_path = os.path.join(npy_path, 'data_npy')\n", 54 | " # make this directory if it does not exist\n", 55 | "\n", 56 | " if not os.path.exists(npy_path):\n", 57 | " os.makedirs(npy_path)\n", 58 | "\n", 59 | " npy_paths.append(npy_path)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "f8811710", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# now write code to load the s2p data and saves it to npy\n", 70 | "plane = 0\n", 71 | "for plane in range(2):\n", 72 | " print(f'Processing plane {plane}...')\n", 73 | " for s2p_path, npy_path in zip(s2p_paths, npy_paths):\n", 74 | " print(f'Loading {s2p_path} and saving to {npy_path}')\n", 75 | "\n", 76 | " s2p_path_full = os.path.join(s2p_path, 'suite2p', f'plane{plane}')\n", 77 | "\n", 78 | " ops = np.load(os.path.join(s2p_path_full, 'ops.npy'), allow_pickle=True).item()\n", 79 | " stat = np.load(os.path.join(s2p_path_full, 'stat.npy'), allow_pickle=True)\n", 80 | " iscell = np.load(os.path.join(s2p_path_full, 'iscell.npy'), allow_pickle=True)\n", 81 | " F = np.load(os.path.join(s2p_path_full, 'F.npy'), allow_pickle=True)\n", 82 | "\n", 83 | " iscell_bool = iscell[:, 0] == 1\n", 84 | " print(sum(iscell_bool), 'cells in this dataset')\n", 85 | "\n", 86 | " # get mean image\n", 87 | " mean_img = ops['meanImg']\n", 88 | " print('Mean image shape:', mean_img.shape)\n", 89 | " print(type(mean_img[0,0]))\n", 90 | " plt.imshow(mean_img, cmap='gray')\n", 91 | " plt.show()\n", 92 | "\n", 93 | " # now filter stat and F by iscell_bool\n", 94 | " stat = stat[iscell_bool]\n", 95 | " F = F[iscell_bool, :]\n", 96 | " print('Stat shape:', stat.shape)\n", 97 | " print('F shape:', F.shape)\n", 98 | " print(type(F[0, 0]))\n", 99 | "\n", 100 | " rois = np.zeros((len(stat), mean_img.shape[0], mean_img.shape[1]), dtype=np.bool)\n", 101 | " for i, s in enumerate(stat):\n", 102 | " # create a mask for the roi\n", 103 | " mask = np.zeros((mean_img.shape[0], mean_img.shape[1]), dtype=np.bool)\n", 104 | " mask[s['ypix'], s['xpix']] = True\n", 105 | " rois[i] = mask\n", 106 | " print('ROIs shape:', rois.shape)\n", 107 | " print(type(rois[0, 0, 0]))\n", 108 | "\n", 109 | " plt.imshow(rois[0], 
cmap='gray')\n", 110 | " plt.show()\n", 111 | " plt.imshow(rois[1], cmap='gray')\n", 112 | " plt.show()\n", 113 | "\n", 114 | " # save to npy\n", 115 | " npy_path = os.path.join(npy_path, f'plane{plane}')\n", 116 | " \n", 117 | " if not os.path.exists(npy_path):\n", 118 | " os.makedirs(npy_path)\n", 119 | "\n", 120 | " np.save(os.path.join(npy_path, 'rois.npy'), rois)\n", 121 | " np.save(os.path.join(npy_path, 'F.npy'), F)\n", 122 | " np.save(os.path.join(npy_path, 'fov.npy'), mean_img)\n", 123 | "\n", 124 | " # TODO: save these to npy files ...\n" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "211c1aac", 130 | "metadata": {}, 131 | "source": [] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "f5ff8c6d", 136 | "metadata": {}, 137 | "source": [] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "track2p", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.9.21" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 5 161 | } 162 | -------------------------------------------------------------------------------- /notebooks/eval/main_eval_fig.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os \n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "# import SimpleNamespace\n", 14 | "from types import SimpleNamespace\n", 15 | "\n", 16 | "from track2p.eval.io import load_pairwise_f1_values, load_alldays_ct_values\n", 17 | "from track2p.eval.plot import plot_alldays_f1, plot_pairwise_f1\n", 18 | "\n", 19 | "# avoid restarting notebook to import changes\n", 20 | "%load_ext autoreload\n", 21 | "%autoreload 2" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Define loading and plotting functions" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Set paths and paramters" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# Chemins des fichiers NumPy\n", 45 | "base_path = '/Volumes/data_jm_share/data_proc/jm' # Remplacez par le chemin correct\n", 46 | "conditions = ['chan0', 'chan1', 'rigid', 'cellreg'] # ['chan0', 'chan1', 'rigid', 'cellreg']\n", 47 | "conditions_names = ['Anatomical', 'Functional', 'Rigid', 'CellReg']\n", 48 | "animals = ['jm038', 'jm039', 'jm046'] \n", 49 | "symbols = {'jm038': 'o', 'jm039': 'o', 'jm046': 'o'}\n", 50 | "colors = {'jm038': (0.8, 0.8, 0.8), 'jm039': 'C0', 'jm046': (0.7, 0.7, 0.7)}\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## All days evaluation" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "ct, acc = load_alldays_ct_values(base_path, animals, conditions, ct_type='CT')\n", 67 | "\n", 68 | "ct_gt, acc_gt = load_alldays_ct_values(base_path, animals, conditions, ct_type='CT_GT')\n", 69 | "\n", 70 | "# in paper CT should be used in 'all day evaluation' - it 
takes into account the false positives as well\n", 71 | "# and CT_GT should be used in 'pairwise evaluation' - it can't take these into account (it only reports proportion of correctly reconstructed traces)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# iterate over keys in ct\n", 81 | "f1_values = {animal: [] for animal in animals}\n", 82 | "\n", 83 | "for animal in ct.keys():\n", 84 | " for (i, condition) in enumerate(conditions):\n", 85 | " # get the last value of associated array\n", 86 | " f1_val = ct[animal][i][-1]\n", 87 | " f1_values[animal].append(f1_val)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# TODO: remove - just manually computing average F1 scores for publication\n", 97 | "mn_f1_anat = np.mean(np.array([0.89, 0.98, 0.93]))\n", 98 | "mn_f1_func = np.mean(np.array([0.85, 0.98, 0.9]))\n", 99 | "mn_f1_rigi = np.mean(np.array([0.16, 0.29, 0.21]))\n", 100 | "mn_f1_creg = 0\n", 101 | "\n", 102 | "# print them to two digits each in its own row\n", 103 | "print(f'Mean values \\nAnatomical: {mn_f1_anat:.2f} \\nFunctional: {mn_f1_func:.2f} \\nRigid: {mn_f1_rigi:.2f} \\nCellReg: {mn_f1_creg:.2f}')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "plot_alldays_f1(animals, conditions_names, f1_values, symbols, colors, xshift=0.2)\n", 113 | "\n", 114 | "animals_names = ['mouse C', 'mouse D', 'mouse F']\n", 115 | "plot_alldays_f1(animals, conditions_names, f1_values, symbols, colors, xshift=0.2, animals_names=animals_names)\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# print mean F1 values for conditions\n" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## Pairwise evaluation" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "\n", 141 | "\n", 142 | "for (i, condition_name) in enumerate(conditions_names):\n", 143 | " \n", 144 | " # initialise dictionary\n", 145 | " pairwise_ct_values = {}\n", 146 | "\n", 147 | " for animal in animals:\n", 148 | " pairwise_ct_values[animal] = ct_gt[animal][i]\n", 149 | "\n", 150 | " plot_pairwise_f1(animals, condition_name, pairwise_ct_values, symbols, colors)\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "condition = 'pw_reg'\n", 160 | "symbols_pw = {'jm038': 's', 'jm039': 's', 'jm046': 's'} # to differentiate from all day registration\n", 161 | "\n", 162 | "pairwise_f1_values = load_pairwise_f1_values(base_path, animals, condition)\n", 163 | "plot_pairwise_f1(animals, condition, pairwise_f1_values, symbols_pw, colors)\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "track2p", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": 
"python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.9.19" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /track2p/match/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | from skimage import measure 4 | from skimage.filters import threshold_otsu 5 | 6 | # compute centroids 7 | def get_centroids(all_roi): 8 | centroids = [] 9 | for i in range(all_roi.shape[2]): 10 | roi = all_roi[:,:,i] 11 | labels = measure.label(roi) 12 | features = measure.regionprops(labels) 13 | try: 14 | centroids.append(features[0].centroid) 15 | except IndexError: 16 | centroids.append([0, 0]) 17 | 18 | return np.array(centroids) 19 | 20 | def get_cent_dist_mat(all_roi_ref, all_roi_reg): 21 | # compute distances 22 | centroids_ref = get_centroids(all_roi_ref) 23 | centroids_reg = get_centroids(all_roi_reg) 24 | distances = cdist(centroids_ref, centroids_reg) 25 | return distances 26 | 27 | def filt_non_overlap(all_roi1, all_roi2, cent_dist_mat): 28 | all_inds = np.arange(all_roi1.shape[2]) 29 | filt_inds_ref = [] 30 | for i in range(all_roi1.shape[2]): 31 | # get the index of closest roi from the distances matrix 32 | roi = all_roi1[:,:,i] 33 | closest_roi_idx = np.argmin(cent_dist_mat[i,:]) 34 | closest_roi = all_roi2[:,:,closest_roi_idx] 35 | intersection = roi*closest_roi 36 | if np.sum(intersection)==0: # TODO:maybe it overlaps a bit with non-first closest roi? 37 | filt_inds_ref.append(i) 38 | 39 | # get all indices that are not in filt_inds_ref 40 | all_inds_filt = all_inds[~np.isin(all_inds, filt_inds_ref)] 41 | return all_inds_filt 42 | 43 | def get_cent_dist_mat_non_overlap(all_roi_ref, all_roi_reg): 44 | 45 | cent_dist_mat = get_cent_dist_mat(all_roi_ref, all_roi_reg) 46 | 47 | # get indices of neuorns that overlap with their closest neighbor (at least will have some chance to match) 48 | all_inds_ref_filt = filt_non_overlap(all_roi_ref, all_roi_reg, cent_dist_mat) 49 | all_inds_reg_filt = filt_non_overlap(all_roi_reg, all_roi_ref, cent_dist_mat.T) 50 | 51 | cost_mat = cent_dist_mat[all_inds_ref_filt, :] 52 | cost_mat = cost_mat[:, all_inds_reg_filt] 53 | 54 | return cost_mat, all_inds_ref_filt, all_inds_reg_filt 55 | 56 | def get_cost_mat(all_roi_ref, all_roi_reg, track_ops): 57 | # compute distances 58 | if track_ops.matching_method=='cent': # simple assignment based on centroids (probably works with sparse data but not in development) 59 | cost_mat = get_cent_dist_mat(all_roi_ref, all_roi_reg) 60 | all_inds_ref_filt = np.arange(all_roi_ref.shape[2]) # here we don't filter the indices 61 | all_inds_reg_filt = np.arange(all_roi_reg.shape[2]) # here we don't filter the indices 62 | elif track_ops.matching_method=='cent_int-filt': # here to simplify the matching we first filter out the ROIs that have no intersection with their closest neighbor 63 | # this also outputs the indices of neurons after filtering 64 | # costa mat here is smaller than above since the matches are filtered 65 | # additionally when outputting we need to be careful to index with the filtered indices as well 66 | cost_mat, all_inds_ref_filt, all_inds_reg_filt = get_cent_dist_mat_non_overlap(all_roi_ref, all_roi_reg) 67 | elif track_ops.matching_method=='iou': 68 | cost_mat = 1-get_cross_iou_mat(all_roi_ref, all_roi_reg, dist_thr=track_ops.iou_dist_thr) 69 | 
all_inds_ref_filt = np.arange(all_roi_ref.shape[2]) 70 | all_inds_reg_filt = np.arange(all_roi_reg.shape[2]) 71 | else: 72 | raise Exception('Matching method not implemented') 73 | 74 | print(f'cost_mat computed with method: {track_ops.matching_method}') 75 | print(f'cost_mat shape: {cost_mat.shape}') 76 | print(f'cost_mat min: {np.min(cost_mat)}') 77 | print(f'cost_mat max: {np.max(cost_mat)}') 78 | 79 | return cost_mat, all_inds_ref_filt, all_inds_reg_filt 80 | 81 | def get_iou(all_roi_ref, all_roi_reg): 82 | 83 | ious = [] 84 | for i in range(all_roi_ref.shape[2]): 85 | roi_ref = all_roi_ref[:,:,i] 86 | roi_reg = all_roi_reg[:,:,i] 87 | intersection = np.sum(np.logical_and(roi_ref, roi_reg)) 88 | union = np.sum(np.logical_or(roi_ref, roi_reg)) 89 | ious.append(intersection/union) 90 | 91 | return np.array(ious) 92 | 93 | def get_cross_iou_mat(all_roi_ref, all_roi_reg, dist_thr=16): 94 | # if the distance between two rois is larger than dist_thr, we assume they are not the same cell and just skip the computation 95 | distances = get_cent_dist_mat(all_roi_ref, all_roi_reg) 96 | 97 | cross_iou_mat = np.zeros((all_roi_ref.shape[2], all_roi_reg.shape[2])) 98 | for i in range(all_roi_ref.shape[2]): 99 | for j in range(all_roi_reg.shape[2]): 100 | if distances[i,j] > dist_thr: # skipping if far apart 101 | continue 102 | # compute IOU 103 | intersection = np.logical_and(all_roi_ref[:,:,i], all_roi_reg[:,:,j]) 104 | union = np.logical_or(all_roi_ref[:,:,i], all_roi_reg[:,:,j]) 105 | iou_score = np.sum(intersection) / np.sum(union) 106 | cross_iou_mat[i,j] = iou_score 107 | 108 | return cross_iou_mat 109 | 110 | def init_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops): 111 | all_pl_match_mat = [] 112 | for i in range(track_ops.nplanes): 113 | # set up the match matrix for each plane; its size is (number of iscell ROIs in the ref recording) x (number of datasets) 114 | pl_match_mat = np.full((all_ds_all_roi_ref[0][i].shape[2], len(track_ops.all_ds_path)), None) 115 | all_pl_match_mat.append(pl_match_mat) 116 | # populate the first column of the match matrices with the matches from the first ref-reg pair 117 | for i in range(track_ops.nplanes): 118 | pl_match_mat = all_pl_match_mat[i] 119 | assign_thr = all_ds_assign_thr[0][i] # zeroth dataset, ith plane 120 | ref_ind = assign_thr[0] 121 | pl_match_mat[ref_ind, 0] = ref_ind 122 | 123 | return all_pl_match_mat 124 | 125 | def filt_by_otsu(vect_filt, vect_comp): 126 | # vect_filt is the vector that we want to filter 127 | # vect_comp is the vector that we compute the otsu threshold on 128 | 129 | thresh = threshold_otsu(vect_comp) 130 | return vect_filt[vect_comp>thresh] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # track2p 2 | 3 | 4 |
5 | 6 | Cell tracking for longitudinal calcium imaging recordings. 7 | 8 | [![PyPI version](https://img.shields.io/pypi/v/track2p)](https://pypi.org/project/track2p/) 9 | [![All time downloads](https://static.pepy.tech/badge/track2p)](https://pepy.tech/project/track2p) 10 | [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) 11 | 12 | For more detailed information on installation and use visit: 13 | 14 | https://track2p.github.io/ 15 |
16 |
17 | Or read the [paper](https://elifesciences.org/articles/107540). 18 |
19 |
20 | 21 | # Installation 22 | 23 | ## Installing via pip 24 | 25 | First we need to set up a conda environment with python 3.9: 26 | 27 | ``` 28 | conda create --name track2p python=3.9 29 | conda activate track2p 30 | ``` 31 | 32 | Then simply install the track2p package using pip: 33 | 34 | ``` 35 | pip install track2p 36 | ``` 37 | 38 | That's it, track2p should be successfully set up :) 39 | You can simply run it by: 40 | 41 | ``` 42 | python -m track2p 43 | ``` 44 | 45 | This opens a GUI allowing the user to launch the algorithm and visualise the results interactively. 46 | 47 | (For instructions on running track2p without the GUI see 'Run via script' under the 'Usage' section) 48 | 49 | Note: For common installation issues see ['Installation > Common issues'](https://track2p.github.io/install_common_issues.html) in the documentation. 50 | 51 | ## Reinstall 52 | 53 | To reinstall, open anaconda and remove the environment with: 54 | 55 | ``` 56 | conda env remove -n track2p 57 | ``` 58 | 59 | Then follow the 'Installing via pip' instructions above :) 60 | 61 | 62 | # Usage 63 | 64 | ## Run track2p through the GUI 65 | 66 | After activating the GUI through `python -m track2p` the user should navigate to the 'Run' tab on the top left of the window and select 'Run track2p algorithm' from the dropdown menu. 67 | 68 | This will open a pop-up window that allows the user to set the paths to the suite2p datasets and to set the algorithm parameters. After configuring these settings, the user can click 'Run' to run the track2p algorithm, and the progress will be displayed in the terminal. 69 | 70 | Once the algorithm finishes, a subsequent pop-up window will prompt the user to decide whether they wish to visualize the results within the interface. 71 | 72 | For more details on how to run the algorithm through the GUI see [run track2p](https://track2p.github.io/run_track2p_gui.html) and for a more detailed description of the parameters see the documentation: [parameters](https://track2p.github.io/run_inputs_and_parameters.html#parameters). 73 | 74 | ## Run track2p via script 75 | 76 | To run via script you can use the `run_track2p.py` script in the root of this repo as a template. It is exactly the same as running through the GUI, except that the paths and the parameters are defined within the script (for more on parameters etc. see the documentation). When running, make sure you are running it within the track2p environment, for example: 77 | 78 | ``` 79 | conda activate track2p 80 | python -m run_track2p 81 | ``` 82 | 83 | 84 | ## Visualising track2p outputs within the GUI 85 | 86 | After activating the GUI through `python -m track2p` the user can import the results of any previous analysis by clicking on the 'File' tab on the top left of the window and selecting 'Load processed data' from the dropdown menu. This will open a pop-up window that allows the user to set the path to the track2p folder (containing the results of the algorithm) and the plane they want to open. 87 | 88 | Once completed, the interface showcases multiple visualizations: 89 | 90 | ![ex_all_vizualizations.png](docs/media/plots/ex_all_vizualizations.png) 91 | 92 | In this example we are using the track2p GUI to visualise the outputs for an experiment containing 7 consecutive daily recordings in mouse barrel cortex (between P8 and P14). 93 | 94 | In the upper left, the GUI visualises the mean image of the motion-corrected functional channel (usually green / GCaMP).
The image is overlaid with the ROIs of the cells detected by track2p across all days, with the color of a particular cell matching across days. These images are interactive, allowing the user to click on a cell, which displays the fluorescence traces for each day at the bottom of the window (sorted from the first day to the last). 95 | 96 | In addition, a zoomed-in image of the cell for each day is shown in the top right. Underneath each zoomed-in image the GUI displays this cell's index in the corresponding 'suite2p' dataset and the 'iscell' probability suite2p has assigned to it on that day. 97 | 98 | Finally, the user can browse all the putative matches detected by the algorithm using the bar at the bottom to toggle through matches; alternatively, they can enter the index of a specific match to display it within the GUI. This bar is also used for manual curation, where we allow the user to evaluate the quality of the tracking for each individual match. 99 | 100 | For more details on how to use the GUI see [GUI usage](https://track2p.github.io/gui_overview.html). 101 | 102 | 103 | # Outputs 104 | 105 | All the outputs of the script will be saved in a `track2p` folder created within the `track_ops.save_path` directory specified by the user when running the algorithm. For an introduction on how to use the outputs for further downstream analysis we provide a useful demo notebook `demo_t2p_output.ipynb` in the root of this repository. Note: You will need to additionally install jupyter for this to work. For example: 106 | 107 | ``` 108 | conda install conda-forge::jupyterlab 109 | ``` 110 | 111 | For more information see the documentation relating to track2p [visualisations](https://track2p.github.io/outputs_visualisations.html) and [outputs](https://track2p.github.io/outputs_matches.html). 112 | 113 | # Troubleshooting 114 | 115 | A brief troubleshooting guide including some common issues is included [here](https://github.com/juremaj/track2p/blob/main/docs/troubleshooting.md). 116 | 117 | # Reference 118 | 119 | If you use the algorithm please reference the [eLife paper](https://elifesciences.org/articles/107540): 120 | 121 | **Majnik, J., Mantez, M., Zangila, S., Bugeon, S., Guignard, L., Platel, J.-C., & Cossart, R. (2025). Longitudinal tracking of neuronal activity from the same cells in the developing brain using Track2p. eLife.** 122 | 123 | You can also see a Youtube recording of a talk related to the paper: [Link to video (starting at 47:20)](https://youtu.be/Tr97HwgQ9ik?t=2839) 124 | 125 | The data associated with the paper is also available on [Zenodo](https://zenodo.org/records/17091226) 126 | 127 | ___ 128 | 129 | 130 | By Cossart Lab (Jure Majnik & Manon Mantez) 131 | 132 | Logo by Eleonora Ambrad Giovannetti 133 | 134 | © Copyright 2025. 135 | -------------------------------------------------------------------------------- /track2p/gui/fluo_plot.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from scipy.stats import zscore 6 | from PyQt5.QtCore import Qt 7 | import matplotlib.patches as patches 8 | from PyQt5 import QtCore 9 | from scipy.ndimage import maximum_filter1d, minimum_filter1d, gaussian_filter 10 | 11 | 12 | class FluorescencePlotWidget(FigureCanvas): 13 | """Displays the fluorescence of the selected cell across days.
It also allows selecting a region of interest (ROI) on the fluorescence plot and zooming in on it.""" 14 | def __init__(self, all_f_t2p=None, all_ops=None, colors=None, all_stat_t2p=None): 15 | self.fig, self.ax_fluorescence = plt.subplots(1, 1) 16 | super().__init__(self.fig) 17 | self.all_f_t2p = all_f_t2p 18 | self.all_ops=all_ops 19 | self.fig.set_facecolor('black') 20 | self.colors = colors 21 | self.all_stat_t2p= all_stat_t2p 22 | 23 | 24 | self.rect = patches.Rectangle((0,0), 1, 1, color='white', linewidth=2) 25 | 26 | self.fig.canvas.mpl_connect('button_press_event', self.on_press) 27 | self.fig.canvas.mpl_connect('button_release_event', self.on_release) 28 | self.fig.canvas.mpl_connect('key_press_event', self.on_key_press) 29 | self.fig.canvas.mpl_connect('key_press_event', self.on_enter_pressed) 30 | self.setFocusPolicy(Qt.StrongFocus) 31 | self.setFocus() 32 | 33 | self.initial_xlim = None 34 | self.initial_ylim = None 35 | self.cmd_pressed = False 36 | 37 | def draw_rectangle(self): 38 | """Draws a rectangle on the fluorescence plot.""" 39 | # if the rectangle is already on the plot, it is removed 40 | if self.rect in self.ax_fluorescence.patches: 41 | self.ax_fluorescence.patches.remove(self.rect) 42 | self.rect = patches.Rectangle((self.x0, self.y0), self.x1 - self.x0, self.y1 - self.y0, fill=None, linewidth=2, edgecolor='white') # create a rectangle; fill is set to None to make the rectangle transparent 43 | self.rect.set_zorder(10) 44 | self.ax_fluorescence.add_patch(self.rect) 45 | self.draw() 46 | 47 | def on_key_press(self, event): 48 | """Resets the zoom of the fluorescence plot. It is called when the 'r' key is pressed.""" 49 | if event.key == 'r': 50 | self.ax_fluorescence.set_xlim(self.initial_xlim) 51 | self.ax_fluorescence.set_ylim(self.initial_ylim) 52 | self.fig.canvas.draw() 53 | 54 | def draw_point(self): 55 | if hasattr(self, 'point') and self.point in self.ax_fluorescence.collections: 56 | print(self.ax_fluorescence.collections) 57 | self.ax_fluorescence.collections.remove(self.point) 58 | self.point = self.ax_fluorescence.scatter([self.x0], [self.y0], s=5, color='w') # Create a new point 59 | self.draw() 60 | 61 | def on_press(self, event): 62 | """Draws a point on the fluorescence plot. It is called when the mouse button is pressed.""" 63 | self.x0 = event.xdata # x coordinate of the mouse cursor 64 | self.y0 = event.ydata # y coordinate of the mouse cursor 65 | self.draw_point() 66 | 67 | def on_release(self, event): 68 | """Draws a rectangle on the fluorescence plot. It is called when the mouse button is released.""" 69 | self.x1 = event.xdata 70 | self.y1 = event.ydata 71 | self.rect.set_width(self.x1 - self.x0) 72 | self.rect.set_height(self.y1 - self.y0) 73 | self.rect.set_xy((self.x0, self.y0)) # set the position of the rectangle 74 | self.draw_rectangle() 75 | 76 | 77 | def on_enter_pressed(self, event): 78 | """Zooms in on the selected rectangle. It is called when the Enter key is pressed.""" 79 | if event.key == 'enter': 80 | self.ax_fluorescence.set_xlim(self.x0, self.x1) 81 | self.ax_fluorescence.set_ylim(self.y0, self.y1) 82 | self.rect.set_visible(False) 83 | self.point.set_visible(False) 84 | self.fig.canvas.draw() 85 | 86 | 87 | def display_all_f_t2p(self, selected_cell_index): 88 | """Plots the fluorescence of the selected cell across days, with each curve being a different day (the curve at the top of the plot is the first day)""" 89 | 90 | if self.all_f_t2p is not None and selected_cell_index is not None: 91 | self.ax_fluorescence.clear() 92 | self.ax_fluorescence.set_facecolor('black') 93 | self.ax_fluorescence.tick_params(axis='x', colors='white') 94 | self.ax_fluorescence.tick_params(axis='y', colors='white') 95 | self.ax_fluorescence.xaxis.label.set_color('white') 96 | self.ax_fluorescence.yaxis.label.set_color('white') 97 | self.ax_fluorescence.spines['bottom'].set_color('#666') 98 | 99 | 100 | for i, fluorescence_data in list(enumerate(reversed(self.all_f_t2p))): 101 | #print(i) 102 | #print(fluorescence_data[selected_cell_index, :]) 103 | fluorescence_zscore = zscore(fluorescence_data, axis=1, ddof=1) # zscore is used to normalize the fluorescence data 104 | offset = i * 12 105 | y_values = fluorescence_zscore[selected_cell_index, :] + offset 106 | #print(y_values) 107 | color = self.colors[selected_cell_index] 108 | # create a gradient of colors for the curves (the darkest shade is for the last day and the lightest shade is for the first day) 109 | if i == 0: 110 | color = color 111 | else: 112 | h, l, s = colorsys.rgb_to_hls(*color) 113 | l_range = 1 - (l + 0.05) 114 | l_add = l_range/len(self.all_f_t2p) 115 | 116 | adjusted_luminosity = l + (l_add *i) 117 | color = colorsys.hls_to_rgb(h, adjusted_luminosity, s) 118 | 119 | ops=self.all_ops[i] 120 | fs = ops['fs'] 121 | tstamps = np.arange(len(y_values)) 122 | 123 | if len(tstamps) != len(y_values): 124 | raise ValueError(f"tstamps and y_values must have the same length, but have lengths {len(tstamps)} and {len(y_values)}") 125 | self.ax_fluorescence.plot(tstamps,y_values, label=f'Curve {i + 1}', color= color) 126 | 127 | self.ax_fluorescence.set_xticks([0, int(len(tstamps)/2), len(tstamps)]) 128 | self.ax_fluorescence.set_yticklabels([]) 129 | self.ax_fluorescence.get_yaxis().set_visible(False) 130 | self.ax_fluorescence.set_xlabel('Frames') 131 | self.initial_xlim=self.ax_fluorescence.get_xlim() 132 | self.initial_ylim=self.ax_fluorescence.get_ylim() 133 | self.fig.tight_layout() 134 | self.draw() 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /track2p/gui/central_widget.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtCore import Qt 2 | from PyQt5.QtWidgets import QTabWidget, QVBoxLayout, QWidget, QSplitter, QHBoxLayout, QFrame 3 | from track2p.gui.fluo_plot import FluorescencePlotWidget 4 | from track2p.gui.roi_plot import ZoomPlotWidget 5 | from track2p.gui.cell_plot import CellPlotWidget 6 | from track2p.gui.data_management import DataManagement 7 | from track2p.gui.raster_wd import RasterWindow 8 | 9 | class CentralWidget(QWidget): 10 | def __init__(self, main_window): 11 | super().__init__() 12 | 13 | self.main_window = main_window 14 | self.fluorescences_plotting = None 15 | self.rois_plotting = None 16 | self.selected_roi = None 17 | self.cell_plot = None 18 | self.track_ops_dict=None 19 | self.data_management
= DataManagement(self) 20 | self.vector_curation_t2p = self.data_management.vector_curation_t2p 21 | self.init_central_widget() 22 | 23 | def init_central_widget(self): 24 | self.top = QFrame() 25 | self.top.setFrameShape(QFrame.StyledPanel) 26 | self.top_layout = QHBoxLayout(self.top) 27 | 28 | self.top_right = QFrame() 29 | self.top_right.setFrameShape(QFrame.StyledPanel) 30 | self.top_layout_right = QVBoxLayout(self.top_right) 31 | 32 | self.tabs = QTabWidget(self) 33 | 34 | self.splitter1 = QSplitter(Qt.Horizontal) 35 | self.splitter1.addWidget(self.tabs) 36 | self.splitter1.addWidget(self.top) 37 | self.splitter1.setSizes([100, 100]) 38 | 39 | self.splitter2 = QSplitter(Qt.Horizontal) 40 | self.splitter2.addWidget(self.top_right) 41 | 42 | self.splitter3 = QSplitter(Qt.Vertical) 43 | self.splitter3.addWidget(self.splitter1) 44 | self.splitter3.addWidget(self.splitter2) 45 | self.splitter3.setSizes([100, 100]) 46 | 47 | central_layout = QVBoxLayout() 48 | central_layout.addWidget(self.splitter3) 49 | self.setLayout(central_layout) 50 | 51 | 52 | def create_mean_img(self,channel): 53 | for i, (ops, stat_t2p) in enumerate(zip(self.data_management.all_ops, self.data_management.all_stat_t2p)): 54 | tab = QWidget() 55 | self.cell_plot = CellPlotWidget(tab, ops=ops, stat_t2p=stat_t2p, f_t2p=self.data_management.all_f_t2p[i], 56 | colors=self.data_management.colors, update_selection_callback=self.update_selection, 57 | all_f_t2p=self.data_management.all_f_t2p, all_ops=self.data_management.all_ops, channel=channel) 58 | layout = QVBoxLayout(tab) 59 | layout.addWidget(self.cell_plot) 60 | tab.setLayout(layout) 61 | self.tabs.addTab(tab, f"Day {i + 1}") 62 | self.cell_plot.cell_selected.connect(self.update_selection) 63 | 64 | 65 | def create_mean_img_from_curation(self): 66 | import_window = self.main_window.window_manager.import_window 67 | t2p_window = self.main_window.window_manager.t2p_window 68 | 69 | if import_window is not None and import_window.plane is not None: 70 | self.data_management.import_files(import_window.path_to_t2p, import_window.plane, import_window.trace_type, import_window.channel) 71 | elif t2p_window is not None and t2p_window.saved_directory is not None: 72 | self.data_management.import_files(t2p_window.saved_directory, t2p_window.dialog.plane, t2p_window.dialog.trace_type, t2p_window.dialog.channel) 73 | else: 74 | print("Both import_window and t2p_window are None or not properly initialized.") 75 | 76 | def clear(self): 77 | self.data_management.reset_attributes() 78 | if self.fluorescences_plotting: 79 | self.top_layout_right.removeWidget(self.fluorescences_plotting) 80 | self.fluorescences_plotting.deleteLater() 81 | self.fluorescences_plotting = None 82 | if self.rois_plotting: 83 | self.top_layout.removeWidget(self.rois_plotting) 84 | self.rois_plotting.deleteLater() 85 | self.rois_plotting = None 86 | for i in range(self.tabs.count()): 87 | self.tabs.removeTab(0) 88 | 89 | 90 | def update_selection(self, selected_cell_index): 91 | self.selected_roi = selected_cell_index 92 | self.main_window.status_bar.spin_box.setValue(selected_cell_index) 93 | self.main_window.status_bar.roi_state_value.setText(f"{self.vector_curation_t2p[selected_cell_index]}") 94 | # it removes the underline of the previously selected cell even if its tab is not visible (not the current tab) 95 | for i in range(self.tabs.count()): 96 | tab_widget = self.tabs.widget(i) 97 | cell_object = tab_widget.findChild(CellPlotWidget) 98 | cell_object.remove_previous_underline() 99 | current_tab_index =
self.tabs.currentIndex() 100 | current_tab_widget = self.tabs.widget(current_tab_index) 101 | cell_plot = current_tab_widget.findChild(CellPlotWidget) 102 | if cell_plot: 103 | cell_plot.underline_cell(selected_cell_index) 104 | if self.fluorescences_plotting is None: 105 | self.fluorescences_plotting = FluorescencePlotWidget(all_f_t2p=self.data_management.all_f_t2p, 106 | all_ops=self.data_management.all_ops, 107 | colors=self.data_management.colors) 108 | self.top_layout_right.addWidget(self.fluorescences_plotting) 109 | if self.rois_plotting is None: 110 | self.rois_plotting = ZoomPlotWidget(all_ops=self.data_management.all_ops, 111 | all_stat_t2p=self.data_management.all_stat_t2p, 112 | colors=self.data_management.colors, 113 | all_iscell_t2p=self.data_management.all_iscell, 114 | t2p_match_mat_allday=self.data_management.t2p_match_mat_allday,track_ops=self.track_ops) 115 | self.top_layout.addWidget(self.rois_plotting) 116 | 117 | self.fluorescences_plotting.display_all_f_t2p(selected_cell_index) 118 | self.rois_plotting.display_zooms(selected_cell_index) 119 | 120 | 121 | def display_first_ROI(self,index): 122 | """it displays the first cell of the t2p_match_mat_allday and its fluorescence and zooms across days. It is called when the application is opened. 123 | An instance of FluorescencePlotWidget and an instance of ZoomPlotWidget are created and added to attributes of the MainWindow class. """ 124 | tab_widget = self.tabs.widget(0) 125 | cell_object = tab_widget.findChild(CellPlotWidget) #It finds the instance of the CellPlotWidget class in the first tab of the QTabWidget 126 | cell_object.underline_cell(index) 127 | cell_object.draw() 128 | if self.fluorescences_plotting is None: 129 | self.fluorescences_plotting = FluorescencePlotWidget(all_f_t2p=self.data_management.all_f_t2p, 130 | all_ops=self.data_management.all_ops, 131 | colors=self.data_management.colors, all_stat_t2p=self.data_management.all_stat_t2p) 132 | self.top_layout_right.addWidget(self.fluorescences_plotting) 133 | if self.rois_plotting is None: 134 | self.rois_plotting = ZoomPlotWidget(all_ops=self.data_management.all_ops, 135 | all_stat_t2p=self.data_management.all_stat_t2p, 136 | colors=self.data_management.colors, 137 | all_iscell_t2p=self.data_management.all_iscell, 138 | t2p_match_mat_allday=self.data_management.t2p_match_mat_allday,track_ops=self.data_management.track_ops, imgs= self.cell_plot.all_img) 139 | self.top_layout.addWidget(self.rois_plotting) 140 | self.fluorescences_plotting.display_all_f_t2p(index) 141 | self.rois_plotting.display_zooms(index) 142 | self.main_window.status_bar.roi_state_value.setText(f"{self.vector_curation_t2p[self.main_window.status_bar.spin_box.value()]}") # 143 | -------------------------------------------------------------------------------- /docs/troubleshooting.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting Guide: Low Cell Counts and Tracking Failures in Track2p 2 | 3 | This guide provides practical steps for diagnosing and resolving common issues encountered when running Track2p, especially when tracking neurons across multiple days. 
It covers: 4 | 5 | - Low numbers of tracked cells 6 | - Potential tracking and registration failures 7 | - Recommended parameter adjustments 8 | - Handling potentially problematic sessions 9 | 10 | Whenever you encounter unexpected results, it is strongly recommended to inspect the **Track2p output figures** (see documentation: *Outputs & Visualisations*) and, if possible, view your raw or Suite2p-processed data side-by-side across days. 11 | 12 | Ideally, the solution is to track a subset of cells manually (see the paper for one way of doing this) and use that as a 'validation set' to help you choose the right parameters for tracking. Generating the ground-truth dataset will also give you an insight into how many cells you should expect to be successfully tracked across all days in your experimental setting. 13 | 14 | --- 15 | 16 | ## 1. Low Numbers of Tracked Cells Across Days 17 | 18 | Low cell counts across days can arise from issues in **cell detection** (relating to **cell activity**, **FOV consistency**, etc.) or **tracking quality**. Track2p only reports cells that can be matched across *all* selected days, so variability in detection or registration can significantly reduce the final cell set. 19 | 20 | ### 1.1 Issues with cell detection ('segmentation') 21 | 22 | #### Important points 23 | - **Suite2p** only detects ROIs that are "active" in that session and considered “good” (based on the outputs of a classifier) 24 | - **Track2p** only includes cells that appear on *every* selected day. 25 | If a cell is missing on any day (due to low activity or detection thresholds), it will not be included in the tracked set. 26 | 27 | #### Common causes 28 | 1. **Different active cells across days** 29 | Some neurons may be active only on some days, making them undetectable on others. 30 | 31 | 2. **Conservative Suite2p parameters** 32 | High thresholds can reduce the number of detected ROIs, especially ones with low activity or ones falling below the classifier threshold. 33 | 34 | 3. **Changes or drift in the recording** 35 | These can cause ROI detection inconsistencies: 36 | - z-shift 37 | - moving blood vessels 38 | - debris or accumulated tissue along FOV edges 39 | - subtle optical changes across sessions 40 | 41 | ### 1.2 Recommended Solutions 42 | 43 | #### 1. Inspect Track2p output figures 44 | See the *Outputs & Visualisations* documentation. 45 | These plots help identify whether failures occur during cell detection, registration, matching, or thresholding. 46 | 47 | #### 2. Inspect Suite2p outputs across days 48 | Open two or more Suite2p GUIs side-by-side and check: 49 | - Are the same ROIs detected across days? 50 | - Do obvious cells appear on one day but not another? 51 | 52 | (Optional but ideal): manually track a small set of neurons to generate “ground truth” and verify Track2p performance. 53 | 54 | #### 3. Adjust Suite2p parameters to detect more cells 55 | If detection seems too strict, reduce conservativeness by tuning Suite2p parameters (see the Suite2p documentation). 56 | This should increase the cell count, but likely at the expense of more false positives. 57 | False ROIs can be removed manually before or after Track2p processing (using either the Suite2p or Track2p curation capabilities). 58 | 59 | #### 4. Include more borderline ROIs in Track2p 60 | Lowering the Track2p parameter `iscell_thr` allows inclusion of ROIs rejected by the Suite2p classifier.
61 | 62 | - Suite2p default: **0.5** 63 | - Recommended more permissive option: **0.25** 64 | 65 | This has a similar effect to the above, so also make sure to remove any additional false positives that might occur as a result. 66 | 67 | --- 68 | 69 | ## 2. Issues with cell tracking ('linking') 70 | 71 | Track2p’s tracking pipeline involves three stages: 72 | 73 | 1. **Registration** of consecutive FOVs 74 | 2. **Matching** ROIs based on spatial overlap 75 | 3. **Automatic thresholding** of IoU values to filter reliable matches 76 | 77 | A problem in any stage can reduce the number of tracked cells. 78 | 79 | ### 2.1 Common Causes of Tracking Failure 80 | 81 | #### 1. Registration failure 82 | Check `reg_img_output.png`. 83 | The bottom row overlays red/green images across sessions. They should align well. 84 | If they do not, all subsequent steps will fail! 85 | 86 | #### 2. Poor matching quality 87 | If registration is poor or ROIs do not overlap sufficiently, spatial matching breaks down. 88 | Matching can also have issues even if the registration is successful, but this would only be expected for samples with extremely dense somata and thick optical sectioning (leading to 'overlapping' ROIs). 89 | 90 | #### 3. Thresholding failure 91 | If registration/matching fail, the IoU histogram (`thr_met_hist.png`) may: 92 | - be unimodal, for example decaying roughly 'exponentially' from 0 93 | - lack a visible separation between the matched/unmatched populations (the two usual modes) 94 | 95 | This causes automatic thresholds to be placed arbitrarily, leading to most matches being discarded or accepted without any statistical justification. 96 | 97 | ### 2.2 Recommended Solutions 98 | 99 | #### 1. Examine diagnostic figures 100 | - **Registration** → `reg_img_output.png` 101 | - **Matching & thresholding** → `thr_met_hist.png` 102 | 103 | These files usually reveal the underlying issue. 104 | 105 | Registration: check that the last row of `reg_img_output.png` shows good overlap between the red and green images; if not, the registration did not work properly. It could also be that registration works only in a part of the field of view and not homogeneously! If you see that the algorithm has failed, make sure to check the FOV images and see if they resemble each other (e.g. could they be aligned manually?). 106 | 107 | Matching & thresholding: check that `thr_met_hist.png` shows a bimodal histogram for all pairs of consecutive sessions. Also check that the statistical threshold marked as a vertical line 'makes sense' statistically (e.g. it should separate the two assumed distributions corresponding to the two modes). 108 | 109 | #### 2. Try alternate registration methods 110 | Track2p supports: 111 | - `'affine'` (default) 112 | - `'rigid'` 113 | 114 | Rigid has fewer degrees of freedom, so in theory it is easier to find a good alignment. However, it cannot account for expansion or shearing across sessions, so in the case of early development `'affine'` would be preferred (based on results from the eLife paper; this might depend on the particular preparation). We have not tested this in adults, but it might be that `'rigid'` would work better - we encourage users to try both :) 115 | 116 | #### 3.
116 | #### 3. Switch thresholding method
117 | If your IoU histograms are bimodal but you think the threshold does not appropriately separate the two distributions, you can try a different thresholding method; currently the options are:
118 | - `'otsu'`
119 | - `'min'`
120 | 
121 | Depending on the histogram shape, one or the other may be more appropriate.
122 | 
123 | #### 4. Remove problematic recording days
124 | If a single session is very different (e.g., z-shift, optical debris, abnormal brightness), it may disrupt registration between the surrounding sessions.
125 | Removing the problematic day often restores normal tracking across the remaining days, but it of course leads to missing values, which might be an issue in downstream analysis.
126 | 
127 | #### 5. Consider shorter intervals
128 | For some scientific questions it might be sufficient to track across only two (or a few) neighbouring days. If this is the case, the user can also perform shorter tracks, which are expected to yield more accurately tracked neurons.
129 | 
130 | ---
131 | 
132 | ## 3. Summary of Recommended Workflow for Troubleshooting
133 | 
134 | 1. **Inspect Track2p output figures** to locate the stage causing failure.
135 | 2. **Check Suite2p results side-by-side** across days for ROI consistency.
136 | 3. **Adjust Suite2p thresholds** or Track2p’s `iscell_thr` to increase detected ROI overlap across days.
137 | 4. **Try different registration methods** and thresholding strategies.
138 | 5. **Identify and remove problematic sessions** (e.g., z-shift or debris).
139 | 6. **Prefer shorter, overlapping blocks** if long-span tracking is inconsistent and this does not interfere with your scientific goals :)
140 | 
141 | This approach usually resolves the most common problems involving low tracked cell counts and tracking failures.
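After adjusting parameters and re-running, a quick sanity check is to count how many cells ended up matched across all selected days in the output match matrix. Below is a minimal sketch using the same logic as the demo notebooks (the output path is a placeholder):

```python
# Minimal sketch: count cells tracked across all days from the track2p output.
import os
import numpy as np

t2p_save_path = '/data/sub01'  # placeholder: folder containing the 'track2p' output
t2p_match_mat = np.load(os.path.join(t2p_save_path, 'track2p', 'plane0_match_mat.npy'),
                        allow_pickle=True)

# rows without any None values are cells matched across every selected day
t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :]
print(f'{t2p_match_mat_allday.shape[0]} / {t2p_match_mat.shape[0]} cells tracked across all days')
```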
142 | -------------------------------------------------------------------------------- /track2p/gui/data_management.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import matplotlib.colors as mcolors 5 | import random 6 | from types import SimpleNamespace 7 | from scipy.ndimage import maximum_filter1d, minimum_filter1d, gaussian_filter 8 | 9 | class DataManagement: 10 | def __init__(self, central_widget): 11 | self.central_widget = central_widget 12 | self.main_window = central_widget.main_window 13 | self.all_f_t2p = [] 14 | self.all_ops = [] 15 | self.all_stat_t2p = [] 16 | self.all_iscell = [] 17 | self.all_fneu= [] 18 | self.colors = None 19 | self.t2p_match_mat_allday = None 20 | self.track_ops = None 21 | self.vector_curation_t2p = None 22 | self.curation_npy = None 23 | self.colors_copy=None 24 | self.plane=None 25 | self.trace_type=None 26 | 27 | def reset_attributes(self): 28 | self.all_f_t2p = [] 29 | self.all_ops = [] 30 | self.all_stat_t2p = [] 31 | self.all_iscell = [] 32 | self.all_fneu= [] 33 | self.colors = None 34 | self.t2p_match_mat_allday = None 35 | self.track_ops = None 36 | self.vector_curation_t2p = None 37 | 38 | def import_files(self, t2p_folder_path, plane, trace_type, channel): 39 | 40 | self.plane=plane 41 | self.trace_type=trace_type 42 | 43 | if self.central_widget.fluorescences_plotting is not None: 44 | self.central_widget.clear() 45 | # load track2p outputs 46 | t2p_match_mat = np.load(os.path.join(t2p_folder_path,"track2p" ,f"plane{plane}_match_mat.npy"), allow_pickle=True) 47 | self.t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :] # remove rows with None values 48 | track_ops_dict = np.load(os.path.join(t2p_folder_path, "track2p", "track_ops.npy"), allow_pickle=True).item() 49 | track_ops = SimpleNamespace(**track_ops_dict) 50 | 51 | 52 | # process suite2p files 53 | for (i, ds_path) in enumerate(track_ops.all_ds_path): 54 | ops = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'ops.npy'), allow_pickle=True).item() 55 | stat = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'stat.npy'), allow_pickle=True) 56 | iscell = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'iscell.npy'), allow_pickle=True) 57 | if trace_type == 'F' : 58 | print('F trace') 59 | f = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'F.npy'), allow_pickle=True) 60 | if trace_type == 'spks': 61 | print('spks trace') 62 | f = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'spks.npy'), allow_pickle=True) 63 | if trace_type == 'dF/F0': 64 | print('dF/F0 trace') 65 | f = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'F.npy'), allow_pickle=True) 66 | fneu = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'Fneu.npy'), allow_pickle=True) 67 | if track_ops.iscell_thr is None: 68 | fneu_iscell = fneu[iscell[:, 0] == 1, :] 69 | else: 70 | fneu_iscell = fneu[iscell[:, 1] > track_ops.iscell_thr, :] 71 | fneu_t2p= fneu_iscell[self.t2p_match_mat_allday[:, i].astype(int), :] 72 | self.all_fneu.append(fneu_t2p) 73 | if track_ops.iscell_thr is None: 74 | stat_iscell = stat[iscell[:, 0] == 1] 75 | f_iscell = f[iscell[:, 0] == 1, :] 76 | else: 77 | stat_iscell = stat[iscell[:, 1] > track_ops.iscell_thr] 78 | f_iscell = f[iscell[:, 1] > track_ops.iscell_thr, :] 79 | 80 | 81 | stat_t2p = stat_iscell[self.t2p_match_mat_allday[:, i].astype(int)] 82 | f_t2p = f_iscell[self.t2p_match_mat_allday[:, i].astype(int), :] 83 | 
self.all_stat_t2p.append(stat_t2p) 84 | self.all_f_t2p.append(f_t2p) 85 | self.all_ops.append(ops) 86 | self.all_iscell.append(iscell) 87 | 88 | if trace_type == 'dF/F0': 89 | for i in range(len(self.all_f_t2p)): 90 | f_prc=self.F_processing(F=self.all_f_t2p[i], Fneu=self.all_fneu[i], fs=self.all_ops[i]['fs']) 91 | self.all_f_t2p[i]=f_prc 92 | 93 | 94 | attr_name = 'vector_curation_plane_' + str(plane) 95 | if hasattr(track_ops, attr_name): 96 | key = 'vector_curation_plane_' + str(plane) 97 | self.vector_curation_t2p=track_ops_dict[key] 98 | else: 99 | vector_curation_keys=np.arange(self.t2p_match_mat_allday.shape[0]) 100 | vector_curation_values = np.ones_like(vector_curation_keys) 101 | self.vector_curation_t2p_dict = dict(zip(vector_curation_keys, vector_curation_values)) 102 | values = list(self.vector_curation_t2p_dict.values()) 103 | self.vector_curation_t2p = np.array(values) 104 | key = 'vector_curation_plane_' + str(plane) 105 | track_ops_dict[key] = self.vector_curation_t2p 106 | np.save(os.path.join(t2p_folder_path, "track2p", "track_ops.npy"), track_ops_dict) 107 | 108 | 109 | attr_name_color = 'colors_plane_' + str(plane) 110 | if hasattr(track_ops, attr_name_color): 111 | self.colors= track_ops_dict[attr_name_color] 112 | else: 113 | self.colors=self.generate_vibrant_colors(len(self.all_stat_t2p[0])) 114 | track_ops_dict[attr_name_color] = self.colors 115 | np.save(os.path.join(t2p_folder_path, "track2p", "track_ops.npy"), track_ops_dict) 116 | 117 | 118 | self.track_ops = track_ops 119 | track_ops_dict = np.load(os.path.join(t2p_folder_path, "track2p", "track_ops.npy"), allow_pickle=True).item() 120 | self.central_widget.track_ops_dict=track_ops_dict 121 | self.main_window.status_bar.vector_curation_t2p = self.vector_curation_t2p 122 | 123 | self.main_window.status_bar.vector_curation_t2p = self.vector_curation_t2p 124 | 125 | self.main_window.status_bar.spin_box.setSuffix(f'/{len(self.t2p_match_mat_allday)-1}') 126 | self.main_window.status_bar.spin_box.setMinimum(0) 127 | self.main_window.status_bar.spin_box.setMaximum(len(self.t2p_match_mat_allday)-1) 128 | 129 | num_ones = {} 130 | 131 | for cell, line in enumerate(t2p_match_mat): 132 | all_iscell_value=[] 133 | for day,index_match in enumerate(line): 134 | if track_ops.iscell_thr is None: 135 | if index_match is None: 136 | all_iscell_value.append(0) 137 | else: 138 | iscell=self.all_iscell[day] 139 | indices_lignes_1 = np.where(iscell[:,0]==1)[0] # take the indices where the ROIs were considered as cells in suite2p 140 | true_index=indices_lignes_1[index_match] # take the "true index" 141 | iscell_value=iscell[true_index,0] 142 | all_iscell_value.append(iscell_value) 143 | else: 144 | if index_match is None: 145 | all_iscell_value.append(0) 146 | else: 147 | iscell=self.all_iscell[day] 148 | indices_lignes_1= np.where(iscell[:,1]>track_ops.iscell_thr)[0] # take the indices where the ROIs have a probability greater than trackops.is_cell_thr 149 | true_index=indices_lignes_1[index_match] # take the "true index" 150 | iscell_value=iscell[true_index,0] 151 | all_iscell_value.append(iscell_value) 152 | num_ones[cell] = all_iscell_value.count(1) 153 | 154 | 155 | for day in range(len(self.t2p_match_mat_allday[1])): 156 | count_cells_day=0 157 | for value in num_ones.values(): 158 | if value == day: 159 | count_cells_day+=1 160 | print(f'Number of cells present over {day} day(s): {count_cells_day}') 161 | print(t2p_match_mat.shape) 162 | 163 | for i, line in enumerate(self.t2p_match_mat_allday): 164 | if 
self.vector_curation_t2p[i]==0:
165 |                 self.colors[i] =(0.78, 0.78, 0.78)
166 |                 print(f'ROI {i} has been considered as "not cell"')
167 | 
168 | 
169 |         self.central_widget.create_mean_img(channel)
170 |         self.central_widget.vector_curation_t2p = self.vector_curation_t2p
171 |         self.central_widget.display_first_ROI(0)
172 | 
173 | 
174 | 
175 |     def generate_vibrant_colors(self, num_colors):
176 |         vibrant_colors = []
177 |         for _ in range(num_colors):
178 |             l = np.random.uniform(0.55, 0.80)
179 |             color = mcolors.hsv_to_rgb((random.random(), 1, l))
180 |             vibrant_colors.append(color)
181 | 
182 |         return vibrant_colors
183 | 
184 | 
185 |     def F_processing(self,F, Fneu, fs, neucoeff=0.0, baseline='maximin', sig_baseline=10.0, win_baseline=60.0, prctile_baseline: float = 8):
186 | 
187 |         print(F.shape)
188 |         print(Fneu.shape)
189 |         # neuropil subtraction
190 |         Fc = F - neucoeff * Fneu
191 | 
192 |         # baseline operation
193 |         win = int(win_baseline * fs)
194 |         if baseline == "maximin":
195 |             Flow = gaussian_filter(Fc, [0., sig_baseline])
196 |             Flow = minimum_filter1d(Flow, win)
197 |             Flow = maximum_filter1d(Flow, win)
198 |         elif baseline == "constant":
199 |             Flow = gaussian_filter(Fc, [0., sig_baseline])
200 |             Flow = np.amin(Flow)
201 |         elif baseline == "constant_prctile":
202 |             Flow = np.percentile(Fc, prctile_baseline, axis=1)
203 |             Flow = np.expand_dims(Flow, axis=1)
204 |         else:
205 |             Flow = 0.
206 | 
207 |         F = Fc - Flow
208 | 
209 |         print('DONE')
210 | 
211 |         return F
212 | 
-------------------------------------------------------------------------------- /track2p/gui/cell_plot.py: --------------------------------------------------------------------------------
1 | import time
2 | from qtpy.QtCore import Signal
3 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import skimage
7 | 
8 | class CellPlotWidget(FigureCanvas):
9 |     '''This class is used to view and interact with the mean image of each recording (day)'''
10 |     cell_selected = Signal(int)
11 | 
12 |     def __init__(self, tab=None, ops=None, stat_t2p=None, f_t2p=None, colors=None, update_selection_callback=None,
13 |                  all_f_t2p=None, all_stat_t2p=None, all_ops=None, initial_colors=None, channel=None):
14 |         """It initializes the class attributes and connects certain events to their respective handlers.
It also creates the figure and the axes to display the mean image of the recording."""
15 |         self.fig, self.ax_image = plt.subplots(1, 1)
16 |         self.fig.set_facecolor('black')
17 |         super().__init__(self.fig)
18 |         self.ops = ops
19 |         self.stat_t2p = stat_t2p
20 |         self.f_t2p = f_t2p
21 |         self.all_fluorescence = all_f_t2p
22 |         self.all_stat_t2p=all_stat_t2p
23 |         self.all_ops=all_ops
24 |         self.colors = colors
25 |         self.initial_colors=initial_colors
26 |         self.channel = channel
27 |         self.all_img, self.img= self.load_all_imgs()
28 |         self.selected_cell_index = None
29 |         self.mpl_connect('button_press_event', self.on_mouse_press)
30 |         self.update_selection_callback = update_selection_callback
31 |         self.nb_cells= len(self.stat_t2p)
32 |         self.plot_cells()
33 |         self.initialize_interactions()
34 | 
35 |     def load_all_imgs(self):
36 |         print('loading all images')
37 |         print('channel img:', self.channel)
38 |         all_img = []
39 |         img= None
40 |         if self.channel == 'max_proj':
41 |             print('max_proj')
42 |             img= self.ops['max_proj']
43 |             Ly = self.ops['Ly'] # number of pixels in y (512)
44 |             Lx = self.ops['Lx'] # number of pixels in x (512)
45 |             yr = self.ops['yrange'] # first and last pixel of crop in y
46 |             xr = self.ops['xrange'] # first and last pixel of crop in x
47 |             range_img = np.zeros((Ly, Lx))
48 |             range_img[yr[0]:yr[1], xr[0]:xr[1]] = img
49 |             img= range_img
50 |             for ops in self.all_ops:
51 |                 img_ops= ops['max_proj']
52 |                 Ly = ops['Ly']
53 |                 Lx = ops['Lx']
54 |                 yr = ops['yrange']
55 |                 xr = ops['xrange']
56 |                 range_img = np.zeros((Ly, Lx))
57 |                 range_img[yr[0]:yr[1], xr[0]:xr[1]] = img_ops
58 |                 img_ops= range_img
59 |                 all_img.append(img_ops)
60 |         if self.channel == 'Vcorr':
61 |             print('Vcorr')
62 |             img= self.ops['Vcorr'] # correlation map (cropped to yrange/xrange, like max_proj)
63 |             Ly = self.ops['Ly'] # number of pixels in y (512)
64 |             Lx = self.ops['Lx'] # number of pixels in x (512)
65 |             yr = self.ops['yrange'] # first and last pixel of crop in y
66 |             xr = self.ops['xrange'] # first and last pixel of crop in x
67 |             range_img = np.zeros((Ly, Lx))
68 |             range_img[yr[0]:yr[1], xr[0]:xr[1]] = img
69 |             img= range_img
70 |             for ops in self.all_ops:
71 |                 img_ops= ops['Vcorr']
72 |                 Ly = ops['Ly']
73 |                 Lx = ops['Lx']
74 |                 yr = ops['yrange']
75 |                 xr = ops['xrange']
76 |                 range_img = np.zeros((Ly, Lx))
77 |                 range_img[yr[0]:yr[1], xr[0]:xr[1]] = img_ops
78 |                 img_ops= range_img
79 |                 all_img.append(img_ops)
80 |         if self.channel == '0':
81 |             #print('0')
82 |             all_img = [ops['meanImg'] for ops in self.all_ops]
83 |             img= self.ops['meanImg']
84 |         if self.channel == '1':
85 |             #print('1')
86 |             all_img = [ops['meanImg_chan2'] for ops in self.all_ops]
87 |             img= self.ops['meanImg_chan2']
88 | 
89 |         return all_img, img
90 | 
91 |     def plot_cells(self):
92 |         """It plots the mean image of the recording and the contours of the cells. It also sets the axis to be invisible and the title of the plot. It uses the colors attribute to color the contours of the cells.
93 |         The match_histograms function of the skimage library is used to match the histograms of the mean image of the recording and the last mean image of the recordings.
This is done to make the mean images of the recordings more comparable.""" 94 | self.ax_image.clear() 95 | start = time.time() 96 | match_mean_img=skimage.exposure.match_histograms(self.img,self.all_img[-1], channel_axis=None) 97 | self.ax_image.imshow(match_mean_img, cmap='gray') 98 | cell_count = 0 99 | for cell in range(self.nb_cells): 100 | # print(cell) 101 | bin_mask = np.zeros_like(self.img) #create a binary mask with the same shape as the mean image of the recording 102 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1 103 | color_cell=self.colors[cell] 104 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=1) 105 | cell_count += 1 106 | self.ax_image.axis('off') 107 | print(f'time for plotting cells on mean image for recording : {time.time()-start}') 108 | print(f'Total cells plotted: {cell_count}') 109 | self.draw() 110 | 111 | def plot_cells_remix(self,keys): 112 | self.ax_image.clear() 113 | start = time.time() 114 | match_mean_img=skimage.exposure.match_histograms(self.img,self.all_img[-1], channel_axis=None) 115 | self.ax_image.imshow(match_mean_img, cmap='gray') 116 | cell_count = 0 117 | for cell in range(self.nb_cells): 118 | if cell in keys: 119 | continue 120 | bin_mask = np.zeros_like(self.img) #create a binary mask with the same shape as the mean image of the recording 121 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1 122 | color_cell=self.colors[cell] 123 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=1) 124 | cell_count += 1 125 | self.ax_image.axis('off') 126 | print(f'time for plotting cells on mean image for recording : {time.time()-start}') 127 | print(f'Total cells plotted: {cell_count}') 128 | self.draw() 129 | 130 | 131 | 132 | def underline_cell_remix(self,colors): 133 | 134 | for cell in range(self.nb_cells): 135 | bin_mask = np.zeros_like(self.img) 136 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1 137 | color_cell=colors[cell] 138 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=1) 139 | self.draw() 140 | 141 | 142 | 143 | def underline_cell(self,selected_cell_index): 144 | """It underlines the selected cell by increasing the linewidth of the contour of the cell""" 145 | for cell in range(self.nb_cells): 146 | if cell == selected_cell_index: 147 | bin_mask = np.zeros_like(self.img) 148 | bin_mask[self.stat_t2p[cell]['ypix'], self.stat_t2p[cell]['xpix']] = 1 149 | color_cell=self.colors[cell] 150 | self.ax_image.contour(bin_mask, levels=[0.5], colors=[color_cell], linewidths=3) 151 | self.draw() 152 | 153 | def remove_previous_underline(self): 154 | """It removes the underline of the previously selected cell by decreasing the linewidth of the contour of the cell.""" 155 | for collection in self.ax_image.collections: 156 | collection.set_linewidth(1) 157 | # important, don't forget it ! (update the plot) 158 | 159 | def initialize_interactions(self): 160 | """This method is used to initialize user interactions with the mean image. It connects the scroll event to the on_scroll method and records the initial xlim and ylim of the mean image.""" 161 | self.cid_scroll = self.fig.canvas.mpl_connect('scroll_event', self.on_scroll) 162 | self.initial_xlim = self.ax_image.get_xlim() 163 | self.initial_ylim = self.ax_image.get_ylim() 164 | 165 | def on_scroll(self, event): 166 | """it allows to zoom in and out of the mean image of the recording. It is called when the mouse wheel is scrolled. 
It uses the base_scale to zoom in and out of the mean image of the recording. """ 167 | if event.inaxes == self.ax_image: 168 | current_xlim = self.ax_image.get_xlim() 169 | current_ylim = self.ax_image.get_ylim() 170 | base_scale = 0.9 171 | scale_factor = base_scale if event.button == 'up' else 1/base_scale 172 | # Get the coordinates of the mouse cursor in data coordinates 173 | x_data, y_data = event.xdata, event.ydata 174 | # Compute the new limits centered on the mouse cursor 175 | new_xlim = [x_data - (x_data - x) * scale_factor for x in current_xlim] 176 | new_ylim = [y_data - (y_data - y) * scale_factor for y in current_ylim] 177 | # Ensure the zoom doesn't exceed the initial image boundaries 178 | new_xlim = [max(self.initial_xlim[0], min(self.initial_xlim[1], x)) for x in new_xlim] 179 | new_ylim = [max(self.initial_ylim[1], min(self.initial_ylim[0], y)) for y in new_ylim] 180 | 181 | self.ax_image.set_xlim(new_xlim) 182 | self.ax_image.set_ylim(new_ylim) 183 | self.fig.canvas.draw_idle() 184 | 185 | def on_mouse_press(self, event): 186 | """It allows to select a cell by clicking on it. It is called when the mouse is clicked. It uses the x and y coordinates of the mouse cursor to determine if a cell is clicked. If a cell is clicked, the selected_cell_index attribute is updated and the cell_selected signal is emitted. 187 | The update_selection method of the MainWindow class is called and used to update the fluorescence and zoom plots with the selected cell. """ 188 | start = time.time() 189 | if event.inaxes == self.ax_image: 190 | x, y = event.xdata, event.ydata 191 | for j, cell_info in enumerate(self.stat_t2p): 192 | ypix = cell_info['ypix'] #ypix are the y coordinates of the pixels of the cell 193 | xpix = cell_info['xpix'] #xpix are the x coordinates of the pixels of the cell 194 | if np.any((xpix == int(x)) & (ypix == int(y))): 195 | self.selected_cell_index = j 196 | self.update_selection_callback(j) 197 | print(f"Cell selected: {j}", flush=True) 198 | break 199 | print(f'time taken for updating: {time.time()-start}') 200 | 201 | 202 | -------------------------------------------------------------------------------- /track2p/gui/t2p_wd.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from PyQt5.QtWidgets import QVBoxLayout, QWidget, QHBoxLayout, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout, QListWidget, QMessageBox,QListWidgetItem, QInputDialog,QCheckBox,QSizePolicy,QComboBox,QDialog 4 | from PyQt5.QtCore import Qt 5 | from track2p.t2p import run_t2p 6 | from track2p.ops.default import DefaultTrackOps 7 | from track2p.gui.custom_wd import CustomDialog 8 | class Track2pWindow(QWidget): 9 | """it is used to set the parameters of the track2p algorithm""" 10 | def __init__(self, main_wd): 11 | super(Track2pWindow,self).__init__() 12 | self.main_window = main_wd 13 | layout = QFormLayout() 14 | self.setLayout(layout) 15 | self.track_ops = DefaultTrackOps() 16 | self.saved_directory=None 17 | self.plane=None 18 | 19 | 20 | instruction1=QLabel("Import the directory containing subfolders for each session of a given subject:") 21 | self.import_recording_button = QPushButton("Import", self) 22 | self.import_recording_button.clicked.connect(self.import_path_to_recordings) 23 | layout.addRow(instruction1,self.import_recording_button) 24 | 25 | instruction2= QLabel("Imported path:") 26 | instruction2.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) 27 | self.path_recording=QLabel() 28 | 
self.path_recording.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) 29 | layout.addRow(instruction2,self.path_recording) 30 | 31 | 32 | instruction3=QLabel("Once loaded press '->' to add to the list of paths to use for track2p (in the correct order):") 33 | 34 | # TODO: Here add da driopdown menu to select the input format (e.g. suite2p, raw npy, etc.) 35 | instruction_format=QLabel("Input format:") 36 | self.format=QComboBox() 37 | self.format.addItem("suite2p") 38 | self.format.addItem("npy") 39 | self.format.setCurrentIndex(0) 40 | self.format.currentIndexChanged.connect(self.display_suite2p_options) 41 | 42 | 43 | instruction4= QLabel("Method for selecting suite2p ROIs:") 44 | field_checkbox= QVBoxLayout() 45 | self.checkbox1 = QCheckBox('manually curated', self) 46 | self.checkbox2 = QCheckBox('iscell threshold', self) 47 | self.checkbox2.stateChanged.connect(self.display_iscell) 48 | field_checkbox.addWidget(self.checkbox1) 49 | field_checkbox.addWidget(self.checkbox2) 50 | 51 | 52 | self.is_cell_thr= QLineEdit() 53 | self.is_cell_thr.setVisible(False) 54 | self.is_cell_thr.setText('0.5') 55 | self.is_cell_thr.setFixedWidth(50) 56 | 57 | 58 | file_layout = QHBoxLayout() 59 | 60 | self.computer_file_list = QListWidget(self) 61 | self.computer_file_list.setFixedHeight(200) 62 | self.move_to_computer_list = QPushButton("<-", self) 63 | self.move_to_computer_list.clicked.connect(self.move_file_to_computer_list) 64 | self.move_to_paths_list = QPushButton("->", self) 65 | self.move_to_paths_list.clicked.connect(self.move_file_to_paths_list) 66 | self.paths_list=QListWidget(self) 67 | self.paths_list.setFixedHeight(200) 68 | 69 | file_layout.addWidget(self.computer_file_list) 70 | file_layout.addWidget(self.move_to_computer_list) 71 | file_layout.addWidget(self.move_to_paths_list) 72 | file_layout.addWidget(self.paths_list) 73 | layout.addRow(instruction3, file_layout) 74 | layout.addRow(instruction_format,self.format) 75 | layout.addRow(instruction4,field_checkbox) 76 | layout.addRow("suite2p iscell threshold:",self.is_cell_thr) 77 | 78 | 79 | instruction5=QLabel("Channel to use for registration (0 : functional, 1 : anatomical (if available))") 80 | self.reg_chan= QLineEdit() 81 | self.reg_chan.setFixedWidth(50) 82 | self.reg_chan.setText('0') 83 | layout.addRow(instruction5,self.reg_chan) 84 | 85 | trsfrm_type=QLabel("Choose the type of transformation to use for registration:") 86 | self.trsfrm_type=QComboBox() 87 | self.trsfrm_type.addItem("affine") 88 | self.trsfrm_type.addItem("rigid") 89 | self.trsfrm_type.setCurrentIndex(0) 90 | layout.addRow(trsfrm_type,self.trsfrm_type) 91 | 92 | # compute_iou=QLabel("iou_dist_thr:") 93 | # self.compute_iou= QLineEdit() 94 | # self.compute_iou.setFixedWidth(50) 95 | # self.compute_iou.setText('16') 96 | # layout.addRow(compute_iou,self.compute_iou) 97 | 98 | thr_method=QLabel("Thresholding method for filtering IoU histogram:") 99 | self.thr_method=QComboBox() 100 | self.thr_method.addItem("min") 101 | self.thr_method.addItem("otsu") 102 | self.thr_method.setCurrentIndex(1) 103 | layout.addRow(thr_method,self.thr_method) 104 | 105 | instruction6=QLabel("Import the directory where outputs will be saved (a 'track2p' sub-folder will be created):") 106 | self.t2p_path_button = QPushButton("Import", self) 107 | self.t2p_path_button.clicked.connect(self.save_directoy) 108 | layout.addRow(instruction6,self.t2p_path_button) 109 | 110 | instruction7= QLabel("Path to access the output 'track2p' folder:") 111 | instruction7.setSizePolicy(QSizePolicy.Fixed, 
QSizePolicy.Fixed) 112 | self.save_path=QLabel() 113 | self.save_path.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) 114 | layout.addRow(instruction7,self.save_path) 115 | 116 | instruction8= QLabel("Save the outputs in suite2p format (containing cells tracked on all days):") 117 | self.checkbox3 = QCheckBox(self) 118 | layout.addRow(instruction8,self.checkbox3) 119 | 120 | 121 | self.run_button = QPushButton("Run", self) 122 | self.run_button.clicked.connect(self.run) 123 | layout.addRow("Run the algorithm:", self.run_button) 124 | 125 | terminal_intruction=QLabel("To monitor progress see outputs in the terminal where the GUI was launched from.") 126 | layout.addRow(terminal_intruction) 127 | 128 | 129 | def display_iscell(self,state): 130 | if state == Qt.Checked: 131 | self.is_cell_thr.setVisible(True) 132 | else: 133 | self.is_cell_thr.setVisible(False) 134 | 135 | def display_suite2p_options(self): 136 | if self.format.currentText() == "suite2p": 137 | self.checkbox1.setVisible(True) 138 | self.checkbox2.setVisible(True) 139 | self.reg_chan.setVisible(True) 140 | self.is_cell_thr.setVisible(self.checkbox2.isChecked()) 141 | else: 142 | self.checkbox1.setVisible(False) 143 | self.checkbox2.setVisible(False) 144 | self.reg_chan.setVisible(False) 145 | self.is_cell_thr.setVisible(False) 146 | 147 | def run(self): 148 | 149 | stored_all_ds_path = [] 150 | for i in range(self.paths_list.count()): 151 | item=self.paths_list.item(i).data(Qt.UserRole) 152 | item_universel=item.replace("\\", "/") 153 | stored_all_ds_path.append(item_universel) 154 | self.track_ops.all_ds_path= stored_all_ds_path 155 | save_path=self.saved_directory 156 | save_path=save_path.replace("\\", "/") 157 | self.track_ops.save_path = save_path 158 | self.track_ops.input_format = self.format.currentText() 159 | self.track_ops.reg_chan=int(self.reg_chan.text()) 160 | self.track_ops.transform_type=self.trsfrm_type.currentText() 161 | # self.track_ops.iou_dist_thr=int(self.compute_iou.text()) 162 | self.track_ops.thr_method=self.thr_method.currentText() 163 | print("transformation type:", self.track_ops.transform_type) 164 | print("iou_dist_thr:", self.track_ops.iou_dist_thr) 165 | print("thr_method:", self.track_ops.thr_method) 166 | if self.checkbox1.isChecked(): 167 | self.track_ops.iscell_thr=None 168 | if self.checkbox2.isChecked(): 169 | self.track_ops.iscell_thr=float(self.is_cell_thr.text()) 170 | if self.checkbox3.isChecked(): 171 | self.track_ops.save_in_s2p_format=True 172 | print("All parameters have been recorded ! 
The track2p algorithm is running...") 173 | run_t2p(self.track_ops) 174 | self.open_track2p_in_gui() 175 | 176 | 177 | def open_track2p_in_gui(self): 178 | reply = QMessageBox.question(self, "", "Run completed successfully!\nDo you want to launch the gui?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No) 179 | 180 | if reply == QMessageBox.Yes: 181 | #print("Opening GUI...") 182 | self.dialog = CustomDialog(self.main_window, self.saved_directory, self.reg_chan.text()) 183 | self.dialog.exec_() 184 | if reply == QMessageBox.No: 185 | pass 186 | 187 | def save_directoy(self): 188 | saved_directory= QFileDialog.getExistingDirectory(self, "Select Directory") 189 | if saved_directory: 190 | self.saved_directory=saved_directory 191 | self.save_path.setText(f'{self.saved_directory}') 192 | 193 | def import_path_to_recordings(self): 194 | directory = QFileDialog.getExistingDirectory(self, "Select Directory") 195 | self.saved_directory=directory 196 | self.save_path.setText(f'{self.saved_directory}') 197 | 198 | if directory: 199 | self.path_recording.setText(f'{directory}') 200 | self.computer_file_list.clear() 201 | files= os.listdir(directory) 202 | for file in sorted(files): 203 | full_path = os.path.join(directory, file) 204 | item=QListWidgetItem(file) 205 | item.setData(Qt.UserRole, full_path) 206 | self.computer_file_list.addItem(item) 207 | 208 | 209 | def move_file_to_paths_list(self): 210 | selected_items = self.computer_file_list.selectedItems() 211 | for item in selected_items: 212 | self.computer_file_list.takeItem(self.computer_file_list.row(item)) 213 | self.paths_list.addItem(item) 214 | 215 | 216 | def move_file_to_computer_list(self): 217 | selected_items = self.paths_list.selectedItems() 218 | for item in selected_items: 219 | self.paths_list.takeItem(self.paths_list.row(item)) 220 | self.computer_file_list.addItem(item) 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /track2p/t2p.py: -------------------------------------------------------------------------------- 1 | from track2p.ops.default import DefaultTrackOps 2 | from types import SimpleNamespace 3 | 4 | from track2p.io.s2p_loaders import load_all_imgs, check_nplanes, load_all_ds_stat_iscell, load_all_ds_mean_img, load_all_ds_centroids 5 | from track2p.io.savers import npy_to_s2p, save_track_ops, save_all_pl_match_mat 6 | 7 | from track2p.register.loop import run_reg_loop, reg_all_ds_all_roi 8 | from track2p.register.utils import get_all_ds_img_for_reg, get_all_ref_nonref_inters 9 | 10 | from track2p.plot.progress import plot_all_planes 11 | from track2p.plot.output import plot_reg_img_output, plot_thr_met_hist, plot_n_matched_roi, plot_roi_reg_output, plot_roi_match_multiplane, plot_allroi_match_multiplane 12 | 13 | from track2p.match.loop import get_all_ds_assign, get_all_pl_match_mat 14 | import numpy as np 15 | import os 16 | import scipy as spicy 17 | import pandas as pd 18 | 19 | 20 | def run_t2p(track_ops): 21 | 22 | # 1) initialise save paths for figures and matched neurons output 23 | track_ops.init_save_paths() 24 | 25 | # 2) Load data 26 | check_nplanes(track_ops) 27 | 28 | if track_ops.input_format == 'npy': 29 | print('Converting npy data to track2p-compatible format...') 30 | npy_to_s2p(track_ops) 31 | 32 | all_ds_avg_ch1, all_ds_avg_ch2 = load_all_imgs(track_ops) 33 | 34 | # 3) Plot available planes for registration 35 | plot_all_planes(all_ds_avg_ch1, track_ops) 36 | if track_ops.nchannels==2: 37 | plot_all_planes(all_ds_avg_ch2, track_ops, 
ch='anatomical')
38 | 
39 |     # 4) do the actual registration based on chosen channel
40 |     all_ds_ref_img, all_ds_mov_img = get_all_ds_img_for_reg(all_ds_avg_ch1, all_ds_avg_ch2, track_ops)
41 | 
42 |     all_ds_mov_img_reg, all_ds_reg_params = run_reg_loop(all_ds_ref_img, all_ds_mov_img, track_ops)
43 | 
44 |     plot_reg_img_output(track_ops)
45 | 
46 | 
47 |     # 5) apply computed transform to all ROIs
48 |     all_ds_all_roi_ref, all_ds_all_roi_mov, all_ds_all_roi_reg, all_ds_roi_counter = reg_all_ds_all_roi(all_ds_reg_params, track_ops)
49 | 
50 | 
51 |     # 6) optional: generate 'yellow intersection' plot (this is only needed for plotting below)
52 |     all_ds_ref_reg_inters = get_all_ref_nonref_inters(all_ds_all_roi_ref, all_ds_all_roi_reg, track_ops)
53 | 
54 |     all_ds_ref_mov_inters = get_all_ref_nonref_inters(all_ds_all_roi_ref, all_ds_all_roi_mov, track_ops)
55 | 
56 | 
57 |     track_ops.all_ds_ref_mov_inters = all_ds_ref_mov_inters
58 |     track_ops.all_ds_ref_reg_inters = all_ds_ref_reg_inters
59 | 
60 |     if track_ops.show_roi_reg_output:
61 |         plot_roi_reg_output(track_ops)
62 | 
63 | 
64 |     # 7) get optimal assignments for all pairs of recordings (first to last)
65 |     all_ds_assign, all_ds_assign_thr, all_ds_thr_met, all_ds_thr = get_all_ds_assign(track_ops, all_ds_all_roi_ref, all_ds_all_roi_reg)
66 |     plot_thr_met_hist(all_ds_thr_met, all_ds_thr, track_ops)
67 |     plot_n_matched_roi(all_ds_thr_met, all_ds_thr, track_ops)
68 | 
69 | 
70 |     # 8) get match matrices for all pairs of recordings (first to last)
71 |     all_pl_match_mat = get_all_pl_match_mat(all_ds_all_roi_ref, all_ds_assign_thr, track_ops)
72 | 
73 | 
74 |     # 9) save results
75 |     save_track_ops(track_ops)
76 | 
77 | 
78 |     save_all_pl_match_mat(all_pl_match_mat, track_ops)
79 | 
80 |     print('Generating suite2p indices')
81 |     generate_suite2p_indices(track_ops)
82 | 
83 | 
84 | 
85 |     # 10) save in suite2p format
86 |     if track_ops.save_in_s2p_format:
87 |         print('Saving in suite2p format...')
88 |         save_in_s2p_format(track_ops)
89 | 
90 |     # 11) plot results
91 |     print('Finished with algorithm!\n\nGenerating plots (this can take some time)...\n\n')
92 |     all_ds_stat_iscell = load_all_ds_stat_iscell(track_ops)
93 |     all_ds_centroids = load_all_ds_centroids(all_ds_stat_iscell, track_ops)
94 |     all_ds_mean_img = load_all_ds_mean_img(track_ops)
95 |     if track_ops.nchannels==2:
96 |         all_ds_mean_img_ch2 = load_all_ds_mean_img(track_ops, ch=2)
97 | 
98 | 
99 |     plot_roi_match_multiplane(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, track_ops, win_size=track_ops.win_size)
100 |     plot_allroi_match_multiplane(all_ds_mean_img, all_pl_match_mat, track_ops)
101 |     if track_ops.nchannels==2:
102 |         plot_roi_match_multiplane(all_ds_mean_img_ch2, all_ds_centroids, all_pl_match_mat, track_ops, win_size=track_ops.win_size, ch=2)
103 |         plot_allroi_match_multiplane(all_ds_mean_img_ch2, all_pl_match_mat, track_ops, ch=2)
104 | 
105 | 
106 | 
107 |     print('\n\n\nDone!\n\n\n')
108 | 
109 | 
110 | 
111 | def generate_suite2p_indices(track_ops):
112 | 
113 | 
114 |     for plane in range(track_ops.nplanes):
115 | 
116 |         t2p_match_mat = np.load(os.path.join(track_ops.save_path, f"plane{plane}_match_mat.npy"), allow_pickle=True)
117 | 
118 |         all_iscell = []
119 | 
120 |         for ds_path in track_ops.all_ds_path:
121 |             iscell = np.load(os.path.normpath(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'iscell.npy')), allow_pickle=True)
122 |             all_iscell.append(iscell)
123 | 
124 |         true_indices = []
125 | 
126 |         for line in t2p_match_mat:
127 |             indexes = []
128 |             for day, index_match in enumerate(line):
129 |                 if index_match is None:
130 |                     indexes.append(None)
131 |                 else:
132 |                     iscell = all_iscell[day]
133 |                     if track_ops.iscell_thr is None:
134 |                         indices_lignes_1 = np.where(iscell[:, 0] == 1)[0]  # ROIs considered as cells (manual curation)
135 |                     else:
136 |                         indices_lignes_1 = np.where(iscell[:, 1] > track_ops.iscell_thr)[0]  # ROIs with probability > threshold
137 | 
138 |                     true_index = indices_lignes_1[index_match]
139 |                     indexes.append(true_index)
140 | 
141 |             true_indices.append(indexes)
142 | 
143 |         true_indices = np.array([[int(x) if x is not None else None for x in row] for row in true_indices])
144 |         true_indices_nan= np.array([[float(x) if x is not None else np.nan for x in row] for row in true_indices])
145 | 
146 |         np.save(os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices.npy"), true_indices)
147 |         np.save(os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices_nan.npy"), true_indices_nan)
148 |         spicy.io.savemat(os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices.mat"), {'data': true_indices_nan})
149 | 
150 |         column_names = [os.path.basename(ds_path) for ds_path in track_ops.all_ds_path]
151 |         csv_path = os.path.join(track_ops.save_path, f"plane{plane}_suite2p_indices.csv")
152 |         df=pd.DataFrame(true_indices, columns=column_names)
153 |         df.to_csv(csv_path, index=False, sep=';', na_rep='NaN')
154 | 
155 |         print(true_indices.dtype)
156 |         print(true_indices_nan.dtype)
157 | 
158 | 
159 | 
160 | 
161 | 
162 | 
163 | def save_in_s2p_format(track_ops):
164 | 
165 |     for ds_path in track_ops.all_ds_path:
166 |         # check how many subfolders starting with plane* in suite2p folder
167 |         n_planes = len([name for name in os.listdir(ds_path + '/suite2p') if name.startswith('plane')])
168 |         print(f'Found {n_planes} planes in {ds_path}')
169 | 
170 |     folderpath=track_ops.save_path
171 |     track_ops_dict = np.load(os.path.join(folderpath, "track_ops.npy"), allow_pickle=True).item()
172 |     track_ops = SimpleNamespace(**track_ops_dict)
173 |     iscell_thr = track_ops.iscell_thr
174 | 
175 | 
176 |     for j in range(track_ops.nplanes):
177 | 
178 |         all_f_t2p= []
179 |         all_ops = []
180 |         all_stat_t2p = []
181 |         all_iscell_t2p = []
182 |         fneu_iscell_t2p= []
183 |         spks_iscell_t2p= []
184 | 
185 |         fneu_chan2_iscell_t2p = []
186 |         f_chan2_iscell_t2p = []
187 |         redcell_iscell_t2p = []
188 | 
189 | 
190 |         print(f'nplanes{j}')
191 |         t2p_match_mat = np.load(os.path.join(folderpath,f'plane{str(j)}_match_mat.npy'), allow_pickle=True)
192 |         t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :]
193 |         for (i, ds_path) in enumerate(track_ops.all_ds_path):
194 |             ops = np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'ops.npy'), allow_pickle=True).item()
195 |             stat = np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'stat.npy'), allow_pickle=True)
196 |             f = np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'F.npy'), allow_pickle=True)
197 |             fneu= np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'Fneu.npy'), allow_pickle=True)
198 |             spks= np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'spks.npy'), allow_pickle=True)
199 |             iscell = np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'iscell.npy'), allow_pickle=True)
200 |             if track_ops.nchannels==2:
201 |                 f_chan2=np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'F_chan2.npy'), allow_pickle=True)
202 |                 fneu_chan2 = np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'Fneu_chan2.npy'), allow_pickle=True)
203 |                 redcell=np.load(os.path.join(ds_path, 'suite2p', f'plane{str(j)}', 'redcell.npy'), allow_pickle=True)
204 | 
205
| if track_ops.iscell_thr==None: 206 | stat_iscell = stat[iscell[:, 0] == 1] 207 | f_iscell = f[iscell[:, 0] == 1, :] 208 | fneu_iscell = fneu[iscell[:, 0] == 1, :] 209 | spks_iscell = spks[iscell[:, 0] == 1, :] 210 | is_cell = iscell[iscell[:, 0] == 1, :] 211 | if track_ops.nchannels==2: 212 | f_chan2_iscell = f_chan2[iscell[:, 0] == 1, :] 213 | fneu_chan2_iscell = fneu_chan2[iscell[:, 0] == 1, :] 214 | redcell_iscell = redcell[iscell[:, 0] == 1] 215 | else: 216 | stat_iscell = stat[iscell[:, 1] > iscell_thr] 217 | f_iscell = f[iscell[:, 1] > iscell_thr, :] 218 | fneu_iscell = fneu[iscell[:, 1] > iscell_thr, :] 219 | spks_iscell = spks[iscell[:, 1] > iscell_thr, :] 220 | is_cell = iscell[iscell[:, 1] > iscell_thr, :] 221 | if track_ops.nchannels==2: 222 | f_chan2_iscell = f_chan2[iscell[:, 1] > track_ops.iscell_thr, :] 223 | fneu_chan2_iscell = fneu_chan2[iscell[:, 1] > track_ops.iscell_thr, :] 224 | redcell_iscell = redcell[iscell[:, 1] > track_ops.iscell_thr] 225 | 226 | stat_t2p = stat_iscell[t2p_match_mat_allday[:, i].astype(int)] 227 | f_t2p = f_iscell[t2p_match_mat_allday[:, i].astype(int), :] 228 | fneu_t2p = fneu_iscell[t2p_match_mat_allday[:, i].astype(int), :] 229 | spks_t2p = spks_iscell[t2p_match_mat_allday[:, i].astype(int), :] 230 | iscell_t2p = is_cell[t2p_match_mat_allday[:, i].astype(int), :] 231 | if track_ops.nchannels==2: 232 | fneu_chan2_t2p = fneu_chan2_iscell[t2p_match_mat_allday[:, i].astype(int), :] 233 | f_chan2_t2p = f_chan2_iscell[t2p_match_mat_allday[:, i].astype(int), :] 234 | redcell_t2p = redcell_iscell[t2p_match_mat_allday[:, i].astype(int)] 235 | 236 | all_stat_t2p.append(stat_t2p) 237 | all_f_t2p.append(f_t2p) 238 | all_ops.append(ops) 239 | all_iscell_t2p.append(iscell_t2p) 240 | fneu_iscell_t2p.append(fneu_t2p) 241 | spks_iscell_t2p.append(spks_t2p) 242 | if track_ops.nchannels==2: 243 | fneu_chan2_iscell_t2p.append(fneu_chan2_t2p) 244 | f_chan2_iscell_t2p.append(f_chan2_t2p) 245 | redcell_iscell_t2p.append(redcell_t2p) 246 | 247 | 248 | output_folderpath=os.path.join(folderpath, 'matched_suite2p') 249 | last_elements = [os.path.basename(path) for path in track_ops.all_ds_path] 250 | print(last_elements) 251 | 252 | # Save each element of each list to a .npy file 253 | for i in range(len(track_ops.all_ds_path)): 254 | stat_t2p, f_t2p, ops, iscell_t2p, fneu_t2p, spks_t2p = all_stat_t2p[i], all_f_t2p[i], all_ops[i], all_iscell_t2p[i], fneu_iscell_t2p[i], spks_iscell_t2p[i] 255 | subfolder_path = os.path.join(output_folderpath, last_elements[i]) 256 | if not os.path.exists(subfolder_path): 257 | os.makedirs(subfolder_path) 258 | inner_folderpath = os.path.join(subfolder_path, 'suite2p') 259 | if not os.path.exists(inner_folderpath): 260 | os.makedirs(inner_folderpath) 261 | plane_folderpath = os.path.join(inner_folderpath, f'plane{j}') 262 | if not os.path.exists(plane_folderpath): 263 | os.makedirs(plane_folderpath) 264 | 265 | 266 | np.save(os.path.join(plane_folderpath,f"stat.npy"), stat_t2p) 267 | np.save(os.path.join(plane_folderpath, f"F.npy"), f_t2p) 268 | np.save(os.path.join(plane_folderpath, f"ops.npy"), ops) 269 | np.save(os.path.join(plane_folderpath, f"iscell.npy"), iscell_t2p) 270 | np.save(os.path.join(plane_folderpath, f"Fneu.npy"), fneu_t2p) 271 | np.save(os.path.join(plane_folderpath,f"spks.npy"), spks_t2p) 272 | if track_ops.nchannels==2: 273 | for i, (redcell_t2p, f_chan2_t2p, fneu_chan2_t2p) in enumerate(zip(redcell_iscell_t2p, f_chan2_iscell_t2p, fneu_chan2_iscell_t2p)): 274 | np.save(os.path.join(plane_folderpath, f"F_chan2.npy"), 
f_chan2_t2p)
275 |                 np.save(os.path.join(plane_folderpath, f"Fneu_chan2.npy"), fneu_chan2_t2p)
276 |                 np.save(os.path.join(plane_folderpath, f"redcell.npy"), redcell_t2p)
-------------------------------------------------------------------------------- /notebooks/demo_t2p_ouputs.ipynb: --------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Analysing output of t2p\n",
8 |     "This is a short demo explaining how to explore the output of track2p and use the matched neurons/traces for custom downstream analysis."
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "markdown",
13 |    "metadata": {},
14 |    "source": [
15 |     "The example here is for a 1 plane recording with simultaneous videography (given dataset is jm032)."
16 |    ]
17 |   },
18 |   {
19 |    "cell_type": "code",
20 |    "execution_count": null,
21 |    "metadata": {},
22 |    "outputs": [],
23 |    "source": [
24 |     "# imports\n",
25 |     "import os\n",
26 |     "from types import SimpleNamespace\n",
27 |     "\n",
28 |     "import numpy as np\n",
29 |     "import matplotlib.pyplot as plt\n",
30 |     "\n",
31 |     "from scipy.stats import zscore\n"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "markdown",
36 |    "metadata": {},
37 |    "source": [
38 |     "### Step by step guide (more detailed explanations below):\n",
39 |     "\n",
40 |     "Each point from this list matches one section of this notebook\n",
41 |     "\n",
42 |     "1) Load the output of track2p\n",
43 |     "2) Find cells that are present in all recordings ('matched cells')\n",
44 |     "3) Load the data from one example dataset and visualise it\n",
45 |     "4) Load the activity of the matched cells\n",
46 |     "5) Visualise the activity of matched cells"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "markdown",
51 |    "metadata": {},
52 |    "source": [
53 |     "### 1) Load the output of track2p\n",
54 |     "We will load the `.npy` files: `t2p_output_path/track2p/plane#_match_mat.npy` and `t2p_output_path/track2p/track_ops.npy`. These are the matrix of cell matches for all days and the settings respectively. For more info see the repo readme and documentation.\n",
55 |     "\n",
56 |     "Note: In this demo a single-plane recording is used, but it can be modified easily for multiplane compatibility (just repeat the same procedure while looping through planes)"
57 |    ]
58 |   },
59 |   {
60 |    "cell_type": "code",
61 |    "execution_count": null,
62 |    "metadata": {},
63 |    "outputs": [],
64 |    "source": [
65 |     "# this is the directory that contains a /track2p folder that is output by running the track2p algorithm\n",
66 |     "t2p_save_path = '/Users/manonmantez/Desktop/el' # (change this based on your data)\n",
67 |     "plane = 'plane0' # which plane to process (the example dataset is single-plane)"
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "code",
72 |    "execution_count": null,
73 |    "metadata": {},
74 |    "outputs": [],
75 |    "source": [
76 |     "# np.load() the match matrix (plane0_match_mat.npy)\n",
77 |     "t2p_match_mat = np.load(os.path.join(t2p_save_path, 'track2p', f'{plane}_match_mat.npy'), allow_pickle=True)\n",
78 |     "\n",
79 |     "# np.load() settings (this contains suite2p paths etc.) (track_ops.npy)\n",
80 |     "track_ops_dict = np.load(os.path.join(t2p_save_path, 'track2p', 'track_ops.npy'), allow_pickle=True).item()\n",
81 |     "track_ops = SimpleNamespace(**track_ops_dict) # create dummy object from the track_ops dictionary"
82 |    ]
83 |   },
84 |   {
85 |    "cell_type": "markdown",
86 |    "metadata": {},
87 |    "source": [
88 |     "### 2) Find cells that are present in all recordings ('matched cells')\n",
89 |     "\n"
90 |    ]
91 |   },
92 |   {
93 |    "cell_type": "markdown",
94 |    "metadata": {},
95 |    "source": [
96 |     "Now from this matrix get the matches that are present on all days:\n",
97 |     "\n",
98 |     "- A matrix (`plane#_match_mat.npy`) containing the indices of matched neurons across the sessions for a given plane (`#` is the index of the plane). Since matching is done from first day to last, some neurons will not be successfully tracked after one or a few days. In this case the matrix contains `None` values. To get neurons tracked across all days only take the rows of the matrices containing no `None` values. \n",
99 |     "\n",
100 |     "Note: of course we can use cells that are not present on all days, but for now this is the intended use case for downstream analysis."
101 |    ]
102 |   },
103 |   {
104 |    "cell_type": "code",
105 |    "execution_count": null,
106 |    "metadata": {},
107 |    "outputs": [],
108 |    "source": [
109 |     "# get the rows that do not contain any Nones (if track2p doesn't find a match for a cell across two consecutive days it will append a None) -> cells with no Nones are cells matched across all days\n",
110 |     "t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat==None, axis=1), :]\n",
111 |     "\n",
112 |     "print(f'Shape of match matrix for cells present on all days: {t2p_match_mat_allday.shape} (cells, days)')"
113 |    ]
114 |   },
115 |   {
116 |    "cell_type": "markdown",
117 |    "metadata": {},
118 |    "source": [
119 |     "### 3) Load the data from one example dataset and visualise it\n",
120 |     "\n",
121 |     "Note: The track_ops.npy ('settings file') contains all the paths to suite2p folders used when running track2p (see cell below)"
122 |    ]
123 |   },
124 |   {
125 |    "cell_type": "code",
126 |    "execution_count": null,
127 |    "metadata": {},
128 |    "outputs": [],
129 |    "source": [
130 |     "print('Datasets used for t2p:\\n')\n",
131 |     "for ds_path in track_ops.all_ds_path:\n",
132 |     "    print(ds_path)"
133 |    ]
134 |   },
135 |   {
136 |    "cell_type": "markdown",
137 |    "metadata": {},
138 |    "source": [
139 |     "Now just to test if the paths work we can try to look at data of one of the recordings (in the case below we use the last one). For this part it is important to know a bit about how suite2p structures the outputs: https://suite2p.readthedocs.io/en/latest/outputs.html (the important things will be the `ops.npy`, `stat.npy`, `iscell.npy` and the `F.npy`). There are also separate tutorials and demos for this so we won't go into so much detail."
140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "# lets take the last dataset\n", 149 | "last_ds_path = track_ops.all_ds_path[-1]\n", 150 | "print(f'We will look at the dataset saved at: {last_ds_path}')" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# load the three files\n", 160 | "last_ops = np.load(os.path.join(last_ds_path, 'suite2p', plane, 'ops.npy'), allow_pickle=True).item()\n", 161 | "last_f = np.load(os.path.join(last_ds_path, 'suite2p', plane, 'F.npy'), allow_pickle=True)\n", 162 | "iscell = np.load(os.path.join(last_ds_path, 'suite2p', plane, 'iscell.npy'), allow_pickle=True)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "# we filter the traces based on suite2p's iscell probability (note: it is crucial to use the same probability as in the track2p settings to keep the correct indexing of matches)\n", 172 | "iscell_thr = track_ops.iscell_thr\n", 173 | "\n", 174 | "print(f'The iscell threshold used when running track2p was: {iscell_thr}')\n", 175 | "\n", 176 | "if track_ops.iscell_thr==None:\n", 177 | " last_f_iscell = last_f[iscell[:, 0] == 1, :]\n", 178 | "\n", 179 | "else:\n", 180 | " last_f_iscell = last_f[iscell[:, 1] > iscell_thr, :]" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# now first plot the mean image of the movie (it is saved in ops.npy, for more info see the suite2p outputs documentation)\n", 190 | "plt.imshow(last_ops['meanImg'], cmap='gray')\n", 191 | "plt.axis('off')\n", 192 | "plt.title('Mean image')\n", 193 | "plt.show()\n", 194 | "\n", 195 | "plt.figure(figsize=(10, 1))\n", 196 | "nonmatch_nrn_idx = 0\n", 197 | "plt.plot(last_f[nonmatch_nrn_idx, :])\n", 198 | "plt.xlabel('Frame')\n", 199 | "plt.ylabel('F')\n", 200 | "plt.title(f'Example trace (nrn_idx: {nonmatch_nrn_idx})')\n", 201 | "plt.show()\n", 202 | "\n", 203 | "plt.figure(figsize=(10, 3))\n", 204 | "plt.imshow(zscore(last_f_iscell, axis=1), aspect='auto', cmap='Greys', vmin=0, vmax=1.96)\n", 205 | "plt.xlabel('Frame')\n", 206 | "plt.ylabel('ROI')\n", 207 | "plt.title('Raster plot')\n", 208 | "plt.show()\n" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "## 4) Load the activity of the matched cells\n", 216 | "\n", 217 | "Now that we know how to look at data in one recording we will use the output from track2p to look at activity of the same cells across all datasets." 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "To do this we need to loop through all datasets and:\n", 225 | "- load the files described above\n", 226 | "- filter `stat.npy` and `fluo.npy` by the track2p iscell threshold (classical suite2p)\n", 227 | "- filter `stat.npy` and `fluo.npy` by the appropriate indices from the matrix of neurons matched on all days (additional filtering step after track2p)\n", 228 | "\n", 229 | "This will produce a nice data structure where the indices of cells are matched within the stat and fluo objects. 
Sorting the object in this way allows for very straightforward extraction of matched data (see cells below)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "iscell_thr = track_ops.iscell_thr # use the same threshold as when running the algo (to be consistent with indexing)\n", 239 | "\n", 240 | "all_stat_t2p = []\n", 241 | "all_f_t2p = []\n", 242 | "all_ops = [] # ops dont change\n", 243 | "\n", 244 | "for (i, ds_path) in enumerate(track_ops.all_ds_path):\n", 245 | " ops = np.load(os.path.join(ds_path, 'suite2p', plane, 'ops.npy'), allow_pickle=True).item()\n", 246 | " stat = np.load(os.path.join(ds_path, 'suite2p', plane, 'stat.npy'), allow_pickle=True)\n", 247 | " f = np.load(os.path.join(ds_path, 'suite2p', plane, 'F.npy'), allow_pickle=True)\n", 248 | " iscell = np.load(os.path.join(ds_path, 'suite2p', plane, 'iscell.npy'), allow_pickle=True)\n", 249 | " \n", 250 | " \n", 251 | " if track_ops.iscell_thr==None:\n", 252 | " stat_iscell = stat[iscell[:, 0] == 1]\n", 253 | " f_iscell = f[iscell[:, 0] == 1, :]\n", 254 | "\n", 255 | " else:\n", 256 | " stat_iscell = stat[iscell[:, 1] > iscell_thr]\n", 257 | " f_iscell = f[iscell[:, 1] > iscell_thr, :]\n", 258 | " \n", 259 | " \n", 260 | " stat_t2p = stat_iscell[t2p_match_mat_allday[:,i].astype(int)]\n", 261 | " f_t2p = f_iscell[t2p_match_mat_allday[:,i].astype(int), :]\n", 262 | "\n", 263 | " all_stat_t2p.append(stat_t2p)\n", 264 | " all_f_t2p.append(f_t2p)\n", 265 | " all_ops.append(ops)\n", 266 | "\n" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "### 5) Visualise the ROIs and the activity of (a) matched cell(s)\n", 274 | "\n", 275 | "\n", 276 | "This example shows how to extract the information of a ROI from all_stat. We first index by the day to get stat_t2p from all_stat2p (this is the sorted stat object for that day). We can then get the roi information by indexing stat_t2p by the index of the cell match (because of resorting we use the same index across days)." 
277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "wind = 24\n", 286 | "nrn_idx = 0\n", 287 | "\n", 288 | "for i in range(len(track_ops.all_ds_path)):\n", 289 | " mean_img = all_ops[i]['meanImg']\n", 290 | " stat_t2p = all_stat_t2p[i]\n", 291 | " median_coord = stat_t2p[nrn_idx]['med']\n", 292 | "\n", 293 | " plt.figure(figsize=(1.5,1.5))\n", 294 | " plt.imshow(mean_img[int(median_coord[0])-wind:int(median_coord[0])+wind, int(median_coord[1])-wind:int(median_coord[1])+wind], cmap='gray') # plot a short window around the ROI centroid\n", 295 | " plt.scatter(wind, wind)\n", 296 | " plt.axis('off')\n", 297 | " plt.show()" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "# first plot the trace of cell c for all days\n", 307 | "nrn_idx = 0 # the activity of the ROI visualised above on all days\n", 308 | "\n", 309 | "for i in range(len(track_ops.all_ds_path)):\n", 310 | " plt.figure(figsize=(10, 1)) # make a wide figure\n", 311 | " plt.plot(all_f_t2p[i][nrn_idx, :])\n", 312 | " plt.xlabel('Frame')\n", 313 | " plt.ylabel('F')\n", 314 | " plt.show()\n" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "Now to visualise the rasters its a simple exercise, since they are already sorted in a way that the rows represent the same cell across days we don't need to do anything other than simply looping through all_f_t2p and plotting each element as we did before." 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "for i in range(len(track_ops.all_ds_path)):\n", 331 | " plt.figure(figsize=(10, 3)) # make a wide figure\n", 332 | " f_plot = zscore(all_f_t2p[i], axis=1)\n", 333 | " plt.imshow(f_plot, aspect='auto', cmap='Greys', vmin=0, vmax=1.96)\n", 334 | " plt.xlabel('Frame') " 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": {}, 340 | "source": [ 341 | "### The End!\n", 342 | "\n", 343 | "Congrats! Hopefully this notebook was a clear and useful way of showing how to interact with the track2p outputs.\n", 344 | "\n", 345 | "From here on custom analysis pipelines can very easily be applied (for example looking at stability of assemblies, representational drift etc etc). \n", 346 | "\n", 347 | "The most straightforward way of doing this is to just run an already implemented pipeline on the data loaded as shown here. 
Alternatively the loaded match indices can be used to look at already-processed data as a way of post-hoc matching.\n", 348 | "\n", 349 | "Thanks and have fun with analysis :)" 350 | ] 351 | } 352 | ], 353 | "metadata": { 354 | "kernelspec": { 355 | "display_name": "track2p", 356 | "language": "python", 357 | "name": "python3" 358 | }, 359 | "language_info": { 360 | "codemirror_mode": { 361 | "name": "ipython", 362 | "version": 3 363 | }, 364 | "file_extension": ".py", 365 | "mimetype": "text/x-python", 366 | "name": "python", 367 | "nbconvert_exporter": "python", 368 | "pygments_lexer": "ipython3", 369 | "version": "3.9.19" 370 | } 371 | }, 372 | "nbformat": 4, 373 | "nbformat_minor": 2 374 | } 375 | -------------------------------------------------------------------------------- /track2p/plot/output.py: -------------------------------------------------------------------------------- 1 | # contains code for generating plots based on track_ops object after the pipeline is run (this will be saved as an npy file) 2 | # here all functions just take the track_ops object as input 3 | import os 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from skimage.exposure import match_histograms 9 | from track2p.plot.utils import make_rgb_img, saturate_perc, get_all_wind_mean_img 10 | from track2p.io.loaders import load_stat_ds_plane, get_all_roi_array_from_stat 11 | 12 | def plot_reg_img_output(track_ops): 13 | # make a plot where on the top its all the images and the bottom is the overlays before and after registration 14 | nplanes = track_ops.nplanes 15 | n_row = nplanes + 2*nplanes # number of plays + 2 overlays per plane 16 | n_col = len(track_ops.all_ds_path) # number of datasets 17 | figsize = (10/3 * n_col, 10/3 * n_row) 18 | fig, axs = plt.subplots(n_row, n_col, figsize=figsize, dpi=300) 19 | 20 | # first populate first (n_planes) rows with images 21 | for i in range(nplanes): 22 | axs[i, 0].set_ylabel(f'plane{i}\nchan{track_ops.reg_chan}', rotation=0, size='large', ha='right', va='center', labelpad=20) 23 | for j in range(len(track_ops.all_ds_path)): 24 | img = track_ops.all_ds_avg_ch1[j][i] if track_ops.reg_chan==0 else track_ops.all_ds_avg_ch2[j][i] 25 | axs[i, j].imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, 99)) 26 | 27 | 28 | for i in range(len(track_ops.all_ds_path)-1): # last one won't have overlay 29 | for j in range(nplanes): 30 | # get subplot indices 31 | row_nonreg = nplanes + 2*j # first shift for number of initial rows, then shift by 2 for each plane (before and after registration) 32 | row_reg = nplanes + 2*j + 1 # the row after the nonreg 33 | 34 | # get images 35 | ref_img = track_ops.all_ds_ref_img[i][j] # get ref image for this pair 36 | mov_img = track_ops.all_ds_mov_img[i][j] # get mov image for this pair 37 | mov_img_ref = track_ops.all_ds_mov_img_reg[i][j] # get mov image after registration for this pair 38 | 39 | # match histograms to reference 40 | mov_img = match_histograms(mov_img, ref_img) 41 | mov_img_ref = match_histograms(mov_img_ref, ref_img) 42 | 43 | # assemble and saturate the overlays 44 | img_rgb = make_rgb_img(ref_img, mov_img) 45 | img_rgb_reg = make_rgb_img(ref_img, mov_img_ref) 46 | img_rgb = saturate_perc(img_rgb, sat_perc=track_ops.sat_perc) 47 | img_rgb_reg = saturate_perc(img_rgb_reg, sat_perc=track_ops.sat_perc) 48 | 49 | # plot the overlays 50 | axs[row_nonreg, i].imshow(img_rgb) 51 | axs[row_reg, i].imshow(img_rgb_reg) 52 | 53 | if i == 0: 54 | for k in range(nplanes): 55 | axs[nplanes + 
/track2p/plot/output.py: -------------------------------------------------------------------------------- 1 | # contains code for generating plots based on the track_ops object after the pipeline is run (the track_ops object itself is saved as an npy file) 2 | # here all functions just take the track_ops object as input 3 | import os 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from skimage.exposure import match_histograms 9 | from track2p.plot.utils import make_rgb_img, saturate_perc, get_all_wind_mean_img 10 | from track2p.io.loaders import load_stat_ds_plane, get_all_roi_array_from_stat 11 | 12 | def plot_reg_img_output(track_ops): 13 | # make a plot where the top rows show all the mean images and the rows below show the overlays before and after registration 14 | nplanes = track_ops.nplanes 15 | n_row = nplanes + 2*nplanes # number of planes + 2 overlay rows per plane 16 | n_col = len(track_ops.all_ds_path) # number of datasets 17 | figsize = (10/3 * n_col, 10/3 * n_row) 18 | fig, axs = plt.subplots(n_row, n_col, figsize=figsize, dpi=300) 19 | 20 | # first populate the first (nplanes) rows with images 21 | for i in range(nplanes): 22 | axs[i, 0].set_ylabel(f'plane{i}\nchan{track_ops.reg_chan}', rotation=0, size='large', ha='right', va='center', labelpad=20) 23 | for j in range(len(track_ops.all_ds_path)): 24 | img = track_ops.all_ds_avg_ch1[j][i] if track_ops.reg_chan==0 else track_ops.all_ds_avg_ch2[j][i] 25 | axs[i, j].imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, 99)) 26 | 27 | 28 | for i in range(len(track_ops.all_ds_path)-1): # the last dataset won't have an overlay 29 | for j in range(nplanes): 30 | # get subplot indices 31 | row_nonreg = nplanes + 2*j # first shift for the number of initial rows, then shift by 2 for each plane (before and after registration) 32 | row_reg = nplanes + 2*j + 1 # the row after the nonreg 33 | 34 | # get images 35 | ref_img = track_ops.all_ds_ref_img[i][j] # get ref image for this pair 36 | mov_img = track_ops.all_ds_mov_img[i][j] # get mov image for this pair 37 | mov_img_reg = track_ops.all_ds_mov_img_reg[i][j] # get mov image after registration for this pair 38 | 39 | # match histograms to reference 40 | mov_img = match_histograms(mov_img, ref_img) 41 | mov_img_reg = match_histograms(mov_img_reg, ref_img) 42 | 43 | # assemble and saturate the overlays 44 | img_rgb = make_rgb_img(ref_img, mov_img) 45 | img_rgb_reg = make_rgb_img(ref_img, mov_img_reg) 46 | img_rgb = saturate_perc(img_rgb, sat_perc=track_ops.sat_perc) 47 | img_rgb_reg = saturate_perc(img_rgb_reg, sat_perc=track_ops.sat_perc) 48 | 49 | # plot the overlays 50 | axs[row_nonreg, i].imshow(img_rgb) 51 | axs[row_reg, i].imshow(img_rgb_reg) 52 | 53 | if i == 0: 54 | for k in range(nplanes): 55 | axs[nplanes + 2*k,i].set_ylabel(f'plane{k}\nchan{track_ops.reg_chan}\nnon-reg', rotation=0, size='large', ha='right', va='center', labelpad=20) 56 | axs[nplanes + 2*k+1,i].set_ylabel(f'plane{k}\nchan{track_ops.reg_chan}\nreg', rotation=0, size='large', ha='right', va='center', labelpad=20) 57 | 58 | # loop through all subplots and remove ticks, labels and spines 59 | 60 | for ax in axs.flat: 61 | ax.set(xticks=[], yticks=[]) 62 | ax.spines['top'].set_visible(False) 63 | ax.spines['right'].set_visible(False) 64 | ax.spines['left'].set_visible(False) 65 | ax.spines['bottom'].set_visible(False) 66 | 67 | # add arrows and dashed lines etc. 68 | 69 | axs[2*nplanes, 0].annotate('', xy=(0, 1.1), xytext=(n_col + 0.5, 1.1), 70 | xycoords='axes fraction', textcoords='axes fraction', 71 | arrowprops=dict(arrowstyle='-', linestyle='dashed', color='grey'), 72 | annotation_clip=False) 73 | 74 | for i in range(len(track_ops.all_ds_path)-1): 75 | axs[nplanes-1,i].annotate('', xy=(0.5, -0.17), xytext=(0.5, -0.02), 76 | xycoords='axes fraction', textcoords='axes fraction', 77 | arrowprops=dict(facecolor=(1,0,0), edgecolor=(1,0,0), shrink=0.05), 78 | annotation_clip=False) 79 | axs[nplanes-1,i+1].annotate('', xy=(-0.6, -0.17), xytext=(0.5, -0.02), 80 | xycoords='axes fraction', textcoords='axes fraction', 81 | arrowprops=dict(facecolor=(0,1,0), edgecolor=(0,1,0), shrink=0.05), 82 | annotation_clip=False) 83 | 84 | 85 | # save figure into the output path 86 | fig.savefig(os.path.join(track_ops.save_path_fig, 'reg_img_output.png'), bbox_inches='tight', dpi=200) 87 | plt.close(fig) 88 | 89 | def plot_roi_reg_output(track_ops): 90 | # make a plot where the top rows show all the mean images and the rows below show the ROI overlays before and after registration 91 | nplanes = track_ops.nplanes 92 | n_row = nplanes + 2*nplanes # number of planes + 2 overlay rows per plane 93 | n_col = len(track_ops.all_ds_path) # number of datasets 94 | figsize = (10/3 * n_col, 10/3 * n_row) 95 | # figsize = (10 * n_col, 10 * n_row) 96 | 97 | fig, axs = plt.subplots(n_row, n_col, figsize=figsize, dpi=300) 98 | 99 | # first populate the first (nplanes) rows with images 100 | for i in range(nplanes): 101 | axs[i, 0].set_ylabel(f'plane{i}\nchan{track_ops.reg_chan}', rotation=0, size='large', ha='right', va='center', labelpad=20) 102 | for j in range(len(track_ops.all_ds_path)): 103 | img = track_ops.all_ds_avg_ch1[j][i] 104 | axs[i, j].imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, 99), alpha=0.5) 105 | # take the ROIs of the reference, unless it is the last dataset, in which case take the ROIs of the mov 106 | all_roi = track_ops.all_ds_all_roi_array_ref[j][i] if j < len(track_ops.all_ds_path)-1 else track_ops.all_ds_all_roi_array_mov[j-1][i] 107 | for k in range(all_roi.shape[2]): 108 | axs[i, j].contour(all_roi[:,:,k], colors='C0', linewidths=0.3) 109 | 110 | # now populate the remaining rows with the overlays before and after registration 111 | for i in range(len(track_ops.all_ds_path)-1): 112 | print(f'Plotting contours for dataset {i}/{len(track_ops.all_ds_path)-1}') 113 | for j in range(track_ops.nplanes): 114 | 115 | row_nonreg = nplanes + 2*j # first shift for the number of initial rows, then shift by 2 for each plane (before and after registration) 116 | row_reg = nplanes + 2*j + 1 # the row after the nonreg 117 | 118 | all_roi_array_ref = track_ops.all_ds_all_roi_array_ref[i][j] 119 | 120 | all_roi_array_mov = track_ops.all_ds_all_roi_array_mov[i][j] 121 | ref_mov_inters = track_ops.all_ds_ref_mov_inters[i][j] 122 | 123 | axs[row_nonreg,i].imshow(ref_mov_inters) 124 | for k in 
range(all_roi_array_ref.shape[2]): 125 | axs[row_nonreg,i].contour(all_roi_array_ref[:,:,k], colors='r', linewidths=0.3) 126 | for k in range(all_roi_array_mov.shape[2]): 127 | axs[row_nonreg,i].contour(all_roi_array_mov[:,:,k], colors='g', linewidths=0.3) 128 | 129 | # same for the reg row 130 | all_roi_array_reg = track_ops.all_ds_all_roi_array_reg[i][j] 131 | ref_reg_inters = track_ops.all_ds_ref_reg_inters[i][j] 132 | 133 | axs[row_reg,i].imshow(ref_reg_inters) 134 | for k in range(all_roi_array_ref.shape[2]): 135 | axs[row_reg,i].contour(all_roi_array_ref[:,:,k], colors='r', linewidths=0.3) 136 | for k in range(all_roi_array_reg.shape[2]): 137 | axs[row_reg,i].contour(all_roi_array_reg[:,:,k], colors='g', linewidths=0.3) 138 | 139 | if i == 0: 140 | for k in range(nplanes): 141 | axs[nplanes + 2*k,i].set_ylabel(f'plane{k}\nnon-reg', rotation=0, size='large', ha='right', va='center', labelpad=20) 142 | axs[nplanes + 2*k+1,i].set_ylabel(f'plane{k}\nreg', rotation=0, size='large', ha='right', va='center', labelpad=20) 143 | 144 | 145 | for ax in axs.flat: 146 | ax.set(xticks=[], yticks=[]) 147 | ax.spines['top'].set_visible(False) 148 | ax.spines['right'].set_visible(False) 149 | ax.spines['left'].set_visible(False) 150 | ax.spines['bottom'].set_visible(False) 151 | 152 | # add arrows and dashed lines etc. 153 | 154 | axs[2*nplanes, 0].annotate('', xy=(0, 1.1), xytext=(n_col + 0.5, 1.1), 155 | xycoords='axes fraction', textcoords='axes fraction', 156 | arrowprops=dict(arrowstyle='-', linestyle='dashed', color='grey'), 157 | annotation_clip=False) 158 | 159 | for i in range(len(track_ops.all_ds_path)-1): 160 | axs[nplanes-1,i].annotate('', xy=(0.5, -0.17), xytext=(0.5, -0.02), 161 | xycoords='axes fraction', textcoords='axes fraction', 162 | arrowprops=dict(facecolor='r', edgecolor='r', shrink=0.05), 163 | annotation_clip=False) 164 | axs[nplanes-1,i+1].annotate('', xy=(-0.6, -0.17), xytext=(0.5, -0.02), 165 | xycoords='axes fraction', textcoords='axes fraction', 166 | arrowprops=dict(facecolor='g', edgecolor='g', shrink=0.05), 167 | annotation_clip=False) 168 | 169 | 170 | # save figure into the output path 171 | fig.savefig(os.path.join(track_ops.save_path_fig, 'reg_roi_output.png'), bbox_inches='tight', dpi=200) 172 | plt.close(fig) 173 | 174 | 175 | def plot_thr_met_hist(all_ds_thr_met, all_ds_thr, track_ops): 176 | fig, axs = plt.subplots(track_ops.nplanes, len(all_ds_thr_met), figsize=(6*len(all_ds_thr_met), 6*track_ops.nplanes), sharey=True, sharex=True) 177 | for i in range(len(all_ds_thr_met)): 178 | for j in range(track_ops.nplanes): 179 | axs = np.array([axs]) if type(axs) is not np.ndarray else axs 180 | 181 | this_ax = axs[i] if track_ops.nplanes==1 or len(track_ops.all_ds_path)==2 else axs[j][i] 182 | 183 | n_reg_roi = len(all_ds_thr_met[i][j]) 184 | n_abovethr_roi = np.sum(all_ds_thr_met[i][j]>all_ds_thr[i][j]) 185 | this_ax.hist(all_ds_thr_met[i][j], bins=20) 186 | this_ax.axvline(all_ds_thr[i][j], color='grey', linestyle='--') 187 | # label the line with the threshold method and value 188 | this_ax.text(all_ds_thr[i][j]+0.02, this_ax.get_ylim()[1]*0.9, f'{track_ops.thr_method} thr.: {all_ds_thr[i][j]:.2f}') 189 | this_ax.set_title(f'ds{i} (ref) to ds{i+1} (reg); matched {n_abovethr_roi}/{n_reg_roi} ({n_abovethr_roi/n_reg_roi*100:.1f}%)') 190 | 191 | if i==0: 192 | this_ax.set_ylabel(f'ROI count (plane{j})') 193 | 194 | # remove the top and right spines 195 | for a in axs.flatten(): 196 | a.spines['top'].set_visible(False) 197 | a.spines['right'].set_visible(False) 198 | 199 | 200 | plt.tight_layout() 
201 | plt.savefig(os.path.join(track_ops.save_path_fig, 'thr_met_hist.png'), dpi=200) 202 | plt.close(fig) 203 | 204 | 205 | 206 | def plot_n_matched_roi(all_ds_thr_met, all_ds_thr, track_ops): 207 | 208 | fig, axs = plt.subplots(track_ops.nplanes, 2, figsize=(8, 4*track_ops.nplanes), sharey='col', sharex=True) 209 | 210 | for i in range(track_ops.nplanes): 211 | all_n_reg_roi = [len(all_ds_thr_met[j][i]) for j in range(len(all_ds_thr_met))] 212 | all_n_abovethr_roi = [np.sum(all_ds_thr_met[j][i]>all_ds_thr[j][i]) for j in range(len(all_ds_thr_met))] 213 | all_prop_abovethr = [n_abovethr_roi/n_reg_roi for n_abovethr_roi, n_reg_roi in zip(all_n_abovethr_roi, all_n_reg_roi)] 214 | 215 | ax0 = axs[i, 0] if track_ops.nplanes>1 else axs[0] 216 | 217 | ax0.bar(np.arange(len(all_n_reg_roi)), all_n_reg_roi, color='grey', label='total') 218 | ax0.bar(np.arange(len(all_n_abovethr_roi)), all_n_abovethr_roi, color='C0', label='above threshold') 219 | 220 | ax0.set_ylabel(f'ROI count (plane {i})') 221 | ax0.set_xticks(np.arange(len(all_n_reg_roi))) 222 | ax0.set_xticklabels([f'r{i}-{i+1}' for i in range(len(all_n_reg_roi))]) 223 | ax0.legend(loc='lower right') 224 | ax0.spines['top'].set_visible(False) 225 | ax0.spines['right'].set_visible(False) 226 | 227 | ax1 = axs[i, 1] if track_ops.nplanes>1 else axs[1] 228 | 229 | ax1.plot(np.arange(len(all_n_reg_roi)), all_prop_abovethr, color='C0', marker='o') 230 | ax1.set_ylabel(f'ROI prop. (plane {i})') 231 | ax1.set_xticks(np.arange(len(all_n_reg_roi))) 232 | ax1.set_xticklabels([f'r{i}-{i+1}' for i in range(len(all_n_reg_roi))]) 233 | ax1.set_ylim(0, 1) 234 | ax1.spines['top'].set_visible(False) 235 | ax1.spines['right'].set_visible(False) 236 | 237 | fig.savefig(os.path.join(track_ops.save_path_fig, 'n_prop_matched_roi.png'), dpi=200, bbox_inches='tight') 238 | plt.close(fig) 239 | 240 | def plot_roi_match(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, neuron_ids, track_ops, plane_idx=0, win_size=64, k=0, n=None, ch=1): 241 | 242 | if n is None: 243 | n = len(neuron_ids) 244 | 245 | nrows = len(neuron_ids) 246 | ncols = len(all_ds_mean_img) 247 | fig, axs = plt.subplots(nrows, ncols, figsize=(2*ncols, 2*nrows), dpi=50) 248 | 249 | for (i, nrn_id) in enumerate(neuron_ids): 250 | # skip the neuron if it is not matched on all days (i.e. if its row of all_pl_match_mat contains any None) 251 | if any(all_pl_match_mat[plane_idx][nrn_id,:]==None): 252 | continue 253 | 254 | # get all wind_mean_img (small window around centroid) 255 | all_wind_mean_img = get_all_wind_mean_img(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, nrn_id, plane_idx=plane_idx, win_size=win_size) 256 | 257 | if i == 0: 258 | fig_ref = all_wind_mean_img[0] # reference window for histogram matching across the whole figure 259 | else: 260 | all_wind_mean_img[0] = match_histograms(all_wind_mean_img[0], fig_ref) # for later ROIs, match this row's reference window to the first ROI's reference 261 | 262 | for j in range(len(all_ds_mean_img)): 263 | wind_mean_img = all_wind_mean_img[j] 264 | ref_img = all_wind_mean_img[0] 265 | matched = match_histograms(wind_mean_img, ref_img) 266 | axs[i, j].imshow(matched, cmap='gray') 267 | # mark the middle pixel of the window 268 | axs[i, j].scatter(win_size/2, win_size/2, color='C0') 269 | if i==0: 270 | axs[i, j].set_title(f'ds {j} (pl {plane_idx})') 271 | if j==0: 272 | axs[i, j].set_ylabel(f'ROI {nrn_id} ({i+k}/{n})') 273 | 274 | # remove axes labels 275 | for ax in axs.flat: 276 | ax.set_xticks([]) 277 | ax.set_yticks([]) 278 | 279 | plt.tight_layout() 280 | if ch == 1: 281 | 
plt.savefig(os.path.join(track_ops.save_path_fig, f'roi_match_plane{plane_idx}_idx{k}-{k+len(neuron_ids)}.png'), dpi=50) 282 | elif ch == 2: 283 | plt.savefig(os.path.join(track_ops.save_path_fig, f'roi_match_chan2_plane{plane_idx}_idx{k}-{k+len(neuron_ids)}.png'), dpi=50) 284 | plt.close(fig) 285 | 286 | def plot_roi_match_multiplane(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, track_ops, win_size=48, ch=1): 287 | 288 | for i in range(track_ops.nplanes): 289 | pl_neuron_ids = np.arange(all_pl_match_mat[i].shape[0]) 290 | 291 | neuron_ids = pl_neuron_ids[~np.any(all_pl_match_mat[i]==None, axis=1)] 292 | 293 | if len(neuron_ids) < 100: 294 | plot_roi_match(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, neuron_ids, track_ops, plane_idx=i, win_size=win_size, ch=ch) 295 | else: 296 | # plot in batches of 100 297 | for k in range(0, len(neuron_ids), 100): 298 | plot_roi_match(all_ds_mean_img, all_ds_centroids, all_pl_match_mat, neuron_ids[k:k+100], track_ops, plane_idx=i, win_size=win_size, k=k, n=len(neuron_ids), ch=ch) 299 | 300 | 301 | def plot_allroi_match_multiplane(all_ds_mean_img, all_pl_match_mat, track_ops, ch=1): 302 | 303 | 304 | fig, axs = plt.subplots(track_ops.nplanes, len(track_ops.all_ds_path), figsize=(4*len(track_ops.all_ds_path), 4*track_ops.nplanes), dpi=300) 305 | 306 | for ds_idx, ds_path in enumerate(track_ops.all_ds_path): 307 | for plane_idx in range(track_ops.nplanes): 308 | # saturating the reference image 309 | ref_mean_img = all_ds_mean_img[0][plane_idx] 310 | perc = np.percentile(ref_mean_img, track_ops.sat_perc) 311 | ref_mean_img[ref_mean_img>perc] = perc 312 | 313 | ax = axs[ds_idx] if track_ops.nplanes==1 else axs[plane_idx, ds_idx] 314 | mean_img = all_ds_mean_img[ds_idx][plane_idx] 315 | mean_img_matched = match_histograms(mean_img, ref_mean_img) 316 | ax.imshow(mean_img_matched, cmap='gray') 317 | ax.set_title(f'ds{ds_idx} (plane{plane_idx})') 318 | 319 | # add contours 320 | for plane_idx in range(track_ops.nplanes): 321 | pl_neuron_ids = np.arange(all_pl_match_mat[plane_idx].shape[0]) 322 | neuron_ids = pl_neuron_ids[~np.any(all_pl_match_mat[plane_idx]==None, axis=1)] 323 | neuron_colors = [np.random.rand(3) for i in range(len(neuron_ids))] 324 | 325 | for ds_idx, _ in enumerate(track_ops.all_ds_path): 326 | 327 | stat_ds_plane, _ = load_stat_ds_plane(track_ops.all_ds_path[ds_idx], track_ops, plane_idx=plane_idx) 328 | all_roi_array = get_all_roi_array_from_stat(stat_ds_plane, track_ops) 329 | 330 | ax = axs[ds_idx] if track_ops.nplanes==1 else axs[plane_idx, ds_idx] 331 | 332 | for (i, neuron_idx) in enumerate(neuron_ids): 333 | idx_orig = all_pl_match_mat[plane_idx][neuron_idx, ds_idx] # the neuron's index in the original recording 334 | 335 | 336 | cont_plot = ax.contour(all_roi_array[:,:,idx_orig], linewidths=0.5) 337 | 338 | for collection in cont_plot.collections: 339 | collection.set_edgecolor(neuron_colors[i]) # color the contour with this neuron's random RGB color 340 | 341 | left_axs = axs[0] if track_ops.nplanes==1 else axs[plane_idx,0] 342 | left_axs.set_ylabel(f'plane{plane_idx} (n={len(neuron_ids)})') 343 | 344 | # remove axes elements 345 | for ax in axs.flat: 346 | ax.set_xticks([]) 347 | ax.set_yticks([]) 348 | ax.set_xticklabels([]) 349 | ax.set_yticklabels([]) 350 | 351 | plt.tight_layout() 352 | if ch == 1: 353 | plt.savefig(os.path.join(track_ops.save_path_fig, 'all_roi_match.png'), dpi=300) 354 | else: 355 | plt.savefig(os.path.join(track_ops.save_path_fig, 'all_roi_match_chan2.png'), dpi=300) 356 | plt.close(fig) 357 | --------------------------------------------------------------------------------
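A minimal usage sketch for this plotting module, assuming the pipeline has already been run and the track_ops object was pickled with np.save (the '/path/to/save_dir/track2p/track_ops.npy' path and filename below are assumptions, not confirmed by this file; adjust them to wherever the object was actually saved):

import numpy as np
from track2p.plot.output import plot_reg_img_output

# np.save wraps a pickled object in a 0-d array, hence .item() to recover the object
track_ops = np.load('/path/to/save_dir/track2p/track_ops.npy', allow_pickle=True).item()  # hypothetical path
plot_reg_img_output(track_ops)  # writes reg_img_output.png into track_ops.save_path_fig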
/track2p/gui/raster_wd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from PyQt5.QtWidgets import QVBoxLayout, QWidget, QPushButton, QFileDialog, QLineEdit, QLabel, QFormLayout, QCheckBox, QComboBox,QGraphicsView,QGraphicsScene,QSplitter,QGroupBox 4 | from PyQt5.QtCore import Qt 5 | from PyQt5.QtGui import QPixmap 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | from scipy.stats import zscore 10 | from sklearn.decomposition import PCA 11 | from openTSNE import TSNE 12 | 13 | import copy 14 | 15 | 16 | 17 | class RasterWindow(QWidget): 18 | # QWidget is the parent class 19 | def __init__(self, mainWindow ): 20 | super(RasterWindow,self).__init__() 21 | self.main_window = mainWindow 22 | self.raster_type=None 23 | self.bin_size = None 24 | self.filename=None 25 | self.all_f_t2p_preproc=None 26 | self.vmin_value=None 27 | self.vmax_value=None 28 | 29 | 30 | # Create the right-hand side of the window 31 | layout = QFormLayout() 32 | 33 | self.checkboxes=[] 34 | label_checkbox= QLabel("Choose the sorting method:") 35 | field_checkbox= QVBoxLayout() 36 | self.checkbox1 = QCheckBox('without sorting', self) 37 | self.checkbox2 = QCheckBox('sorting by PCA', self) 38 | self.checkbox3 = QCheckBox('sorting by PCA on given day', self) 39 | self.checkbox4 = QCheckBox('sorting by tSNE', self) 40 | self.checkbox5 = QCheckBox('sorting by tSNE on given day', self) 41 | self.checkboxes.append(self.checkbox1) 42 | self.checkboxes.append(self.checkbox2) 43 | self.checkboxes.append(self.checkbox3) 44 | self.checkboxes.append(self.checkbox4) 45 | self.checkboxes.append(self.checkbox5) 46 | 47 | self.checkbox3.stateChanged.connect(self.update_day_choice) 48 | self.checkbox5.stateChanged.connect(self.update_day_choice) 49 | 50 | self.day_choice= QComboBox(self) 51 | self.day_choice.addItem('Choose recording index (for sorting on given day)') 52 | field_checkbox.addWidget(self.checkbox1) 53 | field_checkbox.addWidget(self.checkbox2) 54 | field_checkbox.addWidget(self.checkbox3) 55 | field_checkbox.addWidget(self.checkbox4) 56 | field_checkbox.addWidget(self.checkbox5) 57 | field_checkbox.addWidget(self.day_choice) 58 | layout.addRow(label_checkbox,field_checkbox) 59 | 60 | for checkbox in self.checkboxes: 61 | checkbox.stateChanged.connect(self.handle_checkbox_state) 62 | 63 | self.combined_checkbox=QCheckBox('Combined (put neurons of all planes together)', self) 64 | self.combined_checkbox.stateChanged.connect(self.concatenate_rasters) 65 | layout.addRow(" ",self.combined_checkbox) 66 | 67 | label_check=QLabel("Advanced options:") 68 | self.check=QCheckBox(self) 69 | self.check.stateChanged.connect(self.display_advanced_options) 70 | 71 | layout.addRow(label_check,self.check) 72 | self.advanced_options_group = QGroupBox() 73 | self.advanced_options_group.setVisible(False) 74 | group_layout = QFormLayout() 75 | self.bin= QLineEdit() 76 | self.bin.setText('1') 77 | self.bin.setFixedWidth(50) 78 | group_layout.addRow("Averaging bin size:", self.bin) 79 | self.vmin= QLineEdit() 80 | self.vmin.setText('0') 81 | self.vmin.setFixedWidth(50) 82 | group_layout.addRow("vmin:", self.vmin) 83 | self.vmax= QLineEdit() 84 | self.vmax.setText('1.96') 85 | self.vmax.setFixedWidth(50) 86 | group_layout.addRow("vmax:", self.vmax) 87 | self.advanced_options_group.setLayout(group_layout) 88 | 
layout.addRow('',self.advanced_options_group) 89 | 90 | 91 | # Initially hide the widgets 92 | self.bin.setVisible(False) 93 | self.vmin.setVisible(False) 94 | self.vmax.setVisible(False) 95 | 96 | label_run= QLabel("Run the analysis:") 97 | field_run=QPushButton("Run") 98 | field_run.clicked.connect(self.run) 99 | layout.addRow(label_run,field_run) 100 | 101 | label_save= QLabel("Save the figure:") 102 | field_save = QPushButton("Save") 103 | field_save.clicked.connect(self.get_output_save_path) 104 | layout.addRow(label_save,field_save) 105 | 106 | label_output_path= QLabel("The figure has been saved here:") 107 | self.field_output_path=QLabel() 108 | layout.addRow(label_output_path,self.field_output_path) 109 | 110 | self.view=QGraphicsView(self) #QGraphicsView is a subclass of QWidget 111 | 112 | splitter = QSplitter(Qt.Horizontal) 113 | right_widget = QWidget() 114 | right_widget.setLayout(layout) 115 | splitter.addWidget(right_widget) 116 | splitter.addWidget(self.view) 117 | 118 | main_layout = QVBoxLayout(self) 119 | main_layout.addWidget(splitter) 120 | 121 | self.setLayout(main_layout) 122 | 123 | ############################################################################################################################################################################ 124 | 125 | def update_day_choice(self): 126 | if self.day_choice.count() ==1: 127 | self.all_stat_t2p=self.main_window.central_widget.data_management.all_stat_t2p 128 | if self.checkbox3.isChecked() or self.checkbox5.isChecked(): 129 | print(self.day_choice.count()) 130 | self.day_choice.clear() 131 | for i in range(len(self.all_stat_t2p)): 132 | self.day_choice.addItem(str(i + 1)) 133 | print('ComboBox updated') 134 | 135 | def handle_checkbox_state(self): 136 | sender = self.sender() 137 | if sender.isChecked(): 138 | for checkbox in self.checkboxes: 139 | if checkbox != sender: 140 | checkbox.setChecked(False) 141 | 142 | def run(self): 143 | if self.combined_checkbox.isChecked(): 144 | self.all_f_t2p= self.concatenate_rasters() 145 | else: 146 | self.all_f_t2p=self.main_window.central_widget.data_management.all_f_t2p 147 | self.plane=self.main_window.central_widget.data_management.plane 148 | self.preprocessing() 149 | self.get_checkbox_choice() 150 | 151 | 152 | def display_advanced_options(self,state): 153 | if state == Qt.Checked: 154 | self.advanced_options_group.setVisible(True) 155 | self.bin.setVisible(True) 156 | self.vmin.setVisible(True) 157 | self.vmax.setVisible(True) 158 | else: 159 | self.advanced_options_group.setVisible(False) 160 | self.bin.setVisible(False) 161 | self.vmin.setVisible(False) 162 | self.vmax.setVisible(False) 163 | 164 | def fit_pca_1d(self,data): 165 | print('fitting 1d-PCA...') 166 | pca=PCA(n_components=1) 167 | pca.fit(data) 168 | embedding = pca.components_.T 169 | return embedding 170 | 171 | def fit_tsne_1d(self,data): 172 | print('fitting 1d-tSNE...') 173 | tsne = TSNE( 174 | n_components=1, 175 | perplexity=30, 176 | initialization="pca", 177 | metric="euclidean", 178 | n_jobs=8, 179 | random_state=3 180 | ) 181 | 182 | tsne_emb = tsne.fit(data.T) 183 | return tsne_emb 184 | 185 | 186 | def get_output_save_path(self): 187 | save_path, _ = QFileDialog.getSaveFileName(self, "Select Save Path") 188 | if save_path: 189 | self.save_path = save_path 190 | print(f'Selected save path: {self.save_path}') 191 | plt.savefig(self.save_path) 192 | 193 | def get_checkbox_choice(self): 194 | vmin=float(self.vmin.text()) 195 | vmax=float(self.vmax.text()) 196 | if 
self.checkbox1.isChecked(): 197 | self.raster_type='without_sorting' 198 | self.plot_track2p_rasters(self.all_f_t2p_preproc, bin_size=self.bin_size, vmin=vmin, vmax=vmax) 199 | if self.checkbox2.isChecked(): 200 | self.raster_type='sorting_by_PCA' 201 | all_pca_emb_1d = [] 202 | for f in tqdm(self.all_f_t2p_preproc): 203 | print(f.shape) 204 | pca_emb_1d = self.fit_pca_1d(f.T) 205 | print(pca_emb_1d.shape) 206 | all_pca_emb_1d.append(pca_emb_1d) 207 | all_f_t2p_sorted = [] 208 | for (i, f) in enumerate(self.all_f_t2p_preproc): 209 | sort_inds = np.argsort(np.array(all_pca_emb_1d[i]).squeeze()) 210 | f_sorted = f[sort_inds, :].squeeze() 211 | all_f_t2p_sorted.append(f_sorted) 212 | self.plot_track2p_rasters(all_f_t2p_sorted, bin_size=self.bin_size, vmin=vmin, vmax=vmax) # plot the rasters 213 | if self.checkbox3.isChecked(): 214 | self.raster_type='sorting_by_PCA_and_by_day_' + str(self.day_choice.currentText()) 215 | all_pca_emb_1d = [] 216 | for f in tqdm(self.all_f_t2p_preproc): 217 | print(f.shape) 218 | pca_emb_1d = self.fit_pca_1d(f.T) 219 | print(pca_emb_1d.shape) 220 | all_pca_emb_1d.append(pca_emb_1d) 221 | all_f_t2p_sorted = [] 222 | for (i, f) in enumerate(self.all_f_t2p_preproc): 223 | user_i= int(self.day_choice.currentText()) - 1 224 | sort_inds = np.argsort(np.array(all_pca_emb_1d[user_i]).squeeze()) 225 | f_sorted = f[sort_inds, :].squeeze() 226 | all_f_t2p_sorted.append(f_sorted) 227 | self.plot_track2p_rasters(all_f_t2p_sorted, bin_size=self.bin_size, vmin=vmin, vmax=vmax) 228 | if self.checkbox4.isChecked(): 229 | self.raster_type='sorting_by_tSNE' 230 | all_tsne_emb_1d = [] 231 | for f in tqdm(self.all_f_t2p_preproc): 232 | print(f.shape) 233 | tsne_emb_1d = self.fit_tsne_1d(f.T) 234 | all_tsne_emb_1d.append(tsne_emb_1d) 235 | all_f_t2p_sorted = [] 236 | for (i, f) in enumerate(self.all_f_t2p_preproc): 237 | sort_inds = np.argsort(np.array(all_tsne_emb_1d[i]).squeeze()) 238 | f_sorted = f[sort_inds, :].squeeze() 239 | all_f_t2p_sorted.append(f_sorted) 240 | self.plot_track2p_rasters(all_f_t2p_sorted, bin_size=self.bin_size, vmin=vmin, vmax=vmax) # plot the rasters 241 | if self.checkbox5.isChecked(): 242 | self.raster_type='sorting_by_tSNE_and_by_day_' + str(self.day_choice.currentText()) 243 | all_tsne_emb_1d = [] 244 | for f in tqdm(self.all_f_t2p_preproc): 245 | print(f.shape) 246 | tsne_emb_1d = self.fit_tsne_1d(f.T) 247 | all_tsne_emb_1d.append(tsne_emb_1d) 248 | all_f_t2p_sorted = [] 249 | for (i, f) in enumerate(self.all_f_t2p_preproc): 250 | user_i= int(self.day_choice.currentText()) - 1 251 | sort_inds = np.argsort(np.array(all_tsne_emb_1d[user_i]).squeeze()) 252 | f_sorted = f[sort_inds, :].squeeze() 253 | all_f_t2p_sorted.append(f_sorted) 254 | self.plot_track2p_rasters(all_f_t2p_sorted, bin_size=self.bin_size, vmin=vmin, vmax=vmax) # plot the rasters 255 | 256 | 257 | def preprocessing(self): 258 | bin_data = True 259 | self.bin_size = int(self.bin.text()) # number of frames to average (1 = no averaging) 260 | print(f'bin_size: {self.bin_size}') 261 | rem_zero_rows = True 262 | if bin_data: 263 | all_f_t2p_original = copy.deepcopy(self.all_f_t2p) 264 | self.all_f_t2p_preproc = [np.mean(f.reshape(f.shape[0], -1, self.bin_size), axis=2) for f in all_f_t2p_original] 265 | # renormalize 266 | self.all_f_t2p_preproc =self.zscore_all_f_t2p(self.all_f_t2p_preproc) 267 | if rem_zero_rows: 268 | # find rows containing NaNs in any of the datasets (e.g. zero-activity rows that become NaN after z-scoring) 269 | zero_rows = np.any([np.sum(np.isnan(f), axis=1) for f in self.all_f_t2p_preproc], axis=0) 270 | print(f'Number of zero rows: 
{np.sum(zero_rows)}') 271 | self.all_f_t2p_preproc= [f[~zero_rows, :] for f in self.all_f_t2p_preproc] 272 | 273 | 274 | def zscore(self, f, axis=1): 275 | '''Calculate the z-score of each element of the input array f along the specified axis, ignoring NaN values.''' 276 | return (f - np.nanmean(f, axis=axis, keepdims=True)) / np.nanstd(f, axis=axis, keepdims=True) 277 | 278 | def zscore_all_f_t2p(self, all_f_t2p): 279 | return [zscore(f, axis=1) for f in all_f_t2p] # axis=1 means that we are normalizing each neuron's activity (each row of the matrix f) 280 | 281 | 282 | def plot_track2p_rasters(self, all_f_t2p, bin_size=1, vmin=None, vmax=None): 283 | 284 | fig, ax = plt.subplots(len(all_f_t2p), 1, figsize=(6, len(all_f_t2p)*1), dpi=150) # rows, columns, size, resolution 285 | 286 | for i, f in enumerate(all_f_t2p): 287 | ax[i].imshow(zscore(f, axis=1), aspect='auto', cmap='Greys', vmin=vmin, vmax=vmax) 288 | # only first and last yticks 289 | ax[i].set_yticks([0, f.shape[0]-1]) 290 | ax[i].set_yticklabels([0, f.shape[0]]) 291 | # move the text of the yticklabels a bit: the first one down and the last one up 292 | ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0, va='center', ha='right') 293 | 294 | # only first, middle and last xticks on the x axis, with labels rescaled to raw frame counts via bin_size 295 | if i == len(all_f_t2p)-1: 296 | ax[i].set_xlabel('Frames') 297 | ax[i].set_xticks([0, f.shape[1]/2, f.shape[1]-1]) 298 | ax[i].set_xticklabels([0, int((f.shape[1]*bin_size)/2), int(f.shape[1]*bin_size)]) 299 | 300 | else: 301 | ax[i].set_xticks([]) 302 | ax[i].set_xticklabels([]) 303 | 304 | plt.subplots_adjust(bottom=0.2) 305 | buf = io.BytesIO() 306 | plt.savefig(buf, format='png') 307 | buf.seek(0) 308 | pixmap = QPixmap() 309 | pixmap.loadFromData(buf.getvalue()) 310 | scene=QGraphicsScene() 311 | scene.addPixmap(pixmap) 312 | self.view.setScene(scene) 313 | print('Done') 314 | 315 | 316 | 317 | 318 | def concatenate_rasters(self): 319 | 320 | track_ops=self.main_window.central_widget.data_management.track_ops 321 | t2p_folder_path= os.path.dirname(track_ops.all_ds_path[0]) 322 | print(t2p_folder_path) 323 | if track_ops.nplanes > 1: 324 | self.results_by_plane = {} 325 | print(track_ops.nplanes) 326 | for plane in range(track_ops.nplanes): 327 | print(plane) 328 | if plane == self.main_window.central_widget.data_management.plane: 329 | print('plane is the currently displayed plane, reusing the already-loaded traces') 330 | self.results_by_plane[plane]={ 331 | 'all_ft2p': self.main_window.central_widget.data_management.all_f_t2p, 332 | 'all_fneu2p': self.main_window.central_widget.data_management.all_fneu 333 | } 334 | continue 335 | print('plane is not equal to the current plane') 336 | t2p_match_mat = np.load(os.path.join(t2p_folder_path,"track2p" ,f"plane{plane}_match_mat.npy"), allow_pickle=True) 337 | t2p_match_mat_allday = t2p_match_mat[~np.any(t2p_match_mat == None, axis=1), :] 338 | trace_type=self.main_window.central_widget.data_management.trace_type # common to all planes 339 | print(f"Processing plane {plane}") 340 | self.process_plane(plane,track_ops,t2p_match_mat_allday,trace_type) 341 | for plane, data in self.results_by_plane.items(): 342 | print(f"Plane {plane}:") 343 | print(f" all_ft2p: {len(data['all_ft2p'])}") 344 | print(f" {len(data['all_ft2p'][0])}") 345 | 346 | # initialise a list to store the concatenated elements 347 | concatenated_elements = [] 348 | num_elements = len(self.results_by_plane[0]['all_ft2p']) 349 | print(f"Number of elements: {num_elements}") 350 | # iterate over the element indices 
351 | for i in range(num_elements): 352 | elements_to_concatenate = [] 353 | for plane in range(track_ops.nplanes): 354 | if 'all_ft2p' in self.results_by_plane[plane] and isinstance(self.results_by_plane[plane]['all_ft2p'], list): 355 | # get element i of 'all_ft2p' for the current plane 356 | element = self.results_by_plane[plane]['all_ft2p'][i] 357 | elements_to_concatenate.append(element) 358 | else: 359 | print("The key 'all_ft2p' does not exist or is not a list.") 360 | if elements_to_concatenate: 361 | concatenated_element = np.vstack(elements_to_concatenate) 362 | concatenated_elements.append(concatenated_element) 363 | 364 | # display the concatenated elements 365 | for idx, concatenated_element in enumerate(concatenated_elements): 366 | print(f"Concatenated element {idx}:") 367 | # print(concatenated_element) 368 | print(concatenated_element.shape) 369 | 370 | return concatenated_elements 371 | 372 | 373 | 374 | def process_plane(self, plane, track_ops, t2p_match_mat_allday, trace_type): 375 | all_ft2p=[] 376 | all_fneu2p=None 377 | for (i, ds_path) in enumerate(track_ops.all_ds_path): 378 | iscell = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'iscell.npy'), allow_pickle=True) 379 | if trace_type == 'F': 380 | print('F trace') 381 | f = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'F.npy'), allow_pickle=True) 382 | if trace_type == 'spks': 383 | print('spks trace') 384 | f = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'spks.npy'), allow_pickle=True) 385 | if trace_type == 'dF/F0': 386 | print('dF/F0 trace') 387 | if all_fneu2p is None: 388 | all_fneu2p= [] 389 | f = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'F.npy'), allow_pickle=True) 390 | fneu = np.load(os.path.join(ds_path, 'suite2p', f'plane{plane}', 'Fneu.npy'), allow_pickle=True) 391 | if track_ops.iscell_thr is None: 392 | fneu_iscell = fneu[iscell[:, 0] == 1, :] 393 | else: 394 | fneu_iscell = fneu[iscell[:, 1] > track_ops.iscell_thr, :] 395 | fneu_t2p= fneu_iscell[t2p_match_mat_allday[:, i].astype(int), :] 396 | all_fneu2p.append(fneu_t2p) 397 | if track_ops.iscell_thr is None: 398 | f_iscell = f[iscell[:, 0] == 1, :] 399 | else: 400 | f_iscell = f[iscell[:, 1] > track_ops.iscell_thr, :] 401 | 402 | f_t2p = f_iscell[t2p_match_mat_allday[:, i].astype(int), :] 403 | all_ft2p.append(f_t2p) 404 | 405 | self.results_by_plane[plane]={ 406 | 'all_ft2p': all_ft2p, 407 | 'all_fneu2p': all_fneu2p 408 | } 409 | --------------------------------------------------------------------------------
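A standalone sketch of the temporal binning used in RasterWindow.preprocessing above: the reshape trick averages every bin_size consecutive frames, but it requires the number of frames to be divisible by bin_size, otherwise the reshape raises a ValueError; trimming the trailing frames first, as below, is one way to guard against that. The trace array is synthetic:

import numpy as np

def bin_traces(f, bin_size):
    # average every bin_size consecutive frames; f has shape (n_neurons, n_frames)
    n_frames = (f.shape[1] // bin_size) * bin_size  # trim the remainder so the reshape is valid
    return f[:, :n_frames].reshape(f.shape[0], -1, bin_size).mean(axis=2)

f = np.random.default_rng(1).random((5, 103))  # synthetic data: 5 neurons, 103 frames
print(bin_traces(f, 10).shape)  # (5, 10): 100 frames averaged in bins of 10, the 3 trailing frames dropped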