├── .gitmodules
├── Bonsai_tracking_script
├── HDtrackingcoverageRG.bonsai
├── HDtrackingcoverageRG_save_video.bonsai
└── HDtrackingcoverageRG_save_video_hist.bonsai
├── ExampleData
├── position.pkl
└── spatial_firing.pkl
├── Gerlei_et_at_2020_where_to_find_data_and_code.xlsx
├── LICENSE
├── OpenEphys.py
├── OverallAnalysis
├── analyze_field_correlations.py
├── analyze_hd_from_whole_session.py
├── analyze_speed.py
├── basic_info_included_data.py
├── compare_directional_firing_over_days.py
├── compare_field_detectors.py
├── compare_first_and_second_half_cell.py
├── compare_grid_and_conjunctive_fields.py
├── compare_hd_with_expected_hd.py
├── compare_sampling_of_mice_and_rats.py
├── compare_shape_of_fields.py
├── compare_shuffled_from_first_and_second_halves.py
├── compare_shuffled_from_first_and_second_halves_fields.py
├── describe_dataset.py
├── directional_cell_definition_example_figure.py
├── example_get_firing_rates_for_half_fields.py
├── example_hd_histograms_binning.py
├── false_positives.py
├── field_analysis_two_sample_watson.py
├── folder_path_settings.py
├── grid_analysis_other_labs
│ ├── analyze_sargolini_data.py
│ └── firing_maps.py
├── load_data_frames.py
├── model_prediction_schematics_simulated.py
├── open_field_firing_maps_processed_data.py
├── overall_params.py
├── pattern_of_directions_across_nodes.py
├── pattern_of_field_shapes.py
├── plot_example_classic_field_polar_plots.py
├── plot_example_firing_fields.py
├── plot_example_polar_hists_first_and_second_half.py
├── plot_hd_tuning_vs_shuffled.py
├── plot_hd_tuning_vs_shuffled_fields.py
├── shuffle_cell_analysis.py
├── shuffle_cell_analysis_heading.py
├── shuffle_field_analysis.py
├── shuffle_field_analysis_all_animals.py
├── shuffle_field_analysis_all_animals_heading.py
├── shuffle_field_analysis_heading.py
├── simulated_data
│ └── analyze_simulated_grid_cells.py
├── tuning_bias_vs_speed.py
└── tuning_bias_vs_trajectory_bias.py
├── PostSorting
├── compare_first_and_second_half.py
├── compare_rate_maps.py
├── curation.py
├── load_firing_data.py
├── load_snippet_data.py
├── make_plots.py
├── open_field_firing_fields.py
├── open_field_firing_maps.py
├── open_field_grid_cells.py
├── open_field_head_direction.py
├── open_field_heading_direction.py
├── open_field_make_plots.py
├── open_field_spatial_data.py
├── open_field_spatial_firing.py
├── open_field_sync_data.py
├── parameters.py
├── post_process_sorted_data.py
├── process_fields.r
├── speed.py
└── temporal_firing.py
├── README.md
├── SimulationCode
├── Model_cell_characterization.py
├── Mods
│ ├── hcn.mod
│ ├── km.mod
│ ├── kv.mod
│ └── na.mod
├── defining_peak_current.py
├── grid_run.py
├── hocfile.hoc
└── import_behavior.py
├── SimulationComparisons
├── 6cd.R
├── 7d.R
├── grid_cell_models
│ ├── burgess
│ │ └── main.m
│ ├── giocomo
│ │ └── main.m
│ ├── guanella
│ │ └── main.m
│ └── pastoll
│ │ ├── prepare_trajectory.py
│ │ └── submit_simulation.py
├── results_gridsize.csv
├── results_gridspacing.csv
├── results_sampling.csv
├── results_sampling_correlation_cell.csv
├── results_sampling_shuffled_cell.csv
├── results_simulations_long.csv
├── s14.R
└── s15.R
├── array_utility.py
├── data_frame_utility.py
├── example_spatial_analysis.py
├── file_utility.py
├── math_utility.py
├── mdaio.py
├── open_ephys_IO.py
├── plot_utility.py
└── tests
└── unit
├── PostSorting
├── test_load_firing_data.py
├── test_open_field_head_direction.py
├── test_open_field_heading_direction.py
├── test_open_field_light_data.py
└── test_post_process_sorted_data.py
├── test_array_utility.py
└── 
test_control_sorting_analysis.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "SimulationComparisons/grid_cell_models/pastoll/model"] 2 | path = SimulationComparisons/grid_cell_models/pastoll/model 3 | url = git@github.com:ModelDBRepository/150031.git 4 | -------------------------------------------------------------------------------- /ExampleData/position.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MattNolanLab/grid_cell_analysis/3ffbeebaf6a7fd7a0980ce7aee1a555e090c730e/ExampleData/position.pkl -------------------------------------------------------------------------------- /ExampleData/spatial_firing.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MattNolanLab/grid_cell_analysis/3ffbeebaf6a7fd7a0980ce7aee1a555e090c730e/ExampleData/spatial_firing.pkl -------------------------------------------------------------------------------- /Gerlei_et_at_2020_where_to_find_data_and_code.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MattNolanLab/grid_cell_analysis/3ffbeebaf6a7fd7a0980ce7aee1a555e090c730e/Gerlei_et_at_2020_where_to_find_data_and_code.xlsx -------------------------------------------------------------------------------- /OverallAnalysis/analyze_field_correlations.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pylab as plt 3 | import plot_utility 4 | 5 | path = '' 6 | 7 | 8 | def plot_correlation_coef_hist(correlation_coefs, save_path, y_axis_label='Number of fields'): 9 | fig = plt.figure() 10 | ax = fig.add_subplot(1, 1, 1) # specify (nrows, ncols, axnum) 11 | fig, ax = plot_utility.style_plot(ax) 12 | ax.hist(correlation_coefs, color='navy') 13 | ax.xaxis.set_tick_params(labelsize=20) 14 | ax.yaxis.set_tick_params(labelsize=20) 15 | plt.xlabel('Correlation coefficient', fontsize=30) 16 | plt.ylabel(y_axis_label, fontsize=30) 17 | plt.xlim(-1, 1) 18 | plt.axvline(x=0, color='red', linewidth=5) 19 | plt.gcf().subplots_adjust(bottom=0.15) 20 | plt.savefig(save_path) 21 | plt.close() 22 | 23 | plt.cla() 24 | fig = plt.figure() 25 | ax = fig.add_subplot(1, 1, 1) # specify (nrows, ncols, axnum) 26 | fig, ax = plot_utility.style_plot(ax) 27 | # ax.hist(correlation_coefs, color='navy') 28 | plot_utility.plot_cumulative_histogram(correlation_coefs, ax) 29 | ax.xaxis.set_tick_params(labelsize=20) 30 | ax.yaxis.set_tick_params(labelsize=20) 31 | plt.xlabel('Correlation coefficient', fontsize=25) 32 | plt.ylabel(y_axis_label, fontsize=25) 33 | plt.xlim(-1, 1) 34 | plt.axvline(x=0, color='red', linewidth=5) 35 | plt.gcf().subplots_adjust(bottom=0.15) 36 | plt.savefig(save_path + 'cumulative.png') 37 | plt.close() 38 | 39 | 40 | fields = pd.read_excel(path) 41 | print(fields.head()) 42 | significant = (fields.p_value < 0.001) 43 | correlation_coefs = fields[significant]['correlation coef'].values 44 | save_path = path + 'correlation_coef_hist.png' 45 | plot_correlation_coef_hist(correlation_coefs, save_path) 46 | 47 | 48 | grid_cells = fields['cell type'] == 'grid' 49 | hd_cells = fields['cell type'] == 'hd' 50 | conjunctive_cells = fields['cell type'] == 'conjunctive' 51 | not_classified = fields['cell type'] == 'na' 52 | fields[grid_cells & significant]['correlation coef'].std() 53 | 54 | grid_coeffs = fields[grid_cells & 
significant]['correlation coef'].values
55 | save_path = path + 'correlation_coef_hist_grid.png'
56 | plot_correlation_coef_hist(grid_coeffs, save_path)
57 | 
58 | hd_coeffs = fields[hd_cells & significant]['correlation coef'].values
59 | save_path = path + 'correlation_coef_hist_hd.png'
60 | plot_correlation_coef_hist(hd_coeffs, save_path)
61 | 
62 | nc_coeffs = fields[not_classified & significant]['correlation coef'].values
63 | save_path = path + 'correlation_coef_hist_nc.png'
64 | plot_correlation_coef_hist(nc_coeffs, save_path)
65 | 
66 | conj_coeffs = fields[conjunctive_cells & significant]['correlation coef'].values
67 | save_path = path + 'correlation_coef_hist_conj.png'
68 | plot_correlation_coef_hist(conj_coeffs, save_path)
69 | 
70 | 
71 | def main():
72 |     pass
73 | 
74 | 
75 | if __name__ == '__main__':
76 |     main()
--------------------------------------------------------------------------------
/OverallAnalysis/basic_info_included_data.py:
--------------------------------------------------------------------------------
 1 | import data_frame_utility
 2 | import numpy as np
 3 | import os
 4 | import OverallAnalysis.folder_path_settings
 5 | import OverallAnalysis.shuffle_field_analysis
 6 | import OverallAnalysis.compare_shuffled_from_first_and_second_halves_fields
 7 | import OverallAnalysis.false_positives
 8 | import pandas as pd
 9 | import PostSorting.parameters
10 | 
11 | import scipy
12 | 
13 | 
14 | local_path = OverallAnalysis.folder_path_settings.get_local_path()
15 | analysis_path = local_path + '/basic_info_included_data/'
16 | 
17 | prm = PostSorting.parameters.Parameters()
18 | prm.set_pixel_ratio(440)
19 | prm.set_sampling_rate(30000)
20 | 
21 | 
22 | def add_cell_types_to_data_frame(spatial_firing):
23 |     cell_type = []
24 |     for index, cell in spatial_firing.iterrows():
25 |         if cell.hd_score >= 0.5 and cell.grid_score >= 0.4:
26 |             cell_type.append('conjunctive')
27 |         elif cell.hd_score >= 0.5:
28 |             cell_type.append('hd')
29 |         elif cell.grid_score >= 0.4:
30 |             cell_type.append('grid')
31 |         else:
32 |             cell_type.append('na')
33 | 
34 |     spatial_firing['cell type'] = cell_type
35 | 
36 |     return spatial_firing
37 | 
38 | 
39 | def add_combined_id_to_df(spatial_firing):
40 |     animal_ids = [session_id.split('_')[0] for session_id in spatial_firing.session_id.values]
41 |     spatial_firing['animal'] = animal_ids
42 | 
43 |     dates = [session_id.split('_')[1] for session_id in spatial_firing.session_id.values]
44 | 
45 |     cluster = spatial_firing.cluster_id.values
46 |     combined_ids = []
47 |     for cell in range(len(spatial_firing)):
48 |         id = animal_ids[cell] + '-' + dates[cell] + '-Cluster-' + str(cluster[cell])
49 |         combined_ids.append(id)
50 |     spatial_firing['false_positive_id'] = combined_ids
51 |     return spatial_firing
52 | 
53 | 
54 | def tag_false_positives(spatial_firing):
55 |     list_of_false_positives = OverallAnalysis.false_positives.get_list_of_false_positives(analysis_path + 'false_positives_all.txt')
56 |     spatial_firing = add_combined_id_to_df(spatial_firing)
57 |     spatial_firing['false_positive'] = spatial_firing['false_positive_id'].isin(list_of_false_positives)
58 |     return spatial_firing
59 | 
60 | 
61 | def get_time_spent_and_num_spikes(df_grid):
62 |     print('Avg length of recording:')
63 |     print((df_grid.number_of_spikes / df_grid.mean_firing_rate / 60).mean())
64 |     print('sd')
65 |     print((df_grid.number_of_spikes / df_grid.mean_firing_rate / 60).std())
66 | 
67 |     print('Avg number of spikes:')
68 |     print(df_grid.number_of_spikes.mean())
69 |     print('sd')
70 | 
print(df_grid.number_of_spikes.std()) 71 | 72 | 73 | def print_basic_info(df, animal): 74 | if animal == 'mouse': 75 | df = tag_false_positives(df) 76 | else: 77 | df['false_positive'] = False 78 | 79 | good_cells = df.false_positive == False 80 | df_good_cells = df[good_cells] 81 | df = add_cell_types_to_data_frame(df_good_cells) 82 | grid_cells = df['cell type'] == 'grid' 83 | df_grid = df[grid_cells] 84 | conj_cells = df['cell type'] == 'conjunctive' 85 | df_conj = df[conj_cells] 86 | print('Number of grid cells:') 87 | print(len(df_grid)) 88 | print('Number of conjunctive cells:') 89 | print(len(df_conj)) 90 | 91 | get_time_spent_and_num_spikes(df_grid) 92 | 93 | 94 | animals_with_grid_cells = df_grid.animal.unique() 95 | animals_with_conj_cells = df_conj.animal.unique() 96 | 97 | animals_with_grid_or_conj = np.unique(np.concatenate((animals_with_grid_cells, animals_with_conj_cells), axis=0)) 98 | print('Number of animals:') 99 | print(len(animals_with_grid_or_conj)) 100 | 101 | included_cells = df_good_cells[df_good_cells.animal.isin(animals_with_grid_or_conj)] 102 | print('Number of included cells:') 103 | print(len(included_cells)) 104 | 105 | print('Number_of_included sessions:') 106 | print(len(included_cells.session_id.unique())) 107 | 108 | print('Number of recording days per animal:') 109 | print(included_cells.groupby('animal').session_id.nunique()) 110 | print('mean') 111 | print(included_cells.groupby('animal').session_id.nunique().mean()) 112 | print('std') 113 | print(included_cells.groupby('animal').session_id.nunique().std()) 114 | print('min') 115 | print(included_cells.groupby('animal').session_id.nunique().min()) 116 | print('max') 117 | print(included_cells.groupby('animal').session_id.nunique().max()) 118 | 119 | 120 | def main(): 121 | mouse_df_path = analysis_path + 'all_mice_df.pkl' 122 | mouse_df = pd.read_pickle(mouse_df_path) 123 | rat_df_path = analysis_path + 'all_rats_df.pkl' 124 | rat_df = pd.read_pickle(rat_df_path) 125 | 126 | print('mouse') 127 | print_basic_info(mouse_df, 'mouse') 128 | print('rat') 129 | print_basic_info(rat_df, 'rat') 130 | 131 | 132 | if __name__ == '__main__': 133 | main() -------------------------------------------------------------------------------- /OverallAnalysis/compare_directional_firing_over_days.py: -------------------------------------------------------------------------------- 1 | import data_frame_utility 2 | import os 3 | import OverallAnalysis.folder_path_settings 4 | import OverallAnalysis.shuffle_field_analysis 5 | import pandas as pd 6 | import PostSorting.parameters 7 | 8 | 9 | local_path = OverallAnalysis.folder_path_settings.get_local_path() 10 | analysis_path = local_path + '/compare_directional_firing_over_days/' 11 | 12 | prm = PostSorting.parameters.Parameters() 13 | prm.set_pixel_ratio(440) 14 | 15 | 16 | def get_shuffled_field_data(spatial_firing, position_data, shuffle_type='distributive', sampling_rate_video=50): 17 | field_df = data_frame_utility.get_field_data_frame(spatial_firing, position_data) 18 | field_df = OverallAnalysis.shuffle_field_analysis.add_rate_map_values_to_field_df_session(spatial_firing, field_df) 19 | field_df = OverallAnalysis.shuffle_field_analysis.shuffle_field_data(field_df, analysis_path, number_of_bins=20, 20 | number_of_times_to_shuffle=1000, shuffle_type=shuffle_type) 21 | field_df = OverallAnalysis.shuffle_field_analysis.analyze_shuffled_data(field_df, analysis_path, sampling_rate_video, 22 | number_of_bins=20, shuffle_type=shuffle_type) 23 | return field_df 24 | 25 | 26 | 
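# NOTE (comment added for clarity; not in the original file): process_data()
# below caches shuffle results. For each day it loads DataFrames_<n>/fields.pkl
# if that pickle already exists; otherwise it runs the distributive shuffle on
# the day's spatial_firing.pkl and position.pkl and saves the shuffled fields.
# It then compares the number of significantly directional bins
# (number_of_different_bins_bh) for matched fields of one example cell recorded
# on both days (cluster 27 on day 1, cluster 20 on day 2).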
def process_data(): 27 | # load shuffled field data 28 | if os.path.exists(analysis_path + 'DataFrames_1/fields.pkl'): 29 | shuffled_fields_1 = pd.read_pickle(analysis_path + 'DataFrames_1/fields.pkl') 30 | else: 31 | day1_firing = pd.read_pickle(analysis_path + 'DataFrames_1/spatial_firing.pkl') 32 | day1_position = pd.read_pickle(analysis_path + 'DataFrames_1/position.pkl') 33 | shuffled_fields_1 = get_shuffled_field_data(day1_firing, day1_position) 34 | shuffled_fields_1.to_pickle(analysis_path + 'DataFrames_1/fields.pkl') 35 | 36 | if os.path.exists(analysis_path + 'DataFrames_2/fields.pkl'): 37 | shuffled_fields_2 = pd.read_pickle(analysis_path + 'DataFrames_2/fields.pkl') 38 | else: 39 | day2_firing = pd.read_pickle(analysis_path + 'DataFrames_2/spatial_firing.pkl') 40 | day2_position = pd.read_pickle(analysis_path + 'DataFrames_2/position.pkl') 41 | # shuffle field analysis 42 | shuffled_fields_2 = get_shuffled_field_data(day2_firing, day2_position) 43 | shuffled_fields_2.to_pickle(analysis_path + 'DataFrames_2/fields.pkl') 44 | print('I shuffled data from both days.') 45 | 46 | day_1_field_1 = shuffled_fields_1[(shuffled_fields_1.cluster_id == 27) & (shuffled_fields_1.field_id == 0)] 47 | day_1_field_2 = shuffled_fields_1[(shuffled_fields_1.cluster_id == 27) & (shuffled_fields_1.field_id == 1)] 48 | 49 | day_2_field_1 = shuffled_fields_2[(shuffled_fields_2.cluster_id == 20) & (shuffled_fields_2.field_id == 0)] 50 | day_2_field_2 = shuffled_fields_2[(shuffled_fields_2.cluster_id == 20) & (shuffled_fields_2.field_id == 1)] 51 | 52 | print('day 1 field 2') 53 | print(day_1_field_2.number_of_different_bins_bh) 54 | print('day 2 field 2') 55 | print(day_2_field_2.number_of_different_bins_bh) 56 | 57 | print('day 1 field 1') 58 | print(day_1_field_1.number_of_different_bins_bh) 59 | print('day 2 field 1') 60 | print(day_2_field_1.number_of_different_bins_bh) 61 | 62 | 63 | def main(): 64 | process_data() 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /OverallAnalysis/compare_field_detectors.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pandas as pd 4 | import PostSorting.open_field_firing_field_detection 5 | import PostSorting.open_field_firing_fields 6 | import PostSorting.open_field_make_plots 7 | 8 | import matplotlib.pylab as plt 9 | 10 | 11 | # load data 12 | def load_data_frame(path): 13 | spike_data_frame_path = path + '/spatial_firing.pkl' 14 | spike_data = pd.read_pickle(spike_data_frame_path) 15 | return spike_data 16 | 17 | 18 | def plot_fields_on_rate_map(save_path, rate_map, fields, name): 19 | number_of_firing_fields = len(fields) 20 | colors = PostSorting.open_field_make_plots.generate_colors(number_of_firing_fields) 21 | of_figure = plt.figure() 22 | of_plot = of_figure.add_subplot(1, 1, 1) 23 | of_plot.axis('off') 24 | of_plot.imshow(rate_map) 25 | for field_id, field in enumerate(fields[0]): 26 | of_plot = PostSorting.open_field_make_plots.mark_firing_field_with_scatter(field, of_plot, colors, field_id) 27 | plt.savefig(save_path + '/detection_results_' + name + '.png') 28 | plt.close() 29 | 30 | 31 | def call_detector_original(spike_data, cluster, save_path): 32 | firing_fields, max_firing_rates = PostSorting.open_field_firing_fields.analyze_fields_in_cluster(spike_data, cluster, threshold=20) 33 | plot_fields_on_rate_map(save_path, spike_data.firing_maps[cluster], firing_fields, '01Hz_20') 34 | 35 
| 36 | def call_detector_modified_params(spike_data, cluster, save_path): 37 | firing_fields, max_firing_rates = PostSorting.open_field_firing_fields.analyze_fields_in_cluster(spike_data, cluster, threshold=30) 38 | plot_fields_on_rate_map(save_path, spike_data.firing_maps[cluster], firing_fields, '01Hz_30') 39 | 40 | firing_fields, max_firing_rates = PostSorting.open_field_firing_fields.analyze_fields_in_cluster(spike_data, cluster, threshold=35) 41 | plot_fields_on_rate_map(save_path, spike_data.firing_maps[cluster], firing_fields, '01Hz_35') 42 | 43 | firing_fields, max_firing_rates = PostSorting.open_field_firing_fields.analyze_fields_in_cluster(spike_data, cluster, threshold=40) 44 | plot_fields_on_rate_map(save_path, spike_data.firing_maps[cluster], firing_fields, '01Hz_40') 45 | 46 | firing_fields, max_firing_rates = PostSorting.open_field_firing_fields.analyze_fields_in_cluster(spike_data, cluster, threshold=50) 47 | plot_fields_on_rate_map(save_path, spike_data.firing_maps[cluster], firing_fields, '01Hz_50') 48 | 49 | 50 | def call_detector_gauss(spike_data, cluster, save_path): 51 | PostSorting.open_field_firing_field_detection.detect_firing_fields(spike_data, cluster, save_path) 52 | 53 | 54 | def get_clusters_to_analyse(path): 55 | cluster_ids_path = path + '/cluster.txt' 56 | cluster_ids = [] 57 | with open(cluster_ids_path) as cluster_id_file: 58 | for line in cluster_id_file: 59 | cluster_ids.append(int(line) - 1) 60 | return cluster_ids 61 | 62 | 63 | def compare_field_detection_methods(): 64 | folder_path = 'C:/Users/s1466507/Documents/Ephys/field_detection_test' 65 | for name in glob.glob(folder_path + '/*'): 66 | if os.path.isdir(name): 67 | spike_data = load_data_frame(name) 68 | cluster_ids = get_clusters_to_analyse(name) 69 | for cluster in range(len(cluster_ids)): 70 | call_detector_gauss(spike_data, cluster_ids[cluster], name) 71 | call_detector_original(spike_data, cluster_ids[cluster], name) 72 | call_detector_modified_params(spike_data, cluster_ids[cluster], name) 73 | 74 | 75 | def main(): 76 | print('-------------------------------------------------------------') 77 | print('-------------------------------------------------------------') 78 | print('I will compare different field detection methods on a test dataset.') 79 | 80 | compare_field_detection_methods() 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /OverallAnalysis/compare_first_and_second_half_cell.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import plot_utility 4 | import matplotlib.pylab as plt 5 | import numpy as np 6 | import OverallAnalysis.false_positives 7 | import OverallAnalysis.folder_path_settings 8 | import OverallAnalysis.analyze_field_correlations 9 | import os 10 | import scipy.stats 11 | 12 | import rpy2.robjects as ro 13 | from rpy2.robjects.packages import importr 14 | 15 | 16 | local_path_mouse = OverallAnalysis.folder_path_settings.get_local_path() + '/correlation_cell/all_mice_df.pkl' 17 | local_path_rat = OverallAnalysis.folder_path_settings.get_local_path() + '/correlation_cell/all_rats_df.pkl' 18 | local_path_simulated = OverallAnalysis.folder_path_settings.get_local_path() + '/correlation_cell/all_simulated_df.pkl' 19 | path_to_data = OverallAnalysis.folder_path_settings.get_local_path() + '/correlation_cell/' 20 | save_output_path = 
OverallAnalysis.folder_path_settings.get_local_path() + '/correlation_cell/' 21 | server_path_mouse = OverallAnalysis.folder_path_settings.get_server_path_mouse() 22 | server_path_rat = OverallAnalysis.folder_path_settings.get_server_path_rat() 23 | server_path_simulated = OverallAnalysis.folder_path_settings.get_server_path_simulated() 24 | 25 | 26 | def add_cell_types_to_data_frame(cells): 27 | cell_type = [] 28 | for index, field in cells.iterrows(): 29 | if field.hd_score >= 0.5 and field.grid_score >= 0.4: 30 | cell_type.append('conjunctive') 31 | elif field.hd_score >= 0.5: 32 | cell_type.append('hd') 33 | elif field.grid_score >= 0.4: 34 | cell_type.append('grid') 35 | else: 36 | cell_type.append('na') 37 | 38 | cells['cell type'] = cell_type 39 | 40 | return cells 41 | 42 | 43 | def tag_false_positives(all_cells, animal): 44 | if animal == 'mouse': 45 | false_positives_path = path_to_data + 'false_positives_all.txt' 46 | list_of_false_positives = OverallAnalysis.false_positives.get_list_of_false_positives(false_positives_path) 47 | all_cells = add_combined_id_to_df(all_cells) 48 | all_cells['false_positive'] = all_cells['false_positive_id'].isin(list_of_false_positives) 49 | else: 50 | all_cells['false_positive'] = np.full(len(all_cells), False) 51 | return all_cells 52 | 53 | 54 | def load_spatial_firing(output_path, server_path, animal, spike_sorter='', df_path='/DataFrames'): 55 | if os.path.exists(output_path): 56 | spatial_firing = pd.read_pickle(output_path) 57 | return spatial_firing 58 | spatial_firing_data = pd.DataFrame() 59 | for recording_folder in glob.glob(server_path + '*'): 60 | os.path.isdir(recording_folder) 61 | data_frame_path = recording_folder + spike_sorter + df_path + '/spatial_firing.pkl' 62 | position_data_path = recording_folder + spike_sorter + df_path + '/position.pkl' 63 | if os.path.exists(data_frame_path): 64 | print('I found a firing data frame.') 65 | spatial_firing = pd.read_pickle(data_frame_path) 66 | if 'hd_correlation_first_vs_second_half' in spatial_firing: 67 | if animal == 'rat': 68 | spatial_firing = spatial_firing[['session_id', 'cell_id', 'cluster_id', 'firing_times', 69 | 'number_of_spikes', 'hd', 'speed', 'mean_firing_rate', 70 | 'hd_spike_histogram', 'max_firing_rate_hd', 'preferred_HD', 71 | 'grid_spacing', 'field_size', 'grid_score', 'hd_score', 'firing_fields']].copy() 72 | if animal == 'mouse': 73 | spatial_firing = spatial_firing[['session_id', 'cluster_id', 'tetrode', 'firing_times', 74 | 'number_of_spikes', 'hd', 'speed', 'mean_firing_rate', 75 | 'hd_spike_histogram', 'max_firing_rate_hd', 'preferred_HD', 76 | 'grid_spacing', 'field_size', 'grid_score', 'hd_score', 77 | 'firing_fields', 'hd_correlation_first_vs_second_half', 'hd_correlation_first_vs_second_half_p']].copy() 78 | if animal == 'simulated': 79 | spatial_firing = spatial_firing[['session_id', 'cluster_id', 'firing_times', 80 | 'hd', 'hd_spike_histogram', 'max_firing_rate_hd', 'preferred_HD', 81 | 'grid_spacing', 'field_size', 'grid_score', 'hd_score', 'firing_fields']].copy() 82 | downsample = True 83 | 84 | spatial_firing_data = spatial_firing_data.append(spatial_firing) 85 | 86 | spatial_firing_data.to_pickle(output_path) 87 | return spatial_firing_data 88 | 89 | 90 | def add_combined_id_to_df(df_all_mice): 91 | animal_ids = [session_id.split('_')[0] for session_id in df_all_mice.session_id.values] 92 | dates = [session_id.split('_')[1] for session_id in df_all_mice.session_id.values] 93 | tetrode = df_all_mice.tetrode.values 94 | cluster = 
df_all_mice.cluster_id.values 95 | 96 | combined_ids = [] 97 | for cell in range(len(df_all_mice)): 98 | id = animal_ids[cell] + '-' + dates[cell] + '-Tetrode-' + str(tetrode[cell]) + '-Cluster-' + str(cluster[cell]) 99 | combined_ids.append(id) 100 | df_all_mice['false_positive_id'] = combined_ids 101 | return df_all_mice 102 | 103 | 104 | def save_corr_coef_in_csv(good_grid_coef, good_grid_cells_p): 105 | correlation_data = pd.DataFrame() 106 | correlation_data['R'] = good_grid_coef 107 | correlation_data['p'] = good_grid_cells_p 108 | correlation_data.to_csv(OverallAnalysis.folder_path_settings.get_local_path() + '/correlation_cell/whole_cell_correlations.csv') 109 | 110 | 111 | def correlation_between_first_and_second_halves_of_session(df_all_animals, animal='mouse'): 112 | good_cluster = df_all_animals.false_positive == False 113 | grid_cell = df_all_animals['cell type'] == 'grid' 114 | 115 | is_hd_cell = df_all_animals.hd_score >= 0.5 116 | print('number of grid: ' + str(len(df_all_animals[grid_cell]))) 117 | print('number of conj cells: ' + str(len(df_all_animals[grid_cell & is_hd_cell]))) 118 | 119 | print('mean and sd pearson r of correlation between first and second half for grid cells') 120 | print(df_all_animals.hd_correlation_first_vs_second_half[good_cluster & grid_cell].mean()) 121 | print(df_all_animals.hd_correlation_first_vs_second_half[good_cluster & grid_cell].std()) 122 | 123 | print('% of significant correlation values for grid cells: ') 124 | good_grid_coef = df_all_animals.hd_correlation_first_vs_second_half[good_cluster & grid_cell] 125 | good_grid_cells_p = df_all_animals.hd_correlation_first_vs_second_half_p[good_cluster & grid_cell] 126 | number_of_significant_ps = (good_grid_cells_p < 0.01).sum() 127 | all_ps = len(good_grid_cells_p) 128 | proportion = number_of_significant_ps / all_ps * 100 129 | print(proportion) 130 | save_corr_coef_in_csv(good_grid_coef, good_grid_cells_p) 131 | t, p = scipy.stats.wilcoxon(df_all_animals.hd_correlation_first_vs_second_half[good_cluster & grid_cell]) 132 | print('Wilcoxon p value is ' + str(p) + ' T is ' + str(t)) 133 | 134 | OverallAnalysis.analyze_field_correlations.plot_correlation_coef_hist(df_all_animals.hd_correlation_first_vs_second_half[good_cluster & grid_cell], save_output_path + 'correlation_hd_session_' + animal + '.png', y_axis_label='Cumulative probability') 135 | 136 | 137 | def process_data(animal): 138 | print('-------------------------------------------------------------') 139 | if animal == 'mouse': 140 | spike_sorter = '/MountainSort' 141 | local_path_animal = local_path_mouse 142 | server_path_animal = server_path_mouse 143 | df_path = '/DataFrames' 144 | elif animal == 'rat': 145 | spike_sorter = '' 146 | local_path_animal = local_path_rat 147 | server_path_animal = server_path_rat 148 | df_path = '/DataFrames' 149 | else: 150 | spike_sorter = '' 151 | local_path_animal = local_path_simulated 152 | server_path_animal = server_path_simulated 153 | df_path = '' 154 | 155 | all_cells = load_spatial_firing(local_path_animal, server_path_animal, animal, spike_sorter, df_path=df_path) 156 | all_cells = tag_false_positives(all_cells, animal) 157 | all_cells = add_cell_types_to_data_frame(all_cells) 158 | if animal == 'mouse': 159 | correlation_between_first_and_second_halves_of_session(all_cells) 160 | 161 | 162 | def main(): 163 | process_data('mouse') 164 | process_data('rat') 165 | process_data('simulated') 166 | 167 | 168 | if __name__ == '__main__': 169 | main() 
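# ---------------------------------------------------------------------------
# Illustrative sketch added for clarity; it is not part of the original
# analysis and every name in it is hypothetical. The
# hd_correlation_first_vs_second_half values summarized above are Pearson
# correlations between head-direction tuning curves computed separately from
# the two halves of a session; in minimal, self-contained form (hence the
# local imports) the comparison looks like this:
# ---------------------------------------------------------------------------
def example_half_session_hd_correlation():
    import numpy as np
    import scipy.stats
    rng = np.random.default_rng(0)
    angles = np.linspace(0, 2 * np.pi, 360, endpoint=False)
    # two noisy measurements of the same underlying directional tuning,
    # standing in for the first- and second-half HD histograms of one cell
    first_half = np.cos(angles - 1.0) + 1.5 + 0.2 * rng.standard_normal(360)
    second_half = np.cos(angles - 1.0) + 1.5 + 0.2 * rng.standard_normal(360)
    r, p = scipy.stats.pearsonr(first_half, second_half)
    print('correlation between halves: r = ' + str(round(r, 3)) + ', p = ' + str(p))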
-------------------------------------------------------------------------------- /OverallAnalysis/compare_sampling_of_mice_and_rats.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pylab as plt 2 | import OverallAnalysis.folder_path_settings 3 | import pandas as pd 4 | import PostSorting.parameters 5 | prm = PostSorting.parameters.Parameters() 6 | prm.set_pixel_ratio(440) 7 | 8 | from scipy import stats 9 | 10 | local_path = OverallAnalysis.folder_path_settings.get_local_path() + '/compare_sampling/' 11 | 12 | 13 | def load_data(): 14 | mouse_data = pd.read_csv(local_path + "mouse_.csv") 15 | rat_data = pd.read_csv(local_path + "rat_.csv") 16 | return mouse_data, rat_data 17 | 18 | 19 | def compare_sampling(mouse, rat): 20 | print(mouse.head()) 21 | print(rat.head()) 22 | print('mouse') 23 | print('avg time spent in field') 24 | print((mouse.time_spent_in_field / 30).mean()) 25 | print((mouse.time_spent_in_field / 30).std()) 26 | print('number of spikes') 27 | print(mouse.number_of_spikes_in_field.mean()) 28 | print(mouse.number_of_spikes_in_field.std()) 29 | print('number of fields: ' + str(len(mouse))) 30 | 31 | print('rat') 32 | print('avg time spent in field') 33 | print((rat.time_spent_in_field / 50).mean()) 34 | print((rat.time_spent_in_field / 50).std()) 35 | print('number of spikes') 36 | print(rat.number_of_spikes_in_field.mean()) 37 | print(rat.number_of_spikes_in_field.std()) 38 | print('number of fields: ' + str(len(rat))) 39 | 40 | fig = plt.figure() 41 | ax = fig.add_subplot(111) 42 | ax.set_xlabel('Time spent in field') 43 | ax.set_ylabel('Number of spikes in field') 44 | plt.scatter(mouse.time_spent_in_field / 30, mouse.number_of_spikes_in_field, color='navy', alpha=0.6) 45 | plt.scatter(rat.time_spent_in_field / 50, rat.number_of_spikes_in_field, color='lime', alpha=0.6) 46 | plt.savefig(local_path + 'sampling_in_mice_vs_rats.png') 47 | plt.close() 48 | 49 | fig = plt.figure() 50 | ax = fig.add_subplot(111) 51 | ax.set_xlabel('Time spent in field (s)') 52 | ax.set_ylabel('Number of fields') 53 | plt.hist(mouse.time_spent_in_field / 30, color='navy', alpha=0.6) 54 | plt.hist(rat.time_spent_in_field / 50, color='lime', alpha=0.6) 55 | plt.savefig(local_path + 'time_spent_in_field.png') 56 | plt.close() 57 | 58 | fig = plt.figure() 59 | ax = fig.add_subplot(111) 60 | ax.set_xlabel('Number_of_spikes') 61 | ax.set_ylabel('Number of fields') 62 | plt.hist(mouse.number_of_spikes_in_field, color='navy', alpha=0.6) 63 | plt.hist(rat.number_of_spikes_in_field, color='lime', alpha=0.6) 64 | plt.savefig(local_path + 'number_of_spikes_in_field.png') 65 | plt.close() 66 | 67 | print('time spent in field comparison:') 68 | d, p = stats.ks_2samp(mouse.time_spent_in_field, rat.time_spent_in_field) 69 | print(d) 70 | print(p) 71 | 72 | print('number of spikes comparison:') 73 | d, p = stats.ks_2samp(mouse.number_of_spikes_in_field, rat.number_of_spikes_in_field) 74 | print(d) 75 | print(p) 76 | 77 | 78 | def main(): 79 | mouse, rat = load_data() 80 | compare_sampling(mouse, rat) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /OverallAnalysis/example_get_firing_rates_for_half_fields.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | analysis_path = '/grid_fields/analysis/' 5 | analysis_path = analysis_path + 'firing_rates_for_plots/' 6 | 7 | 8 | first_path 
= analysis_path + 'first/' 9 | second_path = analysis_path + 'second/' 10 | all_path = analysis_path 11 | 12 | spatial_first = pd.read_pickle(first_path + 'spatial_firing.pkl') 13 | spatial_second = pd.read_pickle(second_path + 'spatial_firing.pkl') 14 | spatial_all = pd.read_pickle(all_path + 'spatial_firing.pkl') 15 | 16 | cell_first = spatial_first.iloc[0] 17 | cell_second = spatial_second.iloc[0] 18 | all = spatial_all.iloc[0] 19 | 20 | print(' ') 21 | print('fields in whole session') 22 | for field in range(len(cell_first.firing_fields_hd_cluster)): 23 | hd_hist_cluster = all.firing_fields_hd_cluster[field] 24 | hd_hist_session = np.array(all.firing_fields_hd_session[field]) 25 | hd_hist = hd_hist_cluster / hd_hist_session / 1000 26 | max_firing_rate = np.max(hd_hist[~np.isnan(hd_hist)].flatten()) 27 | print(max_firing_rate) 28 | 29 | 30 | print(' ') 31 | print('fields in first half') 32 | for field in range(len(cell_first.firing_fields_hd_cluster)): 33 | hd_hist_cluster = cell_first.firing_fields_hd_cluster[field] 34 | hd_hist_session = cell_first.firing_fields_hd_session[field] / 30000 35 | hd_hist = hd_hist_cluster / hd_hist_session / 1000 36 | max_firing_rate = np.max(hd_hist[~np.isnan(hd_hist)].flatten()) 37 | print(max_firing_rate) 38 | 39 | 40 | print('fields in second half') 41 | for field in range(len(cell_second.firing_fields_hd_cluster)): 42 | hd_hist_cluster = cell_second.firing_fields_hd_cluster[field] 43 | hd_hist_session = cell_second.firing_fields_hd_session[field] / 30000 44 | hd_hist = hd_hist_cluster / hd_hist_session / 1000 45 | max_firing_rate = np.max(hd_hist[~np.isnan(hd_hist)].flatten()) 46 | print(max_firing_rate) 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /OverallAnalysis/example_hd_histograms_binning.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pylab as plt 2 | import numpy as np 3 | import OverallAnalysis.folder_path_settings 4 | import pandas as pd 5 | import plot_utility 6 | import PostSorting.open_field_head_direction 7 | import PostSorting.open_field_make_plots 8 | 9 | 10 | local_path = OverallAnalysis.folder_path_settings.get_local_path() + '/example_hd_histograms/' 11 | server_path_mouse = OverallAnalysis.folder_path_settings.get_server_path_mouse() 12 | server_path_rat = OverallAnalysis.folder_path_settings.get_server_path_rat() 13 | server_path_simulated = OverallAnalysis.folder_path_settings.get_server_path_simulated() 14 | 15 | 16 | def plot_polar_head_direction_histogram(spike_hist, hd_hist, id, save_path): 17 | print('I will make the polar HD plots now.') 18 | 19 | hd_polar_fig = plt.figure() 20 | # hd_polar_fig.set_size_inches(5, 5, forward=True) 21 | ax = hd_polar_fig.add_subplot(1, 1, 1) # specify (nrows, ncols, axnum) 22 | hd_hist_cluster = spike_hist 23 | theta = np.linspace(0, 2*np.pi, 361) # x axis 24 | ax = plt.subplot(1, 1, 1, polar=True) 25 | ax = plot_utility.style_polar_plot(ax) 26 | ax.plot(theta[:-1], hd_hist_cluster, color='red', linewidth=2) 27 | ax.plot(theta[:-1], hd_hist*(max(hd_hist_cluster)/max(hd_hist)), color='black', linewidth=2) 28 | # plt.tight_layout() 29 | max_firing_rate = np.max(hd_hist_cluster.flatten()) 30 | plt.title(str(round(max_firing_rate, 2)) + 'Hz', y=1.08) 31 | # + '\nKuiper p: ' + str(spatial_firing.hd_p[cluster]) 32 | # plt.title('max fr: ' + str(round(spatial_firing.max_firing_rate_hd[cluster], 2)) + ' Hz' + ', preferred HD: ' + str(round(spatial_firing.preferred_HD[cluster][0], 
0)) + ', hd score: ' + str(round(spatial_firing.hd_score[cluster], 2)), y=1.08, fontsize=12) 33 | plt.savefig(save_path + '/' + id + '_hd_polar_' + '.png', dpi=300) 34 | plt.close() 35 | 36 | 37 | def plot_example_hd_histograms(): 38 | position = pd.read_pickle(local_path + 'position.pkl') 39 | hd_pos = np.array(position.hd) 40 | hd_pos = (hd_pos + 180) * np.pi / 180 41 | 42 | spatial_firing = pd.read_pickle(local_path + 'spatial_firing.pkl') 43 | hd = np.array(spatial_firing.hd.iloc[0]) 44 | hd = (hd + 180) * np.pi / 180 45 | hd_spike_histogram_23 = PostSorting.open_field_head_direction.get_hd_histogram(hd, window_size=23) 46 | hd_spike_histogram_10 = PostSorting.open_field_head_direction.get_hd_histogram(hd, window_size=10) 47 | hd_spike_histogram_20 = PostSorting.open_field_head_direction.get_hd_histogram(hd, window_size=20) 48 | hd_spike_histogram_30 = PostSorting.open_field_head_direction.get_hd_histogram(hd, window_size=30) 49 | hd_spike_histogram_40 = PostSorting.open_field_head_direction.get_hd_histogram(hd, window_size=40) 50 | 51 | hd_spike_histogram_23_pos = PostSorting.open_field_head_direction.get_hd_histogram(hd_pos, window_size=23) / 30000 # 30000 is the ephys sampling rate for the mouse data 52 | hd_spike_histogram_10_pos = PostSorting.open_field_head_direction.get_hd_histogram(hd_pos, window_size=10) / 30000 53 | hd_spike_histogram_20_pos = PostSorting.open_field_head_direction.get_hd_histogram(hd_pos, window_size=20) / 30000 54 | hd_spike_histogram_30_pos = PostSorting.open_field_head_direction.get_hd_histogram(hd_pos, window_size=30) / 30000 55 | hd_spike_histogram_40_pos = PostSorting.open_field_head_direction.get_hd_histogram(hd_pos, window_size=40) / 30000 56 | 57 | hd_spike_histogram_10_norm = hd_spike_histogram_10 / hd_spike_histogram_10_pos / 1000 58 | hd_spike_histogram_20_norm = hd_spike_histogram_20 / hd_spike_histogram_20_pos / 1000 59 | hd_spike_histogram_30_norm = hd_spike_histogram_30 / hd_spike_histogram_30_pos / 1000 60 | hd_spike_histogram_40_norm = hd_spike_histogram_40 / hd_spike_histogram_40_pos / 1000 61 | 62 | print('max rate:') 63 | max_firing_rate = np.max(hd_spike_histogram_10_norm.flatten()) 64 | print(max_firing_rate) 65 | max_firing_rate = np.max(hd_spike_histogram_20_norm.flatten()) 66 | print(max_firing_rate) 67 | max_firing_rate = np.max(hd_spike_histogram_30_norm.flatten()) 68 | print(max_firing_rate) 69 | max_firing_rate = np.max(hd_spike_histogram_40_norm.flatten()) 70 | print(max_firing_rate) 71 | 72 | plot_polar_head_direction_histogram(hd_spike_histogram_10_norm, hd_spike_histogram_10_pos, str(10), local_path) 73 | plot_polar_head_direction_histogram(hd_spike_histogram_20_norm, hd_spike_histogram_20_pos, str(20), local_path) 74 | plot_polar_head_direction_histogram(hd_spike_histogram_30_norm, hd_spike_histogram_30_pos, str(30), local_path) 75 | plot_polar_head_direction_histogram(hd_spike_histogram_40_norm, hd_spike_histogram_40_pos, str(40), local_path) 76 | 77 | 78 | def main(): 79 | plot_example_hd_histograms() 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /OverallAnalysis/false_positives.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_list_of_false_positives(false_positives_path): 5 | if os.path.isfile(false_positives_path) is True: 6 | if os.stat(false_positives_path).st_size == 0: 7 | os.remove(false_positives_path) 8 | false_positive_reader = 
open(false_positives_path, 'r')
 9 |         false_positives = false_positive_reader.readlines()
10 |         false_positive_reader.close()
11 |         false_positive_clusters = [x.strip() for x in false_positives]
12 |         return false_positive_clusters
13 | 
14 | 
15 | def tag_false_positives(spike_df, false_positives_path):
16 |     false_positives_list = get_list_of_false_positives(false_positives_path)
17 |     spike_df['false_positive'] = spike_df['fig_name_id'].isin(false_positives_list)
18 |     return spike_df
19 | 
20 | 
21 | def add_figure_name_id(spike_df):
22 |     # todo change order in data and put - current format M9-10/04/2018-Tetrode-1-Cluster-4
23 |     figure_name_ids = spike_df['animal'] + '-' + spike_df['day'].apply(str) + '-Tetrode-' + spike_df['tetrode'].apply(str) + '-Cluster-' + spike_df['cluster'].apply(str)
24 |     spike_df['fig_name_id'] = figure_name_ids
25 |     return spike_df
26 | 
27 | 
28 | def get_accepted_clusters(spike_data_frame, false_positives_path):
29 |     spike_data_frame = add_figure_name_id(spike_data_frame)
30 |     spike_data_frame = tag_false_positives(spike_data_frame, false_positives_path)
31 |     not_false_positive = spike_data_frame['false_positive'] == 0
32 |     good_cluster = spike_data_frame['goodcluster'] == 1
33 |     accepted_clusters = spike_data_frame[good_cluster & not_false_positive]
34 |     return accepted_clusters
35 | 
36 | 
37 | def get_false_positives(spike_data_frame, false_positives_path):
38 |     spike_data_frame = add_figure_name_id(spike_data_frame)
39 |     spike_data_frame = tag_false_positives(spike_data_frame, false_positives_path)
40 |     false_positive = spike_data_frame['false_positive'] == 1
41 |     good_cluster = spike_data_frame['goodcluster'] == 1
42 |     false_positives = spike_data_frame[good_cluster & false_positive]
43 |     return false_positives
44 | 
--------------------------------------------------------------------------------
/OverallAnalysis/folder_path_settings.py:
--------------------------------------------------------------------------------
 1 | server_path_mouse = '/path/to/raw/ephys/data/on/server/'
 2 | server_path_rat = '/path/to/raw/ephys/data/on/server/'
 3 | server_path_simulated = '/path/to/raw/ephys/data/on/server/'
 4 | analysis_path = '/grid_fields/'
 5 | local_test_recording_path = '/local/path/to/test/recording/open/ephys/format/'
 6 | 
 7 | 
 8 | def get_server_path_mouse():
 9 |     return server_path_mouse
10 | 
11 | 
12 | def get_server_path_rat():
13 |     return server_path_rat
14 | 
15 | 
16 | def get_server_path_simulated():
17 |     return server_path_simulated
18 | 
19 | 
20 | def get_local_path():
21 |     return analysis_path
22 | 
23 | 
24 | def get_local_test_recording_path():
25 |     return local_test_recording_path
26 | 
27 | 
--------------------------------------------------------------------------------
/OverallAnalysis/grid_analysis_other_labs/firing_maps.py:
--------------------------------------------------------------------------------
 1 | from joblib import Parallel, delayed
 2 | import multiprocessing
 3 | import matplotlib.pylab as plt
 4 | import pandas as pd
 5 | from numba import jit
 6 | import numpy as np
 7 | import math
 8 | import time
 9 | 
10 | 
11 | def get_dwell(spatial_data, prm):
12 |     min_dwell_distance_cm = 5  # from point to determine min dwell time
13 |     dt_position_ms = spatial_data.time_seconds.diff().mean() * 1000  # sampling interval in position data
14 |     min_dwell_time_ms = 3 * dt_position_ms  # this is about 100 ms
15 |     min_dwell_time = round(min_dwell_time_ms / dt_position_ms)
16 |     return min_dwell_time, 
min_dwell_distance_cm 17 | 18 | 19 | def get_bin_size(prm): 20 | bin_size_cm = 2.5 21 | return bin_size_cm 22 | 23 | 24 | def get_number_of_bins(spatial_data, prm): 25 | bin_size = get_bin_size(prm) 26 | length_of_arena_x = spatial_data.position_x[~np.isnan(spatial_data.position_x)].max() 27 | length_of_arena_y = spatial_data.position_y[~np.isnan(spatial_data.position_y)].max() 28 | number_of_bins_x = math.ceil(length_of_arena_x / bin_size) 29 | number_of_bins_y = math.ceil(length_of_arena_y / bin_size) 30 | return number_of_bins_x, number_of_bins_y 31 | 32 | 33 | @jit 34 | def gaussian_kernel(kernx): 35 | kerny = np.exp(np.power(kernx, 2)/2 * (-1)) 36 | return kerny 37 | 38 | 39 | def calculate_firing_rate_for_cluster_parallel(cluster, smooth, firing_data_spatial, positions_x, positions_y, number_of_bins_x, number_of_bins_y, bin_size_pixels, min_dwell, min_dwell_distance_pixels, dt_position_ms): 40 | print('Started another cluster') 41 | print(cluster) 42 | cluster_index = firing_data_spatial.cluster_id.values[cluster] - 1 43 | cluster_firings = pd.DataFrame({'position_x': firing_data_spatial.position_x[cluster_index], 'position_y': firing_data_spatial.position_y[cluster_index]}) 44 | spike_positions_x = cluster_firings.position_x.values 45 | spike_positions_y = cluster_firings.position_y.values 46 | firing_rate_map = np.zeros((number_of_bins_x, number_of_bins_y)) 47 | for x in range(number_of_bins_x): 48 | for y in range(number_of_bins_y): 49 | px = x * bin_size_pixels + (bin_size_pixels / 2) 50 | py = y * bin_size_pixels + (bin_size_pixels / 2) 51 | spike_distances = np.sqrt(np.power(px - spike_positions_x, 2) + np.power(py - spike_positions_y, 2)) 52 | spike_distances = spike_distances[~np.isnan(spike_distances)] 53 | occupancy_distances = np.sqrt(np.power((px - positions_x), 2) + np.power((py - positions_y), 2)) 54 | occupancy_distances = occupancy_distances[~np.isnan(occupancy_distances)] 55 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_pixels)[0]) 56 | 57 | if bin_occupancy >= min_dwell: 58 | firing_rate_map[x, y] = sum(gaussian_kernel(spike_distances/smooth)) / (sum(gaussian_kernel(occupancy_distances/smooth)) * (dt_position_ms/1000)) 59 | 60 | else: 61 | firing_rate_map[x, y] = 0 62 | #firing_rate_map = np.rot90(firing_rate_map) 63 | return firing_rate_map 64 | 65 | 66 | def get_spike_heatmap_parallel(spatial_data, firing_data_spatial, prm): 67 | print('I will calculate firing rate maps now.') 68 | dt_position_ms = spatial_data.time_seconds.diff().mean()*1000 69 | min_dwell, min_dwell_distance_pixels = get_dwell(spatial_data, prm) 70 | smooth = 5 # / 100 * prm.get_pixel_ratio() 71 | bin_size_pixels = get_bin_size(prm) 72 | number_of_bins_x, number_of_bins_y = get_number_of_bins(spatial_data, prm) 73 | num_cores = multiprocessing.cpu_count() 74 | clusters = range(len(firing_data_spatial)) 75 | time_start = time.time() 76 | firing_rate_maps = Parallel(n_jobs=num_cores, max_nbytes=None)(delayed(calculate_firing_rate_for_cluster_parallel)(cluster, smooth, firing_data_spatial, spatial_data.position_x.values, spatial_data.position_y.values, number_of_bins_x, number_of_bins_y, bin_size_pixels, min_dwell, min_dwell_distance_pixels, dt_position_ms) for cluster in clusters) 77 | time_end = time.time() 78 | print('Making the rate maps took:') 79 | time_diff = time_end - time_start 80 | print(time_diff) 81 | firing_data_spatial['firing_maps'] = firing_rate_maps 82 | 83 | return firing_data_spatial 84 | 85 | 86 | def get_position_heatmap(spatial_data, prm): 87 | 
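# (Comment added for clarity; not in the original file.) This builds the
# occupancy map: for each spatial bin it counts the position samples that fall
# within min_dwell_distance_cm of the bin centre, and bins occupied for less
# than the minimum dwell time are set to NaN below so they read as unvisited.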
min_dwell, min_dwell_distance_cm = get_dwell(spatial_data, prm) 88 | min_dwell_distance_cm = 5 89 | bin_size_cm = get_bin_size(prm) 90 | number_of_bins_x, number_of_bins_y = get_number_of_bins(spatial_data, prm) 91 | 92 | position_heat_map = np.zeros((number_of_bins_x, number_of_bins_y)) 93 | 94 | # find value for each bin for heatmap 95 | for x in range(number_of_bins_x): 96 | for y in range(number_of_bins_y): 97 | px = x * bin_size_cm + (bin_size_cm / 2) 98 | py = y * bin_size_cm + (bin_size_cm / 2) 99 | 100 | occupancy_distances = np.sqrt(np.power((px - spatial_data.position_x.values), 2) + np.power((py - spatial_data.position_y.values), 2)) 101 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_cm)[0]) 102 | 103 | if bin_occupancy >= min_dwell: 104 | position_heat_map[x, y] = bin_occupancy 105 | else: 106 | position_heat_map[x, y] = None 107 | return position_heat_map 108 | 109 | 110 | # this is the firing rate in the bin with the highest rate 111 | def find_maximum_firing_rate(spatial_firing): 112 | max_firing_rates = [] 113 | for cluster in range(len(spatial_firing)): 114 | cluster = spatial_firing.cluster_id.values[cluster] - 1 115 | firing_rate_map = spatial_firing.firing_maps[cluster] 116 | max_firing_rate = np.max(firing_rate_map.flatten()) 117 | max_firing_rates.append(max_firing_rate) 118 | spatial_firing['max_firing_rate'] = max_firing_rates 119 | return spatial_firing 120 | 121 | 122 | def make_firing_field_maps(spatial_data, firing_data_spatial, prm): 123 | position_heat_map = get_position_heatmap(spatial_data, prm) 124 | firing_data_spatial = get_spike_heatmap_parallel(spatial_data, firing_data_spatial, prm) 125 | #position_heat_map = np.rot90(position_heat_map) # to rotate map to be like matlab plots 126 | firing_data_spatial = find_maximum_firing_rate(firing_data_spatial) 127 | return position_heat_map, firing_data_spatial -------------------------------------------------------------------------------- /OverallAnalysis/model_prediction_schematics_simulated.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pylab as plt 3 | import OverallAnalysis.folder_path_settings 4 | import PostSorting.open_field_head_direction 5 | import PostSorting.open_field_make_plots 6 | 7 | 8 | analysis_path = OverallAnalysis.folder_path_settings.get_local_path() + 'prediction_schematics/' 9 | 10 | 11 | def get_smooth_hist_and_plot(distribution, name): 12 | # get smooth hd hist 13 | smooth_hist = PostSorting.open_field_head_direction.get_hd_histogram(distribution, window_size=23) 14 | PostSorting.open_field_make_plots.plot_single_polar_hd_hist(smooth_hist, 0, analysis_path + name, color1='navy', title='') 15 | 16 | 17 | def plot_distributions(): 18 | print('multimodal') 19 | multimodal1 = np.random.vonmises(0, 6, 1000000) 20 | multimodal4 = np.random.vonmises(1.2, 6, 1000000) 21 | multimodal2 = np.random.vonmises(2, 6, 1000000) 22 | multimodal5 = np.random.vonmises(2.9, 7, 1000000) 23 | multimodal3 = np.random.vonmises(4, 6, 1000000) 24 | multimodal6 = np.random.vonmises(5.1, 6, 1000000) 25 | 26 | multimodal = np.concatenate((multimodal1, multimodal2, multimodal3, multimodal4, multimodal5, multimodal6)) 27 | multimodal += np.pi 28 | get_smooth_hist_and_plot(multimodal, 'multimodal') 29 | 30 | uniform = np.random.uniform(0, 2 * np.pi, 1000000) 31 | print('uniform') 32 | get_smooth_hist_and_plot(uniform, 'uniform') 33 | print('unimodal') 34 | unimodal = np.random.vonmises(0.5, 4, 1000000) + np.pi 35 
| get_smooth_hist_and_plot(unimodal, 'unimodal') 36 | 37 | 38 | def main(): 39 | plot_distributions() 40 | 41 | 42 | if __name__ == '__main__': 43 | main() -------------------------------------------------------------------------------- /OverallAnalysis/open_field_firing_maps_processed_data.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | import os 3 | import multiprocessing 4 | import matplotlib.pylab as plt 5 | import pandas as pd 6 | from numba import jit 7 | import numpy as np 8 | import math 9 | import time 10 | 11 | 12 | def get_dwell(spatial_data, prm): 13 | min_dwell_distance_cm = 5 # from point to determine min dwell time 14 | min_dwell_distance_pixels = min_dwell_distance_cm / 100 * prm.get_pixel_ratio() 15 | 16 | dt_position_ms = spatial_data.synced_time.diff().mean()*1000 # average sampling interval in position data (ms) 17 | min_dwell_time_ms = 3 * dt_position_ms # this is about 100 ms 18 | min_dwell = round(min_dwell_time_ms/dt_position_ms) 19 | return min_dwell, min_dwell_distance_pixels 20 | 21 | 22 | def get_bin_size(prm): 23 | bin_size_cm = 2.5 24 | bin_size_pixels = bin_size_cm / 100 * prm.get_pixel_ratio() 25 | return bin_size_pixels 26 | 27 | 28 | def get_number_of_bins(spatial_data, prm): 29 | bin_size_pixels = get_bin_size(prm) 30 | length_of_arena_x = spatial_data.position_x_pixels[~np.isnan(spatial_data.position_x_pixels)].max() 31 | length_of_arena_y = spatial_data.position_y_pixels[~np.isnan(spatial_data.position_y_pixels)].max() 32 | number_of_bins_x = math.ceil(length_of_arena_x / bin_size_pixels) 33 | number_of_bins_y = math.ceil(length_of_arena_y / bin_size_pixels) 34 | return number_of_bins_x, number_of_bins_y 35 | 36 | 37 | @jit 38 | def gaussian_kernel(kernx): 39 | kerny = np.exp(np.power(kernx, 2)/2 * (-1)) 40 | return kerny 41 | 42 | 43 | def calculate_firing_rate_for_cluster_parallel(cluster, smooth, firing_data_spatial, positions_x, positions_y, number_of_bins_x, number_of_bins_y, bin_size_pixels, min_dwell, min_dwell_distance_pixels, dt_position_ms): 44 | print('Started another cluster') 45 | print(cluster) 46 | cluster_index = firing_data_spatial.cluster_id.values[cluster] - 1 47 | cluster_index = 0 48 | cluster_firings = pd.DataFrame({'position_x': firing_data_spatial.position_x_pixels[cluster_index], 'position_y': firing_data_spatial.position_y_pixels[cluster_index]}) 49 | spike_positions_x = cluster_firings.position_x.values 50 | spike_positions_y = cluster_firings.position_y.values 51 | firing_rate_map = np.zeros((number_of_bins_x, number_of_bins_y)) 52 | for x in range(number_of_bins_x): 53 | for y in range(number_of_bins_y): 54 | px = x * bin_size_pixels + (bin_size_pixels / 2) 55 | py = y * bin_size_pixels + (bin_size_pixels / 2) 56 | spike_distances = np.sqrt(np.power(px - spike_positions_x, 2) + np.power(py - spike_positions_y, 2)) 57 | spike_distances = spike_distances[~np.isnan(spike_distances)] 58 | occupancy_distances = np.sqrt(np.power((px - positions_x), 2) + np.power((py - positions_y), 2)) 59 | occupancy_distances = occupancy_distances[~np.isnan(occupancy_distances)] 60 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_pixels)[0]) 61 | 62 | if bin_occupancy >= min_dwell: 63 | firing_rate_map[x, y] = sum(gaussian_kernel(spike_distances/smooth)) / (sum(gaussian_kernel(occupancy_distances/smooth)) * (dt_position_ms/1000)) 64 | 65 | else: 66 | firing_rate_map[x, y] = 0 67 | #firing_rate_map = np.rot90(firing_rate_map) 68 | return 
firing_rate_map 69 | 70 | 71 | def get_spike_heatmap_parallel(spatial_data, firing_data_spatial, prm): 72 | print('I will calculate firing rate maps now.') 73 | dt_position_ms = spatial_data.synced_time.diff().mean()*1000 74 | min_dwell, min_dwell_distance_pixels = get_dwell(spatial_data, prm) 75 | smooth = 5 / 100 * prm.get_pixel_ratio() 76 | bin_size_pixels = get_bin_size(prm) 77 | number_of_bins_x, number_of_bins_y = get_number_of_bins(spatial_data, prm) 78 | clusters = range(len(firing_data_spatial)) 79 | num_cores = int(os.environ['HEATMAP_CONCURRENCY']) if os.environ.get('HEATMAP_CONCURRENCY') else multiprocessing.cpu_count() 80 | time_start = time.time() 81 | firing_rate_maps = Parallel(n_jobs=num_cores)(delayed(calculate_firing_rate_for_cluster_parallel)(cluster, smooth, firing_data_spatial, spatial_data.position_x_pixels.values, spatial_data.position_y_pixels.values, number_of_bins_x, number_of_bins_y, bin_size_pixels, min_dwell, min_dwell_distance_pixels, dt_position_ms) for cluster in clusters) 82 | time_end = time.time() 83 | print('Making the rate maps took:') 84 | time_diff = time_end - time_start 85 | print(time_diff) 86 | firing_data_spatial['firing_maps'] = firing_rate_maps 87 | 88 | return firing_data_spatial 89 | 90 | 91 | def get_position_heatmap(spatial_data, prm): 92 | min_dwell, min_dwell_distance_cm = get_dwell(spatial_data, prm) 93 | bin_size_cm = get_bin_size(prm) 94 | number_of_bins_x, number_of_bins_y = get_number_of_bins(spatial_data, prm) 95 | 96 | position_heat_map = np.zeros((number_of_bins_x, number_of_bins_y)) 97 | 98 | # find value for each bin for heatmap 99 | for x in range(number_of_bins_x): 100 | for y in range(number_of_bins_y): 101 | px = x * bin_size_cm + (bin_size_cm / 2) 102 | py = y * bin_size_cm + (bin_size_cm / 2) 103 | 104 | occupancy_distances = np.sqrt(np.power((px - spatial_data.position_x_pixels.values), 2) + np.power((py - spatial_data.position_y_pixels.values), 2)) 105 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_cm)[0]) 106 | 107 | if bin_occupancy >= min_dwell: 108 | position_heat_map[x, y] = bin_occupancy 109 | else: 110 | position_heat_map[x, y] = None 111 | return position_heat_map 112 | 113 | 114 | # this is the firing rate in the bin with the highest rate 115 | def find_maximum_firing_rate(spatial_firing): 116 | max_firing_rates = [] 117 | for cluster in range(len(spatial_firing)): 118 | cluster = spatial_firing.cluster_id.values[cluster] - 1 119 | cluster = 0 120 | firing_rate_map = spatial_firing.firing_maps[cluster] 121 | max_firing_rate = np.max(firing_rate_map.flatten()) 122 | max_firing_rates.append(max_firing_rate) 123 | spatial_firing['max_firing_rate'] = max_firing_rates 124 | return spatial_firing 125 | 126 | 127 | def make_firing_field_maps(spatial_data, firing_data_spatial, prm): 128 | position_heat_map = get_position_heatmap(spatial_data, prm) 129 | firing_data_spatial = get_spike_heatmap_parallel(spatial_data, firing_data_spatial, prm) 130 | #position_heat_map = np.rot90(position_heat_map) # to rotate map to be like matlab plots 131 | firing_data_spatial = find_maximum_firing_rate(firing_data_spatial) 132 | return position_heat_map, firing_data_spatial -------------------------------------------------------------------------------- /OverallAnalysis/overall_params.py: -------------------------------------------------------------------------------- 1 | class OverallParameters: 2 | 3 | isolation = 0 4 | noise_overlap = 0 5 | snr = 0 6 | 7 | path_to_data = '' 8 | save_output_path = '' 9 | 
false_positives_path_all = '' 10 | false_positives_path_separate = '' 11 | 12 | def __init__(self): 13 | return 14 | 15 | def get_isolation(self): 16 | return OverallParameters.isolation 17 | 18 | def set_isolation(self, isolation_th): 19 | OverallParameters.isolation = isolation_th 20 | 21 | def get_noise_overlap(self): 22 | return OverallParameters.noise_overlap 23 | 24 | def set_noise_overlap(self, noise_overlap_th): 25 | OverallParameters.noise_overlap = noise_overlap_th 26 | 27 | def get_snr(self): 28 | return OverallParameters.snr 29 | 30 | def set_snr(self, signal_to_noise_ratio): 31 | OverallParameters.snr = signal_to_noise_ratio 32 | 33 | def get_path_to_data(self): 34 | return OverallParameters.path_to_data 35 | 36 | def set_path_to_data(self, path): 37 | OverallParameters.path_to_data = path 38 | 39 | def get_save_output_path(self): 40 | return OverallParameters.save_output_path 41 | 42 | def set_save_output_path(self, path): 43 | OverallParameters.save_output_path = path 44 | 45 | def get_false_positives_path_all(self): 46 | return OverallParameters.false_positives_path_all 47 | 48 | def set_false_positives_path_all(self, path): 49 | OverallParameters.false_positives_path_all = path 50 | 51 | def get_false_positives_path_separate(self): 52 | return OverallParameters.false_positives_path_separate 53 | 54 | def set_false_positives_path_separate(self, path): 55 | OverallParameters.false_positives_path_separate = path -------------------------------------------------------------------------------- /OverallAnalysis/pattern_of_field_shapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | 5 | import OverallAnalysis.folder_path_settings 6 | import PostSorting.open_field_make_plots 7 | 8 | 9 | local_path = OverallAnalysis.folder_path_settings.get_local_path() + '/pattern_of_shapes/' 10 | 11 | 12 | # select accepted fields based on list of fields that were correctly identified by field detector 13 | def tag_accepted_fields_mouse(field_data, accepted_fields): 14 | unique_id = field_data.session_id + '_' + field_data.cluster_id.apply(str) + '_' + (field_data.field_id + 1).apply(str) 15 | field_data['unique_id'] = unique_id 16 | unique_id = accepted_fields['Session ID'] + '_' + accepted_fields['Cell'].apply(str) + '_' + accepted_fields['field'].apply(str) 17 | accepted_fields['unique_id'] = unique_id 18 | field_data['unique_cell_id'] = field_data.session_id + '_' + field_data.cluster_id.apply(str) 19 | field_data['accepted_field'] = field_data.unique_id.isin(accepted_fields.unique_id) 20 | return field_data 21 | 22 | 23 | def plot_average_field_in_region(field_data, x1, x2, y1, y2, tag): 24 | sum_of_fields = np.zeros(360) 25 | sum_of_fields_norm = np.zeros(360) 26 | number_of_fields = 0 27 | for index, field in field_data.iterrows(): 28 | classic_hd_hist = field.hd_hist_spikes / field.hd_hist_session 29 | field_indices = field.indices_rate_map 30 | x = (field_indices[:, 0] * 2.5).mean() # convert to cm 31 | y = (field_indices[:, 1] * 2.5).mean() 32 | if (x >= x1) & (x < x2) & (y >= y1) & (y < y2): 33 | field_hist = np.nan_to_num(classic_hd_hist) 34 | sum_of_fields += np.nan_to_num(field_hist) 35 | number_of_fields += 1 36 | normalized_hist = field_hist / np.nanmax(field_hist) 37 | sum_of_fields_norm += normalized_hist 38 | 39 | avg = sum_of_fields / number_of_fields 40 | print('Number of fields in ' + tag) 41 | print(number_of_fields) 42 | if number_of_fields > 0: 43 | save_path = local_path 
+ 'smooth_histograms/' + tag + 'not_normalized' 44 | PostSorting.open_field_make_plots.plot_single_polar_hd_hist(avg, 0, 45 | save_path, color1='red', title='') 46 | 47 | avg_norm = sum_of_fields_norm / number_of_fields 48 | save_path = local_path + 'smooth_histograms/' + tag + 'normalized' 49 | PostSorting.open_field_make_plots.plot_single_polar_hd_hist(avg_norm, 0, 50 | save_path, color1='red', title='') 51 | 52 | 53 | def plot_all_fields(field_data): 54 | if not os.path.isdir(local_path + 'smooth_histograms/'): 55 | os.mkdir(local_path + 'smooth_histograms/') 56 | 57 | plot_average_field_in_region(field_data, x1=0, x2=33, y1=0, y2=33, tag='region_1') 58 | plot_average_field_in_region(field_data, x1=33, x2=66, y1=0, y2=33, tag='region_2') 59 | plot_average_field_in_region(field_data, x1=66, x2=100, y1=0, y2=33, tag='region_3') 60 | 61 | plot_average_field_in_region(field_data, x1=0, x2=33, y1=33, y2=66, tag='region_4') 62 | plot_average_field_in_region(field_data, x1=33, x2=66, y1=33, y2=66, tag='region_5') 63 | plot_average_field_in_region(field_data, x1=66, x2=100, y1=33, y2=66, tag='region_6') 64 | 65 | plot_average_field_in_region(field_data, x1=0, x2=33, y1=66, y2=101, tag='region_7') 66 | plot_average_field_in_region(field_data, x1=33, x2=66, y1=66, y2=101, tag='region_8') 67 | plot_average_field_in_region(field_data, x1=66, x2=100, y1=66, y2=101, tag='region_9') 68 | 69 | for index, field in field_data.iterrows(): 70 | save_path = local_path + 'smooth_histograms/' + field.session_id + str(field.cluster_id) + str(field.field_id) + '_' 71 | field_indices = field.indices_rate_map 72 | 73 | d1 = (field_indices[:, 0] * 2.5).mean() # convert to cm 74 | d2 = (field_indices[:, 1] * 2.5).mean() 75 | classic_hd_hist = field.hd_hist_spikes / field.hd_hist_session 76 | PostSorting.open_field_make_plots.plot_single_polar_hd_hist(classic_hd_hist, 0, save_path + str(round(d1,1)) + '_' + str(round(d2,1)), color1='navy', title='') 77 | 78 | print('I made smooth plots for all fields.') 79 | 80 | 81 | def process_data(): 82 | fields = pd.read_pickle(local_path + 'mice.pkl') 83 | accepted_fields = pd.read_excel(local_path + 'list_of_accepted_fields.xlsx') 84 | fields = tag_accepted_fields_mouse(fields, accepted_fields) 85 | accepted = fields.accepted_field == True 86 | hd = fields.hd_score >= 0.5 87 | grid = fields.grid_score >= 0.4 88 | grid_cells = np.logical_and(grid, np.logical_not(hd)) 89 | plot_all_fields(fields[grid_cells & accepted]) 90 | 91 | 92 | def main(): 93 | process_data() 94 | 95 | 96 | if __name__ == '__main__': 97 | main() -------------------------------------------------------------------------------- /OverallAnalysis/plot_example_classic_field_polar_plots.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import pandas as pd 4 | import plot_utility 5 | import matplotlib.pylab as plt 6 | import OverallAnalysis.folder_path_settings 7 | import PostSorting.open_field_head_direction 8 | import PostSorting.open_field_make_plots 9 | 10 | 11 | local_path = OverallAnalysis.folder_path_settings.get_local_path() 12 | analysis_path = local_path + '/plot_hd_tuning_vs_shuffled_fields/' 13 | output_path = local_path + '/example_fields_classic/' 14 | 15 | 16 | # plot polar hd histograms without needing the whole df as an input 17 | def plot_single_polar_hd_hist(hist_1, cluster, save_path, color1='lime', title=''): 18 | hd_polar_fig = plt.figure() 19 | hd_polar_fig.set_size_inches(5, 5, forward=True) 20 | ax = 
hd_polar_fig.add_subplot(1, 1, 1) # specify (nrows, ncols, axnum) 21 | theta = np.linspace(0, 2*np.pi, 361) # x axis 22 | ax = plt.subplot(1, 1, 1, polar=True) 23 | ax = plot_utility.style_polar_plot(ax) 24 | plt.xticks([]) 25 | plt.yticks([]) 26 | plt.ylim(0, np.nanmax(hist_1) * 1.4) 27 | # plt.xticks([math.radians(0), math.radians(90), math.radians(180), math.radians(270)]) 28 | ax.plot(theta[:-1], hist_1, color=color1, linewidth=12) 29 | plt.title(title) 30 | # ax.plot(theta[:-1], hist_2 * (max(hist_1) / max(hist_2)), color='navy', linewidth=2) 31 | plt.tight_layout() 32 | plt.savefig(save_path + '_hd_polar_' + cluster + '.png', dpi=300, bbox_inches="tight") 33 | # plt.savefig(save_path + '_hd_polar_' + str(cluster + 1) + '.pdf', bbox_inches="tight") 34 | plt.close() 35 | 36 | 37 | def make_example_plots_mouse(): 38 | mouse_df_path = analysis_path + 'shuffled_field_data_all_mice.pkl' 39 | mouse_df = pd.read_pickle(mouse_df_path) 40 | session_id = 'M12_2018-04-10_14-22-14_of' 41 | example_session = mouse_df.session_id == session_id 42 | example_cell = mouse_df[example_session] 43 | colors = PostSorting.open_field_make_plots.generate_colors(len(example_cell)) 44 | for index, field in example_cell.iterrows(): 45 | hd_session = field.hd_in_field_session 46 | hd_session_hist = PostSorting.open_field_head_direction.get_hd_histogram(hd_session) 47 | hd_spikes = field.hd_in_field_spikes 48 | hd_spikes_hist = PostSorting.open_field_head_direction.get_hd_histogram(hd_spikes) 49 | hist = hd_spikes_hist / hd_session_hist 50 | plot_single_polar_hd_hist(hist, 'mouse_' + str(field.field_id), output_path, color1=colors[field.field_id], title='') 51 | 52 | 53 | def make_example_plots_rat(): 54 | rat_df_path = analysis_path + 'shuffled_field_data_all_rats.pkl' 55 | rat_df = pd.read_pickle(rat_df_path) 56 | session_id = '11207-06070501+02' 57 | example_session = rat_df.session_id == session_id 58 | example_cell = rat_df[example_session] 59 | example_cluster = example_cell.cluster_id == 2 60 | example_cell = example_cell[example_cluster] 61 | colors = PostSorting.open_field_make_plots.generate_colors(len(example_cell)) 62 | for index, field in example_cell.iterrows(): 63 | hd_session = field.hd_in_field_session 64 | hd_session_hist = PostSorting.open_field_head_direction.get_hd_histogram(hd_session) 65 | hd_spikes = field.hd_in_field_spikes 66 | hd_spikes_hist = PostSorting.open_field_head_direction.get_hd_histogram(hd_spikes) 67 | hist = hd_spikes_hist / hd_session_hist 68 | plot_single_polar_hd_hist(hist, 'rat_' + str(field.field_id), output_path, color1=colors[field.field_id], title='') 69 | 70 | 71 | def main(): 72 | """ 73 | Make example classic hd plots for example fields (Figure 4a) 74 | """ 75 | make_example_plots_rat() 76 | make_example_plots_mouse() 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /OverallAnalysis/plot_example_firing_fields.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import PostSorting.parameters 3 | import PostSorting.open_field_make_plots 4 | 5 | 6 | analysis_path = '/example_firing_fields/' 7 | 8 | 9 | def plot_hd_in_fields_of_example_cell(prm): 10 | spatial_firing = pd.read_pickle(analysis_path + 'spatial_firing.pkl') 11 | spatial_data = pd.read_pickle(analysis_path + 'position.pkl') 12 | PostSorting.open_field_make_plots.plot_hd_for_firing_fields(spatial_firing, spatial_data, prm) 13 | 14 | 15 | def main(): 16 | prm = 
PostSorting.parameters.Parameters() 17 | prm.set_output_path(analysis_path) 18 | prm.set_sampling_rate(30000) 19 | plot_hd_in_fields_of_example_cell(prm) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /OverallAnalysis/plot_example_polar_hists_first_and_second_half.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import OverallAnalysis.false_positives 3 | import PostSorting.open_field_make_plots 4 | import PostSorting.open_field_head_direction 5 | 6 | save_output_path = '/watson_two_test_cells/' 7 | local_path = 'watson_two_test_cells/all_mice_df.pkl' 8 | false_positives_path = '/watson_two_test_cells/false_positives_all.txt' 9 | 10 | 11 | def load_data_frame(path): 12 | # this is the output of load_df.py which read all dfs from a folder and saved selected columns into a combined df 13 | df = pd.read_pickle(path) 14 | return df 15 | 16 | 17 | def add_combined_id_to_df(df_all_mice): 18 | animal_ids = [session_id.split('_')[0] for session_id in df_all_mice.session_id.values] 19 | dates = [session_id.split('_')[1] for session_id in df_all_mice.session_id.values] 20 | tetrode = df_all_mice.tetrode.values 21 | cluster = df_all_mice.cluster_id.values 22 | 23 | combined_ids = [] 24 | for cell in range(len(df_all_mice)): 25 | id = animal_ids[cell] + '-' + dates[cell] + '-Tetrode-' + str(tetrode[cell]) + '-Cluster-' + str(cluster[cell]) 26 | combined_ids.append(id) 27 | df_all_mice['false_positive_id'] = combined_ids 28 | return df_all_mice 29 | 30 | 31 | def plot_hd_for_example_cells(name, cluster): 32 | path = save_output_path + name 33 | spatial_firing = pd.read_pickle(path + '/spatial_firing.pkl') 34 | position = pd.read_pickle(path + '/position.pkl') 35 | hd_hist_cluster = spatial_firing.hd_spike_histogram[cluster] 36 | hd_position = position.hd 37 | hd_hist_position = PostSorting.open_field_head_direction.get_hd_histogram(hd_position) 38 | hd_hist_position = hd_hist_position * max(hd_hist_cluster)/max(hd_hist_position) 39 | 40 | PostSorting.open_field_make_plots.plot_polar_hd_hist(hd_hist_cluster, hd_hist_cluster, cluster, path + '/' + str(cluster), color1='red', color2='red') 41 | PostSorting.open_field_make_plots.plot_polar_hd_hist(spatial_firing.iloc[0].hd_hist_first_half, spatial_firing.iloc[0].hd_hist_second_half, 1, save_output_path + 'first_vs_second_half' + name) 42 | 43 | 44 | def main(): 45 | print('-------------------------------------------------------------') 46 | print('-------------------------------------------------------------') 47 | df_all_mice = load_data_frame(local_path) 48 | list_of_false_positives = OverallAnalysis.false_positives.get_list_of_false_positives(false_positives_path) 49 | df_all_mice = add_combined_id_to_df(df_all_mice) 50 | df_all_mice['false_positive'] = df_all_mice['false_positive_id'].isin(list_of_false_positives) 51 | 52 | good_cluster = df_all_mice.false_positive == False 53 | 54 | plot_hd_for_example_cells('M13_2018-05-14_09-37-33_of', 6) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /OverallAnalysis/plot_hd_tuning_vs_shuffled.py: -------------------------------------------------------------------------------- 1 | import data_frame_utility 2 | import matplotlib.pylab as plt 3 | import numpy as np 4 | import os 5 | import OverallAnalysis.folder_path_settings 6 | import 
OverallAnalysis.shuffle_cell_analysis 7 | import OverallAnalysis.compare_shuffled_from_first_and_second_halves_fields 8 | import OverallAnalysis.false_positives 9 | import pandas as pd 10 | import PostSorting.parameters 11 | import plot_utility 12 | 13 | import scipy 14 | import scipy.stats 15 | 16 | 17 | local_path = OverallAnalysis.folder_path_settings.get_local_path() 18 | analysis_path = local_path + '/plot_hd_tuning_vs_shuffled/' 19 | 20 | prm = PostSorting.parameters.Parameters() 21 | prm.set_pixel_ratio(440) 22 | prm.set_sampling_rate(30000) 23 | 24 | 25 | def add_cell_types_to_data_frame(spatial_firing): 26 | cell_type = [] 27 | for index, cell in spatial_firing.iterrows(): 28 | if cell.hd_score >= 0.5 and cell.grid_score >= 0.4: 29 | cell_type.append('conjunctive') 30 | elif cell.hd_score >= 0.5: 31 | cell_type.append('hd') 32 | elif cell.grid_score >= 0.4: 33 | cell_type.append('grid') 34 | else: 35 | cell_type.append('na') 36 | 37 | spatial_firing['cell type'] = cell_type 38 | 39 | return spatial_firing 40 | 41 | 42 | def add_combined_id_to_df(spatial_firing): 43 | animal_ids = [session_id.split('_')[0] for session_id in spatial_firing.session_id.values] 44 | spatial_firing['animal'] = animal_ids 45 | 46 | dates = [session_id.split('_')[1] for session_id in spatial_firing.session_id.values] 47 | 48 | cluster = spatial_firing.cluster_id.values 49 | combined_ids = [] 50 | for cell in range(len(spatial_firing)): 51 | id = animal_ids[cell] + '-' + dates[cell] + '-Cluster-' + str(cluster[cell]) 52 | combined_ids.append(id) 53 | spatial_firing['false_positive_id'] = combined_ids 54 | return spatial_firing 55 | 56 | 57 | def tag_false_positives(spatial_firing): 58 | list_of_false_positives = OverallAnalysis.false_positives.get_list_of_false_positives(analysis_path + 'false_positives_all.txt') 59 | spatial_firing = add_combined_id_to_df(spatial_firing) 60 | spatial_firing['false_positive'] = spatial_firing['false_positive_id'].isin(list_of_false_positives) 61 | return spatial_firing 62 | 63 | 64 | def plot_bar_chart_for_cells_percentile_error_bar(spatial_firing, path, animal, shuffle_type='occupancy'): 65 | counter = 0 66 | for index, cell in spatial_firing.iterrows(): 67 | mean = np.append(cell['shuffled_means'], cell['shuffled_means'][0]) 68 | percentile_95 = np.append(cell['error_bar_95'], cell['error_bar_95'][0]) 69 | percentile_5 = np.append(cell['error_bar_5'], cell['error_bar_5'][0]) 70 | shuffled_histograms_hz = cell['shuffled_histograms_hz'] 71 | max_rate = np.round(cell.hd_histogram_real_data_hz.max(), 2) 72 | x_pos = np.linspace(0, 2*np.pi, shuffled_histograms_hz.shape[1] + 1) 73 | 74 | significant_bins_to_mark = np.where(cell.p_values_corrected_bars_bh < 0.05) # indices 75 | significant_bins_to_mark = x_pos[significant_bins_to_mark[0]] 76 | y_value_markers = [max_rate + 0.5] * len(significant_bins_to_mark) 77 | 78 | ax = plt.subplot(1, 1, 1, polar=True) 79 | ax = plot_utility.style_polar_plot(ax) 80 | x_labels = ["0", "", "", "", "", "90", "", "", "", "", "180", "", "", "", "", "270", "", "", "", ""] 81 | plt.xticks(x_pos, x_labels) 82 | ax.fill_between(x_pos, mean - percentile_5, percentile_95 + mean, color='grey', alpha=0.4) 83 | ax.plot(x_pos, mean, color='grey', linewidth=5, alpha=0.7) 84 | observed_data = np.append(cell.hd_histogram_real_data_hz, cell.hd_histogram_real_data_hz[0]) 85 | ax.plot(x_pos, observed_data, color='navy', linewidth=5) 86 | plt.title('\n' + str(max_rate) + ' Hz', fontsize=20, y=1.08) 87 | if (cell.p_values_corrected_bars_bh < 0.05).sum() > 0: 88 | 
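# --- Editor's note (illustrative sketch, not part of the original file): the percentile
# band drawn above is built from per-bin statistics over shuffled histograms. A minimal,
# self-contained version with synthetic data, assuming a (n_shuffles, n_bins) array like
# the 'shuffled_histograms_hz' column; all values below are made up:
import numpy as np
shuffled = np.random.poisson(5, size=(1000, 20)).astype(float)  # fake shuffled HD histograms
observed = np.random.poisson(5, size=20).astype(float)  # fake observed histogram
shuffled_means = shuffled.mean(axis=0)
error_bar_95 = np.percentile(shuffled, 95, axis=0) - shuffled_means  # distance above the mean
error_bar_5 = shuffled_means - np.percentile(shuffled, 5, axis=0)  # distance below the mean
outside_band = (observed > shuffled_means + error_bar_95) | (observed < shuffled_means - error_bar_5)
print('bins outside the 5th-95th percentile band:', np.where(outside_band)[0])
# --- end of editor's note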
ax.scatter(significant_bins_to_mark, y_value_markers, c='red', marker='*', zorder=3, s=100) 89 | plt.subplots_adjust(top=0.85) 90 | plt.savefig(analysis_path + animal + '_' + shuffle_type + '/' + str(counter) + str(cell['session_id']) + str(cell['cluster_id']) + '_percentile_polar_' + str(cell.percentile_value) + '.png') 91 | plt.close() 92 | counter += 1 93 | 94 | 95 | def get_number_of_directional_cells(cells, tag='grid'): 96 | percentiles_no_correction = [] 97 | percentiles_correction = [] 98 | for index, cell in cells.iterrows(): 99 | percentile = scipy.stats.percentileofscore(cell.number_of_different_bins_shuffled, cell.number_of_different_bins) 100 | percentiles_no_correction.append(percentile) 101 | 102 | percentile = scipy.stats.percentileofscore(cell.number_of_different_bins_shuffled_corrected_p, cell.number_of_different_bins_bh) 103 | percentiles_correction.append(percentile) 104 | 105 | cells['percentile_value'] = percentiles_correction 106 | print(tag) 107 | print('Number of cells: ' + str(len(cells))) 108 | print('Number of directional cells [without correction]: ') 109 | print(np.sum(np.array(percentiles_no_correction) > 95)) 110 | cells['directional_no_correction'] = np.array(percentiles_no_correction) > 95 111 | 112 | print('Number of directional cells [with BH correction]: ') 113 | print(np.sum(np.array(percentiles_correction) > 95)) 114 | cells['directional_correction'] = np.array(percentiles_correction) > 95 115 | cells.to_pickle(local_path + tag + 'cells.pkl') 116 | 117 | 118 | def plot_hd_vs_shuffled(): 119 | mouse_df_path = analysis_path + 'all_mice_df.pkl' 120 | mouse_df = pd.read_pickle(mouse_df_path) 121 | df = tag_false_positives(mouse_df) 122 | good_cells = df.false_positive == False 123 | df_good_cells = df[good_cells] 124 | df = add_cell_types_to_data_frame(df_good_cells) 125 | grid_cells = df['cell type'] == 'grid' 126 | df_grid = df[grid_cells] 127 | print('mouse') 128 | get_number_of_directional_cells(df_grid, tag='grid') 129 | plot_bar_chart_for_cells_percentile_error_bar(df_grid, analysis_path, 'mouse', shuffle_type='distributive') 130 | 131 | 132 | def main(): 133 | plot_hd_vs_shuffled() 134 | 135 | 136 | if __name__ == '__main__': 137 | main() -------------------------------------------------------------------------------- /OverallAnalysis/simulated_data/analyze_simulated_grid_cells.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import PostSorting.make_plots 6 | import PostSorting.open_field_make_plots 7 | import PostSorting.open_field_firing_fields 8 | import PostSorting.open_field_head_direction 9 | import PostSorting.open_field_grid_cells 10 | import PostSorting.open_field_spatial_data 11 | import PostSorting.parameters 12 | import PostSorting.open_field_spatial_data 13 | import OverallAnalysis.grid_analysis_other_labs.firing_maps 14 | 15 | import matplotlib.pylab as plt 16 | 17 | 18 | 19 | prm = PostSorting.parameters.Parameters() 20 | prm.set_pixel_ratio(100) # data is in cm already 21 | prm.set_sampling_rate(1000) 22 | 23 | 24 | # load data frames and reorganize to be similar to real data to make it easier to rerun analyses 25 | def organize_data(analysis_path): 26 | spatial_data_path = analysis_path + 'v_spatial_data' 27 | spatial_data = pd.read_pickle(spatial_data_path) 28 | position_data = pd.DataFrame() 29 | position_data['synced_time'] = spatial_data.synced_time.iloc[0] 30 | position_data['time_seconds'] = 
spatial_data.synced_time.iloc[0] 31 | position_data['position_x'] = spatial_data.position_x.iloc[0] 32 | position_data['position_y'] = spatial_data.position_y.iloc[0] 33 | position_data['position_x_pixels'] = spatial_data.position_x.iloc[0] 34 | position_data['position_y_pixels'] = spatial_data.position_y.iloc[0] 35 | position_data['hd'] = spatial_data.hd.iloc[0] - 180 36 | for name in glob.glob(analysis_path + '*'): 37 | if os.path.exists(name) and os.path.isdir(name) is False and name != spatial_data_path: 38 | if not os.path.isdir(name + '_simulated'): 39 | cell = pd.read_pickle(name) 40 | id_count = 1 41 | cell['session_id'] = 'simulated' 42 | cell['cluster_id'] = id_count 43 | cell['animal'] = 'simulated' 44 | os.mkdir(name + '_simulated') 45 | position_data.to_pickle(name + '_simulated/position.pkl') 46 | cell.to_pickle(name + '_simulated/spatial_firing.pkl') 47 | id_count += 1 48 | 49 | 50 | def get_rate_maps(position_data, firing_data): 51 | position_heat_map, spatial_firing = OverallAnalysis.grid_analysis_other_labs.firing_maps.make_firing_field_maps(position_data, firing_data, prm) 52 | return position_heat_map, spatial_firing 53 | 54 | 55 | def make_plots(position_data, spatial_firing, position_heat_map, hd_histogram, prm): 56 | # PostSorting.make_plots.plot_spike_histogram(spatial_firing, prm) 57 | # PostSorting.make_plots.plot_firing_rate_vs_speed(spatial_firing, position_data, prm) 58 | # PostSorting.make_plots.plot_autocorrelograms(spatial_firing, prm) 59 | PostSorting.open_field_make_plots.plot_spikes_on_trajectory(position_data, spatial_firing, prm) 60 | PostSorting.open_field_make_plots.plot_coverage(position_heat_map, prm) 61 | PostSorting.open_field_make_plots.plot_firing_rate_maps(spatial_firing, prm) 62 | PostSorting.open_field_make_plots.plot_rate_map_autocorrelogram(spatial_firing, prm) 63 | try: 64 | PostSorting.open_field_make_plots.plot_hd(spatial_firing, position_data, prm) 65 | except: 66 | print('I did not manage to plot 2d hd scatter.') 67 | PostSorting.open_field_make_plots.plot_polar_head_direction_histogram(hd_histogram, spatial_firing, prm) 68 | PostSorting.open_field_make_plots.plot_hd_for_firing_fields(spatial_firing, position_data, prm) 69 | # PostSorting.open_field_make_plots.plot_spikes_on_firing_fields(spatial_firing, prm) 70 | try: 71 | PostSorting.open_field_make_plots.make_combined_figure(prm, spatial_firing) 72 | except: 73 | print('I did not manage to make combined plots.') 74 | 75 | 76 | def process_data(analysis_path): 77 | organize_data(analysis_path) 78 | for name in glob.glob(analysis_path + '*'): 79 | if os.path.isdir(name): 80 | if os.path.exists(name + '/spatial_firing.pkl'): 81 | print(name) 82 | prm.set_file_path(name) 83 | prm.set_output_path(name) 84 | position = pd.read_pickle(name + '/position.pkl') 85 | # process position data - add hd etc 86 | spatial_firing = pd.read_pickle(name + '/spatial_firing.pkl') 87 | 88 | hd = [item for sublist in spatial_firing.hd[0] for item in sublist] 89 | spatial_firing['hd'] = [np.array(hd) - 180] 90 | #if len(spatial_firing.hd) == 1: 91 | # spatial_firing['hd'] = np.array(spatial_firing.hd) 92 | spatial_firing['position_x_pixels'] = spatial_firing.position_x 93 | spatial_firing['position_y_pixels'] = spatial_firing.position_y 94 | 95 | prm.set_sampling_rate(1000000) # this is to make the histograms similar to the real data 96 | hd_histogram, spatial_firing = PostSorting.open_field_head_direction.process_hd_data(spatial_firing, position, prm) 97 | 98 | # if 'firing_maps' not in spatial_firing: 99 
| position_heat_map, spatial_firing = get_rate_maps(position, spatial_firing) 100 | 101 | spatial_firing = PostSorting.open_field_grid_cells.process_grid_data(spatial_firing) 102 | spatial_firing = PostSorting.open_field_firing_fields.analyze_firing_fields(spatial_firing, position, prm) 103 | spatial_firing.to_pickle(name + '/spatial_firing.pkl') 104 | make_plots(position, spatial_firing, position_heat_map, hd_histogram, prm) 105 | 106 | 107 | def main(): 108 | analysis_path = '/grid_fields/simulated_data/ventral_narrow/' 109 | process_data(analysis_path) 110 | analysis_path = '/grid_fields/simulated_data/control_narrow/' 111 | process_data(analysis_path) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /OverallAnalysis/tuning_bias_vs_trajectory_bias.py: -------------------------------------------------------------------------------- 1 | ''' 2 | I will quantify how similar the trajectory hd distribution is to a uniform distribution (1 sample watson test) 3 | and then correlate the results of this to the number of significant bins from the distributive shuffled analysis 4 | ''' 5 | 6 | 7 | import matplotlib.pylab as plt 8 | import numpy as np 9 | import OverallAnalysis.folder_path_settings 10 | import OverallAnalysis.shuffle_field_analysis_all_animals 11 | import pandas as pd 12 | import rpy2.robjects as ro 13 | from rpy2.robjects.packages import importr 14 | import scipy.stats 15 | # utils = importr('utils') 16 | # utils.install_packages('circular') 17 | from scipy.stats import linregress 18 | 19 | 20 | analysis_path = OverallAnalysis.folder_path_settings.get_local_path() + '/tuning_bias_vs_trajectory_bias/' 21 | local_path_to_shuffled_field_data_mice = analysis_path + 'shuffled_field_data_all_mice.pkl' 22 | local_path_to_shuffled_field_data_rats = analysis_path + 'shuffled_field_data_all_rats.pkl' 23 | 24 | 25 | # run one-sample watson test and put it in df 26 | def run_one_sample_watson_test(hd_session): 27 | circular = importr("circular") 28 | watson_test = circular.watson_test 29 | hd_session = ro.FloatVector(hd_session) 30 | stat = watson_test(hd_session) 31 | return stat[0][0] # this is the part of the return r object that is the stat 32 | 33 | 34 | def compare_trajectory_hd_to_uniform_dist(fields): 35 | hd = fields.hd_in_field_session 36 | stats_values = [] 37 | for field in hd: 38 | stat = run_one_sample_watson_test(field) 39 | stats_values.append(stat) 40 | fields['watson_stat'] = stats_values 41 | return fields 42 | 43 | 44 | def plot_results(grid_fields, animal): 45 | number_of_significantly_directional_bins = grid_fields.number_of_different_bins_bh 46 | watson_stats = grid_fields.watson_stat 47 | plt.figure() 48 | plt.scatter(watson_stats, number_of_significantly_directional_bins) 49 | plt.xlabel('Bias in trajectory', fontsize=18) 50 | plt.ylabel('Number of directional bins', fontsize=18) 51 | plt.savefig(analysis_path + 'number_of_significantly_directional_bins_vs_watson_stats' + animal + '.png') 52 | 53 | 54 | def add_percentiles(fields): 55 | percentiles_correction = [] 56 | for index, field in fields.iterrows(): 57 | percentile = scipy.stats.percentileofscore(field.number_of_different_bins_shuffled_corrected_p, field.number_of_different_bins_bh) 58 | percentiles_correction.append(percentile) 59 | fields['directional_percentile'] = percentiles_correction 60 | return fields 61 | 62 | 63 | def check_if_they_correlate(fields): 64 | print('compare tuning and trajectory bias:') 65 | 
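# --- Editor's note (illustrative sketch, not part of the original file): what
# scipy.stats.percentileofscore returns in add_percentiles above -- the rank of the
# observed bin count within the shuffled distribution. Toy values, made up:
import scipy.stats
shuffled_counts = [0, 1, 1, 2, 3, 3, 4, 5]  # hypothetical shuffled bin counts
print(scipy.stats.percentileofscore(shuffled_counts, 4))  # high percentile: the observed count exceeds most shuffles
# --- end of editor's note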
trajectory_bias = fields.watson_stat 66 | tuning = fields.number_of_different_bins_bh 67 | slope, intercept, r_value, p_value, std_err = linregress(trajectory_bias, tuning) 68 | print("slope: %f intercept: %f p_value %f" % (slope, intercept, p_value)) 69 | 70 | 71 | def process_data(animal): 72 | print(animal) 73 | if animal == 'mouse': 74 | local_path_to_field_data = local_path_to_shuffled_field_data_mice 75 | accepted_fields = pd.read_excel(analysis_path + 'list_of_accepted_fields.xlsx') 76 | shuffled_field_data = pd.read_pickle(local_path_to_field_data) 77 | shuffled_field_data = OverallAnalysis.shuffle_field_analysis_all_animals.tag_accepted_fields_mouse(shuffled_field_data, accepted_fields) 78 | 79 | else: 80 | local_path_to_field_data = local_path_to_shuffled_field_data_rats 81 | accepted_fields = pd.read_excel(analysis_path + 'included_fields_detector2_sargolini.xlsx') 82 | shuffled_field_data = pd.read_pickle(local_path_to_field_data) 83 | shuffled_field_data = OverallAnalysis.shuffle_field_analysis_all_animals.tag_accepted_fields_rat(shuffled_field_data, accepted_fields) 84 | 85 | grid = shuffled_field_data.grid_score >= 0.4 86 | hd = shuffled_field_data.hd_score >= 0.5 87 | grid_cells = np.logical_and(grid, np.logical_not(hd)) 88 | accepted_field = shuffled_field_data.accepted_field == True 89 | grid_fields = shuffled_field_data[grid_cells & accepted_field] 90 | grid_fields = compare_trajectory_hd_to_uniform_dist(grid_fields) 91 | grid_fields = add_percentiles(grid_fields) 92 | 93 | plot_results(grid_fields, animal) 94 | check_if_they_correlate(grid_fields) 95 | 96 | 97 | def main(): 98 | process_data('mouse') 99 | process_data('rat') 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | 106 | 107 | -------------------------------------------------------------------------------- /PostSorting/curation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | 5 | 6 | def load_curation_metrics(spike_data_frame, prm): 7 | isolations = [] 8 | noise_overlaps = [] 9 | signal_to_noise_ratios = [] 10 | peak_amplitudes = [] 11 | sorter_name = prm.get_sorter_name() 12 | path_to_metrics = prm.get_local_recording_folder_path() + '/Electrophysiology' + sorter_name + '/cluster_metrics.json' 13 | if not os.path.exists(path_to_metrics): 14 | print('I did not find the curation results.') 15 | 16 | for filename in os.listdir(prm.get_ms_tmp_path() + 'prvbucket/_mountainprocess/'): 17 | if filename.startswith('output_metrics_out'): 18 | print(filename) 19 | path_to_metrics = prm.get_ms_tmp_path() + '/prvbucket/_mountainprocess/' + filename 20 | 21 | if os.path.exists(path_to_metrics): 22 | with open(path_to_metrics) as metrics_file: 23 | cluster_metrics = json.load(metrics_file) 24 | metrics_file.close() 25 | for cluster in range(len(spike_data_frame)): 26 | isolation = cluster_metrics["clusters"][cluster]["metrics"]["isolation"] 27 | noise_overlap = cluster_metrics["clusters"][cluster]["metrics"]["noise_overlap"] 28 | peak_snr = cluster_metrics["clusters"][cluster]["metrics"]["peak_snr"] 29 | peak_amp = cluster_metrics["clusters"][cluster]["metrics"]["peak_amp"] 30 | 31 | isolations.append(isolation) 32 | noise_overlaps.append(noise_overlap) 33 | signal_to_noise_ratios.append(peak_snr) 34 | peak_amplitudes.append(peak_amp) 35 | 36 | spike_data_frame['isolation'] = isolations 37 | spike_data_frame['noise_overlap'] = noise_overlaps 38 | spike_data_frame['peak_snr'] = signal_to_noise_ratios 39 | 
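# --- Editor's note (illustrative sketch, not part of the original file): the threshold
# logic applied in curate_data below, shown on a toy data frame. The thresholds mirror
# the constants defined in curate_data; the metric values are made up:
import pandas as pd
toy = pd.DataFrame({'isolation': [0.95, 0.7], 'noise_overlap': [0.01, 0.2],
                    'peak_snr': [2.0, 0.5], 'mean_firing_rate': [3.1, 0.1]})
good = toy[(toy.isolation > 0.9) & (toy.noise_overlap < 0.05) & (toy.peak_snr > 1) & (toy.mean_firing_rate > 0.5)]
print(good)  # only the first row passes all four criteria
# --- end of editor's note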
spike_data_frame['peak_amp'] = peak_amplitudes 40 | return spike_data_frame 41 | 42 | 43 | def curate_data(spike_data_frame, prm): 44 | if 'isolation' in spike_data_frame: 45 | noisy_cluster = pd.DataFrame() 46 | noisy_cluster['this is empty'] = 'Noisy clusters were not reloaded. Sort again if you need them.' 47 | return spike_data_frame, noisy_cluster 48 | spike_data_frame = load_curation_metrics(spike_data_frame, prm) 49 | isolation_threshold = 0.9 50 | noise_overlap_threshold = 0.05 51 | peak_snr_threshold = 1 52 | firing_rate_threshold = 0.5 53 | 54 | isolated_cluster = spike_data_frame['isolation'] > isolation_threshold 55 | low_noise_cluster = spike_data_frame['noise_overlap'] < noise_overlap_threshold 56 | high_peak_snr = spike_data_frame['peak_snr'] > peak_snr_threshold 57 | high_mean_firing_rate = spike_data_frame['mean_firing_rate'] > firing_rate_threshold 58 | 59 | good_cluster = spike_data_frame[isolated_cluster & low_noise_cluster & high_peak_snr & high_mean_firing_rate].copy() 60 | noisy_cluster = spike_data_frame.loc[~spike_data_frame.index.isin(list(good_cluster.index))] 61 | 62 | return good_cluster, noisy_cluster 63 | 64 | 65 | -------------------------------------------------------------------------------- /PostSorting/load_firing_data.py: -------------------------------------------------------------------------------- 1 | import mdaio 2 | import numpy as np 3 | import os 4 | from pathlib import Path 5 | import pandas as pd 6 | import PreClustering.dead_channels 7 | import data_frame_utility 8 | 9 | 10 | def get_firing_info(file_path, prm): 11 | firing_times_path = file_path + '/Electrophysiology' + prm.get_sorter_name() + '/firings.mda' 12 | units_list = None 13 | firing_info = None 14 | if os.path.exists(firing_times_path): 15 | firing_info = mdaio.readmda(firing_times_path) 16 | units_list = np.unique(firing_info[2]) 17 | else: 18 | print('I could not find the MountainSort output [firings.mda] file. 
I will check if the data was sorted earlier.') 19 | spatial_firing_path = file_path + '/MountainSort/DataFrames/spatial_firing.pkl' 20 | if os.path.exists(spatial_firing_path): 21 | spatial_firing = pd.read_pickle(spatial_firing_path) 22 | os.mknod(file_path + '/sorted_data_exists.txt') 23 | return units_list, firing_info, spatial_firing 24 | else: 25 | print('There are no sorting results available for this recording.') 26 | return units_list, firing_info, False 27 | 28 | 29 | # if the recording has dead channels, detected channels need to be shifted to get real channel ids 30 | def correct_detected_ch_for_dead_channels(dead_channels, primary_channels): 31 | for dead_channel in dead_channels: 32 | indices_to_add_to = np.where(primary_channels >= dead_channel) 33 | primary_channels[indices_to_add_to] += 1 34 | return primary_channels 35 | 36 | 37 | def correct_for_dead_channels(primary_channels, prm): 38 | PreClustering.dead_channels.get_dead_channel_ids(prm) 39 | dead_channels = prm.get_dead_channels() 40 | if len(dead_channels) != 0: 41 | dead_channels = list(map(int, dead_channels[0])) 42 | primary_channels = correct_detected_ch_for_dead_channels(dead_channels, primary_channels) 43 | return primary_channels 44 | 45 | 46 | def process_firing_times(recording_to_process, session_type, prm): 47 | session_id = recording_to_process.split('/')[-1] 48 | units_list, firing_info, spatial_firing = get_firing_info(recording_to_process, prm) 49 | if isinstance(spatial_firing, pd.DataFrame): 50 | firing_data = spatial_firing[['session_id', 'cluster_id', 'tetrode', 'primary_channel', 'firing_times', 'firing_times_opto', 'isolation', 'noise_overlap', 'peak_snr', 'mean_firing_rate', 'random_snippets', 'position_x', 'position_y', 'hd', 'position_x_pixels', 'position_y_pixels', 'speed']].copy() 51 | return firing_data 52 | cluster_ids = firing_info[2] 53 | firing_times = firing_info[1] 54 | primary_channel = firing_info[0] 55 | primary_channel = correct_for_dead_channels(primary_channel, prm) 56 | if session_type == 'openfield' and prm.get_opto_tagging_start_index() is not None: 57 | firing_data = data_frame_utility.df_empty(['session_id', 'cluster_id', 'tetrode', 'primary_channel', 'firing_times', 'firing_times_opto'], dtypes=[str, np.uint8, np.uint8, np.uint8, np.uint64, np.uint64]) 58 | for cluster in units_list: 59 | cluster_firings_all = firing_times[cluster_ids == cluster] 60 | cluster_firings = np.take(cluster_firings_all, np.where(cluster_firings_all < prm.get_opto_tagging_start_index())[0]) 61 | cluster_firings_opto = np.take(cluster_firings_all, np.where(cluster_firings_all >= prm.get_opto_tagging_start_index())[0]) 62 | channel_detected = primary_channel[cluster_ids == cluster][0] 63 | tetrode = int((channel_detected-1)/4 + 1) 64 | ch = int((channel_detected - 1) % 4 + 1) 65 | firing_data = firing_data.append({ 66 | "session_id": session_id, 67 | "cluster_id": int(cluster), 68 | "tetrode": tetrode, 69 | "primary_channel": ch, 70 | "firing_times": cluster_firings, 71 | "firing_times_opto": cluster_firings_opto 72 | }, ignore_index=True) 73 | else: 74 | firing_data = data_frame_utility.df_empty(['session_id', 'cluster_id', 'tetrode', 'primary_channel', 'firing_times', 'trial_number', 'trial_type'], dtypes=[str, np.uint8, np.uint8, np.uint8, np.uint64, np.uint8, np.uint16]) 75 | for cluster in units_list: 76 | cluster_firings = firing_times[cluster_ids == cluster] 77 | channel_detected = primary_channel[cluster_ids == cluster][0] 78 | tetrode = int((channel_detected-1)/4 + 1) 79 | ch = 
int((channel_detected - 1) % 4 + 1) 80 | firing_data = firing_data.append({ 81 | "session_id": session_id, 82 | "cluster_id": int(cluster), 83 | "tetrode": tetrode, 84 | "primary_channel": ch, 85 | "firing_times": cluster_firings 86 | }, ignore_index=True) 87 | return firing_data 88 | 89 | 90 | def create_firing_data_frame(recording_to_process, session_type, prm): 91 | spike_data = None 92 | spike_data = process_firing_times(recording_to_process, session_type, prm) 93 | return spike_data 94 | 95 | -------------------------------------------------------------------------------- /PostSorting/load_snippet_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mdaio 3 | import numpy as np 4 | import PreClustering.dead_channels 5 | import matplotlib.pylab as plt 6 | 7 | 8 | def extract_random_snippets(filtered_data, firing_times, tetrode, number_of_snippets, prm): 9 | dead_channels = prm.get_dead_channels() 10 | if len(dead_channels) != 0: 11 | for dead_ch in range(len(dead_channels[0])): 12 | to_insert = np.zeros(len(filtered_data[0])) 13 | filtered_data = np.insert(filtered_data, int(dead_channels[0][dead_ch]) - 1, to_insert, 0) 14 | random_indices = np.ceil(np.random.uniform(16, len(firing_times)-16, number_of_snippets)).astype(int) 15 | snippets = np.zeros((4, 30, number_of_snippets)) 16 | 17 | channels = [(tetrode-1)*4, (tetrode-1)*4 + 1, (tetrode-1)*4 + 2, (tetrode-1)*4 + 3] 18 | 19 | for index, event in enumerate(random_indices): 20 | snippets_indices = np.arange(firing_times[event]-10, firing_times[event]+20, 1).astype(int) 21 | snippets[:, :, index] = filtered_data[channels[0]:channels[3]+1, snippets_indices] 22 | # plt.plot(snippets[3,:,:]) # example ch plot 23 | return snippets 24 | 25 | def extract_all_snippets(filtered_data, firing_times, tetrode, prm): 26 | dead_channels = prm.get_dead_channels() 27 | if len(dead_channels) != 0: 28 | for dead_ch in range(len(dead_channels[0])): 29 | to_insert = np.zeros(len(filtered_data[0])) 30 | filtered_data = np.insert(filtered_data, int(dead_channels[0][dead_ch]) - 1, to_insert, 0) 31 | 32 | all_indices = np.arange(16, len(firing_times)-16) 33 | snippets = np.zeros((4, 30, len(all_indices))) 34 | 35 | channels = [(tetrode-1)*4, (tetrode-1)*4 + 1, (tetrode-1)*4 + 2, (tetrode-1)*4 + 3] 36 | 37 | for index, event in enumerate(all_indices): 38 | snippets_indices = np.arange(firing_times[event]-10, firing_times[event]+20, 1).astype(int) 39 | snippets[:, :, index] = filtered_data[channels[0]:channels[3]+1, snippets_indices] 40 | # plt.plot(snippets[3,:,:]) # example ch plot 41 | return snippets 42 | 43 | 44 | def get_snippets(firing_data, prm, random_snippets=True): 45 | if 'random_snippets' in firing_data: 46 | return firing_data 47 | print('I will get some random snippets now for each cluster.') 48 | file_path = prm.get_local_recording_folder_path() 49 | filtered_data_path = [] 50 | 51 | filtered_data_path = file_path + '/Electrophysiology' + prm.get_sorter_name() + '/filt.mda' 52 | 53 | snippets_all_clusters = [] 54 | if os.path.exists(filtered_data_path): 55 | filtered_data = mdaio.readmda(filtered_data_path) 56 | for cluster in range(len(firing_data)): 57 | cluster = firing_data.cluster_id.values[cluster] - 1 58 | firing_times = firing_data.firing_times[cluster] 59 | 60 | if random_snippets is True: 61 | snippets = extract_random_snippets(filtered_data, firing_times, firing_data.tetrode[cluster], 50, prm) 62 | else: 63 | snippets = extract_all_snippets(filtered_data, firing_times, 
firing_data.tetrode[cluster], prm) 64 | snippets_all_clusters.append(snippets) 65 | 66 | if random_snippets is True: 67 | firing_data['random_snippets'] = snippets_all_clusters 68 | else: 69 | firing_data['all_snippets'] = snippets_all_clusters 70 | #plt.plot(firing_data.random_snippets[4][3,:,:]) 71 | return firing_data -------------------------------------------------------------------------------- /PostSorting/open_field_firing_maps.py: -------------------------------------------------------------------------------- 1 | from joblib import Parallel, delayed 2 | import os 3 | import multiprocessing 4 | import matplotlib.pylab as plt 5 | import pandas as pd 6 | from numba import jit 7 | import numpy as np 8 | import math 9 | import time 10 | 11 | 12 | def get_dwell(spatial_data, prm): 13 | min_dwell_distance_cm = 5 # from point to determine min dwell time 14 | min_dwell_distance_pixels = min_dwell_distance_cm / 100 * prm.get_pixel_ratio() 15 | 16 | dt_position_ms = spatial_data.synced_time.diff().mean()*1000 # average sampling interval in position data (ms) 17 | min_dwell_time_ms = 3 * dt_position_ms # this is about 100 ms 18 | min_dwell = round(min_dwell_time_ms/dt_position_ms) 19 | return min_dwell, min_dwell_distance_pixels 20 | 21 | 22 | def get_bin_size(prm): 23 | bin_size_cm = 2.5 24 | bin_size_pixels = bin_size_cm / 100 * prm.get_pixel_ratio() 25 | return bin_size_pixels 26 | 27 | 28 | def get_number_of_bins(spatial_data, prm): 29 | bin_size_pixels = get_bin_size(prm) 30 | length_of_arena_x = spatial_data.position_x_pixels[~np.isnan(spatial_data.position_x_pixels)].max() 31 | length_of_arena_y = spatial_data.position_y_pixels[~np.isnan(spatial_data.position_y_pixels)].max() 32 | number_of_bins_x = math.ceil(length_of_arena_x / bin_size_pixels) 33 | number_of_bins_y = math.ceil(length_of_arena_y / bin_size_pixels) 34 | return number_of_bins_x, number_of_bins_y 35 | 36 | 37 | @jit 38 | def gaussian_kernel(kernx): 39 | kerny = np.exp(np.power(kernx, 2)/2 * (-1)) 40 | return kerny 41 | 42 | 43 | def calculate_firing_rate_for_cluster_parallel(cluster, smooth, firing_data_spatial, positions_x, positions_y, number_of_bins_x, number_of_bins_y, bin_size_pixels, min_dwell, min_dwell_distance_pixels, dt_position_ms): 44 | print('Started another cluster') 45 | print(cluster) 46 | cluster_index = firing_data_spatial.cluster_id.values[cluster] - 1 47 | cluster_firings = pd.DataFrame({'position_x': firing_data_spatial.position_x_pixels[cluster_index], 'position_y': firing_data_spatial.position_y_pixels[cluster_index]}) 48 | spike_positions_x = cluster_firings.position_x.values 49 | spike_positions_y = cluster_firings.position_y.values 50 | firing_rate_map = np.zeros((number_of_bins_x, number_of_bins_y)) 51 | for x in range(number_of_bins_x): 52 | for y in range(number_of_bins_y): 53 | px = x * bin_size_pixels + (bin_size_pixels / 2) 54 | py = y * bin_size_pixels + (bin_size_pixels / 2) 55 | spike_distances = np.sqrt(np.power(px - spike_positions_x, 2) + np.power(py - spike_positions_y, 2)) 56 | spike_distances = spike_distances[~np.isnan(spike_distances)] 57 | occupancy_distances = np.sqrt(np.power((px - positions_x), 2) + np.power((py - positions_y), 2)) 58 | occupancy_distances = occupancy_distances[~np.isnan(occupancy_distances)] 59 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_pixels)[0]) 60 | 61 | if bin_occupancy >= min_dwell: 62 | firing_rate_map[x, y] = sum(gaussian_kernel(spike_distances/smooth)) / (sum(gaussian_kernel(occupancy_distances/smooth)) * 
(dt_position_ms/1000)) 63 | 64 | else: 65 | firing_rate_map[x, y] = 0 66 | #firing_rate_map = np.rot90(firing_rate_map) 67 | return firing_rate_map 68 | 69 | 70 | def get_spike_heatmap_parallel(spatial_data, firing_data_spatial, prm): 71 | print('I will calculate firing rate maps now.') 72 | dt_position_ms = spatial_data.synced_time.diff().mean()*1000 73 | min_dwell, min_dwell_distance_pixels = get_dwell(spatial_data, prm) 74 | smooth = 5 / 100 * prm.get_pixel_ratio() 75 | bin_size_pixels = get_bin_size(prm) 76 | number_of_bins_x, number_of_bins_y = get_number_of_bins(spatial_data, prm) 77 | clusters = range(len(firing_data_spatial)) 78 | num_cores = int(os.environ['HEATMAP_CONCURRENCY']) if os.environ.get('HEATMAP_CONCURRENCY') else multiprocessing.cpu_count() 79 | time_start = time.time() 80 | firing_rate_maps = Parallel(n_jobs=num_cores)(delayed(calculate_firing_rate_for_cluster_parallel)(cluster, smooth, firing_data_spatial, spatial_data.position_x_pixels.values, spatial_data.position_y_pixels.values, number_of_bins_x, number_of_bins_y, bin_size_pixels, min_dwell, min_dwell_distance_pixels, dt_position_ms) for cluster in clusters) 81 | time_end = time.time() 82 | print('Making the rate maps took:') 83 | time_diff = time_end - time_start 84 | print(time_diff) 85 | firing_data_spatial['firing_maps'] = firing_rate_maps 86 | 87 | return firing_data_spatial 88 | 89 | 90 | def get_position_heatmap_fixed_bins(spatial_data, number_of_bins_x, number_of_bins_y, bin_size_cm, min_dwell_distance_cm, min_dwell): 91 | position_heat_map = np.zeros((number_of_bins_x, number_of_bins_y)) 92 | for x in range(number_of_bins_x): 93 | for y in range(number_of_bins_y): 94 | px = x * bin_size_cm + (bin_size_cm / 2) 95 | py = y * bin_size_cm + (bin_size_cm / 2) 96 | 97 | occupancy_distances = np.sqrt(np.power((px - spatial_data.position_x_pixels.values), 2) + np.power((py - spatial_data.position_y_pixels.values), 2)) 98 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_cm)[0]) 99 | 100 | if bin_occupancy >= min_dwell: 101 | position_heat_map[x, y] = bin_occupancy 102 | else: 103 | position_heat_map[x, y] = None 104 | return position_heat_map 105 | 106 | 107 | def get_position_heatmap(spatial_data, prm): 108 | min_dwell, min_dwell_distance_cm = get_dwell(spatial_data, prm) 109 | bin_size_cm = get_bin_size(prm) 110 | number_of_bins_x, number_of_bins_y = get_number_of_bins(spatial_data, prm) 111 | 112 | position_heat_map = np.zeros((number_of_bins_x, number_of_bins_y)) 113 | 114 | # find value for each bin for heatmap 115 | for x in range(number_of_bins_x): 116 | for y in range(number_of_bins_y): 117 | px = x * bin_size_cm + (bin_size_cm / 2) 118 | py = y * bin_size_cm + (bin_size_cm / 2) 119 | 120 | occupancy_distances = np.sqrt(np.power((px - spatial_data.position_x_pixels.values), 2) + np.power((py - spatial_data.position_y_pixels.values), 2)) 121 | bin_occupancy = len(np.where(occupancy_distances < min_dwell_distance_cm)[0]) 122 | 123 | if bin_occupancy >= min_dwell: 124 | position_heat_map[x, y] = bin_occupancy 125 | else: 126 | position_heat_map[x, y] = None 127 | return position_heat_map 128 | 129 | 130 | # this is the firing rate in the bin with the highest rate 131 | def find_maximum_firing_rate(spatial_firing): 132 | max_firing_rates = [] 133 | for cluster in range(len(spatial_firing)): 134 | cluster = spatial_firing.cluster_id.values[cluster] - 1 135 | firing_rate_map = spatial_firing.firing_maps[cluster] 136 | max_firing_rate = np.max(firing_rate_map.flatten()) 137 | 
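# --- Editor's note (illustrative sketch, not part of the original file): the
# Gaussian-weighted rate estimate computed per bin in
# calculate_firing_rate_for_cluster_parallel above, for a single bin with made-up
# distances. smooth is 5 cm in pixels assuming a pixel ratio of 440:
import numpy as np
def kernel(kernx):
    return np.exp(-np.power(kernx, 2) / 2)  # same Gaussian kernel as gaussian_kernel above
smooth = 22.0
dt_position_ms = 33.3  # ~30 Hz position sampling
spike_distances = np.array([10.0, 30.0, 55.0])  # spike-to-bin-centre distances (pixels)
occupancy_distances = np.array([8.0, 12.0, 40.0, 70.0, 90.0])  # position samples near the bin
rate = kernel(spike_distances / smooth).sum() / (kernel(occupancy_distances / smooth).sum() * (dt_position_ms / 1000))
print(round(rate, 1), 'Hz')  # weighted spike count divided by weighted occupancy time
# --- end of editor's note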
max_firing_rates.append(max_firing_rate) 138 | spatial_firing['max_firing_rate'] = max_firing_rates 139 | return spatial_firing 140 | 141 | 142 | def make_firing_field_maps(spatial_data, firing_data_spatial, prm): 143 | position_heat_map = get_position_heatmap(spatial_data, prm) 144 | firing_data_spatial = get_spike_heatmap_parallel(spatial_data, firing_data_spatial, prm) 145 | #position_heat_map = np.rot90(position_heat_map) # to rotate map to be like matlab plots 146 | firing_data_spatial = find_maximum_firing_rate(firing_data_spatial) 147 | return position_heat_map, firing_data_spatial -------------------------------------------------------------------------------- /PostSorting/open_field_spatial_firing.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def calculate_corresponding_indices(spike_data, spatial_data, avg_sampling_rate_open_ephys=30000): 5 | avg_sampling_rate_bonsai = float(1 / spatial_data['synced_time'].diff().mean()) 6 | sampling_rate_rate = avg_sampling_rate_open_ephys/avg_sampling_rate_bonsai 7 | spike_data['bonsai_indices'] = spike_data.firing_times/sampling_rate_rate 8 | return spike_data 9 | 10 | 11 | def find_firing_location_indices(spike_data, spatial_data): 12 | spike_data = calculate_corresponding_indices(spike_data, spatial_data) 13 | spatial_firing = pd.DataFrame(columns=['position_x', 'position_x_pixels', 'position_y', 'position_y_pixels', 'hd', 'speed']) 14 | for cluster in range(len(spike_data)): 15 | cluster = spike_data.cluster_id.values[cluster] - 1 16 | bonsai_indices_cluster = spike_data.bonsai_indices[cluster] 17 | bonsai_indices_cluster_round = bonsai_indices_cluster.round(0) 18 | spatial_firing = spatial_firing.append({ 19 | "position_x": list(spatial_data.position_x[bonsai_indices_cluster_round]), 20 | "position_x_pixels": list(spatial_data.position_x_pixels[bonsai_indices_cluster_round]), 21 | "position_y": list(spatial_data.position_y[bonsai_indices_cluster_round]), 22 | "position_y_pixels": list(spatial_data.position_y_pixels[bonsai_indices_cluster_round]), 23 | "hd": list(spatial_data.hd[bonsai_indices_cluster_round]), 24 | "speed": list(spatial_data.speed[bonsai_indices_cluster_round]) 25 | }, ignore_index=True) 26 | spike_data['position_x'] = spatial_firing.position_x.values 27 | spike_data['position_x_pixels'] = spatial_firing.position_x_pixels.values 28 | spike_data['position_y'] = spatial_firing.position_y.values 29 | spike_data['position_y_pixels'] = spatial_firing.position_y_pixels.values 30 | spike_data['hd'] = spatial_firing.hd.values 31 | spike_data['speed'] = spatial_firing.speed.values 32 | return spike_data 33 | 34 | 35 | def add_firing_locations(spike_data, spatial_data): 36 | spike_data = find_firing_location_indices(spike_data, spatial_data) 37 | spike_data = spike_data.drop(['bonsai_indices'], axis=1) 38 | return spike_data 39 | 40 | 41 | def process_spatial_firing(spike_data, spatial_data): 42 | if 'position_x' in spike_data: 43 | return spike_data 44 | spatial_spike_data = add_firing_locations(spike_data, spatial_data) 45 | return spatial_spike_data 46 | -------------------------------------------------------------------------------- /PostSorting/parameters.py: -------------------------------------------------------------------------------- 1 | class Parameters: 2 | 3 | is_ubuntu = True 4 | is_windows = False 5 | is_stable = False 6 | is_interleaved_opto = False 7 | delete_two_min = False 8 | first_half_only = False 9 | second_half_only = False 10 | 
pixel_ratio = None 11 | opto_channel = '' 12 | sync_channel = '' 13 | sampling_rate = 0 14 | opto_tagging_start_index = None 15 | sampling_rate_rate = 0 16 | local_recording_folder_path = '' 17 | file_path = [] 18 | output_path = [] 19 | ms_tmp_path = [] 20 | total_length_sampling_points = 0 21 | dead_channels = [] 22 | sorter_name = [] 23 | 24 | # vr parameters 25 | first_trial_channel = '' # vr 26 | second_trial_channel = '' # vr 27 | movement_channel = '' # vr 28 | stop_threshold = 10.7 # vr 29 | track_length = 200 # vr 30 | cue_conditioned_goal = False 31 | 32 | def __init__(self): 33 | return 34 | 35 | def get_is_stable(self): 36 | return Parameters.is_stable 37 | 38 | def set_is_stable(self, is_stbl): 39 | Parameters.is_stable = is_stbl 40 | 41 | def get_sorter_name(self): 42 | return Parameters.sorter_name 43 | 44 | def set_sorter_name(self, name): 45 | Parameters.sorter_name = name 46 | 47 | def get_first_half_only(self): 48 | return Parameters.first_half_only 49 | 50 | def set_first_half_only(self, is_first): 51 | Parameters.first_half_only = is_first 52 | 53 | def get_second_half_only(self): 54 | return Parameters.second_half_only 55 | 56 | def set_second_half_only(self, is_second): 57 | Parameters.second_half_only = is_second 58 | 59 | def get_is_ubuntu(self): 60 | return Parameters.is_ubuntu 61 | 62 | def set_is_ubuntu(self, is_ub): 63 | Parameters.is_ubuntu = is_ub 64 | 65 | def get_is_windows(self): 66 | return Parameters.is_windows 67 | 68 | def set_is_windows(self, is_win): 69 | Parameters.is_windows = is_win 70 | 71 | def get_pixel_ratio(self): 72 | return Parameters.pixel_ratio 73 | 74 | def set_pixel_ratio(self, pr): 75 | Parameters.pixel_ratio = pr 76 | 77 | def get_opto_channel(self): 78 | return Parameters.opto_channel 79 | 80 | def set_opto_channel(self, opto_ch): 81 | Parameters.opto_channel = opto_ch 82 | 83 | def get_sync_channel(self): 84 | return Parameters.sync_channel 85 | 86 | def set_sync_channel(self, sync_ch): 87 | Parameters.sync_channel = sync_ch 88 | 89 | def get_sampling_rate(self): 90 | return Parameters.sampling_rate 91 | 92 | def set_sampling_rate(self, sr): 93 | Parameters.sampling_rate = sr 94 | 95 | def get_opto_tagging_start_index(self): 96 | return Parameters.opto_tagging_start_index 97 | 98 | def set_opto_tagging_start_index(self, opto_start): 99 | Parameters.opto_tagging_start_index = opto_start 100 | 101 | def get_sampling_rate_rate(self): 102 | return Parameters.sampling_rate_rate 103 | 104 | def set_sampling_rate_rate(self, sr): 105 | Parameters.sampling_rate_rate = sr 106 | 107 | def get_local_recording_folder_path(self): 108 | return Parameters.local_recording_folder_path 109 | 110 | def set_local_recording_folder_path(self, path): 111 | Parameters.local_recording_folder_path = path 112 | 113 | def get_filepath(self): 114 | return Parameters.file_path 115 | 116 | def set_file_path(self, path): 117 | Parameters.file_path = path 118 | 119 | def get_output_path(self): 120 | return Parameters.output_path 121 | 122 | def set_output_path(self, path): 123 | Parameters.output_path = path 124 | 125 | def get_ms_tmp_path(self): 126 | return Parameters.ms_tmp_path 127 | 128 | def set_ms_tmp_path(self, path): 129 | Parameters.ms_tmp_path = path 130 | 131 | def get_total_length_sampling_points(self): 132 | return Parameters.total_length_sampling_points 133 | 134 | def set_total_length_sampling_points(self, length): 135 | Parameters.total_length_sampling_points = length 136 | 137 | def get_dead_channels(self): 138 | return Parameters.dead_channels 139 
| 140 | def set_dead_channels(self, *args): 141 | dead_ch = [] 142 | for dead_chan in args: 143 | dead_ch.append(dead_chan) 144 | 145 | Parameters.dead_channels = dead_ch 146 | 147 | def get_dead_channel_path(self): 148 | return Parameters.dead_channel_path 149 | 150 | def set_dead_channel_path(self, dead_ch): 151 | Parameters.dead_channel_path = dead_ch 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /PostSorting/process_fields.r: -------------------------------------------------------------------------------- 1 | #! /usr/bin/Rscript 2 | args = commandArgs(trailingOnly=TRUE) 3 | .libPaths("/home/nolanlab/R/x86_64-pc-linux-gnu-library/3.4") 4 | library(tidyverse) 5 | library(circular) 6 | 7 | 8 | setwd(args[1]) 9 | 10 | # load data 11 | read_plus <- function(flnm) { 12 | read_csv(flnm, col_names = FALSE) %>% 13 | mutate(filename = flnm) 14 | } 15 | 16 | c <- list.files(pattern = "*cluster.csv", 17 | full.names = T) %>% 18 | map_dfr(~read_plus(.)) %>% 19 | group_by(filename) %>% 20 | nest(filename, X1, .key = "c") 21 | 22 | 23 | s <- list.files(pattern = "*session.csv", 24 | full.names = T) %>% 25 | map_dfr(~read_plus(.)) %>% 26 | group_by(filename) %>% 27 | nest(filename, X1, .key = "s") 28 | 29 | cs <- bind_cols(c, s) 30 | 31 | # Test for uniformity of data using Kuiper’s test. First performs test on all cluster data then on all session data. Makes new columns for test statistic and p value. 32 | 33 | cs <- cs %>% 34 | mutate(c_circ = map(c,circular)) %>% 35 | mutate(c_kuiper = map(c_circ, kuiper.test)) %>% 36 | mutate(c_kuiper_ts = map_dbl(c_kuiper, ~.$statistic)) %>% 37 | mutate(s_circ = map(s,circular)) %>% 38 | mutate(s_kuiper = map(s_circ, kuiper.test)) %>% 39 | mutate(s_kuiper_ts = map_dbl(s_kuiper, ~.$statistic)) 40 | 41 | # Test for uniformity of data using Watson’s test 42 | 43 | cs <- cs %>% 44 | mutate(w_c = map(c_circ, watson.test)) %>% 45 | mutate(w_c_ts = map_dbl(w_c, ~.$statistic)) %>% 46 | mutate(w_s = map(s_circ, watson.test)) %>% 47 | mutate(w_s_ts = map_dbl(w_s, ~.$statistic)) 48 | 49 | # Compare the two distributions using Watson’s two sample test. 50 | 51 | cs <- cs %>% 52 | mutate(w2st = map2(c_circ, s_circ, watson.two.test)) %>% 53 | mutate(w2st_ts = map_dbl(w2st, ~.$statistic)) 54 | 55 | # Make table with test statistics 56 | 57 | table <- tibble(cluster = cs$filename, session = cs$filename1, Kuiper_Cluster = cs$c_kuiper_ts, Kuiper_Session = cs$s_kuiper_ts, Watson_Cluster = cs$w_c_ts, Watson_Session = cs$w_s_ts, Watson_two_sample = cs$w2st_ts) 58 | 59 | knitr::kable(table) 60 | 61 | # save table as csv 62 | write_csv(table, "circular_out.csv") 63 | 64 | # close open files 65 | closeAllConnections() 66 | -------------------------------------------------------------------------------- /PostSorting/speed.py: -------------------------------------------------------------------------------- 1 | import array_utility 2 | import matplotlib.pylab as plt 3 | import numpy as np 4 | import pandas as pd 5 | import plot_utility 6 | import scipy.ndimage 7 | import scipy.stats 8 | 9 | from typing import Tuple 10 | import OverallAnalysis.analyze_speed 11 | 12 | 13 | ''' 14 | 15 | The speed score is a measure of the correlation between the firing rate of the neuron and the running speed of the 16 | animal. The firing times of the neuron are binned at the same sampling rate as the position data (speed). The resulting 17 | temporal firing histogram is then smoothed with a Gaussian (standard deviation ~250ms). 
Speed and temporal firing rate 18 | are correlated (Pearson correlation) to obtain the speed score. 19 | 20 | Based on: Gois & Tort, 2018, Cell Reports 25, 1872–1884 21 | 22 | 23 | position : data frame that contains the speed of the animal as a column ('speed'). 24 | spatial_firing : data frame that contains the firing times ('firing_times') 25 | sigma : standard deviation for Gaussian filter (sigma = 250 / video_sampling) 26 | sampling_rate_conversion : sampling rate of ephys data relative to seconds. If the firing times are in seconds then this 27 | should be 1. 28 | 29 | ''' 30 | 31 | 32 | def calculate_speed_score(position: pd.DataFrame, spatial_firing: pd.DataFrame, gauss_sd: float, sampling_rate_conversion: int) -> pd.DataFrame: 33 | avg_sampling_rate_video = float(1 / position['synced_time'].diff().mean()) 34 | sigma = gauss_sd / avg_sampling_rate_video 35 | speed = scipy.ndimage.filters.gaussian_filter(position.speed, sigma) 36 | speed_scores = [] 37 | speed_score_ps = [] 38 | for index, cell in spatial_firing.iterrows(): 39 | firing_times = cell.firing_times 40 | firing_hist, edges = np.histogram(firing_times, bins=len(speed), range=(0, max(position.synced_time) * sampling_rate_conversion)) 41 | smooth_hist = scipy.ndimage.filters.gaussian_filter(firing_hist.astype(float), sigma) 42 | speed, smooth_hist = array_utility.remove_nans_from_both_arrays(speed, smooth_hist) 43 | speed_score, p = scipy.stats.pearsonr(speed, smooth_hist) 44 | speed_scores.append(speed_score) 45 | speed_score_ps.append(p) 46 | spatial_firing['speed_score'] = speed_scores 47 | spatial_firing['speed_score_p_values'] = speed_score_ps 48 | 49 | return spatial_firing 50 | 51 | 52 | ''' 53 | Calculate median, 25th and 75th percentile of firing rate (y) at given speed (x) values. Speed is binned into 6 cm/s 54 | overlapping bins with a 2 cm/s step size. 
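# --- Editor's note (illustrative sketch, not part of the original file): a compact,
# end-to-end version of the speed-score recipe described in the module docstring above,
# run on synthetic data. The sigma of 7.5 samples corresponds to ~250 ms at an assumed
# 30 Hz position sampling rate; all values are made up:
import numpy as np
import scipy.ndimage
import scipy.stats
rng = np.random.default_rng(0)
speed = scipy.ndimage.gaussian_filter(rng.uniform(0, 40, 3000), sigma=7.5)  # synthetic running speed
firing_hist = rng.poisson(speed / 10).astype(float)  # synthetic spike counts that loosely follow speed
smooth_hist = scipy.ndimage.gaussian_filter(firing_hist, sigma=7.5)
speed_score, p = scipy.stats.pearsonr(speed, smooth_hist)
print(round(speed_score, 2))  # positive score: firing increases with running speed
# --- end of editor's note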
55 | 56 | Based on: Gois & Tort, 2018, Cell Reports 25, 1872–1884 57 | ''' 58 | 59 | 60 | def calculate_median_for_scatter_binned(x: np.ndarray, y: np.ndarray) -> 'Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]': 61 | bin_size = 6 62 | step_size = 2 63 | number_of_bins = int((max(x) - min(x)) / 2) 64 | 65 | median_x = [] 66 | median_y = [] 67 | percentile_25 = [] 68 | percentile_75 = [] 69 | for x_bin in range(number_of_bins): 70 | median_x.append(x_bin * step_size + bin_size/2) 71 | data_in_bin = np.take(y, np.where((x_bin * step_size < x) & (x < x_bin * step_size + bin_size))) 72 | if len(data_in_bin[0]) > 0: 73 | med_y = np.median(data_in_bin) 74 | median_y.append(med_y) 75 | percentile_25.append(np.percentile(data_in_bin, 25)) 76 | percentile_75.append(np.percentile(data_in_bin, 75)) 77 | else: 78 | median_y.append(0) 79 | percentile_25.append(0) 80 | percentile_75.append(0) 81 | 82 | return np.array(median_x), np.array(median_y), np.array(percentile_25), np.array(percentile_75) 83 | 84 | 85 | # plot grid cells only 86 | def plot_speed_vs_firing_rate_grid(position: pd.DataFrame, spatial_firing: pd.DataFrame, sampling_rate_conversion: int, video_sampling_rate: int, save_path: str) -> None: 87 | sigma = 250 / video_sampling_rate 88 | speed = scipy.ndimage.filters.gaussian_filter(position.speed, sigma) 89 | spatial_firing = OverallAnalysis.analyze_speed.add_cell_types_to_data_frame(spatial_firing) 90 | for index, cell in spatial_firing.iterrows(): 91 | if cell['cell type'] == 'grid': 92 | firing_times = cell.firing_times 93 | firing_hist, edges = np.histogram(firing_times, bins=len(speed), range=(0, max(position.synced_time) * sampling_rate_conversion)) 94 | firing_hist *= video_sampling_rate 95 | smooth_hist = scipy.ndimage.filters.gaussian_filter(firing_hist.astype(float), sigma) 96 | speed, smooth_hist = array_utility.remove_nans_from_both_arrays(speed, smooth_hist) 97 | median_x, median_y, percentile_25, percentile_75 = calculate_median_for_scatter_binned(speed, smooth_hist) 98 | plt.cla() 99 | fig, ax = plt.subplots() 100 | ax = plot_utility.format_bar_chart(ax, 'Speed (cm/s)', 'Firing rate (Hz)') 101 | plt.scatter(speed[::10], smooth_hist[::10], color='gray', alpha=0.7) 102 | plt.plot(median_x, percentile_25, color='black', linewidth=5) 103 | plt.plot(median_x, percentile_75, color='black', linewidth=5) 104 | plt.scatter(median_x, median_y, color='black', s=100) 105 | plt.title('speed score: ' + str(np.round(cell.speed_score, 4))) 106 | plt.xlim(0, 50) 107 | plt.ylim(0, None) 108 | plt.savefig(save_path + cell.session_id + str(cell.cluster_id) + '_speed.png') 109 | plt.close() 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /PostSorting/temporal_firing.py: -------------------------------------------------------------------------------- 1 | # calculate number of spikes and mean firing rate for each cluster and add to spatial firing df 2 | def add_temporal_firing_properties_to_df(spatial_firing, prm): 3 | total_number_of_spikes_per_cluster = [] 4 | mean_firing_rates = [] 5 | for cluster in range(len(spatial_firing)): 6 | cluster = spatial_firing.cluster_id.values[cluster] - 1 7 | firing_times = spatial_firing.firing_times[cluster] 8 | total_number_of_spikes = len(firing_times) 9 | total_length_of_recordings = prm.get_total_length_sampling_points() # this does not include opto 10 | mean_firing_rate = total_number_of_spikes / total_length_of_recordings 11 | 12 | 
total_number_of_spikes_per_cluster.append(total_number_of_spikes) 13 | mean_firing_rates.append(mean_firing_rate) 14 | 15 | spatial_firing['number_of_spikes'] = total_number_of_spikes_per_cluster 16 | spatial_firing['mean_firing_rate'] = mean_firing_rates 17 | 18 | return spatial_firing 19 | -------------------------------------------------------------------------------- /SimulationCode/Mods/hcn.mod: -------------------------------------------------------------------------------- 1 | COMMENT 2 | Schmidt-Hieber C, Häusser M (2013) Cellular mechanisms of spatial navigation in the medial entorhinal cortex. Nat Neurosci 16:325-31 3 | 4 | 17/07/2012 5 | (c) 2012, C. Schmidt-Hieber, University College London 6 | Based on an initial version by Chris Burgess 07/2011 7 | 8 | Kinetics based on: 9 | E. Fransen, A. A. Alonso, C. T. Dickson, J. Magistretti, M. E. Hasselmo 10 | Ionic mechanisms in the generation of subthreshold oscillations and 11 | action potential clustering in entorhinal layer II stellate neurons. 12 | Hippocampus 14, 368 (2004). 13 | 14 | ENDCOMMENT 15 | 16 | 17 | NEURON { 18 | SUFFIX ih 19 | NONSPECIFIC_CURRENT i 20 | RANGE i, gslow, gfast, gslowbar, gfastbar 21 | GLOBAL ehcn, taufn, taufdo, taufdd, taufro, taufrd 22 | GLOBAL tausn, tausdo, tausdd, tausro, tausrd 23 | GLOBAL mifo, mifd, mife, miso, misd, mise 24 | } 25 | 26 | UNITS { 27 | (mV) = (millivolt) 28 | (S) = (siemens) 29 | (mA) = (milliamp) 30 | } 31 | 32 | PARAMETER { 33 | gfastbar = 9.8e-5 (S/cm2) 34 | gslowbar = 5.3e-5 (S/cm2) 35 | ehcn = -20 (mV) 36 | taufn = 0.51 (ms) : original: .51 parameters for tau_fast 37 | taufdo = 1.7 (mV) 38 | taufdd = 10 (mV) 39 | taufro = 340 (mV) 40 | taufrd = 52 (mV) 41 | tausn = 5.6 (ms) : parameters for tau_slow 42 | tausdo = 17 (mV) 43 | tausdd = 14 (mV) 44 | tausro = 260 (mV) 45 | tausrd = 43 (mV) 46 | mifo = 74.2 (mV) : parameters for steady state m_fast 47 | mifd = 9.78 (mV) 48 | mife = 1.36 49 | miso = 2.83 (mV) : parameters for steady state m_slow 50 | misd = 15.9 (mV) 51 | mise = 58.5 52 | } 53 | 54 | ASSIGNED { 55 | v (mV) 56 | gslow (S/cm2) 57 | gfast (S/cm2) 58 | i (mA/cm2) 59 | alphaf (/ms) : alpha_fast 60 | betaf (/ms) : beta_fast 61 | alphas (/ms) : alpha_slow 62 | betas (/ms) : beta_slow 63 | } 64 | 65 | INITIAL { 66 | : assume steady state 67 | settables(v) 68 | mf = alphaf/(alphaf+betaf) 69 | ms = alphas/(alphas+betas) 70 | } 71 | 72 | BREAKPOINT { 73 | SOLVE states METHOD cnexp 74 | gfast = gfastbar*mf 75 | gslow = gslowbar*ms 76 | i = (gfast+gslow)*(v-ehcn) 77 | } 78 | 79 | STATE { 80 | mf ms 81 | } 82 | 83 | DERIVATIVE states { 84 | settables(v) 85 | mf' = alphaf*(1-mf) - betaf*mf 86 | ms' = alphas*(1-ms) - betas*ms 87 | } 88 | 89 | PROCEDURE settables(v (mV)) { 90 | LOCAL mif, mis, tauf, taus 91 | TABLE alphaf, betaf, alphas, betas FROM -100 TO 100 WITH 200 92 | 93 | tauf = taufn/( exp( (v-taufdo)/taufdd ) + exp( -(v+taufro)/taufrd ) ) 94 | taus = tausn/( exp( (v-tausdo)/tausdd ) + exp( -(v+tausro)/tausrd ) ) 95 | mif = 1/pow( 1 + exp( (v+mifo)/mifd ), mife ) 96 | mis = 1/pow( 1 + exp( (v+miso)/misd ), mise ) 97 | 98 | alphaf = mif/tauf 99 | alphas = mis/taus 100 | betaf = (1-mif)/tauf 101 | betas = (1-mis)/taus 102 | } -------------------------------------------------------------------------------- /SimulationCode/Mods/km.mod: -------------------------------------------------------------------------------- 1 | COMMENT 2 | km.mod 3 | 4 | Mainen ZF, Sejnowski TJ (1996) Influence of dendritic structure on firing pattern in model neocortical neurons. 
Nature 382:363-6 5 | 6 | Potassium channel, Hodgkin-Huxley style kinetics 7 | Based on I-M (muscarinic K channel) 8 | Slow, noninactivating 9 | 10 | Author: Zach Mainen, Salk Institute, 1995, zach@salk.edu 11 | 12 | 26 Ago 2002 Modification of original channel to allow 13 | variable time step and to correct an initialization error. 14 | Done by Michael Hines(michael.hines@yale.e) and 15 | Ruggero Scorcioni(rscorcio@gmu.edu) at EU Advance Course 16 | in Computational Neuroscience. Obidos, Portugal 17 | 18 | 20110202 made threadsafe by Ted Carnevale 19 | 20120514 fixed singularity in PROCEDURE rates 20 | ENDCOMMENT 21 | 22 | NEURON { 23 | THREADSAFE 24 | SUFFIX km 25 | USEION k READ ek WRITE ik 26 | RANGE n, gk, gbar 27 | RANGE ninf, ntau 28 | GLOBAL Ra, Rb 29 | GLOBAL q10, temp, tadj, vmin, vmax 30 | } 31 | 32 | UNITS { 33 | (mA) = (milliamp) 34 | (mV) = (millivolt) 35 | (pS) = (picosiemens) 36 | (um) = (micron) 37 | } 38 | 39 | PARAMETER { 40 | gbar (mho/cm2) 41 | 42 | tha = -30 (mV) : v 1/2 for inf 43 | qa = 9 (mV) : inf slope 44 | 45 | Ra = 0.001 (/ms) : max act rate (slow) 46 | Rb = 0.001 (/ms) : max deact rate (slow) 47 | 48 | : dt (ms) 49 | temp = 23 (degC) : original temp 50 | q10 = 2.3 : temperature sensitivity 51 | 52 | vmin = -120 (mV) 53 | vmax = 100 (mV) 54 | } 55 | 56 | 57 | ASSIGNED { 58 | v (mV) 59 | celsius (degC) 60 | a (/ms) 61 | b (/ms) 62 | ik (mA/cm2) 63 | gk (pS/um2) 64 | ek (mV) 65 | ninf 66 | ntau (ms) 67 | tadj 68 | } 69 | 70 | 71 | STATE { n } 72 | 73 | INITIAL { 74 | tadj = q10^((celsius - temp)/(10 (degC))) : make all threads calculate tadj at initialization 75 | 76 | trates(v) 77 | n = ninf 78 | } 79 | 80 | BREAKPOINT { 81 | SOLVE states METHOD cnexp 82 | gk = tadj*gbar*n 83 | ik = (1e-4) * gk * (v - ek) 84 | } 85 | 86 | LOCAL nexp 87 | 88 | DERIVATIVE states { :Computes state variable n 89 | trates(v) : at the current v and dt. 90 | n' = (ninf-n)/ntau 91 | 92 | } 93 | 94 | PROCEDURE trates(v (mV)) { :Computes rate and other constants at current v. 95 | :Call once from HOC to initialize inf at resting v. 96 | TABLE ninf, ntau 97 | DEPEND celsius, temp, Ra, Rb, tha, qa 98 | FROM vmin TO vmax WITH 199 99 | 100 | rates(v): not consistently executed from here if usetable_hh == 1 101 | 102 | : tinc = -dt * tadj 103 | : nexp = 1 - exp(tinc/ntau) 104 | } 105 | 106 | UNITSOFF 107 | PROCEDURE rates(v (mV)) { :Computes rate and other constants at current v. 108 | :Call once from HOC to initialize inf at resting v. 
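: The commented lines below record the algebra that recasts a and b in terms of
: efun(z) = z/(exp(z) - 1); efun() switches to the series expansion 1 - z/2 for
: |z| < 1e-4, so both rates stay finite at the singular point v = tha.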
109 | 110 | : singular when v = tha 111 | : a = Ra * (v - tha) / (1 - exp(-(v - tha)/qa)) 112 | : a = Ra * qa*((v - tha)/qa) / (1 - exp(-(v - tha)/qa)) 113 | : a = Ra * qa*(-(v - tha)/qa) / (exp(-(v - tha)/qa) - 1) 114 | a = Ra * qa * efun(-(v - tha)/qa) 115 | 116 | : singular when v = tha 117 | : b = -Rb * (v - tha) / (1 - exp((v - tha)/qa)) 118 | : b = -Rb * qa*((v - tha)/qa) / (1 - exp((v - tha)/qa)) 119 | : b = Rb * qa*((v - tha)/qa) / (exp((v - tha)/qa) - 1) 120 | b = Rb * qa * efun((v - tha)/qa) 121 | 122 | tadj = q10^((celsius - temp)/10) 123 | ntau = 1/tadj/(a+b) 124 | ninf = a/(a+b) 125 | } 126 | UNITSON 127 | 128 | FUNCTION efun(z) { 129 | if (fabs(z) < 1e-4) { 130 | efun = 1 - z/2 131 | }else{ 132 | efun = z/(exp(z) - 1) 133 | } 134 | } -------------------------------------------------------------------------------- /SimulationCode/Mods/kv.mod: -------------------------------------------------------------------------------- 1 | COMMENT 2 | kv.mod 3 | 4 | Mainen ZF, Sejnowski TJ (1996) Influence of dendritic structure on firing pattern in model neocortical neurons. Nature 382:363-6 5 | 6 | Potassium channel, Hodgkin-Huxley style kinetics 7 | Kinetic rates based roughly on Sah et al. and Hamill et al. (1991) 8 | 9 | Author: Zach Mainen, Salk Institute, 1995, zach@salk.edu 10 | 11 | 26 Ago 2002 Modification of original channel to allow 12 | variable time step and to correct an initialization error. 13 | Done by Michael Hines(michael.hines@yale.e) and 14 | Ruggero Scorcioni(rscorcio@gmu.edu) at EU Advance Course 15 | in Computational Neuroscience. Obidos, Portugal 16 | 17 | 20110202 made threadsafe by Ted Carnevale 18 | 20120514 fixed singularity in PROCEDURE rates 19 | 20 | ENDCOMMENT 21 | 22 | 23 | NEURON { 24 | THREADSAFE 25 | SUFFIX kv 26 | USEION k READ ek WRITE ik 27 | RANGE n, gk, gbar 28 | RANGE ninf, ntau 29 | GLOBAL Ra, Rb 30 | GLOBAL q10, temp, tadj, vmin, vmax 31 | } 32 | 33 | UNITS { 34 | (mA) = (milliamp) 35 | (mV) = (millivolt) 36 | (pS) = (picosiemens) 37 | (um) = (micron) 38 | } 39 | 40 | PARAMETER { 41 | gbar (mho/cm2) 42 | 43 | tha = 25 (mV) : v 1/2 for inf 44 | qa = 9 (mV) : inf slope 45 | 46 | Ra = 0.02 (/ms) : max act rate 47 | Rb = 0.002 (/ms) : max deact rate 48 | 49 | : dt (ms) 50 | temp = 23 (degC) : original temp 51 | q10 = 2.3 : temperature sensitivity 52 | 53 | vmin = -120 (mV) 54 | vmax = 100 (mV) 55 | } 56 | 57 | 58 | ASSIGNED { 59 | v (mV) 60 | celsius (degC) 61 | a (/ms) 62 | b (/ms) 63 | ik (mA/cm2) 64 | gk (pS/um2) 65 | ek (mV) 66 | ninf 67 | ntau (ms) 68 | tadj 69 | } 70 | 71 | 72 | STATE { n } 73 | 74 | INITIAL { 75 | tadj = q10^((celsius - temp)/(10 (degC))) : make all threads calculate tadj at initialization 76 | 77 | trates(v) 78 | n = ninf 79 | } 80 | 81 | BREAKPOINT { 82 | SOLVE states METHOD cnexp 83 | gk = tadj*gbar*n 84 | ik = (1e-4) * gk * (v - ek) 85 | } 86 | 87 | DERIVATIVE states { :Computes state variable n 88 | trates(v) : at the current v and dt. 89 | n' = (ninf-n)/ntau 90 | } 91 | 92 | PROCEDURE trates(v (mV)) { :Computes rate and other constants at current v. 93 | :Call once from HOC to initialize inf at resting v. 94 | TABLE ninf, ntau 95 | DEPEND celsius, temp, Ra, Rb, tha, qa 96 | FROM vmin TO vmax WITH 199 97 | 98 | rates(v): not consistently executed from here if usetable_hh == 1 99 | 100 | : tinc = -dt * tadj 101 | : nexp = 1 - exp(tinc/ntau) 102 | } 103 | 104 | UNITSOFF 105 | PROCEDURE rates(v (mV)) { :Computes rate and other constants at current v. 106 | :Call once from HOC to initialize inf at resting v. 
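: As in km.mod, a and b are evaluated through efun(), whose series expansion
: near z = 0 removes the 0/0 singularity at v = tha.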
107 | 108 | : singular when v = tha 109 | : a = Ra * (v - tha) / (1 - exp(-(v - tha)/qa)) 110 | : a = Ra * qa*((v - tha)/qa) / (1 - exp(-(v - tha)/qa)) 111 | : a = Ra * qa*(-(v - tha)/qa) / (exp(-(v - tha)/qa) - 1) 112 | a = Ra * qa * efun(-(v - tha)/qa) 113 | 114 | : singular when v = tha 115 | : b = -Rb * (v - tha) / (1 - exp((v - tha)/qa)) 116 | : b = -Rb * qa*((v - tha)/qa) / (1 - exp((v - tha)/qa)) 117 | : b = Rb * qa*((v - tha)/qa) / (exp((v - tha)/qa) - 1) 118 | b = Rb * qa * efun((v - tha)/qa) 119 | 120 | tadj = q10^((celsius - temp)/10) 121 | ntau = 1/tadj/(a+b) 122 | ninf = a/(a+b) 123 | } 124 | UNITSON 125 | 126 | FUNCTION efun(z) { 127 | if (fabs(z) < 1e-4) { 128 | efun = 1 - z/2 129 | }else{ 130 | efun = z/(exp(z) - 1) 131 | } 132 | } -------------------------------------------------------------------------------- /SimulationCode/Mods/na.mod: -------------------------------------------------------------------------------- 1 | COMMENT 2 | na.mod 3 | 4 | Mainen ZF, Sejnowski TJ (1996) Influence of dendritic structure on firing pattern in model neocortical neurons. Nature 382:363-6 5 | 6 | Sodium channel, Hodgkin-Huxley style kinetics. 7 | 8 | Kinetics were fit to data from Huguenard et al. (1988) and Hamill et 9 | al. (1991) 10 | 11 | qi is not well constrained by the data, since there are no points 12 | between -80 and -55. So this was fixed at 5 while the thi1,thi2,Rg,Rd 13 | were optimized using a simplex least square proc 14 | 15 | voltage dependencies are shifted approximately from the best 16 | fit to give higher threshold 17 | 18 | Author: Zach Mainen, Salk Institute, 1994, zach@salk.edu 19 | 20 | 26 Ago 2002 Modification of original channel to allow 21 | variable time step and to correct an initialization error. 22 | Done by Michael Hines(michael.hines@yale.e) and 23 | Ruggero Scorcioni(rscorcio@gmu.edu) at EU Advance Course 24 | in Computational Neuroscience. Obidos, Portugal 25 | 26 | 11 Jan 2007 Fixed glitch in trap where (v/th) was where (v-th)/q is. 27 | (thanks Ronald van Elburg!) 
28 | 29 | 20110202 made threadsafe by Ted Carnevale 30 | 20120514 replaced vtrap0 with efun, which is a better approximation 31 | in the vicinity of a singularity 32 | 33 | ENDCOMMENT 34 | 35 | NEURON { 36 | THREADSAFE 37 | SUFFIX na 38 | USEION na READ ena WRITE ina 39 | RANGE m, h, gna, gbar 40 | GLOBAL tha, thi1, thi2, qa, qi, qinf, thinf 41 | RANGE minf, hinf, mtau, htau 42 | GLOBAL Ra, Rb, Rd, Rg 43 | GLOBAL q10, temp, tadj, vmin, vmax, vshift 44 | } 45 | 46 | UNITS { 47 | (mA) = (milliamp) 48 | (mV) = (millivolt) 49 | (pS) = (picosiemens) 50 | (um) = (micron) 51 | } 52 | 53 | PARAMETER { 54 | gbar = 1000 (pS/um2) : 0.12 mho/cm2 55 | vshift = -10 (mV) : voltage shift (affects all) 56 | 57 | tha = -35 (mV) : v 1/2 for act (-42) 58 | qa = 9 (mV) : act slope 59 | Ra = 0.182 (/ms) : open (v) 60 | Rb = 0.124 (/ms) : close (v) 61 | 62 | thi1 = -50 (mV) : v 1/2 for inact 63 | thi2 = -75 (mV) : v 1/2 for inact 64 | qi = 5 (mV) : inact tau slope 65 | thinf = -65 (mV) : inact inf slope 66 | qinf = 6.2 (mV) : inact inf slope 67 | Rg = 0.0091 (/ms) : inact (v) 68 | Rd = 0.024 (/ms) : inact recov (v) 69 | 70 | temp = 23 (degC) : original temp 71 | q10 = 2.3 : temperature sensitivity 72 | 73 | : dt (ms) 74 | vmin = -120 (mV) 75 | vmax = 100 (mV) 76 | } 77 | 78 | ASSIGNED { 79 | v (mV) 80 | celsius (degC) 81 | ina (mA/cm2) 82 | gna (pS/um2) 83 | ena (mV) 84 | minf hinf 85 | mtau (ms) htau (ms) 86 | tadj 87 | } 88 | 89 | STATE { m h } 90 | 91 | INITIAL { 92 | tadj = q10^((celsius - temp)/(10 (degC))) : make all threads calculate tadj at initialization 93 | 94 | trates(v+vshift) 95 | m = minf 96 | h = hinf 97 | } 98 | 99 | BREAKPOINT { 100 | SOLVE states METHOD cnexp 101 | gna = tadj*gbar*m*m*m*h 102 | ina = (1e-4) * gna * (v - ena) 103 | } 104 | 105 | : LOCAL mexp, hexp 106 | 107 | DERIVATIVE states { :Computes state variables m, h, and n 108 | trates(v+vshift) : at the current v and dt. 
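: Standard first-order gating: m and h relax exponentially toward the
: voltage-dependent steady states minf and hinf, with time constants mtau
: and htau, all re-evaluated from the (possibly shifted) membrane potential.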
109 | m' = (minf-m)/mtau 110 | h' = (hinf-h)/htau 111 | } 112 | 113 | PROCEDURE trates(v (mV)) { 114 | TABLE minf, hinf, mtau, htau 115 | DEPEND celsius, temp, Ra, Rb, Rd, Rg, tha, thi1, thi2, qa, qi, qinf 116 | FROM vmin TO vmax WITH 199 117 | 118 | rates(v): not consistently executed from here if usetable == 1 119 | 120 | : tinc = -dt * tadj 121 | 122 | : mexp = 1 - exp(tinc/mtau) 123 | : hexp = 1 - exp(tinc/htau) 124 | } 125 | 126 | 127 | : efun() is a better approx than trap0 in vicinity of singularity-- 128 | 129 | UNITSOFF 130 | PROCEDURE rates(vm (mV)) { 131 | LOCAL a, b 132 | 133 | : a = trap0(vm,tha,Ra,qa) 134 | a = Ra * qa * efun((tha - vm)/qa) 135 | 136 | : b = trap0(-vm,-tha,Rb,qa) 137 | b = Rb * qa * efun((vm - tha)/qa) 138 | 139 | tadj = q10^((celsius - temp)/10) 140 | 141 | mtau = 1/tadj/(a+b) 142 | minf = a/(a+b) 143 | 144 | :"h" inactivation 145 | 146 | : a = trap0(vm,thi1,Rd,qi) 147 | a = Rd * qi * efun((thi1 - vm)/qi) 148 | 149 | : b = trap0(-vm,-thi2,Rg,qi) 150 | b = Rg * qi * efun((vm - thi2)/qi) 151 | 152 | htau = 1/tadj/(a+b) 153 | hinf = 1/(1+exp((vm-thinf)/qinf)) 154 | } 155 | UNITSON 156 | 157 | COMMENT 158 | FUNCTION trap0(v,th,a,q) { 159 | if (fabs((v-th)/q) > 1e-6) { 160 | trap0 = a * (v - th) / (1 - exp(-(v - th)/q)) 161 | } else { 162 | trap0 = a * q 163 | } 164 | } 165 | ENDCOMMENT 166 | 167 | FUNCTION efun(z) { 168 | if (fabs(z) < 1e-6) { 169 | efun = 1 - z/2 170 | }else{ 171 | efun = z/(exp(z) - 1) 172 | } 173 | } -------------------------------------------------------------------------------- /SimulationCode/defining_peak_current.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import random 4 | import math 5 | from neuron import h, gui 6 | from scipy.stats import multivariate_normal 7 | from mpl_toolkits.mplot3d import Axes3D 8 | import h5py 9 | from os import listdir 10 | from os.path import isfile 11 | import csv 12 | import glob 13 | import os 14 | import pandas as pd 15 | from scipy.misc import electrocardiogram 16 | from scipy.signal import find_peaks 17 | from scipy.ndimage import gaussian_filter 18 | import cmocean 19 | import pickle 20 | from skimage import measure 21 | from skimage.transform import rotate 22 | from sympy import Eq, Symbol, solve 23 | #Morphology 24 | h.xopen("./nolan/hocfile.hoc") 25 | 26 | 27 | def generate_cells(n): 28 | single_cell = [] 29 | for i in np.arange(n): 30 | cell=h.Cell1() 31 | single_cell.append(cell) 32 | return single_cell 33 | 34 | def generate_attachments(single_cell): 35 | num_dend_basal=0 36 | for i in single_cell[0].basal: 37 | num_dend_basal=num_dend_basal+1 38 | 39 | shape=(num_dend_basal) 40 | length_basal=np.zeros(shape) 41 | j=0 42 | for i in single_cell[0].basal: 43 | length_basal[j]=i.L 44 | j=j+1 45 | 46 | n_dend_attach=round(num_dend_basal/10) 47 | total_length=sum(length_basal) 48 | prob_dend=length_basal/total_length 49 | attachment_dend=np.random.choice(num_dend_basal, n_dend_attach, replace=False, p=prob_dend) 50 | 51 | return attachment_dend, n_dend_attach 52 | 53 | def generate_stimulus(single_cell, attachment_dend, times, weight, n_dend_attach): 54 | times=h.Vector(times) #eventvec move from seconds to ms 55 | single_cell[0].tlist.append(times) 56 | inputs=h.VecStim() #presyn 57 | inputs.play(times) 58 | single_cell[0].vslist.append(inputs) 59 | 60 | for i in np.arange(n_dend_attach): #create synapses 61 | synapse = h.ExpSyn(single_cell[0].dend[int(attachment_dend[i])](0.5)) 62 | 
synapse.tau=2 63 | single_cell[0].synlist.append(synapse) 64 | ncstim = h.NetCon(inputs, synapse) 65 | ncstim.weight[0]=weight 66 | ncstim.threshold=-32.4 67 | ncstim.delay=1 68 | single_cell[0].nclist.append(ncstim) 69 | 70 | return single_cell 71 | 72 | 73 | def set_recording_vectors(cell): 74 | soma_v_vec=h.Vector() 75 | t_vec=h.Vector() 76 | soma_v_vec.record(cell.soma(0.5)._ref_v) 77 | t_vec.record(h._ref_t) 78 | return soma_v_vec, t_vec 79 | 80 | 81 | def simulate (): 82 | h.tstop=7000 83 | h.run() 84 | 85 | 86 | def plot(single_cell): 87 | soma_v_vec, t_vec = set_recording_vectors(single_cell[0]) 88 | simulate() #1 min60000 max(t_array)*1000 89 | 90 | return t_vec, soma_v_vec 91 | 92 | 93 | def parameters(single_cell): 94 | h.celsius=37 95 | 96 | single_cell[0].soma.ena=45 97 | single_cell[0].soma.ek=-85 98 | 99 | #Sodium and potassium 100 | single_cell[0].soma.gbar_kv = 10000#potassium fast 101 | single_cell[0].soma.gbar_km = 15 #potassium slow 102 | single_cell[0].soma.gbar_na=72000 #sodium 103 | 104 | single_cell[0].soma.g_pas = 0.000033 105 | single_cell[0].soma.e_pas = -65 #v_init 106 | 107 | j=0 108 | for i in single_cell[0].axonal: 109 | single_cell[0].axon[j].gbar_kv =500 110 | single_cell[0].axon[j].gbar_km = 7 111 | single_cell[0].axon[j].gbar_na=1000 112 | single_cell[0].axon[j].ena=45 113 | single_cell[0].axon[j].ek=-85 114 | single_cell[0].axon[j].g_pas= 0.00001 115 | single_cell[0].axon[j].e_pas=-65 116 | j=j+1 117 | 118 | 119 | j=0 120 | for i in single_cell[0].basal: 121 | single_cell[0].dend[j].g_pas=0.00001 122 | single_cell[0].dend[j].e_pas=-65 123 | j=j+1 124 | return single_cell 125 | 126 | 127 | def create_test_cell(): 128 | single_cell=generate_cells(1) 129 | single_cell=parameters(single_cell) 130 | for i in np.arange(n_hd): 131 | times=random.sample(range(5000,6000), 12) 132 | times=np.sort(times, axis=0) 133 | attachment_dend, n_dend_attach=generate_attachments(single_cell) 134 | single_cell=generate_stimulus(single_cell, attachment_dend, times, weight, n_dend_attach) 135 | 136 | #plot 137 | t_vec, soma_v_vec=plot(single_cell) 138 | 139 | return t_vec, soma_v_vec 140 | 141 | 142 | def define_peak_current(n_hd_o): 143 | #maximal input current tested is adjusted manually for each input number such that firing rate reaches plateau 144 | h.cvode.active(1) 145 | inp=np.linspace(0, 0.000322, 50) 146 | tstop=7000 147 | np.random.seed(n_hd_o) 148 | random.seed(n_hd_o) 149 | 150 | shape=(len(n_hd_o),len(inp),20) 151 | l5b_all=np.zeros(shape) 152 | 153 | for j in np.arange(len(n_hd_o)): 154 | n_hd=n_hd_o[j] 155 | for i in np.arange(len(inp)): 156 | weight=inp[i] 157 | for m in np.arange(20): 158 | t_vec, soma_v_vec=create_test_cell() 159 | 160 | soma_array=soma_v_vec.to_python() 161 | peaks, _=find_peaks(soma_array,height=20) 162 | peak_times = [t_vec[i] for i in peaks] #ms 163 | l5b_all[j,i,m]=len(np.asarray(peak_times).astype(int)) 164 | 165 | mean_rates=np.mean(l5b_all, axis=2) 166 | variance_rates=np.var(l5b_all, axis=2) 167 | x=inp 168 | 169 | #plot and fit 170 | a=19 171 | b=28 172 | plt.errorbar(x[a:b], mean_rates[0][a:b], np.sqrt(variance_rates[0][a:b])) 173 | z=np.polyfit(x[a:b], mean_rates[0][a:b], 2) 174 | p = np.poly1d(z) 175 | plt.plot(x[a:b], p(x)[a:b], linewidth=3, color='k') #black spline fit 176 | xfm = Symbol('x') 177 | eqn = Eq(p[0]+p[1]*xfm+p[2]*xfm**2, 12) 178 | sol=solve(eqn) 179 | print('peak_current='+str(sol)) 180 | 181 | #run and print peak_current 182 | n_hd_o=30 #run for 30 input cells 183 | define_peak_current(n_hd_o) 184 | 185 | 
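# Usage note: define_peak_current() calls len(n_hd_o) and allocates an array of
# shape (len(n_hd_o), len(inp), 20), so it expects a sequence of input-cell
# counts rather than the bare int passed above; create_test_cell() also reads
# n_hd and weight from module scope rather than taking them as arguments.
# Passing the counts as a list (e.g. define_peak_current([30])) and seeding
# with an int inside the function (np.random.seed(n_hd_o[0]), and likewise for
# random.seed) is one minimal reconciliation, noted here as an assumption
# rather than as the authors' intended invocation.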
-------------------------------------------------------------------------------- /SimulationCode/import_behavior.py: -------------------------------------------------------------------------------- 1 | #importing M5-0313, processed through PostSorting/open_field_spatial_data.py 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import random 6 | import math 7 | import csv 8 | import pandas as pd 9 | import pickle 10 | 11 | def get_data(t, h, location_x, location_y, res_t): 12 | max_t = max(t) 13 | t_array=np.arange(0,max_t,res_t) #1 s bins, 0-1, 1-2... 14 | shape=(t_array.shape[0],2) 15 | positions=np.zeros(shape) 16 | shape=(t_array.shape[0]) 17 | hd=np.zeros(shape) 18 | 19 | for i in np.arange(t_array.shape[0]): 20 | positions[i,0]=np.mean(location_x[(t>=t_array[i]) & (t<(t_array[i]+res_t))]) 21 | positions[i,1]=np.mean(location_y[(t>=t_array[i]) & (t<(t_array[i]+res_t))]) 22 | hd[i]=np.mean(h[(t>=t_array[i]) & (t<(t_array[i]+res_t))]) 23 | 24 | df = pd.DataFrame(positions[:,0]).interpolate() 25 | positions[:,0]=df.values[:,0] 26 | df2 = pd.DataFrame(positions[:,1]).interpolate() 27 | positions[:,1]=df2.values[:,0] 28 | 29 | df3 = pd.DataFrame(hd).interpolate() 30 | hd=df3.values 31 | hd=hd+180 32 | 33 | x=max(positions[:,0]) 34 | y=max(positions[:,1]) 35 | 36 | return positions, x, y, hd, t_array 37 | 38 | #import dataframe 39 | all_dat=pd.read_pickle("./trajectory_all_mice.pkl") 40 | dat1=all_dat[112:113] 41 | location_x_p=dat1.position_x 42 | location_x=location_x_p[~np.isnan(location_x_p)] 43 | location_y=dat1.position_y[~np.isnan(location_x_p)] 44 | h=dat1.hd[~np.isnan(location_x_p)] 45 | t_pre=dat1.synced_time[~np.isnan(location_x_p)]#seconds 46 | t=t_pre-t_pre[0] 47 | res_t=0.001 48 | res=1 #cm 49 | 50 | #segment into 1 ms bins 51 | positions, x, y, hd, t_array=get_data(t, h, location_x, location_y, res_t) 52 | 53 | #save output 54 | f = open('behaviour', 'wb') 55 | pickle.dump([positions, hd, t_array], f) 56 | f.close() 57 | 58 | -------------------------------------------------------------------------------- /SimulationComparisons/6cd.R: -------------------------------------------------------------------------------- 1 | library(knitr) 2 | library(tidyverse) 3 | library(scales) 4 | library(glue) 5 | library(ggstatsplot) 6 | 7 | options(scipen=10000) 8 | 9 | theme_set(theme_classic()) 10 | 11 | simulation_data_long <- read_csv('results_simulations_long.csv') 12 | 13 | simulation_data_long_mt <- simulation_data_long %>% 14 | group_by(name) %>% 15 | mutate(frac_directional = sum(directional_correction) / n()) %>% 16 | ungroup() 17 | 18 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Burgess', 'Giocomo', 'Guanella', 'Pastoll')] = 'Previous models' 19 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Non-uniform conjunctive', '0.25 Non-uniform conjunctive', 'Uniform conjunctive', '0.25 Uniform conjunctive')] = 'New conjunctive cell input models' 20 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Experimental data (mouse)', 'Experimental data (rat)')] = 'Experimental data' 21 | 22 | simulation_data_long$type = factor(simulation_data_long_mt$type, levels = c('Previous models', 'New conjunctive cell input models', 'Experimental data')) 23 | 24 | simulation_data_long_mt$name[simulation_data_long_mt$name == '0.25 Non-uniform conjunctive'] = 'Non-uniform (high gₑₓ)' 25 | simulation_data_long_mt$name[simulation_data_long_mt$name == '0.25 Uniform conjunctive'] = 'Uniform (high gₑₓ)' 26 | 
simulation_data_long_mt$name[simulation_data_long_mt$name == 'Uniform conjunctive'] = 'Uniform (low gₑₓ)' 27 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Non-uniform conjunctive'] = 'Non-uniform (low gₑₓ)' 28 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Experimental data (rat)'] = 'Rat' 29 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Experimental data (mouse)'] = 'Mouse' 30 | 31 | 32 | simulation_data_long_mt$name = factor(simulation_data_long_mt$name, levels=c( 33 | 'Burgess', 34 | 'Giocomo', 35 | 'Guanella', 36 | 'Pastoll', 37 | 'Uniform (low gₑₓ)', 38 | 'Non-uniform (low gₑₓ)', 39 | 'Uniform (high gₑₓ)', 40 | 'Non-uniform (high gₑₓ)', 41 | 'Rat', 42 | 'Mouse' 43 | )) 44 | 45 | 46 | 47 | p1 <- simulation_data_long_mt %>% 48 | filter(type %in% c('Previous models', 'Experimental data')) %>% 49 | ggplot(aes(name, number_of_different_bins_bh)) + 50 | geom_boxplot(aes(), outlier.shape = NA) + 51 | geom_jitter(alpha = 0.3, height = 0) + 52 | theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust=1)) + 53 | scale_fill_distiller(palette = 'Reds', direction = 0) + 54 | labs(x = '', y = 'Significant bins per field', fill = '') + 55 | scale_y_continuous(limits = c(0,20)) + 56 | facet_grid(~ type, scale = 'free_x', space = 'free') + 57 | theme(strip.background = element_blank(), 58 | strip.text.y = element_blank(), text = element_text(size = 24), strip.text.x = element_text(size = 13), plot.margin = margin(10, 10, 10, 15)) 59 | 60 | ggsave("6c.png", p1) 61 | 62 | p2 <- simulation_data_long_mt %>% 63 | filter(type %in% c('Previous models', 'Experimental data')) %>% 64 | group_by(name) %>% 65 | filter(row_number() == 1) %>% 66 | ggplot(aes(name, frac_directional)) + 67 | geom_bar(stat = 'identity') + 68 | theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust=1)) + 69 | scale_fill_distiller(palette = 'Reds', direction = 0) + 70 | labs(x = '', y = 'Proportion of directional fields', fill = '') + 71 | scale_y_continuous(limits = c(0,1)) + 72 | facet_grid(~ type, scale = 'free_x', space = 'free') + 73 | theme(strip.background = element_blank(), 74 | strip.text.y = element_blank(), text = element_text(size = 24), strip.text.x = element_text(size = 13), plot.margin = margin(11, 11, 11, 15)) 75 | 76 | ggsave("6d.png", p2) 77 | 78 | 79 | p <- ggbetweenstats( 80 | data = simulation_data_long_mt %>% filter(type %in% c('Previous models', 'Experimental data')), 81 | x = name, 82 | y = number_of_different_bins_bh, 83 | type = "parametric", 84 | k = 9, 85 | outlier.tagging = FALSE, 86 | outlier.label.color = "darkgreen", 87 | pairwise.comparisons = TRUE, 88 | ) + theme(axis.text.x = element_text(angle = 45, hjust = 1)) 89 | pb <- ggplot_build(p) 90 | pb$plot$plot_env$df_pairwise 91 | 92 | write_csv(pb$plot$plot_env$df_pairwise, '6cd_comparisons.csv') 93 | 94 | -------------------------------------------------------------------------------- /SimulationComparisons/7d.R: -------------------------------------------------------------------------------- 1 | library(knitr) 2 | library(tidyverse) 3 | library(scales) 4 | library(glue) 5 | library(ggstatsplot) 6 | 7 | options(scipen=10000) 8 | 9 | theme_set(theme_classic()) 10 | 11 | simulation_data_long <- read_csv('results_simulations_long.csv') 12 | 13 | simulation_data_long_mt <- simulation_data_long %>% 14 | group_by(name) %>% 15 | mutate(frac_directional = sum(directional_correction) / n()) %>% 16 | ungroup() 17 | 18 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Burgess', 
'Giocomo', 'Guanella', 'Pastoll')] = 'Previous models' 19 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Non-uniform conjunctive', '0.25 Non-uniform conjunctive', 'Uniform conjunctive', '0.25 Uniform conjunctive')] = 'New conjunctive cell input models' 20 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Experimental data (mouse)', 'Experimental data (rat)')] = 'Experimental data' 21 | 22 | simulation_data_long$type = factor(simulation_data_long_mt$type, levels = c('Previous models', 'New conjunctive cell input models', 'Experimental data')) 23 | 24 | simulation_data_long_mt$name[simulation_data_long_mt$name == '0.25 Non-uniform conjunctive'] = 'Non-uniform (high gₑₓ)' 25 | simulation_data_long_mt$name[simulation_data_long_mt$name == '0.25 Uniform conjunctive'] = 'Uniform (high gₑₓ)' 26 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Uniform conjunctive'] = 'Uniform (low gₑₓ)' 27 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Non-uniform conjunctive'] = 'Non-uniform (low gₑₓ)' 28 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Experimental data (rat)'] = 'Rat' 29 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Experimental data (mouse)'] = 'Mouse' 30 | 31 | 32 | simulation_data_long_mt$name = factor(simulation_data_long_mt$name, levels=c( 33 | 'Burgess', 34 | 'Giocomo', 35 | 'Guanella', 36 | 'Pastoll', 37 | 'Uniform (low gₑₓ)', 38 | 'Non-uniform (low gₑₓ)', 39 | 'Uniform (high gₑₓ)', 40 | 'Non-uniform (high gₑₓ)', 41 | 'Rat', 42 | 'Mouse' 43 | )) 44 | 45 | 46 | simulation_data_long_mt <- simulation_data_long %>% 47 | group_by(name) %>% 48 | mutate(frac_directional = sum(directional_correction) / n()) %>% 49 | ungroup() 50 | 51 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Burgess', 'Giocomo', 'Guanella', 'Pastoll')] = 'Previous models' 52 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Non-uniform conjunctive', '0.25 Non-uniform conjunctive', 'Uniform conjunctive', '0.25 Uniform conjunctive')] = 'New conjunctive cell input models' 53 | simulation_data_long_mt$type[simulation_data_long_mt$name %in% c('Experimental data (mouse)', 'Experimental data (rat)')] = 'Experimental data' 54 | 55 | simulation_data_long$type = factor(simulation_data_long_mt$type, levels = c('Previous models', 'New conjunctive cell input models', 'Experimental data')) 56 | 57 | simulation_data_long_mt$name[simulation_data_long_mt$name == '0.25 Non-uniform conjunctive'] = 'Non-uniform (high gₑₓ)' 58 | simulation_data_long_mt$name[simulation_data_long_mt$name == '0.25 Uniform conjunctive'] = 'Uniform (high gₑₓ)' 59 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Uniform conjunctive'] = 'Uniform (low gₑₓ)' 60 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Non-uniform conjunctive'] = 'Non-uniform (low gₑₓ)' 61 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Experimental data (rat)'] = 'Rat' 62 | simulation_data_long_mt$name[simulation_data_long_mt$name == 'Experimental data (mouse)'] = 'Mouse' 63 | 64 | 65 | simulation_data_long_mt$name = factor(simulation_data_long_mt$name, levels=c( 66 | 'Burgess', 67 | 'Giocomo', 68 | 'Guanella', 69 | 'Pastoll', 70 | 'Uniform (low gₑₓ)', 71 | 'Non-uniform (low gₑₓ)', 72 | 'Uniform (high gₑₓ)', 73 | 'Non-uniform (high gₑₓ)', 74 | 'Rat', 75 | 'Mouse' 76 | )) 77 | 78 | p3 <- simulation_data_long_mt %>% 79 | filter(type %in% c('New conjunctive cell input models', 'Experimental data')) %>% 80 | 
ggplot(aes(name, number_of_different_bins_bh)) + 81 | geom_boxplot(aes(), outlier.shape = NA) + 82 | geom_jitter(alpha = 0.3, height = 0) + 83 | theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust=1)) + 84 | scale_fill_distiller(palette = 'Reds', direction = 0) + 85 | labs(x = '', y = 'Significant bins per field', fill = '') + 86 | scale_y_continuous(limits = c(0,20)) + 87 | facet_grid(~ type, scale = 'free_x', space = 'free') + 88 | theme(strip.background = element_blank(), 89 | strip.text.y = element_blank(), text = element_text(size = 21), strip.text.x = element_text(size = 13), plot.margin = margin(10, 10, 10, 15)) 90 | 91 | ggsave('7d_left.png', p3) 92 | 93 | 94 | p4 <- simulation_data_long_mt %>% 95 | filter(type %in% c('New conjunctive cell input models', 'Experimental data')) %>% 96 | group_by(name) %>% 97 | filter(row_number() == 1) %>% 98 | ggplot(aes(name, frac_directional)) + 99 | geom_bar(stat = 'identity') + 100 | theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust=1)) + 101 | scale_fill_distiller(palette = 'Reds', direction = 0) + 102 | labs(x = '', y = 'Proportion of directional fields', fill = '') + 103 | scale_y_continuous(limits = c(0,1)) + 104 | facet_grid(~ type, scale = 'free_x', space = 'free') + 105 | theme(strip.background = element_blank(), 106 | strip.text.y = element_blank(), text = element_text(size = 21), strip.text.x = element_text(size = 13), plot.margin = margin(10, 10, 10, 15)) 107 | 108 | ggsave('7d_right.png', p4) 109 | 110 | 111 | p <- ggbetweenstats( 112 | data = simulation_data_long_mt %>% filter(type %in% c('New conjunctive cell input models', 'Experimental data')), 113 | x = name, 114 | y = number_of_different_bins_bh, 115 | type = "parametric", 116 | k = 9, 117 | outlier.tagging = FALSE, 118 | outlier.label.color = "darkgreen", 119 | pairwise.comparisons = TRUE, 120 | ) + theme(axis.text.x = element_text(angle = 45, hjust = 1)) 121 | pb <- ggplot_build(p) 122 | pb$plot$plot_env$df_pairwise 123 | 124 | write_csv(pb$plot$plot_env$df_pairwise, '7d_comparisons.csv') -------------------------------------------------------------------------------- /SimulationComparisons/grid_cell_models/burgess/main.m: -------------------------------------------------------------------------------- 1 | % Burgess, Barry, O'Keefe 2007's abstract oscillatory interference model 2 | % Adapted from Zilli 2012 3 | % This code is released into the public domain. Not for use in skynet. 4 | 5 | basePhases = 0:6 6 | spikeTimes = cell(length(basePhases), 1); 7 | spikeCoords = cell(length(basePhases), 1); 8 | actualHdAtSpike = cell(length(basePhases), 1); 9 | 10 | for basePhase=basePhases 11 | watchCellIndex = basePhase + 1; 12 | 13 | livePlot = 000; 14 | 15 | % if =0, just give constant velocity. if =1, load trajectory from disk 16 | useRealTrajectory = 1; 17 | constantVelocity = 1*[.5; 0*0.5]; % m/s 18 | 19 | %% Simulation parameters 20 | dt = .02; % time step, s 21 | simdur = 1293; % total simulation time, s 22 | tind = 1; % time step number for indexing 23 | t = 0; % simulation time variable, s 24 | x = 0; % position, m 25 | y = 0; % position, m 26 | 27 | %% Model parameters 28 | ncells = 1; 29 | % Basline maintains a fixed frequency 30 | baseFreq = 6; % Hz 31 | % Directional preference of each dendrite (this also sets the number of dendrites) 32 | dirPreferences = [0 2*pi/3 4*pi/3]; 33 | % Scaling factor relating speed to oscillator frequencies 34 | % NB paper uses 0.05*2pi rad/cm [=(rad/s)/(cm/s)]. 
But we do the conversion to rad later, 35 | % leaving 0.05 Hz/(cm/s) = 5 Hz/(m/s) which produces very tight field spacing. For cosmetic 36 | % purposes for the trajectory we use here, we'll use beta = 2. 37 | beta = 2; % Hz/(m/s) 38 | spikeThreshold = 1.8; 39 | 40 | 41 | %% History variables 42 | speed = zeros(1,ceil(simdur/dt)); 43 | curDir = zeros(1,ceil(simdur/dt)); 44 | vhist = zeros(1,ceil(simdur/dt)); 45 | fhist = zeros(1,ceil(simdur/dt)); 46 | 47 | %% Firing field plot variables 48 | nSpatialBins = 60; 49 | minx = 0; maxx = 1.1; % m 50 | miny = 0; maxy = 1.1; % m 51 | occupancy = zeros(nSpatialBins); 52 | spikes = zeros(nSpatialBins); 53 | 54 | spikePhases = []; 55 | 56 | %% Initial conditions 57 | % Oscillators will start at phase 0: 58 | dendritePhases = zeros(1,length(dirPreferences)); % rad 59 | %basePhase = 0; % rad 60 | 61 | %% Make optional figure of sheet of activity 62 | if livePlot 63 | h = figure('color','w','name','Activity of one cell'); 64 | if useRealTrajectory 65 | set(h,'position',[520 378 1044 420]) 66 | end 67 | drawnow 68 | end 69 | 70 | %% Possibly load trajectory from disk 71 | if useRealTrajectory 72 | load ../trajectory_data.mat; 73 | % interpolate down to simulation time step 74 | pos(1:2,:) = pos(1:2,:)/100; % cm to m 75 | vels = [diff(pos(1,:)); diff(pos(2,:))]/dt; % m/s 76 | x = pos(1,1); % m 77 | y = pos(2,1); % m 78 | actualHd = pos(4, :); % real hd for reference 79 | end 80 | 81 | %% !! Main simulation loop 82 | fprintf('Simulation starting. Press ctrl+c to end...\n') 83 | while t0); 132 | 133 | % Save for later 134 | fhist(tind) = f; 135 | 136 | % Save firing field information 137 | if f>spikeThreshold 138 | spikeTimes{watchCellIndex}= [spikeTimes{watchCellIndex}; t]; 139 | spikeCoords{watchCellIndex} = [spikeCoords{watchCellIndex}; x(tind) y(tind)]; 140 | spikePhases = [spikePhases; basePhase]; 141 | actualHdAtSpike{watchCellIndex} = [actualHdAtSpike{watchCellIndex}; actualHd(tind)]; 142 | end 143 | if useRealTrajectory 144 | xindex = round((x(tind)-minx)/(maxx-minx)*nSpatialBins)+1; 145 | yindex = round((y(tind)-miny)/(maxy-miny)*nSpatialBins)+1; 146 | occupancy(yindex,xindex) = occupancy(yindex,xindex) + dt; 147 | spikes(yindex,xindex) = spikes(yindex,xindex) + double(f>spikeThreshold); 148 | end 149 | 150 | if livePlot>0 && (livePlot==1 || mod(tind,livePlot)==1) 151 | if ~useRealTrajectory 152 | figure(h); 153 | subplot(121); 154 | plot(fhist(1:tind)); 155 | title('Activity'); 156 | xlabel('Time (s)') 157 | axis square 158 | set(gca,'ydir','normal') 159 | title(sprintf('t = %.1f s',t)) 160 | subplot(122); 161 | plot(x(1:tind),y(1:tind)) 162 | hold on; 163 | if ~isempty(spikeCoords) 164 | cmap = jet; 165 | cmap = [cmap((end/2+1):end,:); cmap(1:end/2,:)]; 166 | phaseInds = mod(spikePhases,2*pi)*(length(cmap)-1)/2/pi; 167 | pointColors = cmap(ceil(phaseInds)+1,:); 168 | 169 | scatter3(spikeCoords(:,1), ... 170 | spikeCoords(:,2), ... 171 | zeros(size(spikeCoords(:,1))), ... 172 | 30*ones(size(spikeCoords(:,1))), ... 173 | pointColors, ... 174 | 'o','filled'); 175 | end 176 | axis square 177 | title({'Trajectory (blue) and',... 178 | 'spikes (colored by theta phase',... 
179 | 'blues before baseline peak, reds after)'}) 180 | drawnow 181 | else 182 | figure(h); 183 | subplot(131); 184 | plot((0:tind-1)*dt,fhist(1:tind)); 185 | hold on; 186 | plot([0 tind-1]*dt,[spikeThreshold spikeThreshold],'r') 187 | title('Activity (blue) and threshold (red)'); 188 | xlabel('Time (s)') 189 | axis square 190 | set(gca,'ydir','normal') 191 | subplot(132); 192 | imagesc(spikes./occupancy); 193 | axis square 194 | set(gca,'ydir','normal') 195 | title({'Rate map',sprintf('t = %.1f s',t)}) 196 | subplot(133); 197 | plot(x(1:tind),y(1:tind)) 198 | hold on; 199 | if ~isempty(spikeCoords) 200 | cmap = jet; 201 | cmap = [cmap((end/2+1):end,:); cmap(1:end/2,:)]; 202 | phaseInds = mod(spikePhases,2*pi)*(length(cmap)-1)/2/pi; 203 | pointColors = cmap(ceil(phaseInds)+1,:); 204 | 205 | scatter3(spikeCoords(:,1), ... 206 | spikeCoords(:,2), ... 207 | zeros(size(spikeCoords(:,1))), ... 208 | 30*ones(size(spikeCoords(:,1))), ... 209 | pointColors, ... 210 | 'o','filled'); 211 | end 212 | axis square 213 | title({'Trajectory (blue) and',... 214 | 'spikes (colored by theta phase',... 215 | 'blues before baseline peak, reds after)'}) 216 | drawnow 217 | end 218 | end 219 | end 220 | end 221 | 222 | save('results.mat') 223 | disp('done'); 224 | -------------------------------------------------------------------------------- /SimulationComparisons/grid_cell_models/giocomo/main.m: -------------------------------------------------------------------------------- 1 | % Giocomo, Zilli, Fransen, and Hasselmo 2007's temporal interference model 2 | % Adapted from Zilli 2012 3 | % This code is released into the public domain. Not for use in skynet. 4 | 5 | 6 | % if >0, plots the sheet of activity during the simulation on every livePlot'th step 7 | 8 | basePhases = 0:6 9 | spikeCoords = cell(length(basePhases), 1); 10 | actualHdAtSpike = cell(length(basePhases), 1); 11 | spikeTimes = cell(length(basePhases), 1); 12 | 13 | for basePhase=basePhases 14 | watchCell = basePhase + 1; 15 | 16 | livePlot = 000; 17 | 18 | % if =0, just give constant velocity. 
if =1, load trajectory from disk 19 | useRealTrajectory = 1; 20 | constantVelocity = 1*[.5; 0*0.5]; % m/s 21 | 22 | %% Simulation parameters 23 | dt = .02; % time step, s 24 | simdur = 1293; % total simulation time, s % 25 | tind = 1; % time step number for indexing 26 | t = 0; % simulation time variable, s 27 | x = 0; % position, m 28 | y = 0; % position, m 29 | 30 | %% Model parameters 31 | ncells = 1; 32 | % Basline maintains a fixed frequency 33 | baseFreq = 6.42; % dorsal, Hz 34 | % baseFreq = 4.23; % ventral, Hz 35 | % Directional preference of each dendrite (this also sets the number of dendrites) 36 | dirPreferences = [0 2*pi/3 4*pi/3]; 37 | % Scaling factor relating speed to oscillator frequencies 38 | beta = 0.385; % Hz/(m/s) 39 | spikeThreshold = 1.8; 40 | 41 | 42 | %% History variables 43 | speed = zeros(1,ceil(simdur/dt)); 44 | curDir = zeros(1,ceil(simdur/dt)); 45 | vhist = zeros(1,ceil(simdur/dt)); 46 | fhist = zeros(1,ceil(simdur/dt)); 47 | 48 | %% Firing field plot variables 49 | nSpatialBins = 60; 50 | minx = 0; maxx = 1.1; % m 51 | miny = 0; maxy = 1.1; % m 52 | occupancy = zeros(nSpatialBins); 53 | spikes = zeros(nSpatialBins); 54 | 55 | %spikeTimes = []; 56 | %spikeCoords = []; 57 | spikePhases = []; 58 | %actualHdAtSpike = []; 59 | 60 | %% Initial conditions 61 | % Oscillators will start at phase 0: 62 | dendritePhases = zeros(1,length(dirPreferences)); % rad 63 | %basePhase = 0; % rad 64 | disp(basePhase) 65 | 66 | %% Make optional figure of sheet of activity 67 | if livePlot 68 | h = figure('color','w','name','Activity of one cell'); 69 | if useRealTrajectory 70 | set(h,'position',[520 378 1044 420]) 71 | end 72 | drawnow 73 | end 74 | 75 | %% Possibly load trajectory from disk 76 | if useRealTrajectory 77 | load ../trajectory_data.mat; 78 | % interpolate down to simulation time step 79 | %pos = [interp1(pos(3,:),pos(1,:),0:dt:pos(3,end)); 80 | % interp1(pos(3,:),pos(2,:),0:dt:pos(3,end)); 81 | % interp1(pos(3,:),pos(3,:),0:dt:pos(3,end)); 82 | % interp1(pos(3,:),pos(4,:),0:dt:pos(3,end))]; % hd reference 83 | pos(1:2,:) = pos(1:2,:)/100; % cm to m 84 | vels = [diff(pos(1,:)); diff(pos(2,:))]/dt; % m/s 85 | x = pos(1,1); % m 86 | y = pos(2,1); % m 87 | actualHd = pos(4, :); % real hd for reference 88 | end 89 | 90 | %% !! Main simulation loop 91 | fprintf('Simulation starting. 
Press ctrl+c to end...\n') 92 | while t0); 129 | 130 | % Save for later 131 | fhist(tind) = f; 132 | 133 | % Save firing field information 134 | if f>spikeThreshold 135 | spikeTimes{watchCell} = [spikeTimes{watchCell}; t]; 136 | spikeCoords{watchCell} = [spikeCoords{watchCell}; x(tind) y(tind)]; 137 | spikePhases = [spikePhases; basePhase]; 138 | actualHdAtSpike{watchCell} = [actualHdAtSpike{watchCell}; actualHd(tind)]; 139 | end 140 | if useRealTrajectory 141 | xindex = round((x(tind)-minx)/(maxx-minx)*nSpatialBins)+1; 142 | yindex = round((y(tind)-miny)/(maxy-miny)*nSpatialBins)+1; 143 | occupancy(yindex,xindex) = occupancy(yindex,xindex) + dt; 144 | spikes(yindex,xindex) = spikes(yindex,xindex) + double(f>spikeThreshold); 145 | end 146 | 147 | if livePlot>0 && (livePlot==1 || mod(tind,livePlot)==1) 148 | if ~useRealTrajectory 149 | figure(h); 150 | subplot(121); 151 | plot(fhist(1:tind)); 152 | title('Activity'); 153 | xlabel('Time (s)') 154 | axis square 155 | set(gca,'ydir','normal') 156 | title(sprintf('t = %.1f s',t)) 157 | subplot(122); 158 | plot(x(1:tind),y(1:tind)) 159 | hold on; 160 | if ~isempty(spikeCoords) 161 | cmap = jet; 162 | cmap = [cmap((end/2+1):end,:); cmap(1:end/2,:)]; 163 | phaseInds = mod(spikePhases,2*pi)*(length(cmap)-1)/2/pi; 164 | pointColors = cmap(ceil(phaseInds)+1,:); 165 | 166 | scatter3(spikeCoords(:,1), ... 167 | spikeCoords(:,2), ... 168 | zeros(size(spikeCoords(:,1))), ... 169 | 30*ones(size(spikeCoords(:,1))), ... 170 | pointColors, ... 171 | 'o','filled'); 172 | end 173 | axis square 174 | title({'Trajectory (blue) and',... 175 | 'spikes (colored by theta phase',... 176 | 'blues before baseline peak, reds after)'}) 177 | drawnow 178 | else 179 | figure(h); 180 | subplot(131); 181 | plot((0:tind-1)*dt,fhist(1:tind)); 182 | hold on; 183 | plot([0 tind-1]*dt,[spikeThreshold spikeThreshold],'r') 184 | title('Activity (blue) and threshold (red)'); 185 | xlabel('Time (s)') 186 | axis square 187 | set(gca,'ydir','normal') 188 | subplot(132); 189 | imagesc(spikes./occupancy); 190 | axis square 191 | set(gca,'ydir','normal') 192 | title({'Rate map',sprintf('t = %.1f s',t)}) 193 | subplot(133); 194 | plot(x(1:tind),y(1:tind)) 195 | hold on; 196 | if ~isempty(spikeCoords) 197 | cmap = jet; 198 | cmap = [cmap((end/2+1):end,:); cmap(1:end/2,:)]; 199 | phaseInds = mod(spikePhases,2*pi)*(length(cmap)-1)/2/pi; 200 | pointColors = cmap(ceil(phaseInds)+1,:); 201 | 202 | scatter3(spikeCoords(:,1), ... 203 | spikeCoords(:,2), ... 204 | zeros(size(spikeCoords(:,1))), ... 205 | 30*ones(size(spikeCoords(:,1))), ... 206 | pointColors, ... 207 | 'o','filled'); 208 | end 209 | axis square 210 | title({'Trajectory (blue) and',... 211 | 'spikes (colored by theta phase',... 
212 | 'blues before baseline peak, reds after)'}) 213 | drawnow 214 | end 215 | end 216 | end 217 | end 218 | 219 | save('results.mat') 220 | disp('done'); 221 | -------------------------------------------------------------------------------- /SimulationComparisons/grid_cell_models/pastoll/prepare_trajectory.py: -------------------------------------------------------------------------------- 1 | # Prepares a trajectory data file for use in the pastoll 2013 model 2 | 3 | import scipy.io 4 | import numpy as np 5 | 6 | pos = scipy.io.loadmat('../trajectory_data.mat')['pos'] 7 | 8 | t = pos[2, :] 9 | x = pos[0, :] - 50 10 | y = pos[1, :] - 50 11 | 12 | traj = { 13 | 'pos_timeStamps': t.T, 14 | 'pos_x': x.T, 15 | 'pos_y': y.T, 16 | 'dt': np.array(0.02) 17 | } 18 | 19 | 20 | scipy.io.savemat('./model/grid_cell_model/kg_trajectory.mat', traj) 21 | -------------------------------------------------------------------------------- /SimulationComparisons/grid_cell_models/pastoll/submit_simulation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Note: might need to add the model directory to PYTHONPATH, 3 | # or just move this file to the grid_cell_model dir 4 | 5 | 6 | from default_params import defaultParameters 7 | from common import GenericSubmitter, ArgumentCreator 8 | 9 | import logging as lg 10 | lg.basicConfig(level=lg.DEBUG) 11 | 12 | 13 | parameters = defaultParameters 14 | 15 | parameters['time'] = 1293e3 # ms 16 | parameters['ndumps'] = 1 17 | 18 | parameters['placeT'] = 10e3 # ms 19 | 20 | parameters['stateMonDur'] = 10e3 21 | 22 | parameters['bumpCurrentSlope'] = 1.175 # pA/(cm/s), !! this will depend on prefDirC !! 23 | parameters['gridSep'] = 60 # cm, grid field inter-peak distance 24 | 25 | parameters['ratVelFName'] = './kg_trajectory.mat' 26 | 27 | startJobNum = 0 28 | numRepeat = 1 29 | 30 | # Workstation parameters 31 | programName = 'python simulation_basic_grids.py' 32 | blocking = True 33 | 34 | ac = ArgumentCreator(parameters) 35 | submitter = GenericSubmitter(ac, programName, blocking=blocking) 36 | submitter.submitAll(startJobNum, numRepeat, dry_run=False) 37 | -------------------------------------------------------------------------------- /SimulationComparisons/results_sampling_shuffled_cell.csv: -------------------------------------------------------------------------------- 1 | ,trial,length_minutes,sig_bins 2 | 0,3.0,80.0,11.0 3 | 1,13.0,80.0,11.0 4 | 2,11.0,80.0,14.0 5 | 3,18.0,80.0,11.0 6 | 4,4.0,80.0,12.0 7 | 5,2.0,80.0,12.0 8 | 6,19.0,80.0,13.0 9 | 7,16.0,80.0,13.0 10 | 8,10.0,80.0,17.0 11 | 9,15.0,80.0,9.0 12 | 10,8.0,80.0,9.0 13 | 11,12.0,80.0,13.0 14 | 12,9.0,80.0,11.0 15 | 13,5.0,80.0,12.0 16 | 14,14.0,80.0,15.0 17 | 15,17.0,7.0,3.0 18 | 16,12.0,7.0,1.0 19 | 17,19.0,7.0,7.0 20 | 18,2.0,7.0,4.0 21 | 19,13.0,7.0,2.0 22 | 20,4.0,7.0,1.0 23 | 21,14.0,7.0,1.0 24 | 22,7.0,7.0,1.0 25 | 23,9.0,7.0,4.0 26 | 24,8.0,7.0,1.0 27 | 25,3.0,7.0,3.0 28 | 26,14.0,500.0,17.0 29 | 27,3.0,500.0,17.0 30 | 28,8.0,500.0,14.0 31 | 29,6.0,500.0,18.0 32 | 30,2.0,500.0,15.0 33 | 31,5.0,500.0,19.0 34 | 32,17.0,500.0,13.0 35 | 33,12.0,6.0,1.0 36 | 34,2.0,6.0,4.0 37 | 35,8.0,6.0,0.0 38 | 36,13.0,6.0,2.0 39 | 37,0.0,6.0,3.0 40 | 38,6.0,6.0,3.0 41 | 39,3.0,6.0,1.0 42 | 40,7.0,6.0,1.0 43 | 41,5.0,6.0,2.0 44 | 42,15.0,6.0,3.0 45 | 43,4.0,3.0,0.0 46 | 44,9.0,3.0, 47 | 45,14.0,3.0,0.0 48 | 46,18.0,3.0,4.0 49 | 47,8.0,3.0,0.0 50 | 48,17.0,3.0,0.0 51 | 49,6.0,3.0,3.0 52 | 50,3.0,3.0,0.0 53 | 51,2.0,3.0, 54 | 52,12.0,3.0,1.0 55 | 
53,5.0,3.0,0.0 56 | 54,19.0,3.0,4.0 57 | 55,7.0,3.0,1.0 58 | 56,15.0,3.0,2.0 59 | 57,8.0,400.0,15.0 60 | 58,11.0,400.0,17.0 61 | 59,5.0,400.0,18.0 62 | 60,9.0,400.0,18.0 63 | 61,4.0,400.0,14.0 64 | 62,14.0,400.0,17.0 65 | 63,2.0,400.0,15.0 66 | 64,3.0,400.0,17.0 67 | 65,15.0,400.0,14.0 68 | 66,13.0,400.0,19.0 69 | 67,7.0,400.0,12.0 70 | 68,10.0,400.0,14.0 71 | 69,6.0,400.0,17.0 72 | 70,14.0,60.0,14.0 73 | 71,3.0,60.0,11.0 74 | 72,12.0,60.0,13.0 75 | 73,4.0,60.0,11.0 76 | 74,6.0,60.0,13.0 77 | 75,16.0,60.0,8.0 78 | 76,2.0,60.0,11.0 79 | 77,9.0,60.0,9.0 80 | 78,13.0,60.0,10.0 81 | 79,19.0,60.0,10.0 82 | 80,8.0,60.0,13.0 83 | 81,0.0,60.0,12.0 84 | 82,5.0,60.0,11.0 85 | 83,5.0,5.0,1.0 86 | 84,13.0,5.0,0.0 87 | 85,12.0,5.0,1.0 88 | 86,6.0,5.0,2.0 89 | 87,7.0,5.0,1.0 90 | 88,9.0,5.0,2.0 91 | 89,3.0,5.0,0.0 92 | 90,8.0,5.0,0.0 93 | 91,2.0,5.0,4.0 94 | 92,10.0,5.0,5.0 95 | 93,15.0,5.0,4.0 96 | 94,0.0,5.0,4.0 97 | 95,8.0,2.0,0.0 98 | 96,6.0,2.0,0.0 99 | 97,4.0,2.0, 100 | 98,2.0,2.0,1.0 101 | 99,3.0,2.0,0.0 102 | 100,14.0,2.0,0.0 103 | 101,7.0,2.0,0.0 104 | 102,5.0,2.0,0.0 105 | 103,10.0,2.0,1.0 106 | 104,12.0,2.0,0.0 107 | 105,4.0,300.0,13.0 108 | 106,19.0,300.0,14.0 109 | 107,2.0,300.0,13.0 110 | 108,0.0,300.0,13.0 111 | 109,15.0,300.0,13.0 112 | 110,9.0,300.0,16.0 113 | 111,8.0,300.0,13.0 114 | 112,3.0,300.0,16.0 115 | 113,4.0,55.0,11.0 116 | 114,13.0,55.0,10.0 117 | 115,12.0,55.0,13.0 118 | 116,5.0,55.0,10.0 119 | 117,14.0,55.0,11.0 120 | 118,9.0,55.0,9.0 121 | 119,3.0,55.0,10.0 122 | 120,8.0,55.0,10.0 123 | 121,2.0,55.0,11.0 124 | 122,12.0,50.0,12.0 125 | 123,11.0,50.0,10.0 126 | 124,0.0,50.0,12.0 127 | 125,17.0,50.0,6.0 128 | 126,2.0,50.0,12.0 129 | 127,9.0,50.0,9.0 130 | 128,3.0,50.0,10.0 131 | 129,8.0,50.0,12.0 132 | 130,13.0,50.0,9.0 133 | 131,5.0,50.0,10.0 134 | 132,19.0,50.0,8.0 135 | 133,7.0,50.0,8.0 136 | 134,3.0,4.0,0.0 137 | 135,2.0,4.0,2.0 138 | 136,13.0,4.0,1.0 139 | 137,14.0,4.0,0.0 140 | 138,19.0,4.0,5.0 141 | 139,9.0,4.0,2.0 142 | 140,6.0,4.0,3.0 143 | 141,5.0,4.0,1.0 144 | 142,12.0,4.0,1.0 145 | 143,4.0,1.0, 146 | 144,8.0,1.0, 147 | 145,13.0,1.0, 148 | 146,16.0,1.0, 149 | 147,17.0,1.0, 150 | 148,18.0,1.0, 151 | 149,19.0,1.0, 152 | 150,11.0,1.0, 153 | 151,14.0,1.0,0.0 154 | 152,1.0,1.0, 155 | 153,12.0,1.0, 156 | 154,10.0,1.0, 157 | 155,15.0,1.0, 158 | 156,2.0,1.0, 159 | 157,3.0,1.0,1.0 160 | 158,9.0,1.0, 161 | 159,7.0,200.0,9.0 162 | 160,6.0,200.0,16.0 163 | 161,19.0,200.0,11.0 164 | 162,3.0,200.0,16.0 165 | 163,8.0,200.0,13.0 166 | 164,13.0,200.0,13.0 167 | 165,2.0,200.0,12.0 168 | 166,4.0,200.0,11.0 169 | 167,9.0,200.0,17.0 170 | 168,18.0,200.0,12.0 171 | 169,11.0,200.0,15.0 172 | 170,13.0,45.0,10.0 173 | 171,3.0,45.0,10.0 174 | 172,2.0,45.0,11.0 175 | 173,15.0,45.0,8.0 176 | 174,7.0,45.0,6.0 177 | 175,8.0,45.0,9.0 178 | 176,5.0,45.0,9.0 179 | 177,19.0,45.0,8.0 180 | 178,6.0,45.0,11.0 181 | 179,9.0,40.0,8.0 182 | 180,3.0,40.0,7.0 183 | 181,7.0,40.0,7.0 184 | 182,18.0,40.0,12.0 185 | 183,2.0,40.0,11.0 186 | 184,14.0,40.0,8.0 187 | 185,8.0,40.0,10.0 188 | 186,17.0,40.0,5.0 189 | 187,5.0,40.0,9.0 190 | 188,12.0,40.0,10.0 191 | 189,8.0,150.0,10.0 192 | 190,12.0,150.0,14.0 193 | 191,7.0,150.0,8.0 194 | 192,11.0,150.0,14.0 195 | 193,5.0,150.0,12.0 196 | 194,6.0,150.0,16.0 197 | 195,19.0,150.0,13.0 198 | 196,4.0,150.0,12.0 199 | 197,15.0,150.0,13.0 200 | 198,3.0,150.0,14.0 201 | 199,14.0,150.0,18.0 202 | 200,13.0,150.0,12.0 203 | 201,8.0,100.0,8.0 204 | 202,7.0,100.0,7.0 205 | 203,12.0,100.0,15.0 206 | 204,15.0,100.0,9.0 207 | 205,6.0,100.0,15.0 208 | 206,10.0,100.0,16.0 209 | 
207,2.0,100.0,12.0 210 | 208,3.0,100.0,12.0 211 | 209,9.0,100.0,15.0 212 | 210,4.0,100.0,11.0 213 | 211,7.0,35.0,4.0 214 | 212,3.0,35.0,8.0 215 | 213,5.0,35.0,8.0 216 | 214,19.0,35.0,6.0 217 | 215,8.0,35.0,8.0 218 | 216,15.0,35.0,9.0 219 | 217,2.0,35.0,10.0 220 | 218,17.0,30.0,7.0 221 | 219,7.0,30.0,5.0 222 | 220,3.0,30.0,7.0 223 | 221,0.0,30.0,12.0 224 | 222,9.0,30.0,8.0 225 | 223,19.0,30.0,9.0 226 | 224,12.0,30.0,11.0 227 | 225,5.0,30.0,9.0 228 | 226,2.0,30.0,9.0 229 | 227,11.0,25.0,10.0 230 | 228,6.0,25.0,9.0 231 | 229,5.0,25.0,7.0 232 | 230,2.0,25.0,9.0 233 | 231,7.0,25.0,4.0 234 | 232,19.0,25.0,9.0 235 | 233,9.0,25.0,7.0 236 | 234,3.0,25.0,9.0 237 | 235,14.0,25.0,6.0 238 | 236,12.0,25.0,8.0 239 | 237,2.0,20.0,9.0 240 | 238,9.0,20.0,6.0 241 | 239,11.0,20.0,10.0 242 | 240,18.0,20.0,5.0 243 | 241,14.0,20.0,8.0 244 | 242,17.0,20.0,5.0 245 | 243,5.0,20.0,7.0 246 | 244,7.0,20.0,4.0 247 | 245,15.0,20.0,7.0 248 | 246,12.0,20.0,9.0 249 | 247,8.0,20.0,5.0 250 | 248,13.0,20.0,6.0 251 | 249,10.0,20.0,13.0 252 | 250,19.0,20.0,7.0 253 | 251,16.0,20.0,5.0 254 | 252,1.0,20.0,7.0 255 | 253,4.0,20.0,6.0 256 | 254,3.0,20.0,8.0 257 | 255,3.0,9.0,3.0 258 | 256,15.0,9.0,5.0 259 | 257,14.0,9.0,3.0 260 | 258,2.0,9.0,4.0 261 | 259,8.0,9.0,1.0 262 | 260,5.0,9.0,4.0 263 | 261,9.0,9.0,3.0 264 | 262,12.0,9.0,3.0 265 | 263,18.0,9.0,5.0 266 | 264,6.0,9.0,3.0 267 | 265,14.0,250.0,18.0 268 | 266,13.0,250.0,15.0 269 | 267,5.0,250.0,15.0 270 | 268,2.0,250.0,13.0 271 | 269,7.0,250.0,9.0 272 | 270,6.0,250.0,17.0 273 | 271,3.0,250.0,16.0 274 | 272,8.0,250.0,13.0 275 | 273,12.0,250.0,16.0 276 | 274,9.0,250.0,17.0 277 | 275,11.0,15.0,7.0 278 | 276,7.0,15.0,4.0 279 | 277,9.0,15.0,6.0 280 | 278,19.0,15.0,5.0 281 | 279,0.0,15.0,8.0 282 | 280,5.0,15.0,5.0 283 | 281,2.0,15.0,9.0 284 | 282,15.0,15.0,5.0 285 | 283,8.0,15.0,4.0 286 | 284,12.0,15.0,7.0 287 | 285,17.0,15.0,2.0 288 | 286,3.0,15.0,6.0 289 | 287,18.0,10.0,5.0 290 | 288,13.0,10.0,4.0 291 | 289,10.0,10.0,10.0 292 | 290,14.0,10.0,5.0 293 | 291,8.0,10.0,1.0 294 | 292,3.0,10.0,2.0 295 | 293,5.0,10.0,4.0 296 | 294,6.0,10.0,6.0 297 | 295,11.0,8.0,1.0 298 | 296,12.0,8.0,2.0 299 | 297,5.0,8.0,5.0 300 | 298,15.0,8.0,4.0 301 | 299,8.0,8.0,1.0 302 | 300,7.0,8.0,1.0 303 | 301,3.0,8.0,2.0 304 | 302,17.0,8.0,2.0 305 | 303,4.0,8.0,2.0 306 | 304,9.0,8.0,5.0 307 | 305,15.0,120.0,10.0 308 | 306,12.0,120.0,14.0 309 | 307,3.0,120.0,13.0 310 | 308,13.0,120.0,13.0 311 | 309,10.0,120.0,14.0 312 | 310,8.0,120.0,9.0 313 | 311,4.0,120.0,10.0 314 | 312,7.0,120.0,6.0 315 | 313,7.0,600.0,15.0 316 | 314,9.0,600.0,18.0 317 | 315,6.0,600.0,17.0 318 | 316,14.0,600.0,19.0 319 | 317,11.0,600.0,17.0 320 | 318,5.0,600.0,18.0 321 | 319,12.0,600.0,20.0 322 | 320,13.0,600.0,19.0 323 | 321,0.0,600.0,16.0 324 | 322,4.0,600.0,14.0 325 | -------------------------------------------------------------------------------- /SimulationComparisons/s15.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(scales) 3 | library(glue) 4 | library(patchwork) 5 | library(ggstatsplot) 6 | library(ggpubr) 7 | 8 | options(scipen=10000) 9 | 10 | theme_set(theme_classic()) 11 | 12 | 13 | spacing_data <- read_csv('results_gridspacing.csv') %>% 14 | mutate(is_grid_cell = grid_score >= 0.4) 15 | 16 | 17 | size_data <- read_csv('results_gridsize.csv') %>% 18 | mutate(is_grid_cell = grid_score >= 0.4) 19 | 20 | size_data_mt <- size_data %>% 21 | mutate(frac_sig_bins = sig_bins / total_bins) %>% 22 | filter(is_grid_cell == TRUE, !is.na(frac_sig_bins)) 23 | 24 | size_plot <- 
ggscatter(size_data_mt, 25 | x = 'field_size', 26 | y = 'frac_sig_bins', 27 | add = 'reg.line', 28 | conf.int = TRUE, 29 | cor.coef = TRUE, 30 | cor.method = 'pearson', 31 | xlab = expression('Field size cm'^{'2'}), 32 | ylab = 'Proportion of directional bins', 33 | caption = '' 34 | ) 35 | 36 | size_plot 37 | 38 | 39 | 40 | spacing_data_mt <- spacing_data %>% 41 | mutate(frac_sig_bins = sig_bins / total_bins) %>% 42 | filter(is_grid_cell == TRUE, !is.na(frac_sig_bins)) 43 | 44 | spacing_plot <- ggscatter(spacing_data_mt, 45 | x = 'calculated_grid_spacing', 46 | y = 'frac_sig_bins', 47 | add = 'reg.line', 48 | conf.int = TRUE, 49 | cor.coef = TRUE, 50 | cor.method = 'pearson', 51 | xlab = 'Field spacing (cm)', 52 | ylab = 'Proportion of directional bins', 53 | caption = '' 54 | ) 55 | spacing_plot 56 | 57 | 58 | panel_plot <- size_plot + spacing_plot + plot_layout(ncol = 1) + plot_annotation(tag_levels = 'a') & theme(plot.tag = element_text(size = 16)) 59 | 60 | ggsave('s15.png', panel_plot) 61 | ggsave('s15.pdf', panel_plot) 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /array_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # https://stackoverflow.com/questions/30399534/shift-elements-in-a-numpy-array 5 | def shift(array_to_shift, n): 6 | if n > 0:  # n == 0 must take the else branch: with n >= 0 here, array_to_shift[:-0] is an empty slice 7 | return np.concatenate((np.full(n, np.nan), array_to_shift[:-n])) 8 | else: 9 | return np.concatenate((array_to_shift[-n:], np.full(-n, np.nan))) 10 | 11 | 12 | ''' 13 | Shifts 2d array along given axis. 14 | 15 | array_to_shift : 2d array that is to be shifted 16 | n : array will be shifted by n places 17 | axis : shift along this axis (should be 0 or 1) 18 | ''' 19 | 20 | 21 | def shift_2d(array_to_shift, n, axis): 22 | shifted_array = np.zeros_like(array_to_shift) 23 | if axis == 0: # shift along x axis 24 | if n == 0: 25 | return array_to_shift 26 | if n > 0: 27 | shifted_array[:, :n] = 0 28 | shifted_array[:, n:] = array_to_shift[:, :-n] 29 | else: 30 | shifted_array[:, n:] = 0 31 | shifted_array[:, :n] = array_to_shift[:, -n:] 32 | 33 | if axis == 1: # shift along y axis 34 | if n == 0: 35 | return array_to_shift 36 | elif n > 0: 37 | shifted_array[-n:, :] = 0 38 | shifted_array[:-n, :] = array_to_shift[n:, :] 39 | else: 40 | shifted_array[:-n, :] = 0 41 | shifted_array[-n:, :] = array_to_shift[:n, :] 42 | return shifted_array 43 | 44 | 45 | def nan_helper(y): 46 | """Helper to handle indices and logical indices of NaNs. 
47 | 48 | Input: 49 | - y, 1d numpy array with possible NaNs 50 | Output: 51 | - nans, logical indices of NaNs 52 | - index, a function, with signature indices= index(logical_indices), 53 | to convert logical indices of NaNs to 'equivalent' indices 54 | Example: 55 | # linear interpolation of NaNs 56 | nans, x= nan_helper(y) 57 | y[nans]= np.interp(x(nans), x(~nans), y[~nans]) 58 | """ 59 | 60 | return np.isnan(y), lambda z: z.nonzero()[0] 61 | 62 | 63 | def remove_nans_from_both_arrays(array1, array2): 64 | not_nans_in_array1 = ~np.isnan(array1) 65 | not_nans_in_array2 = ~np.isnan(array2) 66 | array1 = array1[not_nans_in_array1 & not_nans_in_array2] 67 | array2 = array2[not_nans_in_array1 & not_nans_in_array2] 68 | return array1, array2 69 | 70 | 71 | def remove_nans_and_inf_from_both_arrays(array1, array2): 72 | not_nans_in_array1 = ~np.isnan(array1) 73 | not_nans_in_array2 = ~np.isnan(array2) 74 | array1 = array1[not_nans_in_array1 & not_nans_in_array2] 75 | array2 = array2[not_nans_in_array1 & not_nans_in_array2] 76 | 77 | not_nans_in_array1 = ~np.isinf(array1) 78 | not_nans_in_array2 = ~np.isinf(array2) 79 | array1 = array1[not_nans_in_array1 & not_nans_in_array2] 80 | array2 = array2[not_nans_in_array1 & not_nans_in_array2] 81 | return array1, array2 82 | 83 | 84 | 85 | def main(): 86 | print('-------------------------------------------------------------') 87 | print('-------------------------------------------------------------') 88 | 89 | array_to_shift = np.array([[1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]]) 90 | n = -2 91 | axis = 1 92 | 93 | desired_result = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4]])  # shift_2d pads with zeros, not NaN 94 | result = shift_2d(array_to_shift, n, axis) 95 | assert np.allclose(result, desired_result) 96 | array_to_shift2 = np.array([[[1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3]], [[4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]]]) 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | 102 | -------------------------------------------------------------------------------- /data_frame_utility.py: -------------------------------------------------------------------------------- 1 | import OverallAnalysis.folder_path_settings 2 | import pandas as pd 3 | import PostSorting.open_field_head_direction 4 | import numpy as np 5 | 6 | 7 | # source: https://stackoverflow.com/users/48956/user48956 8 | def df_empty(columns, dtypes, index=None): 9 | assert len(columns) == len(dtypes) 10 | df = pd.DataFrame(index=index) 11 | for c, d in zip(columns, dtypes): 12 | df[c] = pd.Series(dtype=d) 13 | return df 14 | 15 | 16 | def append_field_to_data_frame(field_df, session_id, cluster_id, field_id, indices_rate_map, spike_times, number_of_spikes_in_field, position_x_spikes, position_y_spikes, hd_in_field_spikes, hd_hist_spikes, times_session, time_spent_in_field, position_x_session, position_y_session, hd_in_field_session, hd_hist_session, hd_score, grid_score, grid_spacing, field_size): 17 | field_df = field_df.append({ 18 | "session_id": session_id, 19 | "cluster_id": cluster_id, 20 | "field_id": field_id, 21 | "indices_rate_map": indices_rate_map, 22 | "spike_times": spike_times, 23 | "number_of_spikes_in_field": number_of_spikes_in_field, 24 | "position_x_spikes": position_x_spikes, 25 | "position_y_spikes": position_y_spikes, 26 | "hd_in_field_spikes": hd_in_field_spikes, 27 | "hd_hist_spikes": hd_hist_spikes, 28 | "times_session": times_session, 29 | "time_spent_in_field": time_spent_in_field, 30 | 
"position_x_session": position_x_session, 31 | "position_y_session": position_y_session, 32 | "hd_in_field_session": hd_in_field_session, 33 | "hd_hist_session": hd_hist_session, 34 | "hd_score": hd_score, 35 | "grid_score": grid_score, 36 | "grid_spacing": grid_spacing, 37 | "field_size": field_size 38 | }, ignore_index=True) 39 | return field_df 40 | 41 | 42 | def get_field_data_frame(spatial_firing, position_data): 43 | field_df = pd.DataFrame(columns=['session_id', 'cluster_id', 'field_id', 'indices_rate_map', 'spike_times', 'number_of_spikes_in_field', 'position_x_spikes', 'position_y_spikes', 'hd_in_field_spikes', 'hd_hist_spikes', 'times_session', 'time_spent_in_field', 'position_x_session', 'position_y_session', 'hd_in_field_session', 'hd_hist_session', 'hd_score', 'grid_score', 'grid_spacing', 'field_size']) 44 | for index, cluster in spatial_firing.iterrows(): 45 | cluster_id = spatial_firing.cluster_id[index] 46 | session_id = spatial_firing.session_id[index] 47 | number_of_firing_fields = len(spatial_firing.firing_fields[index]) 48 | if number_of_firing_fields > 0: 49 | firing_field_spike_times = spatial_firing.spike_times_in_fields[index] 50 | for field_id, field in enumerate(firing_field_spike_times): 51 | indices_rate_map = spatial_firing.firing_fields[index][field_id] 52 | mask_firing_times_in_field = np.in1d(spatial_firing.firing_times[index], field) 53 | spike_times = field 54 | number_of_spikes_in_field = len(field) 55 | position_x_spikes = np.array(spatial_firing.position_x_pixels[index])[mask_firing_times_in_field] 56 | position_y_spikes = np.array(spatial_firing.position_y_pixels[index])[mask_firing_times_in_field] 57 | hd_in_field_spikes = np.array(spatial_firing.hd[index])[mask_firing_times_in_field] 58 | hd_in_field_spikes = (np.array(hd_in_field_spikes) + 180) * np.pi / 180 59 | hd_hist_spikes = PostSorting.open_field_head_direction.get_hd_histogram(hd_in_field_spikes) 60 | 61 | times_session = spatial_firing.times_in_session_fields[index][field_id] 62 | time_spent_in_field = len(times_session) 63 | mask_times_in_field = np.in1d(position_data.synced_time, times_session) 64 | position_x_session = position_data.position_x_pixels.values[mask_times_in_field] 65 | position_y_session = position_data.position_y_pixels.values[mask_times_in_field] 66 | hd_in_field_session = position_data.hd.values[mask_times_in_field] 67 | hd_in_field_session = (np.array(hd_in_field_session) + 180) * np.pi / 180 68 | hd_hist_session = PostSorting.open_field_head_direction.get_hd_histogram(hd_in_field_session) 69 | hd_score = cluster.hd_score 70 | if 'grid_score' in spatial_firing: 71 | grid_score = cluster.grid_score 72 | grid_spacing = cluster.grid_spacing 73 | field_size = cluster.field_size 74 | else: 75 | grid_score = np.nan 76 | grid_spacing = np.nan 77 | field_size = np.nan 78 | 79 | field_df = append_field_to_data_frame(field_df, session_id, cluster_id, field_id, indices_rate_map, spike_times, number_of_spikes_in_field, position_x_spikes, position_y_spikes, hd_in_field_spikes, hd_hist_spikes, times_session, time_spent_in_field, position_x_session, position_y_session, hd_in_field_session, hd_hist_session, hd_score, grid_score, grid_spacing, field_size) 80 | return field_df 81 | 82 | 83 | def main(): 84 | spatial_firing = pd.read_pickle(OverallAnalysis.folder_path_settings.get_local_test_recording_path() + 'DataFrames/spatial_firing.pkl') 85 | position_data = pd.read_pickle(OverallAnalysis.folder_path_settings.get_local_test_recording_path() + 'DataFrames/position.pkl') 86 | 
get_field_data_frame(spatial_firing, position_data) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() -------------------------------------------------------------------------------- /example_spatial_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code shows a few of the analyses implemented in PostSorting on the data frames in ExampleData. 3 | Output figures will be saved in ExampleOutput. 4 | """ 5 | 6 | path_where_you_want_to_save_the_output = '' 7 | 8 | import os 9 | import pandas as pd 10 | import PostSorting.open_field_firing_maps, PostSorting.open_field_grid_cells  # open_field_grid_cells is needed for process_grid_data below 11 | import PostSorting.open_field_head_direction 12 | import PostSorting.make_plots 13 | import PostSorting.open_field_make_plots 14 | import PostSorting.parameters 15 | import PostSorting.speed 16 | 17 | # the prm object contains some recording parameters such as sampling rate 18 | prm = PostSorting.parameters.Parameters() 19 | # set path to output folder (for figures) 20 | prm.set_output_path(path_where_you_want_to_save_the_output + '/ExampleOutput/') 21 | prm.set_pixel_ratio(440) 22 | prm.set_sampling_rate(30000) # sampling rate of mouse ephys data 23 | 24 | 25 | def spatial_firing_analysis(): 26 | # Load data frame with spike sorted data 27 | spatial_firing = pd.read_pickle('ExampleData/spatial_firing.pkl') 28 | # Load data frame with the trajectory of the animal 29 | position = pd.read_pickle('ExampleData/position.pkl') 30 | 31 | # plot how well the animal explored the arena 32 | position_heat_map = PostSorting.open_field_firing_maps.get_position_heatmap(position, prm) 33 | PostSorting.open_field_make_plots.plot_coverage(position_heat_map, prm) 34 | 35 | # plot spikes on the trajectory of the animal 36 | PostSorting.open_field_make_plots.plot_spikes_on_trajectory(position, spatial_firing, prm) 37 | # plot firing rate of cell vs running speed 38 | PostSorting.make_plots.plot_firing_rate_vs_speed(spatial_firing, position, prm) 39 | 40 | # calculate the speed score of the cell(s) in spatial_firing and add it to the data frame as a new column 41 | spatial_firing = PostSorting.speed.calculate_speed_score(position, spatial_firing, 250, 42 | prm.get_sampling_rate()) 43 | 44 | # make another plot to look at speed dependence 45 | PostSorting.make_plots.plot_speed_vs_firing_rate(position, spatial_firing, prm.get_sampling_rate(), 250, prm) 46 | 47 | # plot firing rate maps 48 | PostSorting.open_field_make_plots.plot_firing_rate_maps(spatial_firing, prm) 49 | 50 | # if the rate map is not in the data frame, run the rate map analyses 51 | if 'firing_maps' not in spatial_firing: 52 | position_heat_map, spatial_firing = PostSorting.open_field_firing_maps.make_firing_field_maps(position, 53 | spatial_firing, 54 | prm) 55 | # also rerun the grid cell analysis 56 | spatial_firing = PostSorting.open_field_grid_cells.process_grid_data(spatial_firing) 57 | 58 | # plot the autocorrelograms of the firing rate maps 59 | PostSorting.open_field_make_plots.plot_rate_map_autocorrelogram(spatial_firing, prm) 60 | 61 | # get the head direction histogram from the trajectory and rerun hd analyses 62 | hd_histogram, spatial_firing = PostSorting.open_field_head_direction.process_hd_data(spatial_firing, 63 | position, prm) 64 | # plot traditional polar head direction plots 65 | PostSorting.open_field_make_plots.plot_polar_head_direction_histogram(hd_histogram, spatial_firing, prm) 66 | 67 | # plot head direction in individual firing fields 68 | PostSorting.open_field_make_plots.plot_hd_for_firing_fields(spatial_firing, 
position, prm) 69 | 70 | 71 | def main(): 72 | if not os.path.isdir(prm.get_output_path()): # check if output folder exists 73 | os.mkdir(prm.get_output_path()) # make it if it doesn't exist 74 | spatial_firing_analysis() 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /file_utility.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | def find_the_file(file_path, pattern, type): 6 | name = None 7 | file_found = True 8 | file_name = None 9 | 10 | file_counter = 0 11 | for name in glob.glob(file_path + pattern): 12 | file_counter += 1 13 | pass 14 | 15 | if file_counter > 1: 16 | print('There is more than one ' + type + ' file in this folder. This may not be okay.') 17 | 18 | if name is not None: 19 | file_name = os.path.basename(name)  # portable; the original rsplit('\\', 1)[1] assumed Windows path separators 20 | else: 21 | print('The ' + type + ' file (such as ' + pattern + ') is not here, or it has an unusual name.') 22 | 23 | file_found = False 24 | 25 | return file_name, file_found 26 | 27 | 28 | def init_data_file_names(prm, beginning, end): 29 | prm.set_continuous_file_name(beginning) 30 | prm.set_continuous_file_name_end(end) 31 | 32 | 33 | def set_continuous_data_path(prm): 34 | file_path = prm.get_filepath() 35 | continuous_file_name_1 = '105_CH' 36 | continuous_file_name_end_1 = '_0' 37 | continuous_file_name_2 = '100_CH' 38 | continuous_file_name_end_2 = '' 39 | 40 | recording_path = file_path + continuous_file_name_1 + str(1) + continuous_file_name_end_1 + '.continuous' 41 | if os.path.isfile(recording_path) is True: 42 | init_data_file_names(prm, continuous_file_name_1, continuous_file_name_end_1) 43 | 44 | recording_path = file_path + continuous_file_name_2 + str(1) + continuous_file_name_end_2 + '.continuous' 45 | if os.path.isfile(recording_path) is True: 46 | init_data_file_names(prm, continuous_file_name_2, continuous_file_name_end_2) 47 | 48 | 49 | def set_dead_channel_path(prm): 50 | file_path = prm.get_filepath() 51 | dead_ch_path = file_path + "/dead_channels.txt" 52 | prm.set_dead_channel_path(dead_ch_path) 53 | 54 | 55 | def create_behaviour_folder_structure(prm): 56 | movement_path = prm.get_filepath() + 'Behaviour' 57 | prm.set_behaviour_path(movement_path) 58 | 59 | data_path = movement_path + '/Data' 60 | analysis_path = movement_path + '/Analysis' 61 | 62 | prm.set_behaviour_data_path(data_path) 63 | prm.set_behaviour_analysis_path(analysis_path) 64 | 65 | if os.path.exists(movement_path) is False: 66 | print('Behavioural data will be saved in {}.'.format(movement_path)) 67 | os.makedirs(movement_path) 68 | os.makedirs(data_path) 69 | os.makedirs(analysis_path) 70 | 71 | 72 | # main path is the folder that contains 'recordings' and 'sorting_files' 73 | def get_main_path(prm): 74 | file_path = prm.get_filepath() 75 | main_path = file_path.rsplit('/', 3)[-4] 76 | return main_path 77 | 78 | 79 | def get_raw_mda_path_all_channels(prm): 80 | raw_mda_path = prm.get_filepath() + 'Electrophysiology/' + prm.get_spike_sorter() + '/raw.mda' 81 | return raw_mda_path 82 | 83 | 84 | def get_raw_mda_path_separate_tetrodes(prm): 85 | raw_mda_path = '/data/raw.mda' 86 | return raw_mda_path 87 | 88 | 89 | def folders_for_separate_tetrodes(prm): 90 | ephys_path = prm.get_filepath() + 'Electrophysiology' 91 | 92 | spike_path = ephys_path + '/Spike_sorting' 93 | data_path = ephys_path + '/Data' 94 | sorting_t1_path_continuous = spike_path + '/t1' 95 | sorting_t2_path_continuous = spike_path + 
'/t2' 96 | sorting_t3_path_continuous = spike_path + '/t3' 97 | sorting_t4_path_continuous = spike_path + '/t4' 98 | 99 | mountain_data_folder_t1 = spike_path + '/t1/data' 100 | mountain_data_folder_t2 = spike_path + '/t2/data' 101 | mountain_data_folder_t3 = spike_path + '/t3/data' 102 | mountain_data_folder_t4 = spike_path + '/t4/data' 103 | 104 | if os.path.exists(ephys_path) is False: 105 | os.makedirs(ephys_path) 106 | os.makedirs(spike_path) 107 | os.makedirs(data_path) 108 | 109 | if os.path.exists(sorting_t1_path_continuous) is False: 110 | os.makedirs(sorting_t1_path_continuous) 111 | os.makedirs(sorting_t2_path_continuous) 112 | os.makedirs(sorting_t3_path_continuous) 113 | os.makedirs(sorting_t4_path_continuous) 114 | 115 | os.makedirs(mountain_data_folder_t1) 116 | os.makedirs(mountain_data_folder_t2) 117 | os.makedirs(mountain_data_folder_t3) 118 | os.makedirs(mountain_data_folder_t4) 119 | 120 | 121 | def create_ephys_folder_structure(prm): 122 | ephys_path = prm.get_filepath() + 'Electrophysiology' 123 | prm.set_ephys_path(ephys_path) 124 | data_path = ephys_path + '/' + prm.get_spike_sorter() 125 | 126 | if os.path.exists(ephys_path) is False: 127 | os.makedirs(ephys_path) 128 | if os.path.exists(data_path) is False: 129 | os.makedirs(data_path) 130 | 131 | 132 | def create_folder_structure(prm): 133 | create_behaviour_folder_structure(prm) 134 | create_ephys_folder_structure(prm) 135 | 136 | 137 | -------------------------------------------------------------------------------- /math_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cmath import rect, phase 3 | from math import radians, degrees 4 | from scipy.stats import pearsonr 5 | from scipy.stats import chi2 6 | 7 | 8 | def cart2pol(x, y): 9 | rho = np.sqrt(x**2 + y**2) 10 | phi = np.arctan2(y, x) 11 | return rho, phi 12 | 13 | 14 | def pol2cart(rho, phi): 15 | x = rho * np.cos(phi) 16 | y = rho * np.sin(phi) 17 | return x, y 18 | 19 | 20 | # source: https://rosettacode.org/wiki/Averages/Mean_angle#Python 21 | def mean_angle(deg): 22 | return degrees(phase(sum(rect(1, radians(d)) for d in deg)/len(deg))) 23 | 24 | 25 | def circ_corrcc(alpha, x): 26 | """Correlation coefficient between one circular and one linear random 27 | variable. 28 | 29 | Args: 30 | alpha: vector 31 | Sample of angles in radians 32 | 33 | x: vector 34 | Sample of linear random variable 35 | 36 | Returns: 37 | rho: float 38 | Correlation coefficient 39 | 40 | pval: float 41 | p-value 42 | 43 | Code taken from the Circular Statistics Toolbox for Matlab 44 | By Philipp Berens, 2009 45 | Python adaptation by Etienne Combrisson 46 | """ 47 | if len(alpha) != len(x): 48 | raise ValueError('The length of alpha and x must be the same') 49 | n = len(alpha) 50 | 51 | # Compute correlation coefficient for sin and cos independently 52 | rxs = pearsonr(x, np.sin(alpha))[0] 53 | rxc = pearsonr(x, np.cos(alpha))[0] 54 | rcs = pearsonr(np.sin(alpha), np.cos(alpha))[0] 55 | 56 | # Compute angular-linear correlation (equ. 27.47) 57 | rho = np.sqrt((rxc ** 2 + rxs ** 2 - 2 * rxc * rxs * rcs) / (1 - rcs ** 2)) 58 | 59 | # Compute pvalue 60 | pval = 1 - chi2.cdf(n * rho ** 2, 2) 61 | 62 | return rho, pval 63 | 64 | 65 | def circ_r(alpha, w=None, d=0, axis=0): 66 | """Computes mean resultant vector length for circular data. 
67 | 68 | Args: 69 | alpha: array 70 | Sample of angles in radians 71 | 72 | Kargs: 73 | w: array, optional, [def: None] 74 | Number of incidences in case of binned angle data 75 | 76 | d: radians, optional, [def: 0] 77 | Spacing of bin centers for binned data, if supplied 78 | correction factor is used to correct for bias in 79 | estimation of r 80 | 81 | axis: int, optional, [def: 0] 82 | Compute along this dimension 83 | 84 | Return: 85 | r: mean resultant length 86 | 87 | Code taken from the Circular Statistics Toolbox for Matlab 88 | By Philipp Berens, 2009 89 | Python adaptation by Etienne Combrisson 90 | """ 91 | # alpha = np.array(alpha) 92 | # if alpha.ndim == 1: 93 | # alpha = np.matrix(alpha) 94 | # if alpha.shape[0] is not 1: 95 | # alpha = alpha 96 | 97 | if w is None: 98 | w = np.ones(alpha.shape) 99 | elif alpha.size != w.size: 100 | raise ValueError("Input dimensions do not match") 101 | 102 | # Compute weighted sum of cos and sin of angles: 103 | r = np.multiply(w, np.exp(1j * alpha)).sum(axis=axis) 104 | 105 | # Obtain length: 106 | r = np.abs(r) / w.sum(axis=axis) 107 | 108 | # For data with known spacing, apply correction factor to 109 | # correct for bias in the estimation of r 110 | if d != 0: 111 | c = d / 2 / np.sin(d / 2) 112 | r = c * r 113 | 114 | return np.array(r) 115 | 116 | 117 | def circ_rtest(alpha, w=None, d=0): 118 | """Computes Rayleigh test for non-uniformity of circular data. 119 | H0: the population is uniformly distributed around the circle 120 | HA: the population is not distributed uniformly around the circle 121 | Assumption: the distribution has maximally one mode and the data is 122 | sampled from a von Mises distribution! 123 | 124 | Args: 125 | alpha: array 126 | Sample of angles in radians 127 | 128 | Kargs: 129 | w: array, optional, [def: None] 130 | Number of incidences in case of binned angle data 131 | 132 | d: radians, optional, [def: 0] 133 | Spacing of bin centers for binned data, if supplied 134 | correction factor is used to correct for bias in 135 | estimation of r 136 | 137 | Code taken from the Circular Statistics Toolbox for Matlab 138 | By Philipp Berens, 2009 139 | Python adaptation by Etienne Combrisson 140 | """ 141 | alpha = np.array(alpha) 142 | if alpha.ndim == 1: 143 | alpha = np.matrix(alpha) 144 | if alpha.shape[1] > alpha.shape[0]: 145 | alpha = alpha.T 146 | 147 | if w is None: 148 | r = circ_r(alpha) 149 | n = len(alpha) 150 | else: 151 | if len(alpha) != len(w): 152 | raise ValueError("Input dimensions do not match") 153 | r = circ_r(alpha, w, d) 154 | n = w.sum() 155 | 156 | # Compute Rayleigh's R 157 | R = n * r 158 | 159 | # Compute Rayleigh's z 160 | z = (R ** 2) / n 161 | 162 | # Compute p value using approximation in Zar, p. 
617 163 | pval = np.exp(np.sqrt(1 + 4 * n + 4 * (n ** 2 - R ** 2)) - (1 + 2 * n)) 164 | 165 | return np.squeeze(pval), np.squeeze(z) 166 | 167 | -------------------------------------------------------------------------------- /open_ephys_IO.py: -------------------------------------------------------------------------------- 1 | import OpenEphys 2 | import numpy as np 3 | import matplotlib.pylab as plt 4 | 5 | 6 | def delete_noise(file_path, name, waveforms, timestamps): 7 | to_delete = np.array([], dtype=int)  # integer dtype so these values can be used as array indices below 8 | for wave in range(0, waveforms.shape[0]): 9 | if np.ndarray.max(abs(waveforms[wave, :, :])) > 0.0025: 10 | to_delete = np.append(to_delete, wave) 11 | 12 | # print('these are deleted') 13 | # print(to_delete) 14 | # print(waveforms[to_delete[0], :, 0]) 15 | 16 | for spk in range(0, to_delete.shape[0]): 17 | plt.plot(waveforms[to_delete[spk], :, 0]) 18 | 19 | plt.savefig(file_path + name + '_deleted_waves.png') 20 | 21 | waveforms = np.delete(waveforms, to_delete, axis=0) 22 | timestamps = np.delete(timestamps, to_delete) 23 | 24 | return waveforms, timestamps 25 | 26 | 27 | def get_data_spike(folder_path, file_path, name): 28 | data = OpenEphys.load(file_path) # returns a dict with data, timestamps, etc. 29 | timestamps = data['timestamps'] 30 | waveforms = data['spikes'] 31 | 32 | # print('{} waveforms were found in the spike file'.format(waveforms.shape[0])) 33 | 34 | waveforms, timestamps = delete_noise(folder_path, name, waveforms, timestamps) 35 | 36 | return waveforms, timestamps 37 | 38 | 39 | def get_data_continuous(prm, file_path): 40 | data = OpenEphys.load(file_path) 41 | signal = data['data'] 42 | signal = np.asanyarray(signal) 43 | return signal 44 | 45 | 46 | def get_events(prm, file_path): 47 | events = OpenEphys.load(file_path) 48 | return events -------------------------------------------------------------------------------- /plot_utility.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pylab as plt 2 | import math 3 | import numpy as np 4 | import random 5 | import PostSorting.parameters 6 | prm = PostSorting.parameters.Parameters() 7 | 8 | 9 | ''' 10 | colour functions are from https://gist.github.com/adewes/5884820 11 | ''' 12 | 13 | 14 | def style_plot(ax): 15 | ax.spines['top'].set_visible(False) 16 | ax.spines['right'].set_visible(False) 17 | ax.xaxis.set_ticks_position('bottom') 18 | ax.yaxis.set_ticks_position('left') 19 | return plt, ax 20 | 21 | 22 | def style_open_field_plot(ax): 23 | ax.spines['top'].set_visible(False) 24 | ax.spines['right'].set_visible(False) 25 | ax.spines['left'].set_visible(False) 26 | ax.spines['bottom'].set_visible(False) 27 | plt.tick_params( 28 | axis='both', # changes apply to both axes 29 | which='both', # both major and minor ticks are affected 30 | bottom=False, # ticks along the bottom edge are off 31 | top=False, # ticks along the top edge are off 32 | right=False, 33 | left=False, 34 | labelleft=False, 35 | labelbottom=False) # labels along the bottom edge are off 36 | 37 | ax.set_aspect('equal') 38 | return ax 39 | 40 | 41 | def style_polar_plot(ax): 42 | ax.spines['polar'].set_visible(False) 43 | ax.set_yticklabels([]) # remove yticklabels 44 | # ax.grid(None) 45 | plt.xticks([math.radians(0), math.radians(90), math.radians(180), math.radians(270)]) 46 | ax.axvline(math.radians(90), color='black', linewidth=1, alpha=0.6) 47 | ax.axvline(math.radians(180), color='black', linewidth=1, alpha=0.6) 48 | ax.axvline(math.radians(270), color='black', linewidth=1, alpha=0.6) 49 | 
ax.axvline(math.radians(0), color='black', linewidth=1, alpha=0.6) 50 | ax.set_theta_direction(-1) 51 | ax.set_theta_offset(np.pi/2.0) 52 | ax.xaxis.set_tick_params(labelsize=25) 53 | return ax 54 | 55 | 56 | def get_random_color(pastel_factor = 0.5): 57 | return [(x+pastel_factor)/(1.0+pastel_factor) for x in [random.uniform(0,1.0) for i in [1,2,3]]] 58 | 59 | 60 | def color_distance(c1,c2): 61 | return sum([abs(x[0]-x[1]) for x in zip(c1,c2)]) 62 | 63 | 64 | def generate_new_color(existing_colors, pastel_factor=0.5): 65 | max_distance = None 66 | best_color = None 67 | for i in range(0, 100): 68 | color = get_random_color(pastel_factor = pastel_factor) 69 | if not existing_colors: 70 | return color 71 | best_distance = min([color_distance(color, c) for c in existing_colors]) 72 | if not max_distance or best_distance > max_distance: 73 | max_distance = best_distance 74 | best_color = color 75 | return best_color 76 | 77 | 78 | 79 | def adjust_spine_thickness(ax): 80 | for axis in ['left','bottom']: 81 | ax.spines[axis].set_linewidth(1) 82 | 83 | 84 | def adjust_spines(ax,spines): 85 | for loc, spine in ax.spines.items(): 86 | if loc in spines: 87 | spine.set_position(('outward',0)) # outward by 0 points, i.e. keep the spine in place 88 | #spine.set_smart_bounds(True) 89 | else: 90 | spine.set_color('none') # don't draw spine 91 | 92 | # turn off ticks where there is no spine 93 | if 'left' in spines: 94 | ax.yaxis.set_ticks_position('left') 95 | else: 96 | # no yaxis ticks 97 | ax.yaxis.set_ticks([]) 98 | 99 | if 'bottom' in spines: 100 | ax.xaxis.set_ticks_position('bottom') 101 | else: 102 | # no xaxis ticks 103 | ax.xaxis.set_ticks([]) 104 | 105 | 106 | def get_weights_normalized_hist(array_in): 107 | weights = np.ones_like(array_in) / float(len(array_in)) 108 | return weights 109 | 110 | 111 | def format_bar_chart(ax, x_label, y_label): 112 | plt.gcf().subplots_adjust(bottom=0.2) 113 | plt.gcf().subplots_adjust(left=0.2) 114 | ax.spines['top'].set_visible(False) 115 | ax.spines['right'].set_visible(False) 116 | ax.xaxis.set_ticks_position('bottom') 117 | ax.yaxis.set_ticks_position('left') 118 | ax.set_xlabel(x_label, fontsize=25) 119 | ax.set_ylabel(y_label, fontsize=25) 120 | ax.xaxis.set_tick_params(labelsize=20) 121 | ax.yaxis.set_tick_params(labelsize=20) 122 | return ax 123 | 124 | 125 | def plot_cumulative_histogram(corr_values, ax, color='black', number_of_bins=40): 126 | plt.xlim(-1, 1) 127 | plt.yticks([0, 1]) 128 | ax = format_bar_chart(ax, 'r', 'Cumulative probability') 129 | values, base = np.histogram(corr_values, bins=number_of_bins, range=(-1, 1)) 130 | # evaluate the cumulative 131 | cumulative = np.cumsum(values / len(corr_values)) 132 | # plot the cumulative function 133 | plt.plot(base[:-1], cumulative, c=color, linewidth=5, alpha=0.6) 134 | return ax 135 | 136 | 137 | def plot_cumulative_histogram_from_zero(corr_values, ax, color='black', number_of_bins=40): 138 | plt.xlim(0, 1) 139 | plt.yticks([0, 1], fontsize=20) 140 | plt.gcf().subplots_adjust(bottom=0.2) 141 | plt.gcf().subplots_adjust(left=0.2) 142 | ax.spines['top'].set_visible(False) 143 | ax.spines['right'].set_visible(False) 144 | ax.xaxis.set_ticks_position('bottom') 145 | ax.yaxis.set_ticks_position('left') 146 | plt.xlabel('Percentile score', fontsize=25) 147 | plt.ylabel('Cumulative probability', fontsize=25) 148 | # ax.xaxis.set_tick_params(labelsize=20) 149 | # ax.yaxis.set_tick_params(labelsize=20) 150 | plt.xticks([0, 1], ["0", "100"], fontsize=20) 151 | values, base = np.histogram(corr_values, bins=number_of_bins, 
range=(-1, 1)) 152 | # evaluate the cumulative 153 | cumulative = np.cumsum(values / len(corr_values)) 154 | # plot the cumulative function 155 | plt.plot(base[:-1], cumulative, c=color, linewidth=5, alpha=0.6) 156 | return ax 157 | -------------------------------------------------------------------------------- /tests/unit/PostSorting/test_load_firing_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PostSorting.load_firing_data 3 | 4 | 5 | def test_correct_detected_ch_for_dead_channels(): 6 | dead_channels = [1, 2] 7 | dead_channels = list(map(int, dead_channels)) 8 | primary_channels = np.array([1, 1, 3, 6, 7, 9, 11, 3]) 9 | 10 | desired_result = [3, 3, 5, 8, 9, 11, 13, 5] 11 | result = PostSorting.load_firing_data.correct_detected_ch_for_dead_channels(dead_channels, primary_channels) 12 | 13 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 14 | -------------------------------------------------------------------------------- /tests/unit/PostSorting/test_open_field_head_direction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PostSorting.open_field_head_direction 3 | 4 | 5 | def test_get_rolling_sum(): 6 | 7 | array_in = [1, 2, 3, 4, 5, 6] 8 | window = 3 9 | 10 | desired_result = [9, 6, 9, 12, 15, 12] 11 | result = PostSorting.open_field_head_direction.get_rolling_sum(array_in, window) 12 | 13 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 14 | 15 | array_in = [3, 4, 5, 8, 11, 1, 3, 5] 16 | window = 3 17 | 18 | desired_result = [12, 12, 17, 24, 20, 15, 9, 11] 19 | result = PostSorting.open_field_head_direction.get_rolling_sum(array_in, window) 20 | 21 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 22 | 23 | array_in = [3, 4, 5, 8, 11, 1, 3, 5, 4] 24 | window = 3 25 | 26 | desired_result = [11, 12, 17, 24, 20, 15, 9, 12, 12] 27 | result = PostSorting.open_field_head_direction.get_rolling_sum(array_in, window) 28 | 29 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 30 | 31 | 32 | def test_get_rayleighscore_for_cluster(): 33 | hd_hist = np.ones(360) 34 | expected_result = 1 # uniform distribution 35 | result = PostSorting.open_field_head_direction.get_rayleigh_score_for_cluster(hd_hist) 36 | assert np.allclose(result, expected_result, rtol=1e-05, atol=1e-08) 37 | 38 | hd_hist = np.ones(20) * 300 39 | expected_result = 1 # uniform distribution (different array shape) 40 | result = PostSorting.open_field_head_direction.get_rayleigh_score_for_cluster(hd_hist) 41 | assert np.allclose(result, expected_result, rtol=1e-05, atol=1e-08) 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /tests/unit/PostSorting/test_open_field_heading_direction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PostSorting.open_field_heading_direction 3 | 4 | 5 | def test_calculate_heading_direction(): 6 | x = [0, 1, 2, 2, 1] 7 | y = [0, 1, 1, 0, 1] 8 | 9 | desired_result = [225, 225, 180, 90, 315] 10 | result = PostSorting.open_field_heading_direction.calculate_heading_direction(x, y, pad_first_value=True) 11 | 12 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 13 | 14 | 15 | desired_result = [45 + 180, 0 + 180, -90 + 180, 135 + 180] 16 | result = PostSorting.open_field_heading_direction.calculate_heading_direction(x, y, pad_first_value=False) 17 | 18 | 
assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 19 | -------------------------------------------------------------------------------- /tests/unit/PostSorting/test_open_field_light_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PostSorting.open_field_light_data 3 | import pandas as pd 4 | from pandas.testing import assert_frame_equal  # pandas.util.testing is deprecated 5 | 6 | 7 | def test_make_opto_data_frame(): 8 | 9 | # pulses equally spaced and same length 10 | array_in = ([[1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21]]) 11 | desired_df = pd.DataFrame() 12 | desired_df['opto_start_times'] = [1, 9, 17] 13 | desired_df['opto_end_times'] = [5, 13, 21] 14 | 15 | result_df = PostSorting.open_field_light_data.make_opto_data_frame(array_in) 16 | assert assert_frame_equal(desired_df, result_df, check_dtype=False) is None 17 | 18 | 19 | # lengths of pulses are different 20 | array_in = ([[1, 2, 3, 10, 11, 12, 13, 14, 21, 22, 23, 24, 25, 26, 27]]) 21 | desired_df = pd.DataFrame() 22 | desired_df['opto_start_times'] = [1, 10, 21] 23 | desired_df['opto_end_times'] = [3, 14, 27] 24 | 25 | result_df = PostSorting.open_field_light_data.make_opto_data_frame(array_in) 26 | assert assert_frame_equal(desired_df, result_df, check_dtype=False) is None 27 | 28 | # spacings between pulses are different 29 | array_in = ([[1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 23, 24, 25, 26, 27]]) 30 | desired_df = pd.DataFrame() 31 | desired_df['opto_start_times'] = [1, 10, 23] 32 | desired_df['opto_end_times'] = [5, 14, 27] 33 | 34 | result_df = PostSorting.open_field_light_data.make_opto_data_frame(array_in) 35 | assert assert_frame_equal(desired_df, result_df, check_dtype=False) is None 36 | 37 | # lengths and spacing between pulses are different 38 | array_in = ([[1, 2, 3, 10, 11, 12, 13, 14, 26, 27, 28, 29, 30, 31, 32]]) 39 | desired_df['opto_start_times'] = [1, 10, 26] 40 | desired_df['opto_end_times'] = [3, 14, 32] 41 | 42 | result_df = PostSorting.open_field_light_data.make_opto_data_frame(array_in) 43 | assert assert_frame_equal(desired_df, result_df, check_dtype=False) is None 44 | 45 | # pulse start != 1 46 | array_in = ([[10, 11, 12, 13, 14, 26, 27, 28, 29, 30, 42, 43, 44, 45, 46]]) 47 | desired_df['opto_start_times'] = [10, 26, 42] 48 | desired_df['opto_end_times'] = [14, 30, 46] 49 | 50 | result_df = PostSorting.open_field_light_data.make_opto_data_frame(array_in) 51 | assert assert_frame_equal(desired_df, result_df, check_dtype=False) is None 52 | 53 | 54 | def main(): 55 | test_make_opto_data_frame() 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /tests/unit/PostSorting/test_post_process_sorted_data.py: -------------------------------------------------------------------------------- 1 | import PostSorting.post_process_sorted_data 2 | import numpy as np 3 | 4 | 5 | def test_process_running_parameter_tag(): 6 | tags = 'interleaved_opto*test1*cat' 7 | result = PostSorting.post_process_sorted_data.process_running_parameter_tag(tags) 8 | desired_result = True, True, False, False 9 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 10 | 11 | tags = 'interleaved_opto*test1*cat*pixel_ratio=555' 12 | result = PostSorting.post_process_sorted_data.process_running_parameter_tag(tags) 13 | desired_result = True, True, False, 555 14 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 15 | 16 | 17 | def main(): 18 | 
test_process_running_parameter_tag() 19 | 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /tests/unit/test_array_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append("..") 4 | import array_utility 5 | 6 | 7 | def test_shift(): 8 | array_to_shift = np.array([0, 1, 2, 3]) 9 | n = 2 10 | desired_result = np.array([np.nan, np.nan, 0, 1])  # shift() pads with NaN 11 | assert np.allclose(array_utility.shift(array_to_shift, n), desired_result, equal_nan=True) 12 | 13 | n = -2 14 | desired_result = np.array([2, 3, np.nan, np.nan]) 15 | assert np.allclose(array_utility.shift(array_to_shift, n), desired_result, equal_nan=True) 16 | 17 | 18 | def test_shift_2d(): 19 | # shift right along x 20 | array_to_shift = np.array([[1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]]) 21 | n = 2 22 | axis = 0 23 | 24 | desired_result = np.array([[0, 0, 1, 1], [0, 0, 2, 2], [0, 0, 3, 3], [0, 0, 4, 4], [0, 0, 5, 5], [0, 0, 6, 6]]) 25 | result = array_utility.shift_2d(array_to_shift, n, axis) 26 | 27 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 28 | 29 | # shift left along x 30 | array_to_shift = np.array([[1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]]) 31 | n = -2 32 | axis = 0 33 | 34 | desired_result = np.array([[1, 1, 0, 0], [2, 9, 0, 0], [3, 3, 0, 0], [4, 4, 0, 0], [5, 5, 0, 0], [6, 6, 0, 0]]) 35 | result = array_utility.shift_2d(array_to_shift, n, axis) 36 | 37 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 38 | 39 | # shift up 40 | array_to_shift = np.array([[1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]]) 41 | n = 2 42 | axis = 1 43 | 44 | desired_result = np.array([[3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6], [0, 0, 0, 0], [0, 0, 0, 0]]) 45 | result = array_utility.shift_2d(array_to_shift, n, axis) 46 | 47 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) 48 | 49 | # shift down 50 | array_to_shift = np.array([[1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]]) 51 | n = -2 52 | axis = 1 53 | 54 | desired_result = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 9], [3, 3, 3, 3], [4, 4, 4, 4]]) 55 | result = array_utility.shift_2d(array_to_shift, n, axis) 56 | 57 | assert np.allclose(result, desired_result, rtol=1e-05, atol=1e-08) -------------------------------------------------------------------------------- /tests/unit/test_control_sorting_analysis.py: -------------------------------------------------------------------------------- 1 | import control_sorting_analysis 2 | import pytest 3 | 4 | 5 | class TestGetSessionType: 6 | 7 | def test_openfield_type(self, tmp_path): 8 | parameters = '''openfield 9 | JohnWick/Open_field_opto_tagging_p038/M5_2018-03-06_15-34-44_of 10 | ''' 11 | with open(tmp_path / 'parameters.txt', 'w') as f: 12 | f.write(parameters) 13 | 14 | is_vr, is_open_field = control_sorting_analysis.get_session_type(str(tmp_path)) 15 | 16 | assert is_vr == False 17 | assert is_open_field == True 18 | 19 | def test_vr_type(self, tmp_path): 20 | parameters = '''vr 21 | JohnWick/Open_field_opto_tagging_p038/M5_2018-03-06_15-34-44_of 22 | ''' 23 | with open(tmp_path / 'parameters.txt', 'w') as f: 24 | f.write(parameters) 25 | 26 | is_vr, is_open_field = control_sorting_analysis.get_session_type(str(tmp_path)) 27 | 28 | assert is_vr == True 29 | assert is_open_field == False 30 | 31 | def test_invalid_type(self, tmp_path): 
32 | parameters = '''openvr 33 | JohnWick/Open_field_opto_tagging_p038/M5_2018-03-06_15-34-44_of 34 | ''' 35 | with open(tmp_path / 'parameters.txt', 'w') as f: 36 | f.write(parameters) 37 | 38 | is_vr, is_open_field = control_sorting_analysis.get_session_type(str(tmp_path)) 39 | 40 | assert is_vr == False 41 | assert is_open_field == False 42 | 43 | def test_file_is_dir(self, tmp_path): 44 | with pytest.raises(Exception): 45 | is_vr, is_open_field = control_sorting_analysis.get_session_type(str(tmp_path)) 46 | --------------------------------------------------------------------------------
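A short usage sketch for the utilities above (an editor's illustration, not a file from the repository; it assumes array_utility.py and math_utility.py are importable from the working directory, and the data are made up):

import numpy as np
import array_utility
import math_utility

# shift() pads the vacated positions with NaN; nan_helper() then supports
# interpolating over the gap, following the example in its docstring
y = array_utility.shift(np.array([1.0, 2.0, 3.0, 4.0]), 2)  # [nan, nan, 1.0, 2.0]
nans, index = array_utility.nan_helper(y)
y[nans] = np.interp(index(nans), index(~nans), y[~nans])

# Rayleigh test for non-uniformity of head-direction samples (in radians);
# a small p-value suggests the angles are not uniformly distributed
angles = np.random.vonmises(mu=0.0, kappa=2.0, size=500)
pval, z = math_utility.circ_rtest(angles)
print('Rayleigh p = {:.3g}, z = {:.3g}'.format(float(pval), float(z)))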