├── .gitignore
├── Models_output
│   ├── CNN1D.pickle
│   ├── CNN2D.pickle
│   ├── LSTM.pickle
│   ├── SVM.pickle
│   └── XGBOOST.pickle
├── README.md
├── __pycache__
│   ├── aux_fcn.cpython-39.pyc
│   └── rippl_AI.cpython-39.pyc
├── aux_fcn.py
├── examples_detection.ipynb
├── examples_explore
│   ├── example_CNN1D.ipynb
│   ├── example_CNN2D.ipynb
│   ├── example_LSTM.ipynb
│   ├── example_SVM.ipynb
│   └── example_XGBOOST.ipynb
├── examples_retraining.ipynb
├── figures
│   ├── CNNs.png
│   ├── LSTM.png
│   ├── SVM.png
│   ├── XGBoost.png
│   ├── detection-method.png
│   ├── manual-curation.png
│   ├── output-probabilities.png
│   ├── rippl-AI-logo.png
│   ├── ripple-variability.png
│   └── threshold-selection.png
├── optimized_models
│   ├── CNN1D_1_Ch8_W60_Ts16_OGmodel12
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN1D_5_Ch8_W60_Ts40_OGmodel32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN2D_1_Ch8_W60_Ts40_OgModel
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── ENS
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── SVM_1_Ch8_W60_Ts001_Us0.05
│   ├── SVM_2_Ch8_W60_Ts001_Us0.10
│   ├── SVM_3_Ch8_W60_Ts002_Us0.05
│   ├── SVM_4_Ch8_W60_Ts001_Us1.00
│   ├── SVM_5_Ch8_W60_Ts001_Us0.50
│   ├── SVM_6_Ch8_W60_Ts060_Us1.00
│   ├── XGBOOST_1_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE1
│   ├── XGBOOST_2_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE5
│   ├── XGBOOST_3_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE3
│   ├── XGBOOST_4_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE5
│   ├── XGBOOST_5_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE3
│   └── XGBOOST_6_Ch1_W60_Ts032_D7_Lr0.10_G0.00_L10_SCALE3
└── rippl_AI.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__
3 | Downloaded_data
4 |
--------------------------------------------------------------------------------
/Models_output/CNN1D.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/Models_output/CNN1D.pickle
--------------------------------------------------------------------------------
/Models_output/CNN2D.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/Models_output/CNN2D.pickle
--------------------------------------------------------------------------------
/Models_output/LSTM.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/Models_output/LSTM.pickle
--------------------------------------------------------------------------------
/Models_output/SVM.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/Models_output/SVM.pickle
--------------------------------------------------------------------------------
/Models_output/XGBOOST.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/Models_output/XGBOOST.pickle
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # rippl-AI
2 |
3 |
4 |
5 | __rippl-AI__ is an open toolbox of Artificial Intelligence (AI) resources for the detection of hippocampal neurophysiological signals, in particular sharp-wave ripples (SWRs). This toolbox offers multiple successful plug-and-play machine learning (ML) models from 5 different architectures (1D-CNN, 2D-CNN, LSTM, SVM and XGBoost) that are ready to use to detect SWRs in hippocampal recordings. Moreover, there is an additional package that allows easy re-training, so that models are updated to better detect particular features of your own recordings. More details can be found in [Navas-Olive, Rubio, et al. Commun Biol 7, 211 (2024)](https://www.nature.com/articles/s42003-024-05871-w)!
6 |
7 | # Description
8 |
9 | ## Sharp-wave ripples
10 |
11 | Sharp-wave ripples (SWRs) are transient fast oscillatory events (100-250 Hz) of around 50 ms that appear in the hippocampus and have been associated with memory consolidation. During SWRs, the sequential firing of neuronal ensembles is _replayed_, reactivating memory traces of previously encoded experiences. SWR-related interventions can influence hippocampal-dependent cognitive function, making their detection crucial to understand underlying mechanisms. However, existing SWR identification tools mostly rely on spectral methods, which remain suboptimal.
12 |
13 | Because of the micro-circuit properties of the hippocampus, CA1 SWRs share a common profile, consisting of a _ripple_ in the _stratum pyramidale_ (SP) and a _sharp-wave_ deflection in the _stratum radiatum_ that reflects the large excitatory input coming from CA3. Yet, SWRs can differ greatly depending on the underlying reactivated circuit. The continuous recording below shows this variability:
14 |
15 | 
16 |
17 | ## Artificial intelligence architectures
18 |
19 | In this project, we take advantage of supervised machine learning approaches to train different AI architectures so they can unbiasedly learn to identify signature SWR features on raw Local Field Potential (LFP) recordings. These are the explored architectures:
20 |
21 | ### Convolutional Neural Networks (CNNs)
22 |
23 | 
24 |
25 | ### Support Vector Machine (SVM)
26 |
27 | 
28 |
29 | ### Long-Short Term Memory Recurrent Neural Networks (LSTM)
30 |
31 | 
32 |
33 | ### Extreme-Gradient Boosting (XGBoost)
34 |
35 | 
36 |
37 |
38 |
39 | # The toolbox
40 |
41 | This toolbox contains three main blocks: **detection**, **re-training** and **exploration**. These three packages can be used jointly or separately. We describe the purpose and usage of each below.
42 |
43 | ## Detection
44 |
45 | In previous works ([Navas-Olive, Amaducci et al, 2022](https://elifesciences.org/articles/77772)), we demonstrated that using feature-based algorithms to detect electrophysiological events, such as SWRs, had several advantages:
46 | * Performance lies within the expert range
47 | * It is more stable and less biased than spectral methods
48 | * It can detect a wider variety of SWRs
49 | * It can be used as an interpretation tool
50 | All this is available in our [cnn-ripple repository](https://github.com/PridaLab/cnn-ripple).
51 |
52 | In this toolbox, we widen the machine learning spectrum by offering multiple plug-and-play models from very different AI architectures: 1D-CNN, 2D-CNN, LSTM, SVM and XGBoost. We performed an exhaustive parametric search to find different `architecture` solutions (i.e. `model`s) that achieve:
53 | * **High performance**, so detections were as similar as manually labeled SWRs
54 | * **High stability**, so performance does not depend on threshold selection
55 | * **High generalizability**, so performance remains good in very different contexts
56 |
57 | This repository contains the best five `model`s from each of these five `architecture`s. These `model`s are already trained using mice data, and can be found in the [optimized_models/](https://github.com/PridaLab/rippl-AI/blob/main/optimized_models/) folder.
58 |
59 | The [rippl_AI](https://github.com/PridaLab/rippl-AI/blob/main/rippl_AI.py) python module contains all the necessary functions to easily use any `model` to detect SWRs. Additionally, we also provide auxiliary functions in the [aux_fcn](https://github.com/PridaLab/rippl-AI/blob/main/aux_fcn.py) module, which contains useful code to process LFP and evaluate detection performance.
60 |
61 | Moreover, several usage examples of all functions can be found in the [examples_detection.ipynb](https://github.com/PridaLab/rippl-AI/blob/main/examples_detection.ipynb) python notebook.
62 |
63 |
64 |
65 | ### rippl_AI.predict()
66 |
67 | The python function `predict(LFP, sf, arch='CNN1D', model_number=1, channels=np.arange(8), d_sf=1250)` of the `rippl_AI` module computes the SWR probability for a given LFP.
68 |
69 | In the figure below, you can see an example of a high-density LFP recording (top) with manually labeled data (gray). The objective of these `model`s is to generate an output signal that most closely matches the manually labeled signal. The output of the uploaded optimized models can be seen at the bottom, where values go from 0 (low probability of SWR) to 1 (high probability of SWR) for each LFP sample.
70 |
71 | 
72 |
73 | The `rippl_AI.predict()` input and output variables are:
74 |
75 | * Mandatory inputs:
76 |     - `LFP`: LFP recorded data (`np.array`: `n_samples` x `n_channels`). Although there are no restrictions on `n_channels`, some considerations should be taken into account (see `channels`). Data does not need to be normalized, because it will be z-scored internally (see `aux_fcn.process_LFP()`).
77 | - `sf`: sampling frequency (in Hz).
78 |
79 | * Optional inputs:
80 | - `arch`: Name of the AI architecture to use (`string`). It can be: `CNN1D`, `CNN2D`, `LSTM`, `SVM` or `XGBOOST`.
81 |     - `model_number`: Number of the model to use (`integer`). There are six different models for each architecture, sorted by performance, with `1` being the best and `5` the last; model `6` can be used if single-channel data needs to be used.
82 |     - `channels`: Channels to be used for detection (`np.array` or `list`: `1` x `8`). This is the most sensitive parameter, because models will be looking for specific spatial features over all channels. Counting starts at `0`. The two main remarks are:
83 |         * All models have been trained to look at features in the pyramidal layer (SP), so for them to work at their maximum potential, the selected channels would ideally be centered on the SP, with a positive deflection on the first channels (upper channels) and a negative deflection on the last channels (lower channels). The image above can be used as a visual reference of how to choose channels.
84 |         * For all combinations of `architectures` and `model_numbers`, `channels` **has to be of size 8**. There is only one exception, for `architecture = 2D-CNN` with `models = {3, 4, 5}`, which needs **3 channels**.
85 |         * If you are using a high-density probe, we recommend using equidistant channels from the beginning to the end of the SP. For example, for Neuropixels in mice, a good set of channels would be `pyr_channel` + [-8,-6,-4,-2,0,2,4,6].
86 |         * In the case of linear probes or tetrodes, there is not enough density to cover the SP with 8 channels. For that, interpolation of recorded channels can be done without compromising performance. New artificial interpolated channels will be added to the LFP wherever there is a `-1` in `channels`. For example, if `pyr_channel=11` in your linear probe, so that 10 is in the _stratum oriens_ and 12 in the _stratum radiatum_, then we could define `channels=[10,-1,-1,11,-1,-1,-1,12]`, where the 2nd and 3rd channels will be an interpolation of the SO and SP channels, and the 5th to 7th an interpolation of the SP and SR channels. For tetrodes, organising channels according to their spatial profile is very convenient to ensure the best performance. These interpolations are done using the function `aux_fcn.interpolate_channels()`.
87 | * Several examples of all these usages can be found in the [examples_detection.ipynb](https://github.com/PridaLab/rippl-AI/blob/main/examples_detection.ipynb) python notebook.
88 |     - `new_model`: A re-trained model you want to use for detection. If you have used our re-train function to adapt the optimized models to your own data (see `rippl_AI.retrain_model()` for more details), you can pass the `new_model` here to use that model to predict your events.
89 | - `d_sf`: Desired subsampling frequency in Hz (`int`). By default all works in 1250 Hz, but can be changed if you retrain your models using `rippl_AI.retrain_model`.
90 |
91 | * Output:
92 |     - `SWR_prob`: model output for every sample of the LFP (`np.array`: `n_samples` x 1). It can be interpreted as the confidence or probability of a SWR event, so values close to 0 mean that the `model` is certain that there is no SWR, and values close to 1 that the model is very sure that there is a SWR happening.
93 |     - `LFP_norm`: LFP data used as input to the model (`np.array`: `n_samples` x `len(channels)`). It is undersampled to 1250 Hz, z-scored, and transformed to use the channels specified in `channels`.
94 |
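
A minimal usage sketch putting these inputs and outputs together (the LFP array here is a random placeholder; load your own recording instead):

```
import numpy as np
import rippl_AI

# Placeholder recording: 10 s of 8-channel LFP at 30 kHz (replace with your own data)
sf = 30000
LFP = np.random.randn(10 * sf, 8)

# Default usage: best 1D-CNN model, channels 0-7
SWR_prob, LFP_norm = rippl_AI.predict(LFP, sf)

# Explicitly choosing the architecture, model number and channels
SWR_prob_lstm, _ = rippl_AI.predict(LFP, sf, arch='LSTM', model_number=1,
                                    channels=[0, 1, 2, 3, 4, 5, 6, 7])
```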
95 |
96 | ### rippl_AI.get_intervals()
97 |
98 | The python function `get_intervals(SWR_prob, LFP_norm=None, sf=1250, win_size=100, threshold=None, file_path=None, merge_win=0)` of the `rippl_AI` module takes the output of `rippl_AI.predict()` (i.e. the SWR probability) and identifies SWR beginnings and ends by establishing a threshold. In the figure below, you can see how the threshold can decisively determine which events are detected. For example, lowering the threshold to 0.5 would have resulted in XGBoost correctly detecting the first SWR, and the 1D-CNN detecting the sharp-wave that has no ripple.
99 |
100 | 
101 |
102 | * Mandatory inputs:
103 |     - `SWR_prob`: output of `rippl_AI.predict()`. If this is the only input, the function will display a histogram of all SWR probability values (i.e. `n_samples`), and a draggable bar to set a threshold based on the values of this particular session. When the 'Done' button is pressed, the GUI takes the value of the draggable bar as the threshold, and computes the beginnings and ends of the events.
104 |
105 | * Optional inputs - **Setting the threshold**
106 | Depending on the inputs, different possibilities arise:
107 | - `threshold`: Threshold of predictions (`float`)
108 | - `LFP_norm`: Normalized input signal of the model (`np.array`: `n_samples` x `n_channels`). It is recommended to use `LFP_norm`.
109 | - `file_path`: Absolute path of the folder where the .txt with the predictions will be generated (`string`). Leave empty if you don't want to generate the file.
110 |     - `win_size`: Length of the displayed ripples in milliseconds (`integer`). By default 100 ms.
111 |     - `sf`: Sampling frequency (Hz) of `LFP_norm` (`integer`). By default 1250 Hz.
112 |     - `merge_win`: Minimal length of the interval in milliseconds between predictions (`float`). If two detections are closer in time than this parameter, they will be merged together.
113 |
114 | There are 4 possible use cases, depending on which parameter combination is used when calling the function.
115 | 1. `rippl_AI.get_intervals(SWR_prob)`: a histogram of the output is displayed, and you drag a vertical bar to select your `threshold`
116 | 2. `rippl_AI.get_intervals(SWR_prob,threshold)`: no GUI is displayed, and the predictions are generated automatically
117 | 3. `rippl_AI.get_intervals(SWR_prob,LFP_norm)`: some examples of detected events are displayed next to the histogram
118 | 4. `rippl_AI.get_intervals(SWR_prob,LFP_norm,threshold)`: same case as 3, but the initial location of the bar is `threshold`
119 |
120 | Examples:
121 | - `get_intervals(SWR_prob, LFP_norm=LFP_norm, sf=sf, win_size=win_size)`: as `LFP_norm` is also provided as an input, the GUI adds up to 50 examples of SWR detections. If the 'Update' button is pressed, another 50 random detections are shown. When the 'Save' button is pressed, the GUI takes the value of the draggable bar as the threshold. The sampling frequency `sf` (in Hz) and window size `win_size` (in milliseconds) can be used to set the window length of the displayed examples. It automatically discards false positives due to drifts, but if you want to turn this off, you can set `discard_drift` to `false`. By default, it discards noises whose mean LFP is above `std_discard` times the standard deviation, which by default is 1 SD. This parameter can also be changed.
122 | - `get_intervals(SWR_prob, threshold=threshold)`: if a threshold is given, it uses that threshold without displaying any GUI.
123 |
124 | * Outputs:
125 |     - `predictions`: Returns the time (in seconds) of the beginning and end of each event (`n_events` x 2).
126 |
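For example, a minimal sketch of both the non-interactive and the interactive calls (the threshold value is arbitrary; `SWR_prob` and `LFP_norm` come from `rippl_AI.predict()` above):

```
# Fixed threshold: no GUI is shown, intervals are returned directly (in seconds)
events = rippl_AI.get_intervals(SWR_prob, threshold=0.7, merge_win=10)
print(f'{events.shape[0]} events detected')

# Interactive threshold selection, with LFP examples displayed next to the histogram
events = rippl_AI.get_intervals(SWR_prob, LFP_norm=LFP_norm, sf=1250, win_size=100)
```
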
127 |
128 | ### aux_fcn.manual_curation()
129 |
130 | The python function `aux_fcn.manual_curation(events, data, file_path, win_size=100, gt_events=None, sf=1250)` of the `aux_fcn` module allows manual curation of the detected events. It displays an interactive GUI to manually select/discard the events.
131 |
132 | * Mandatory inputs:
133 |     - `events`: array with event beginning and end times in seconds (`2`,`n_det`).
134 | - `data`: normalized array with the input data (`n,n_channels`)
135 | - `file_path`: absolute path of the folder where the .txt with the curated predictions will be saved (`str`).
136 |     - `win_size`: length of the displayed ripples in milliseconds (`int`)
137 | - `gt_events`: ground truth events beginning and end times in seconds (`2`,`n_gt_events`)
138 | - `sf`: sampling frequency (Hz) of the data/model output (`int`). Change if different than 1250 Hz.
139 |
140 | * Output: It always writes the curated events' beginning and end times to `file_path`.
141 |     - `curated_ids`: boolean array with `True` for events that have been selected, and `False` for events that have been discarded (`#events`,)
142 |
143 | 
144 |
145 | Use cases:
146 | 1. If no GT events are provided, all the detected events will be displayed; you can select which ones you want to keep (highlighted in green) and which ones to discard (in red)
147 | 2. If GT events are provided, true positive detections (TP) will be displayed in green. If for any reason you want to discard correct detections, they will be displayed in yellow.
148 |
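A minimal sketch of a curation call (the output folder is a placeholder; `events` and `LFP_norm` come from the detection functions above):

```
import aux_fcn

# Opens the curation GUI and writes the curated intervals to a .txt in file_path
curated_ids = aux_fcn.manual_curation(events, LFP_norm, file_path='Curated_txt',
                                      win_size=100, sf=1250)

# Keep only the accepted events (assuming events is n_events x 2, as returned by get_intervals)
curated_events = events[curated_ids]
```
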
149 |
150 | ### aux_fcn.plot_all_events()
151 |
152 | The python function `aux_fcn.plot_all_events(t_events, lfp, sf, win=0.1, title='', savefig='')` of the `aux_fcn` module plots all events in a single plot. It can be used as a fast summary/check after detection and/or curation.
153 |
154 | * Mandatory inputs:
155 |     - `t_events`: numpy array of size (`#events`, `1`) with the times of all events
156 |     - `lfp`: formatted LFP with all channels
157 | - `sf`: sampling frequency of `lfp`
158 |
159 | * Optional inputs:
160 | - `win`: window size at each side of the center of the ripple (`float`)
161 | - `title`: if provided, displays this title (`string`)
162 |     - `savefig`: if provided, saves the image in the `savefig` directory (`string`). Full name required: e.g. images/session1_events.png
163 |
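As an illustration, a quick summary figure after curation might be produced along these lines (event centers are computed from the detected intervals; the output path is a placeholder):

```
import aux_fcn

# Use the center of each curated interval as the event time
t_events = curated_events.mean(axis=1)

aux_fcn.plot_all_events(t_events, LFP_norm, 1250, win=0.1,
                        title='Curated SWRs', savefig='figures/session1_events.png')
```
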
164 |
165 | ### aux_fcn.process_LFP()
166 |
167 | The python function `process_LFP(LFP, sf, d_sf, channels)` of the `aux_fcn` module processes the LFP before it is input to the algorithm. It downsamples the LFP to `d_sf` and normalizes each channel separately by z-scoring it.
168 |
169 | * Mandatory inputs:
170 | - `LFP`: LFP recorded data (`np.array`: `n_samples` x `n_channels`).
171 | - `sf`: sampling frequency (in Hz).
172 | - `d_sf`: Desired subsampling frequency in Hz (`int`). By default all works in 1250 Hz, but can be changed if you retrain your models using `rippl_AI.retrain_model`.
173 |     - `channels`: channels on which to compute the undersampling and z-score normalization. Counting starts at `0`. If `channels` contains any `-1`, interpolation will also be applied. See `channels` of `rippl_AI.predict()`, or `aux_fcn.interpolate_channels()`, for more information.
174 |
175 | * Output:
176 |     - `LFP_norm`: normalized LFP (`np.array`: `n_samples` x `len(channels)`). It is undersampled to `d_sf` (1250 Hz by default), z-scored, and transformed to use the channels specified in `channels`.
177 |
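A minimal sketch of the pre-processing step on its own (the raw array is a placeholder; the selected channels assume an 8-channel span over the SP):

```
import numpy as np
import aux_fcn

# Placeholder raw recording: 16 channels at 30 kHz
sf = 30000
LFP_raw = np.random.randn(60 * sf, 16)

# Downsample to 1250 Hz and z-score the 8 channels spanning the pyramidal layer
LFP_norm = aux_fcn.process_LFP(LFP_raw, sf, 1250, [4, 5, 6, 7, 8, 9, 10, 11])
```
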
178 |
179 | ### aux_fcn.interpolate_channels()
180 |
181 | The python function `interpolate_channels(LFP, channels)` of the `aux_fcn` module allows creating more intermediate channels using interpolation.
182 |
183 | Because these models performed best using a richer spatial profile, all combinations of `architectures` and `model_numbers` **work with 8 channels**. There is only one exception, for `architecture = 2D-CNN` with `models = {3, 4, 5}`, which needs **3 channels**. However, sometimes it's not possible to get that many channels in the pyramidal layer, for example when using linear probes (only 2 or 3 channels fit in the pyramidal layer) or tetrodes (there are 4 recording channels). For this, we developed this interpolation function, which creates new channels between any pair of your recording channels. Using this approach, we can successfully use the already built algorithms with an equally high performance.
184 |
185 | * Mandatory inputs:
186 | - `LFP`: LFP recorded data (`np.array`: `n_samples` x `n_channels`).
187 | - `channels`: list of channels over which to make interpolations (`np.array` or `list`: 1 x `# channels needed by the model` - 8 in most cases). Interpolated channels will be created in the positions of the `-1` elements of the list. Examples:
188 | - Let's say we have only 4 channels, so `LFP` is `n_samples` x 4. We can interpolate to get 8 functional channels. We will interpolate 1 channel between the first two, another one between 2nd and 3rd, and two more interpolated channels between the last two:
189 | ```
190 | # Define channels
191 | channels_interpolation = [0,-1,1,-1,2,-1,-1,3]
192 |
193 | # Make interpolation
194 | LFP_interpolated = aux_fcn.interpolate_channels(LFP, channels_interpolation)
195 | ```
196 |   - Let's say we have 8 channels, but channels 2 and 5 are dead. Then we want to interpolate them to get 8 functional channels:
197 | ```
198 | # Define channels
199 | channels_interpolation = [0,1,-1,3,4,-1,6,7]
200 |
201 | # Make interpolation
202 | LFP_interpolated = aux_fcn.interpolate_channels(LFP, channels_interpolation)
203 | ```
204 | - More usage examples can be found in the [examples_detection.ipynb](https://github.com/PridaLab/rippl-AI/blob/main/examples_detection.ipynb) python notebook.
205 |
206 | * Output:
207 |     - `LFP_interpolated`: Interpolated LFP (`np.array`: `n_samples` x `len(channels)`).
208 |
209 |
210 | ### aux_fcn.get_performance()
211 |
212 | The python function `get_performance(predictions, true_events, threshold=0, exclude_matched_trues=False, verbose=True)` of the `aux_fcn` module computes several performance metrics:
213 | * precision: also called *positive predictive value*, it is computed as (# good detections) / (# all detections)
214 | * recall: also called *sensitivity*, it is computed as (# good detections) / (# all ground truth events)
215 | * F1: computed as the harmonic mean between precision and recall, it is a conservative and fair measure of performance. If either precision or recall is low, F1 will be low. F1=1 only happens if detected events exactly match ground truth events.
216 |
217 | Therefore, this function can only be used when some ground truth (i.e. events that we are considering the _truth_) is given. In order to check if a true event has been predicted, it computes the **Intersection over Union** (IoU). This metric measures how much two intervals *intersect* with respect to the *union* of their sizes. So if `pred_events = [[2,3], [6,7]]` and `true_events = [[2,4], [8,9]]`, then we would expect `IoU(pred_events[0], true_events[0]) > 0`, while the rest will be zero.
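
For illustration only (this is not the library's implementation), the IoU of two 1-D intervals can be computed as follows:

```
def interval_iou(a, b):
    # a and b are [start, end] intervals in seconds
    intersection = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - intersection
    return intersection / union if union > 0 else 0.0

print(interval_iou([2, 3], [2, 4]))  # 0.5 -> overlapping pair
print(interval_iou([6, 7], [8, 9]))  # 0.0 -> disjoint pair
```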
218 |
219 | * Mandatory inputs:
220 |     - `predictions`: detected events (`np.array`: `n_predictions` x 2). The first column contains the beginnings of the events (in seconds), the second column the ends (in seconds). This should be the output of `rippl_AI.get_intervals()`.
221 |     - `true_events`: ground truth events (`np.array`: `n_groundtruth` x 2). Same format as `predictions`.
222 |
223 | * Optional inputs:
224 |     - `threshold`: Threshold for the IoU (`float`). By default it is 0, so any intersection will be considered a match.
225 |     - `exclude_matched_trues`: Boolean determining whether true events that have already been matched to one prediction can be considered for other predicted events (`bool`). By default it is `False`, so one true event can match many predictions.
226 | - `verbose`: Print results (`bool`).
227 |
228 | * Output:
229 | - `precision`: Metric indicating the percentage of correct predictions out of total predictions
230 | - `recall`: Metric indicating the percentage of true events predicted correctly
231 | - `F1`: Metric with a measure that combines precision and recall.
232 |     - `TP`: True Positives (`np.array`: `n_predictions` x 1). It indicates which `pred_event` **detected** a `true_event`, so `True` are true positives, and `False` are false positives.
233 | - `FN`: False Negatives (`np.array`: `n_groundtruth` x 1). It indicates which `true_event` was **not detected** by `pred_event`, so `True` are false negatives, and `False` are true positives.
234 | - `IOU`: IoU matrix (`np.array`: `n_predictions` x `n_groundtruth`). This can be used to know the matching indexes between `pred_event` and `true_event`.
235 |
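A minimal sketch of a performance evaluation, reusing the toy intervals from the IoU illustration above (the six return values follow the output list above and the way the example notebook unpacks them):

```
import numpy as np
import aux_fcn

pred_events = np.array([[2.0, 3.0], [6.0, 7.0]])   # detected intervals, in seconds
true_events = np.array([[2.0, 4.0], [8.0, 9.0]])   # ground truth intervals, in seconds

precision, recall, F1, TP, FN, IOU = aux_fcn.get_performance(pred_events, true_events,
                                                             threshold=0, verbose=True)
```
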
236 |
237 | ## Re-training
238 |
239 | Here, we provide a unique toolbox to easily re-train `model`s and adapt them to new datasets. These models have been selected because their architectural parameters are best fit to look for high-frequency electrophysiological events. So whether you are interested in finding SWRs or other electrophysiological events, this toolbox offers you the possibility to skip all the parametric search and parameter tuning just by running these scripts. The advantages of the re-training module are:
240 | * Avoid starting from scratch in making **your own feature-based detection algorithm**
241 | * Easily plug-and-play to re-train already tested algorithms
242 | * **Extend detection to other events** such as pathological fast ripples or interictal spikes
243 | * **Extend detection to human** recordings
244 |
245 |
246 | ### rippl_AI.retrain_model()
247 |
248 | The python function `rippl_AI.retrain_model(train_data, train_GT, test_data, test_GT, arch, parameters=None, save_path=None, d_sf=1250, merge_win=0)` of the `rippl_AI` module re-trains the best model of a given `architecture` so that it re-learns the optimal features for detecting the events annotated in the new ground truth.
249 |
250 | * Mandatory inputs:
251 |     - `train_data`: LFP recorded data that will be used to train the model (`np.array`: `n_samples` x `n_channels`). If several sessions are needed, concatenate them to get the specified format.
252 |     - `train_GT`: ground truth events corresponding to the `train_data` (`np.array`: `n_events` x 2). If several sessions were used, don't forget to readjust the times so that they properly refer to `train_data`. Same format as `predictions`.
253 | - `test_data`: LFP recorded data that will be used to test the re-trained model (`list()` of `np.array`: `n_samples` x `n_channels`).
254 | - `test_GT`: ground truth events corresponding to the `test_data` (`list()` of `np.array`: `n_events` x 2). Event times refer to each element of the `test_data` list.
255 |
256 | * Optional inputs:
257 | - `arch`: Name of the AI architecture to use (`string`). It can be: `CNN1D`, `CNN2D`, `LSTM`, `SVM` or `XGBOOST`.
258 |     - `parameters`: dictionary with the parameters that will be used in each specific architecture retraining:
259 |         * In 'XGBOOST': not needed.
260 |         * In 'SVM': `parameters['Undersampler proportion']`, any value between 0 and 1. This parameter eliminates
261 |           samples where no ripple is present until the desired proportion is achieved:
262 |           Undersampler proportion = positive samples / negative samples.
263 |         * In 'LSTM', 'CNN1D' and 'CNN2D': `parameters['Epochs']`, the number of times the training data set will be
264 |           used to train the model, and `parameters['Training batch']`, the number of windows that will be
265 |           processed before updating the weights.
267 | - `save_path`: string, path where the retrained model will be saved
268 | - `d_sf`: Desired subsampling frequency in Hz (`int`). By default all works in 1250 Hz, but this function allows using different subsampling frequencies.
269 |     - `merge_win`: Minimal length of the interval in milliseconds between predictions (`float`). If two detections are closer in time than this parameter, they will be merged together.
270 |
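A minimal sketch of a re-training call (the data arrays are random placeholders at the working sampling frequency, and we assume the function returns the re-trained model, which can then be passed to `rippl_AI.predict()` through `new_model`):

```
import numpy as np
import rippl_AI

# Placeholder data: 8-channel LFP at 1250 Hz with ground-truth intervals in seconds
train_data = np.random.randn(100 * 1250, 8)
train_GT   = np.array([[10.20, 10.26], [40.10, 40.16]])
test_data  = [np.random.randn(50 * 1250, 8)]
test_GT    = [np.array([[5.00, 5.06]])]

parameters = {'Epochs': 2, 'Training batch': 32}   # only used by LSTM / CNN1D / CNN2D

new_model = rippl_AI.retrain_model(train_data, train_GT, test_data, test_GT,
                                   arch='CNN1D', parameters=parameters,
                                   save_path='retrained_models', d_sf=1250)

# SWR_prob, LFP_norm = rippl_AI.predict(LFP, sf, new_model=new_model)
```
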
271 | Usage examples can be found in the [examples_retraining.ipynb](https://github.com/PridaLab/rippl-AI/blob/main/examples_retraining.ipynb) python notebook.
272 |
273 |
274 | ## Exploration
275 |
276 | Finally, as a further extension of this toolbox, we also offer an exploration module, in which you can create your own model. In the [examples_explore](https://github.com/PridaLab/rippl-AI/blob/main/examples_explore/) folder, you can see how the different architectures can be modified by multiple parameters to create an almost infinite number of other models that can be better adjusted to the needs of your desired events. For example, if you are interested in lower-frequency events, such as theta cycles, this exploratory module will be of utmost convenience to find an AI architecture that better adapts to the needs of your research. Below, we list the most common parameters to explore for each architecture, followed by an example parameter grid.
277 |
278 | ### 1D-CNN
279 | * Channels: number of LFP channels
280 | * Window size: LFP window size to evaluate
281 | * Kernel factor
282 | * Batch size
283 | * Number of epochs
284 |
285 | ### 2D-CNN
286 | * Channels: number of LFP channels
287 | * Window size: LFP window size to evaluate
288 |
289 | ### LSTM
290 | * Channels: number of LFP channels
291 | * Window size: LFP window size to evaluate
292 | * Bidirectionality
293 | * Number of layers
294 | * Number of units per layer
295 | * Number of epochs
296 |
297 | ### SVM
298 | * Channels: number of LFP channels
299 | * Window size: LFP window size to evaluate
300 | * Undersampling
301 |
302 | ### XGBoost
303 | * Channels: number of LFP channels
304 | * Window size: LFP window size to evaluate
305 | * Maximum tree depth
306 | * Learning rate
307 | * Gamma
308 | * Lambda regularization
309 | * Scale
310 |
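As an example, the 1D-CNN exploration notebook expresses such a parameter grid as a dictionary along these lines (the values shown are just a small illustrative grid):

```
conf = {
    "timesteps": [16, 32],          # LFP window sizes to evaluate
    "configuration": [              # one [n_kernels, kernel size/stride] pair per layer
        [[4, 2], [2, 1], [8, 2], [4, 1], [16, 2], [8, 1], [32, 2]],
    ],
    "epochs": [1, 2],               # passes over the training set
    "train_batch": [32],            # windows processed before each weight update
}
```
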
311 |
312 | # Environment setup
313 |
314 | 1. Install miniconda, following the tutorial: https://docs.conda.io/en/latest/miniconda.html
315 | 2. Launch the Anaconda console by typing anaconda prompt in the Windows/Linux search bar.
316 | 3. In the Anaconda prompt, create a conda environment (e.g. `rippl_AI_env`):
317 | ```
318 | conda create -n rippl_AI_env python=3.9.15
319 | ```
320 | 4. This will create an environment in your miniconda3 environments folder, usually: `C:\Users\\miniconda3\envs`
321 | 5. Check that the environment `rippl_AI_env` has been created by typing:
322 | ```
323 | conda env list
324 | ```
325 | 6. Activate the environment with:
326 | ```conda activate rippl_AI_env```
327 | if you want to launch the scripts from the command prompt. If you are using Visual Studio Code, select the `rippl_AI_env` python interpreter instead.
328 | 7. The next step after activating the environment is to install every necessary python package:
329 | ```
330 | conda install pip
331 | pip install tensorflow==2.11 keras==2.11 xgboost==1.6.1 imblearn numpy matplotlib pandas scipy
332 | pip install -U scikit-learn==1.1.2
333 | ```
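As an optional sanity check (not part of the original setup steps), you can verify that the main dependencies import correctly inside the new environment:
```
python -c "import tensorflow, keras, xgboost, sklearn; print('rippl-AI dependencies OK')"
```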
334 | To download the lab data from figshare (not normalized, sampled at the original frequency of 30,000 Hz):
335 | ```
336 | git clone https://github.com/cognoma/figshare.git
337 | cd figshare
338 | python setup.py install
339 | ```
340 | The package versions compatible with the toolbox are:
341 |
342 | - h5py==3.11.0
343 | - imbalanced-learn==0.12.2
344 | - imblearn==0.0
345 | - ipython==8.18.1
346 | - keras==2.11.0
347 | - numpy==1.26.4
348 | - pandas==2.2.2
349 | - pip==23.3.1
350 | - python==3.9.19
351 | - scikit-learn==1.1.2
352 | - scipy==1.13.0
353 | - tensorflow==2.11.0
354 | - xgboost==1.6.1
355 |
--------------------------------------------------------------------------------
/__pycache__/aux_fcn.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/__pycache__/aux_fcn.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/rippl_AI.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/__pycache__/rippl_AI.cpython-39.pyc
--------------------------------------------------------------------------------
/examples_detection.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Imports"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "import os\n",
18 | "import rippl_AI\n",
19 | "import importlib\n",
20 | "importlib.reload(rippl_AI)\n",
21 | "import aux_fcn\n",
22 | "\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "import numpy as np\n",
25 | "%matplotlib qt"
26 | ]
27 | },
28 | {
29 | "attachments": {},
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "# Basic detection example\n",
34 | "In this section, a usage example of the predict and detection functions is provided"
35 | ]
36 | },
37 | {
38 | "attachments": {},
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "### Download data"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 2,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "Data already exists. Moving on.\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "from figshare.figshare import Figshare\n",
60 | "fshare = Figshare()\n",
61 | "\n",
62 | "article_ids = [14959449] \n",
63 | "sess=['Dlx1'] \n",
64 | "for id,s in zip(article_ids,sess):\n",
65 | " datapath = os.path.join('Downloaded_data', f'{s}')\n",
66 | " if os.path.isdir(datapath):\n",
67 | " print(\"Data already exists. Moving on.\")\n",
68 | " else:\n",
69 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
70 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
71 | " print(\"Data downloaded!\")"
72 | ]
73 | },
74 | {
75 | "attachments": {},
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "### Data loading"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 3,
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "Channel map value: [0, 1, 2, 3, 4, 5, 6, 7]\n",
92 | "Downloaded_data\\Dlx1\\figshare_14959449/lfp_Dlx1-2021-02-12_12-46-54.dat\n",
93 | "fileStart 0\n",
94 | "fileStop 490242048\n",
95 | "nSamples 245121024\n",
96 | "nSamplesPerChannel 30640128\n",
97 | "nSamplesPerChunk 10000\n",
98 | "size data 30640128\n",
99 | "Sampling frequency: 30000\n",
100 | "Shape of the original data (30640128, 8)\n",
101 | "[[1 1 2]\n",
102 | " [2 1 2]\n",
103 | " [3 1 2]\n",
104 | " [4 1 2]\n",
105 | " [5 1 2]\n",
106 | " [6 1 2]\n",
107 | " [7 1 2]\n",
108 | " [8 1 2]]\n"
109 | ]
110 | }
111 | ],
112 | "source": [
113 | "path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')\n",
114 | "\n",
115 | "sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)\n",
116 | "channels_map = aux_fcn.load_channels_map(path)\n",
117 | "\n",
118 | "# Reformat channels into correct values\n",
119 | "channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)\n",
120 | "# Read .dat\n",
121 | "print('Channel map value: ',channels)\n",
122 | "LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)\n",
123 | "print('Sampling frequency: ', sf)\n",
124 | "print('Shape of the original data',LFP.shape)\n",
125 | "print(channels_map)\n"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "# The predict function takes care of normalizing and subsampling your data\n",
135 | "# If no architecture or model is specified, the best CNN1D will be used\n",
136 | "prob,LFP_norm=rippl_AI.predict(LFP,sf) "
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "# An interactive GUI will be displayed, choose your desired threshold\n",
146 | "det_ind=rippl_AI.get_intervals(prob,LFP_norm=LFP_norm) \n",
147 | "print(f\"{det_ind.shape[0]} events were detected\")"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "A **manual event curation GUI** is included in the toolbox. You can select and discard events. As a result of the curation a .txt file will be generated with the events' times.\n",
155 | "For a detailed description of its features, parameters and functionalities, check the comments of *manual_curation* in *aux_fcn.py*"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "aux_fcn.manual_curation(det_ind,LFP_norm,file_path='C:\\\\rippl-AI\\\\Curated_txt')"
165 | ]
166 | },
167 | {
168 | "attachments": {},
169 | "cell_type": "markdown",
170 | "metadata": {},
171 | "source": [
172 | "# Get performances after detection\n",
173 | "For every model, predict and get_intervals are run automatically and the performance metric is plotted"
174 | ]
175 | },
176 | {
177 | "attachments": {},
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "### Data download"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "from figshare.figshare import Figshare\n",
191 | "fshare = Figshare()\n",
192 | "\n",
193 | "article_ids = [14959449] \n",
194 | "sess=['Dlx1'] \n",
195 | "for id,s in zip(article_ids,sess):\n",
196 | " datapath = os.path.join('Downloaded_data', f'{s}')\n",
197 | " if os.path.isdir(datapath):\n",
198 | " print(\"Data already exists. Moving on.\")\n",
199 | " else:\n",
200 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
201 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
202 | " print(\"Data downloaded!\")"
203 | ]
204 | },
205 | {
206 | "attachments": {},
207 | "cell_type": "markdown",
208 | "metadata": {},
209 | "source": [
210 | "### Data loading"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')\n",
220 | "\n",
221 | "sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)\n",
222 | "\n",
223 | "channels_map = aux_fcn.load_channels_map(path)\n",
224 | "# Now the ground truth (GT) tagged events is loaded \n",
225 | "ripples=aux_fcn.load_ripples(path)/sf\n",
226 | "# Reformat channels into correct values\n",
227 | "channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)\n",
228 | "# Read .dat\n",
229 | "print('Channel map value: ',channels)\n",
230 | "LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)\n",
231 | "print('Sampling frequency: ', sf)\n",
232 | "print('Shape of the original data',LFP.shape)"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "# Two loops going over every possible model\n",
242 | "architectures=['XGBOOST','SVM','LSTM','CNN1D','CNN2D']\n",
243 | "SWR_prob=[[None]*5 for _ in range(5)]  # independent inner lists (avoids aliasing)\n",
244 | "for i,architecture in enumerate(architectures):\n",
245 | " print(i,architecture)\n",
246 | " for n in range(1,6):\n",
247 | " # Make sure the selected model expected number of channels is the same as the channels array passed to the predict fcn\n",
248 | " # In this case, we are manually setting the channel array to 3 \n",
249 | " if architecture=='CNN2D' and n>=3:\n",
250 | " channels=[0,3,7]\n",
251 | " else:\n",
252 | " channels=[0,1,2,3,4,5,6,7]\n",
253 | " SWR_prob[i][n-1],_=rippl_AI.predict(LFP,sf,arch=architecture,model_number=n,channels=channels)\n",
254 | "\n",
255 | "# SWR_prob contains the output of each model\n"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "th_arr=np.linspace(0.1,1,10)\n",
265 | "fig,axs=plt.subplots(5,5,figsize=(10,10),sharex='all',sharey='all')\n",
266 | "for i in range(5):\n",
267 | " for j in range(5):\n",
268 | " F1_arr=np.zeros(shape=(len(th_arr)))\n",
269 | " for k,th in enumerate(th_arr):\n",
270 | " det_ind=rippl_AI.get_intervals(SWR_prob[i][j],threshold=th)\n",
271 | " #print(ripples)\n",
272 | " _,_,F1_arr[k],_,_,_=aux_fcn.get_performance(det_ind,ripples)\n",
273 | " axs[i,j].plot(th_arr,F1_arr)\n",
274 | " axs[i,0].set_title(architectures[i])\n",
275 | "\n",
276 | "axs[0,0].set_xlabel('Threshold')\n",
277 | "axs[0,0].set_ylabel('F1')"
278 | ]
279 | },
280 | {
281 | "attachments": {},
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "# Detecting with less than 8 channels\n",
286 | "Detectors need several channels for optimal performance. We found out that 8 channels have enough information to assure good performance. But what happens if we don't have 8? We have seen that interpolating the missing channels also works. In this section, we will show how to use the interpolation function we have created for this purpose, inside the `aux_fcn` package."
287 | ]
288 | },
289 | {
290 | "attachments": {},
291 | "cell_type": "markdown",
292 | "metadata": {},
293 | "source": [
294 | "### Data download"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": null,
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "from figshare.figshare.figshare import Figshare\n",
304 | "fshare = Figshare()\n",
305 | "\n",
306 | "article_ids = [14959449] \n",
307 | "sess=['Dlx1'] \n",
308 | "for id,s in zip(article_ids,sess):\n",
309 | " datapath = os.path.join('Downloaded_data', f'{s}')\n",
310 | " if os.path.isdir(datapath):\n",
311 | " print(\"Data already exists. Moving on.\")\n",
312 | " else:\n",
313 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
314 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
315 | " print(\"Data downloaded!\")"
316 | ]
317 | },
318 | {
319 | "attachments": {},
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "### Data load\n",
324 | "To illustrate how 'interpolate_channels' can be used to extract the desired number of channels, we will simulate two cases using the Dlx1 session:\n",
325 | "1. We are using a recording probe that extracts 4 channels, when we need 8.\n",
326 | "2. Some channels are dead or have too much noise."
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')\n",
336 | "\n",
337 | "sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)\n",
338 | "\n",
339 | "channels_map = aux_fcn.load_channels_map(path)\n",
340 | "channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)\n",
341 | "LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=False)\n",
342 | "print('Sampling frequency: ', sf)\n",
343 | "print('Shape of the original data',LFP.shape)\n",
344 | "LFP_linear=LFP[:,[0,2,4,6]]\n",
345 | "print('Shape of the 4 channels simulated data: ',LFP_linear.shape)\n",
346 | "LFP[:,[2,5]]=0\n",
347 | "LFP_dead=LFP\n",
348 | "print('Sample of the simulated dead LFP: ',LFP_dead[0])"
349 | ]
350 | },
351 | {
352 | "attachments": {},
353 | "cell_type": "markdown",
354 | "metadata": {},
355 | "source": [
356 | "After interpolation, the data is ready to use in prediction"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": null,
362 | "metadata": {},
363 | "outputs": [],
364 | "source": [
365 | "# Define channels\n",
366 | "channels_interpolation = [0,-1,1,-1,2,-1,-1,3]\n",
367 | "\n",
368 | "# Make interpolation\n",
369 | "LFP_interpolated = aux_fcn.interpolate_channels(LFP_linear, channels_interpolation)\n",
370 | "print('Shape of the interpolated LFP: ',LFP_interpolated.shape)"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": null,
376 | "metadata": {},
377 | "outputs": [],
378 | "source": [
379 | " # Define channels\n",
380 | "channels_interpolation = [0,1,-1,3,4,-1,6,7]\n",
381 | "\n",
382 | "# Make interpolation\n",
383 | "LFP_interpolated = aux_fcn.interpolate_channels(LFP_dead, channels_interpolation)\n",
384 | "print('Value of the 1st sample of the interpolated LFP: ',LFP_interpolated[0])"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "metadata": {},
390 | "source": [
391 | "# Using an ensemble model\n",
392 | "In this section, we will show how to use an ensemble model that combines the output of the best models of each architecture. This model has better performance and more stability than the individual models. In this case, only the best ensemble model will be provided.\n",
393 | "\n",
394 | "First, the output of the 5 selected models needs to be reshaped"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": null,
400 | "metadata": {},
401 | "outputs": [],
402 | "source": [
403 | "# 5 outputs are generated\n",
404 | "architectures=['XGBOOST','SVM','LSTM','CNN1D','CNN2D']\n",
405 | "output=[]\n",
406 | "for architecture in architectures:\n",
407 | " channels=[0,1,2,3,4,5,6,7]\n",
408 | " SWR_prob,_=rippl_AI.predict(LFP,sf,arch=architecture,model_number=1,channels=channels)\n",
409 | " output.append(SWR_prob)\n",
410 | "ens_input=np.array(output).transpose()\n"
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {},
416 | "source": [
417 | "Generating ensemble model output"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "\n",
427 | "prob_ens=rippl_AI.predict_ens(ens_input)"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "metadata": {},
433 | "source": [
434 | "Plot performance"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": null,
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "fig,ax=plt.subplots()\n",
444 | "th_arr=np.linspace(0.1,1,10)\n",
445 | "F1_arr=np.zeros(shape=(len(th_arr)))\n",
446 | "for k,th in enumerate(th_arr):\n",
447 | " det_ind=rippl_AI.get_intervals(prob_ens,threshold=th)\n",
448 | " _,_,F1_arr[k],_,_,_=aux_fcn.get_performance(det_ind,ripples)\n",
449 | "ax.plot(th_arr,F1_arr)\n",
450 | "ax.set_title('Ensemble model')\n",
451 | "ax.set_ylim(-0.05,0.8)\n",
452 | "ax.set_xlabel('Threshold')\n",
453 | "ax.set_ylabel('F1')\n"
454 | ]
455 | }
456 | ],
457 | "metadata": {
458 | "kernelspec": {
459 | "display_name": "PublicBCG_d",
460 | "language": "python",
461 | "name": "python3"
462 | },
463 | "language_info": {
464 | "codemirror_mode": {
465 | "name": "ipython",
466 | "version": 3
467 | },
468 | "file_extension": ".py",
469 | "mimetype": "text/x-python",
470 | "name": "python",
471 | "nbconvert_exporter": "python",
472 | "pygments_lexer": "ipython3",
473 | "version": "3.9.15"
474 | },
475 | "orig_nbformat": 4
476 | },
477 | "nbformat": 4,
478 | "nbformat_minor": 2
479 | }
480 |
--------------------------------------------------------------------------------
/examples_explore/example_CNN1D.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# CNN1D parameter exploration\n",
9 | "This notebook is a template for finding the CNN1D model best suited for your needs\n",
10 | "\n",
11 | "This architecture is based on the CNN1D designed by A. Navas-Olivé (https://doi.org/10.7554/eLife.77772)."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import os\n",
21 | "import importlib\n",
22 | "import matplotlib.pyplot as plt\n",
23 | "import numpy as np\n",
24 | "import sys\n",
25 | "parent_dir=os.path.dirname(os.getcwd())\n",
26 | "sys.path.insert(0,parent_dir)\n",
27 | "import rippl_AI\n",
28 | "import aux_fcn\n",
29 | "importlib.reload(aux_fcn)\n",
30 | "importlib.reload(rippl_AI)"
31 | ]
32 | },
33 | {
34 | "attachments": {},
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "### Data download\n",
39 | "4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training; Dlx1 and Thy7 for validation\n"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "from figshare.figshare.figshare import Figshare\n",
49 | "fshare = Figshare()\n",
50 | "\n",
51 | "article_ids = [16847521,16856137,14959449,14960085] \n",
52 | "sess=['Amigo2','Som2','Dlx1','Thy7'] \n",
53 | "for id,s in zip(article_ids,sess):\n",
54 | " datapath = os.path.join(parent_dir,'Downloaded_data', f'{s}')\n",
55 | " if os.path.isdir(datapath):\n",
56 | " print(f\"{s} session already exists. Moving on.\")\n",
57 | " else:\n",
58 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
59 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
60 | " print(\"Data downloaded!\")"
61 | ]
62 | },
63 | {
64 | "attachments": {},
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "### Data load\n",
69 | "The training sessions' LFP will be appended together in a list. The same will happen with the ripple detection times.\n",
70 | "That is the required input for the training parser"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "# The training sessions will be appended together. Replace this cell with your own data loading\n",
80 | "train_LFPs=[]\n",
81 | "train_GTs=[]\n",
82 | "# Amigo2\n",
83 | "path=os.path.join(parent_dir,'Downloaded_data','Amigo2','figshare_16847521')\n",
84 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
85 | "train_LFPs.append(LFP)\n",
86 | "train_GTs.append(GT)\n",
87 | "# Som2\n",
88 | "path=os.path.join(parent_dir,'Downloaded_data','Som2','figshare_16856137')\n",
89 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
90 | "train_LFPs.append(LFP)\n",
91 | "train_GTs.append(GT)\n",
92 | "\n",
93 | "## Append all your validation sessions\n",
94 | "val_LFPs=[]\n",
95 | "val_GTs=[]\n",
96 | "# Dlx1 Validation\n",
97 | "path=os.path.join(parent_dir,'Downloaded_data','Dlx1','figshare_14959449')\n",
98 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
99 | "val_LFPs.append(LFP)\n",
100 | "val_GTs.append(GT)\n",
101 | "# Thy07 Validation\n",
102 | "path=os.path.join(parent_dir,'Downloaded_data','Thy7','figshare_14960085')\n",
103 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
104 | "val_LFPs.append(LFP)\n",
105 | "val_GTs.append(GT)\n",
106 | "\n",
107 | "x_training,GT_training,x_val_list,GT_val_list=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)"
108 | ]
109 | },
110 | {
111 | "attachments": {},
112 | "cell_type": "markdown",
113 | "metadata": {},
114 | "source": [
115 | "## CNN1D training parameters"
116 | ]
117 | },
118 | {
119 | "attachments": {},
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "#### Parameters:\n",
124 | "* Channels: number of channels that will be used to train the model, extracted from the data shape defined in the previous cell\n",
125 | "* Timesteps: number of samples that the model will use to produce a single output\n",
126 | "* Configuration: list with as many elements as layers in the model, each shaped [number of kernels, kernel size and stride]. The kernel size and stride were matched to reduce design complexity.\n",
127 | "* Epoch: number of times the training data set is used to train the model\n",
128 | "* Training batch: number of windows that are processed before weight updating"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "conf= {\"timesteps\": [16], # 16, 32, 64 ...\n",
138 | " \"configuration\": [[[4,2],[2,1],[8,2],[4,1],[16,2],[8,1],[32,2]], \n",
139 | " [[4,4],[2,1],[8,2],[4,1],[16,2],[8,1],[32,2]]], \n",
140 | " \"epochs\": [1], # 1, 2, 3, 5...\n",
141 | " \"train_batch\": [2**5], # 32, 64, 128...\n",
142 | "}"
143 | ]
144 | },
145 | {
146 | "attachments": {},
147 | "cell_type": "markdown",
148 | "metadata": {},
149 | "source": [
150 | "### Training"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "# Desired sampling frequency of the models\n",
160 | "sf=1250\n",
161 | "th_arr=np.linspace(0.1,0.9,9)\n",
162 | "model_name_arr=[] # To plot in the next cell\n",
163 | "model_arr=[] # Actual model array, used in the next validation section\n",
164 | "n_channels=x_training.shape[1]\n",
165 | "timesteps_arr=conf['timesteps']\n",
166 | "\n",
167 | "config_arr=conf['configuration']\n",
168 | "epochs_arr=conf['epochs']\n",
169 | "train_batch_arr=conf['train_batch'] \n",
170 | "\n",
171 | "l_ts=len(timesteps_arr)\n",
172 | "l_conf=len(config_arr)\n",
173 | "l_epochs =len(epochs_arr)\n",
174 | "l_batch =len(train_batch_arr)\n",
175 | "n_iters=l_ts*l_conf*l_epochs*l_batch\n",
176 | "# GT is in the shape (n_events x 2), a y output signal with the same length as x is required\n",
177 | "perf_train_arr=np.zeros(shape=(n_iters,len(th_arr),3)) # Performance array, (n_models x n_th x 3 ) [P R F1]\n",
178 | "perf_test_arr=np.zeros_like(perf_train_arr)\n",
179 | "timesteps_arr_ploting=[] # Array that will be used in the validation, to be able to call the function predict\n",
180 | "\n",
181 | "print(f'{n_channels} channels will be used to train the CNN1D models')\n",
182 | "\n",
183 | "print(f'{n_iters} models will be trained')\n",
184 | "\n",
185 | "x_test_or,GT_test,x_train_or,GT_train=aux_fcn.split_data(x_training,GT_training,split=0.7,sf=sf)\n",
186 | "\n",
187 | "y_test_or= np.zeros(shape=(len(x_test_or)))\n",
188 | "for ev in GT_test:\n",
189 | " y_test_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
190 | "y_train_or= np.zeros(shape=(len(x_train_or)))\n",
191 | "for ev in GT_train:\n",
192 | " y_train_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
193 | "\n",
194 | "\n",
195 | "for i_ts,timesteps in enumerate(timesteps_arr):\n",
196 | " x_train=x_train_or[:len(x_train_or)-len(x_train_or)%timesteps].reshape(-1,timesteps,n_channels)\n",
197 | " y_train_aux=y_train_or[:len(y_train_or)-len(y_train_or)%timesteps].reshape(-1,timesteps)\n",
198 | " x_test=x_test_or[:len(x_test_or)-len(x_test_or)%timesteps].reshape(-1,timesteps,n_channels)\n",
199 | " y_test_aux=y_test_or[:len(y_test_or)-len(y_test_or)%timesteps].reshape(-1,timesteps)\n",
200 | "\n",
201 | " y_train=np.zeros(shape=[x_train.shape[0],1])\n",
202 | " for i in range(y_train_aux.shape[0]):\n",
203 | " y_train[i]=1 if any (y_train_aux[i]==1) else 0\n",
204 | " print(\"Train Input and Output dimension\", x_train.shape,y_train.shape)\n",
205 | " \n",
206 | " y_test=np.zeros(shape=[x_test.shape[0],1])\n",
207 | " for i in range(y_test_aux.shape[0]):\n",
208 | " y_test[i]=1 if any (y_test_aux[i]==1) else 0\n",
209 | "\n",
210 | " for i_conf, configuration in enumerate(config_arr):\n",
211 | " for i_epochs,epochs in enumerate(epochs_arr):\n",
212 | " for i_batch,train_batch in enumerate(train_batch_arr):\n",
213 | " iter=((i_ts*l_conf+i_conf)*l_epochs + i_epochs)*l_batch + i_batch\n",
214 | " print(f\"\\nIteration {iter+1} out of {n_iters}\")\n",
215 | " print(f'Number of channels: {n_channels:d}, Time steps: {timesteps:d},\\nconfiguration: {configuration}\\nEpochs: {epochs:d}, Samples per batch: {train_batch:d}')\n",
216 | "\n",
217 | " model = aux_fcn.build_CNN1D(n_channels,timesteps,configuration)\n",
218 | " # Training\n",
219 | " model.fit(x_train, y_train,shuffle=False, epochs=epochs,batch_size=train_batch,validation_data=(x_test,y_test), verbose=1)\n",
220 | " model_arr.append(model)\n",
221 | " # Prediction\n",
222 | " test_signal = model.predict(x_test,verbose=1)\n",
223 | " train_signal=model.predict(x_train,verbose=1)\n",
224 | "\n",
225 | " y_train_predict=np.empty(shape=(x_train.shape[0]*timesteps,1,1))\n",
226 | " for i,window in enumerate(train_signal):\n",
227 | " y_train_predict[i*timesteps:(i+1)*timesteps]=window\n",
228 | " y_test_predict=np.empty(shape=(x_test.shape[0]*timesteps,1,1))\n",
229 | " for i,window in enumerate(test_signal):\n",
230 | " y_test_predict[i*timesteps:(i+1)*timesteps]=window\n",
231 | "\n",
232 | " ############################\n",
233 | " for i,th in enumerate(th_arr):\n",
234 | " # Test\n",
235 | " ytest_pred_ind=aux_fcn.get_predictions_index(y_test_predict,th)/sf\n",
236 | " perf_test_arr[iter,i]=aux_fcn.get_performance(ytest_pred_ind,GT_test,0)[0:3]\n",
237 | " # Train\n",
238 | " ytrain_pred_ind=aux_fcn.get_predictions_index(y_train_predict,th)/sf\n",
239 | " perf_train_arr[iter,i]=aux_fcn.get_performance(ytrain_pred_ind,GT_train,0)[0:3]\n",
240 | "\n",
241 | " # Saving the model\n",
242 | " model_name=f\"CNN1D_Ch{n_channels:d}_Ts{timesteps:03d}_C{i_conf:02d}_E{epochs:02d}_TB{train_batch:04d}\"\n",
243 | " model.save(os.path.join(parent_dir,'explore_models',model_name))\n",
244 | "\n",
245 | " model_name_arr.append(model_name)\n",
246 | " timesteps_arr_ploting.append(timesteps)"
247 | ]
248 | },
249 | {
250 | "attachments": {},
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "### Plot training results"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "# Plot training results\n",
264 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
265 | "\n",
266 | "for i in range(n_iters):\n",
267 | " axs[i,0].plot(perf_train_arr[i,:,0],perf_train_arr[i,:,1],'k.-')\n",
268 | " axs[i,0].plot(perf_test_arr[i,:,0],perf_test_arr[i,:,1],'b.-')\n",
269 | " axs[i,1].plot(th_arr,perf_train_arr[i,:,2],'k.-')\n",
270 | " axs[i,1].plot(th_arr,perf_test_arr[i,:,2],'b.-')\n",
271 | " axs[i,0].set_title(model_name_arr[i])\n",
272 | " axs[i,0].set_ylabel('Precision')\n",
273 | " axs[i,1].set_ylabel('F1')\n",
274 | "axs[-1,0].set_xlabel('Recall')\n",
275 | "axs[-1,1].set_xlabel('Threshold')\n",
276 | "axs[0,0].legend(['Training','Test'])\n",
277 | "plt.show()"
278 | ]
279 | },
280 | {
281 | "attachments": {},
282 | "cell_type": "markdown",
283 | "metadata": {},
284 | "source": [
285 | "### Validation"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": null,
291 | "metadata": {},
292 | "outputs": [],
293 | "source": [
294 | "# For loop iterating over the models\n",
295 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
296 | "for n_m,model in enumerate(model_arr):\n",
297 | " F1_arr=np.zeros(shape=(len(x_val_list),len(th_arr))) #(n_val_sess x n_th) Array where the F1 val of each sesion will be stored\n",
298 | " for n_sess,LFP in enumerate(x_val_list):\n",
299 | " val_pred=rippl_AI.predict(LFP,sf=1250,arch='CNN1D',new_model=model,n_channels=n_channels,n_timesteps=timesteps_arr_ploting[n_m])[0]\n",
300 | " for i,th in enumerate(th_arr):\n",
301 | " val_pred_ind=aux_fcn.get_predictions_index(val_pred,th)/sf\n",
302 | " F1_arr[n_sess,i]=aux_fcn.get_performance(val_pred_ind,GT_val_list[n_sess],verbose=False)[2]\n",
303 | " \n",
304 | " axs[n_m,0].plot(th_arr,perf_train_arr[n_m,:,2],'k.-')\n",
305 | " axs[n_m,0].plot(th_arr,perf_test_arr[n_m,:,2],'b.-')\n",
306 | " for F1 in F1_arr:\n",
307 | " axs[n_m,1].plot(th_arr,F1)\n",
308 | " axs[n_m,1].plot(th_arr,np.mean(F1_arr,axis=0),'k.-')\n",
309 | " axs[n_m,0].set_title(model_name_arr[n_m])\n",
310 | " axs[n_m,0].set_ylabel('Precision')\n",
311 | " axs[n_m,1].set_ylabel('F1')\n",
312 | "axs[-1,0].set_xlabel('Recall')\n",
313 | "axs[-1,1].set_xlabel('Threshold')\n",
314 | "plt.show()\n",
315 | " "
316 | ]
317 | }
318 | ],
319 | "metadata": {
320 | "kernelspec": {
321 | "display_name": "PublicBCG_d",
322 | "language": "python",
323 | "name": "python3"
324 | },
325 | "language_info": {
326 | "codemirror_mode": {
327 | "name": "ipython",
328 | "version": 3
329 | },
330 | "file_extension": ".py",
331 | "mimetype": "text/x-python",
332 | "name": "python",
333 | "nbconvert_exporter": "python",
334 | "pygments_lexer": "ipython3",
335 | "version": "3.9.15"
336 | },
337 | "orig_nbformat": 4
338 | },
339 | "nbformat": 4,
340 | "nbformat_minor": 2
341 | }
342 |
--------------------------------------------------------------------------------
/examples_explore/example_CNN2D.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# CNN2D parameter exploration\n",
9 | "This notebook is a template for finding the CNN2D model best suited for your needs \n",
10 | "\n",
11 | "This arquitecture is inspired by the UNet (https://doi.org/10.48550/arXiv.1505.04597) and YOLOR (https://doi.org/10.48550/arXiv.2105.04206)\n",
12 | "The 1st half uses convolution and MaxPooling to reduce the dimnensinality of the input, and the late half expands it"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "import os\n",
22 | "import matplotlib.pyplot as plt\n",
23 | "import numpy as np\n",
24 | "import sys\n",
25 | "parent_dir=os.path.dirname(os.getcwd())\n",
26 | "sys.path.insert(0,parent_dir)\n",
27 | "import rippl_AI\n",
28 | "import aux_fcn"
29 | ]
30 | },
31 | {
32 | "attachments": {},
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "### Data download\n",
37 | "4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training ; Dlx1 and Thy7 for validation\n"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "from figshare.figshare.figshare import Figshare\n",
47 | "fshare = Figshare()\n",
48 | "\n",
49 | "article_ids = [16847521,16856137,14959449,14960085] \n",
50 | "sess=['Amigo2','Som2','Dlx1','Thy7'] \n",
51 | "for id,s in zip(article_ids,sess):\n",
52 | " datapath = os.path.join(parent_dir,'Downloaded_data', f'{s}')\n",
53 | " if os.path.isdir(datapath):\n",
54 | " print(f\"{s} session already exists. Moving on.\")\n",
55 | " else:\n",
56 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
57 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
58 | " print(\"Data downloaded!\")"
59 | ]
60 | },
61 | {
62 | "attachments": {},
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "### Data load\n",
67 | "The training sessions' LFP will be appended together in a list. The same will happen with the ripples detection times.\n",
68 | "That is the required input for the training parser"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "# The training sessions will be appended together. Replace this cell with your own data loading\n",
78 | "train_LFPs=[]\n",
79 | "train_GTs=[]\n",
80 | "# Amigo2\n",
81 | "path=os.path.join(parent_dir,'Downloaded_data','Amigo2','figshare_16847521')\n",
82 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
83 | "train_LFPs.append(LFP)\n",
84 | "train_GTs.append(GT)\n",
85 | "# Som2\n",
86 | "path=os.path.join(parent_dir,'Downloaded_data','Som2','figshare_16856137')\n",
87 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
88 | "train_LFPs.append(LFP)\n",
89 | "train_GTs.append(GT)\n",
90 | "\n",
91 | "## Append all your validation sessions\n",
92 | "val_LFPs=[]\n",
93 | "val_GTs=[]\n",
94 | "# Dlx1 Validation\n",
95 | "path=os.path.join(parent_dir,'Downloaded_data','Dlx1','figshare_14959449')\n",
96 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
97 | "val_LFPs.append(LFP)\n",
98 | "val_GTs.append(GT)\n",
99 | "# Thy07 Validation\n",
100 | "path=os.path.join(parent_dir,'Downloaded_data','Thy7','figshare_14960085')\n",
101 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
102 | "val_LFPs.append(LFP)\n",
103 | "val_GTs.append(GT)\n",
104 | "\n",
105 | "x_training,GT_training,x_val_list,GT_val_list=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)"
106 | ]
107 | },
108 | {
109 | "attachments": {},
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "## CNN2D training parameters"
114 | ]
115 | },
116 | {
117 | "attachments": {},
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "#### Parameters:\n",
122 | "* Channels: number of channels that will be used to train the model, extracted from the data shape defined in the previous cell\n",
123 | "* Timesteps: number of samples that the model will use to produce a single output\n",
124 | "* Configuration: each element configures a convolutional layer of the model, shaped like [Number of filters, 1st dimension of kernel, 2nd dimension of kernel]. After each of the convolutional layers of the 1st half of the model there is a Max pooling 2D layer that reduces the output dimensionality.\n",
125 | "WARNING: each max Pooling layer (each layer in the 1st half of the models) halves the dimensionality of the layer: this limits the number \n",
126 | "of layers of the model according to the number of channels or timesteps, the smallest of the two: 8 channels allows 4 max Pooling steps (8 -> 4 -> 2 -> 1), 3 channels only 1 (3 ->1)\n",
127 | "Please take it into consideration when designing models\n",
128 | "* Epoch: number of times the training data set is used to train the model\n",
129 | "* Training batch: number of windows that are proccessed before weight updating"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "conf= {\"timesteps\":[40], # 8, 16, 40 ...\n",
139 | " \"configuration\": [[[32,2,2],[16,2,2],[8,3,2],[16,4,1],[16,6,1],[8,8,2]], # Configuration examples:\n",
140 | " [[16,2,2],[8,2,2],[4,3,2],[8,4,1],[8,6,1],[4,8,1]]], # [[32,2,2],[16,2,2],[8,3,2],[16,4,1],[16,6,1],[8,8,2]]\n",
141 | " # [[16,2,2],[8,2,2],[4,3,2],[8,4,1],[8,6,1],[4,8,1]]\n",
142 | " # [[64,2,2],[32,2,2],[16,3,2],[32,4,1],[32,6,1],[16,8,1]]\n",
143 | " \"epochs\": [2], # 2, 10, 30 ...\n",
144 | " \"train_batch\": [2**5], # 32, 64, 128 ... \n",
145 | "}"
146 | ]
147 | },
148 | {
149 | "attachments": {},
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "### Training"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "metadata": {},
160 | "outputs": [],
161 | "source": [
162 | "# Desired sampling frequency of the models\n",
163 | "sf=1250\n",
164 | "th_arr=np.linspace(0.1,0.9,9)\n",
165 | "model_name_arr=[] # To plot in the next cell\n",
166 | "model_arr=[] # Actual model array, used in the next validation section\n",
167 | "n_channels=x_training.shape[1]\n",
168 | "timesteps_arr=conf['timesteps']\n",
169 | "\n",
170 | "conf_arr=conf['configuration'] \n",
171 | "epochs_arr=conf['epochs']\n",
172 | "train_batch_arr=conf['train_batch'] \n",
173 | "\n",
174 | "l_ts=len(timesteps_arr)\n",
175 | "\n",
176 | "l_conf=len(conf_arr)\n",
177 | "l_epochs =len(epochs_arr)\n",
178 | "l_batch =len(train_batch_arr)\n",
179 | "n_iters=l_ts*l_conf*l_epochs*l_batch\n",
180 | "# GT is in the shape (n_events x 2), a y output signal with the same length as x is required\n",
181 | "perf_train_arr=np.zeros(shape=(n_iters,len(th_arr),3)) # Performance array, (n_models x n_th x 3 ) [P R F1]\n",
182 | "perf_test_arr=np.zeros_like(perf_train_arr)\n",
183 | "timesteps_arr_ploting=[] # Array that will be used in the validation, to be able to call the function predict\n",
184 | "\n",
185 | "print(f'{n_channels} channels will be used to train the CNN2D models')\n",
186 | "\n",
187 | "print(f'{n_iters} models will be trained')\n",
188 | "\n",
189 | "x_test_or,GT_test,x_train_or,GT_train=aux_fcn.split_data(x_training,GT_training,split=0.7,sf=sf)\n",
190 | "\n",
191 | "y_test_or= np.zeros(shape=(len(x_test_or)))\n",
192 | "for ev in GT_test:\n",
193 | " y_test_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
194 | "y_train_or= np.zeros(shape=(len(x_train_or)))\n",
195 | "for ev in GT_train:\n",
196 | " y_train_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
197 | "\n",
198 | "\n",
199 | "for i_ts,timesteps in enumerate(timesteps_arr):\n",
200 | " x_train=x_train_or[:len(x_train_or)-len(x_train_or)%timesteps].reshape(-1,timesteps,n_channels,1)\n",
201 | " y_train_aux=y_train_or[:len(y_train_or)-len(y_train_or)%timesteps].reshape(-1,timesteps,1)\n",
202 | " x_test=x_test_or[:len(x_test_or)-len(x_test_or)%timesteps].reshape(-1,timesteps,n_channels,1)\n",
203 | " y_test_aux=y_test_or[:len(y_test_or)-len(y_test_or)%timesteps].reshape(-1,timesteps,1)\n",
204 | "\n",
205 | " y_train=np.zeros(shape=[x_train.shape[0],1])\n",
206 | " for i in range(y_train_aux.shape[0]):\n",
207 | " y_train[i]=sum(y_train_aux[i])/timesteps\n",
208 | " \n",
209 | " y_test=np.zeros(shape=[x_test.shape[0],1])\n",
210 | " for i in range(y_test_aux.shape[0]):\n",
211 | " y_test[i]=sum(y_test_aux[i])/timesteps\n",
212 | " input_shape=x_test.shape[1:]\n",
213 | " \n",
214 | " for i_conf,conf in enumerate(conf_arr):\n",
215 | " for i_epochs,epochs in enumerate(epochs_arr):\n",
216 | " for i_batch,train_batch in enumerate(train_batch_arr):\n",
217 | " iter=((i_ts*l_conf+ i_conf)*l_epochs + i_epochs)*l_batch + i_batch\n",
218 | " print(f\"\\nIteration {iter+1} out of {n_iters}\")\n",
219 | " print(f'Number of channels: {n_channels:d}, Time steps: {timesteps:d}, Configuration: {i_conf:d}\\nEpochs: {epochs:d}, Samples per batch: {train_batch:d}')\n",
220 | "\n",
221 | " model = aux_fcn.build_CNN2D(conf, input_shape = input_shape)\n",
222 | " model.fit(x_train, y_train,shuffle=False, epochs=epochs,batch_size=train_batch,validation_data=(x_test,y_test), verbose=1)\n",
223 | " model_arr.append(model)\n",
224 | "\n",
225 | " test_signal = model.predict(x_test,verbose=1)\n",
226 | " train_signal=model.predict(x_train,verbose=1)\n",
227 | "\n",
228 | " y_train_predict=np.empty(shape=(x_train.shape[0]*timesteps,1,1))\n",
229 | " for i,window in enumerate(train_signal):\n",
230 | " y_train_predict[i*timesteps:(i+1)*timesteps]=window\n",
231 | "\n",
232 | " y_test_predict=np.empty(shape=(x_test.shape[0]*timesteps,1,1))\n",
233 | " for i,window in enumerate(test_signal):\n",
234 | " y_test_predict[i*timesteps:(i+1)*timesteps]=window\n",
235 | "\n",
236 | "\n",
237 | " ############################\n",
238 | " for i,th in enumerate(th_arr):\n",
239 | " # Test\n",
240 | " ytest_pred_ind=aux_fcn.get_predictions_index(y_test_predict,th)/sf\n",
241 | " perf_test_arr[iter,i]=aux_fcn.get_performance(ytest_pred_ind,GT_test,0)[0:3]\n",
242 | " # Train\n",
243 | " ytrain_pred_ind=aux_fcn.get_predictions_index(y_train_predict,th)/sf\n",
244 | " perf_train_arr[iter,i]=aux_fcn.get_performance(ytrain_pred_ind,GT_train,0)[0:3]\n",
245 | "\n",
246 | " # Saving the model\n",
247 | " model_name=f\"CNN2D_Ch{n_channels:d}_Ts{timesteps:03d}_C{i_conf:02d}_E{epochs:02d}_TB{train_batch:04d}\"\n",
248 | " model.save(os.path.join(parent_dir,'explore_models',model_name))\n",
249 | "\n",
250 | " model_name_arr.append(model_name)\n",
251 | " timesteps_arr_ploting.append(timesteps)"
252 | ]
253 | },
254 | {
255 | "attachments": {},
256 | "cell_type": "markdown",
257 | "metadata": {},
258 | "source": [
259 | "### Plot training results"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "metadata": {},
266 | "outputs": [],
267 | "source": [
268 | "# Plot training results\n",
269 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
270 | "\n",
271 | "for i in range(n_iters):\n",
272 | " axs[i,0].plot(perf_train_arr[i,:,0],perf_train_arr[i,:,1],'k.-')\n",
273 | " axs[i,0].plot(perf_test_arr[i,:,0],perf_test_arr[i,:,1],'b.-')\n",
274 | " axs[i,1].plot(th_arr,perf_train_arr[i,:,2],'k.-')\n",
275 | " axs[i,1].plot(th_arr,perf_test_arr[i,:,2],'b.-')\n",
276 | " axs[i,0].set_title(model_name_arr[i])\n",
277 | " axs[i,0].set_ylabel('Precision')\n",
278 | " axs[i,1].set_ylabel('F1')\n",
279 | "axs[-1,0].set_xlabel('Recall')\n",
280 | "axs[-1,1].set_xlabel('Threshold')\n",
281 | "axs[0,0].legend(['Training','Test'])\n",
282 | "plt.show()"
283 | ]
284 | },
285 | {
286 | "attachments": {},
287 | "cell_type": "markdown",
288 | "metadata": {},
289 | "source": [
290 | "### Validation"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "# For loop iterating over the models\n",
300 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
301 | "for n_m,model in enumerate(model_arr):\n",
302 | " F1_arr=np.zeros(shape=(len(x_val_list),len(th_arr))) #(n_val_sess x n_th) Array where the F1 val of each sesion will be stored\n",
303 | " for n_sess,LFP in enumerate(x_val_list):\n",
304 | " val_pred=rippl_AI.predict(LFP,sf=1250,arch='CNN2D',new_model=model,n_channels=n_channels,n_timesteps=timesteps_arr_ploting[n_m])[0]\n",
305 | " for i,th in enumerate(th_arr):\n",
306 | " val_pred_ind=aux_fcn.get_predictions_index(val_pred,th)/sf\n",
307 | " F1_arr[n_sess,i]=aux_fcn.get_performance(val_pred_ind,GT_val_list[n_sess],verbose=False)[2]\n",
308 | " \n",
309 | " axs[n_m,0].plot(th_arr,perf_train_arr[n_m,:,2],'k.-')\n",
310 | " axs[n_m,0].plot(th_arr,perf_test_arr[n_m,:,2],'b.-')\n",
311 | " for F1 in F1_arr:\n",
312 | " axs[n_m,1].plot(th_arr,F1)\n",
313 | " axs[n_m,1].plot(th_arr,np.mean(F1_arr,axis=0),'k.-')\n",
314 | " axs[n_m,0].set_title(model_name_arr[n_m])\n",
315 | " axs[n_m,0].set_ylabel('Precision')\n",
316 | " axs[n_m,1].set_ylabel('F1')\n",
317 | "axs[-1,0].set_xlabel('Recall')\n",
318 | "axs[-1,1].set_xlabel('Threshold')\n",
319 | "plt.show()\n",
320 | " "
321 | ]
322 | }
323 | ],
324 | "metadata": {
325 | "kernelspec": {
326 | "display_name": "PublicBCG_d",
327 | "language": "python",
328 | "name": "python3"
329 | },
330 | "language_info": {
331 | "codemirror_mode": {
332 | "name": "ipython",
333 | "version": 3
334 | },
335 | "file_extension": ".py",
336 | "mimetype": "text/x-python",
337 | "name": "python",
338 | "nbconvert_exporter": "python",
339 | "pygments_lexer": "ipython3",
340 | "version": "3.9.15"
341 | },
342 | "orig_nbformat": 4
343 | },
344 | "nbformat": 4,
345 | "nbformat_minor": 2
346 | }
347 |
--------------------------------------------------------------------------------
/examples_explore/example_LSTM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# LSTM parameter exploration\n",
9 | "This notebook is a template for finding the LSTM model best suited for your needs"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import numpy as np\n",
21 | "from tensorflow import keras\n",
22 | "import sys\n",
23 | "parent_dir=os.path.dirname(os.getcwd())\n",
24 | "sys.path.insert(0,parent_dir)\n",
25 | "import rippl_AI\n",
26 | "import aux_fcn"
27 | ]
28 | },
29 | {
30 | "attachments": {},
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "### Data download\n",
35 | "4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training ; Dlx1 and Thy7 for validation\n"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "from figshare.figshare.figshare import Figshare\n",
45 | "fshare = Figshare()\n",
46 | "\n",
47 | "article_ids = [16847521,16856137,14959449,14960085] \n",
48 | "sess=['Amigo2','Som2','Dlx1','Thy7'] \n",
49 | "for id,s in zip(article_ids,sess):\n",
50 | " datapath = os.path.join(parent_dir,'Downloaded_data', f'{s}')\n",
51 | " if os.path.isdir(datapath):\n",
52 | " print(f\"{s} session already exists. Moving on.\")\n",
53 | " else:\n",
54 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
55 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
56 | " print(\"Data downloaded!\")"
57 | ]
58 | },
59 | {
60 | "attachments": {},
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "### Data load\n",
65 | "The training sessions' LFP will be appended together in a list. The same will happen with the ripples detection times.\n",
66 | "That is the required input for the training parser"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# The training sessions will be appended together. Replace this cell with your own data loading\n",
76 | "train_LFPs=[]\n",
77 | "train_GTs=[]\n",
78 | "# Amigo2\n",
79 | "path=os.path.join(parent_dir,'Downloaded_data','Amigo2','figshare_16847521')\n",
80 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
81 | "train_LFPs.append(LFP)\n",
82 | "train_GTs.append(GT)\n",
83 | "# Som2\n",
84 | "path=os.path.join(parent_dir,'Downloaded_data','Som2','figshare_16856137')\n",
85 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
86 | "train_LFPs.append(LFP)\n",
87 | "train_GTs.append(GT)\n",
88 | "\n",
89 | "## Append all your validation sessions\n",
90 | "val_LFPs=[]\n",
91 | "val_GTs=[]\n",
92 | "# Dlx1 Validation\n",
93 | "path=os.path.join(parent_dir,'Downloaded_data','Dlx1','figshare_14959449')\n",
94 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
95 | "val_LFPs.append(LFP)\n",
96 | "val_GTs.append(GT)\n",
97 | "# Thy07 Validation\n",
98 | "path=os.path.join(parent_dir,'Downloaded_data','Thy7','figshare_14960085')\n",
99 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
100 | "val_LFPs.append(LFP)\n",
101 | "val_GTs.append(GT)\n",
102 | "\n",
103 | "x_training,GT_training,x_val_list,GT_val_list=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)"
104 | ]
105 | },
106 | {
107 | "attachments": {},
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "## LSTM training parameters"
112 | ]
113 | },
114 | {
115 | "attachments": {},
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "#### Parameters:\n",
120 | "* Channels: number of channels that will be used to train the model, extracted from the data shape defined in the previous cell\n",
121 | "* Timesteps: number of samples that the will be processed at once\n",
122 | "* Bidirectionality: if the model processes simutaneously the window forward and backwards\n",
123 | "* Layers: number of LSTM layers\n",
124 | "* Epoch: number of times the training data set is used to train the model\n",
125 | "* Training batch: number of windows that are proccessed before weight updating\n",
126 | "\n",
127 | "#\n",
128 | "LSTM contains more parameters, feel free to add your own modifications. Check the oficial documentation:\n",
129 | "https://keras.io/api/layers/recurrent_layers/lstm/"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "conf= {\"timesteps\":[15,40], # 8,16,40 ...\n",
139 | " \"bidirectional\": [0], # 0 or 1\n",
140 | " \"layers\": [1,2], # 2,3,4\n",
141 | " \"units\": [10], # 5,6,10,12...\n",
142 | " \"epochs\": [2], # 1,2,3...\n",
143 | " \"train_batch\": [2**8]} # 16,32,64 (Powers of two are recommended for computacional efficiency)"
144 | ]
145 | },
146 | {
147 | "attachments": {},
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "### Training"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "# Desired sampling frequency of the models\n",
161 | "sf=1250\n",
162 | "th_arr=np.linspace(0.1,0.9,9)\n",
163 | "model_name_arr=[] # To plot in the next cell\n",
164 | "model_arr=[] # Actual model array, used in the next validation section\n",
165 | "n_channels=x_training.shape[1]\n",
166 | "timesteps_arr=conf['timesteps']\n",
167 | "\n",
168 | "bi_arr=conf['bidirectional']\n",
169 | "layer_arr=conf['layers'] \n",
170 | "units_arr=conf['units'] \n",
171 | "epochs_arr=conf['epochs']\n",
172 | "train_batch_arr=conf['train_batch'] \n",
173 | "\n",
174 | "l_ts=len(timesteps_arr)\n",
175 | "\n",
176 | "l_bi=len(bi_arr)\n",
177 | "l_layer =len(layer_arr)\n",
178 | "l_units=len(units_arr)\n",
179 | "l_epochs =len(epochs_arr)\n",
180 | "l_batch =len(train_batch_arr)\n",
181 | "n_iters=l_ts*l_bi*l_layer*l_units*l_epochs*l_batch\n",
182 | "# GT is in the shape (n_events x 2), a y output signal with the same length as x is required\n",
183 | "perf_train_arr=np.zeros(shape=(n_iters,len(th_arr),3)) # Performance array, (n_models x n_th x 3 ) [P R F1]\n",
184 | "perf_test_arr=np.zeros_like(perf_train_arr)\n",
185 | "timesteps_arr_ploting=[] # Array that will be used in the validation, to be able to call the function predict\n",
186 | "\n",
187 | "print(f'{n_channels} channels will be used to train the LSTM models')\n",
188 | "\n",
189 | "print(f'{n_iters} models will be trained')\n",
190 | "\n",
191 | "x_test_or,GT_test,x_train_or,GT_train=aux_fcn.split_data(x_training,GT_training,split=0.7,sf=sf)\n",
192 | "\n",
193 | "y_test_or= np.zeros(shape=(len(x_test_or)))\n",
194 | "for ev in GT_test:\n",
195 | " y_test_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
196 | "y_train_or= np.zeros(shape=(len(x_train_or)))\n",
197 | "for ev in GT_train:\n",
198 | " y_train_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
199 | "\n",
200 | "\n",
201 | "for i_ts,timesteps in enumerate(timesteps_arr):\n",
202 | "\n",
203 | " x_train=x_train_or[:len(x_train_or)-len(x_train_or)%timesteps].reshape(-1,timesteps,n_channels)\n",
204 | " y_train=y_train_or[:len(y_train_or)-len(y_train_or)%timesteps].reshape(-1,timesteps,1)\n",
205 | " \n",
206 | " x_test=x_test_or[:len(x_test_or)-len(x_test_or)%timesteps].reshape(-1,timesteps,n_channels)\n",
207 | " y_test=y_test_or[:len(y_test_or)-len(y_test_or)%timesteps].reshape(-1,timesteps,1)\n",
208 | "\n",
209 | " for i_bi,bi in enumerate(bi_arr):\n",
210 | " for i_layer,layers in enumerate(layer_arr):\n",
211 | " for i_units,units in enumerate(units_arr):\n",
212 | " for i_epochs,epochs in enumerate(epochs_arr):\n",
213 | " for i_batch,train_batch in enumerate(train_batch_arr):\n",
214 | " iter=((((i_ts*l_bi+i_bi)*l_layer+i_layer)*l_units+ i_units)*l_epochs + i_epochs)*l_batch + i_batch\n",
215 | " print(f\"\\nIteration {iter+1} out of {n_iters}\")\n",
216 | " print(f'Number of channels: {n_channels:d}, Time steps: {timesteps:d}, Bidirectional: {bi:d}\\nN of layers: {layers:d}, N of units: {units:d}, epochs: {epochs:d}, Samples per batch: {train_batch:d}')\n",
217 | "\n",
218 | " model=aux_fcn.build_LSTM(input_shape=(timesteps,n_channels),n_layers=layers,layer_size=units,bidirectional=bi)\n",
219 | " # Training\n",
220 | " model.fit(x_train, y_train, epochs=epochs,batch_size=train_batch,validation_data=(x_test,y_test), verbose=1)\n",
221 | " model_arr.append(model)\n",
222 | "\n",
223 | " # Prediction\n",
224 | " y_test_predict = model.predict(x_test,verbose=1).reshape(-1,1,1)\n",
225 | " y_train_predict= model.predict(x_train,verbose=1).reshape(-1,1,1)\n",
226 | " \n",
227 | " for i,th in enumerate(th_arr):\n",
228 | " # Test\n",
229 | " ytest_pred_ind=aux_fcn.get_predictions_index(y_test_predict,th)/sf\n",
230 | " perf_test_arr[iter,i]=aux_fcn.get_performance(ytest_pred_ind,GT_test,0)[0:3]\n",
231 | " # Train\n",
232 | " ytrain_pred_ind=aux_fcn.get_predictions_index(y_train_predict,th)/sf\n",
233 | " perf_train_arr[iter,i]=aux_fcn.get_performance(ytrain_pred_ind,GT_train,0)[0:3]\n",
234 | "\n",
235 | " # Saving the model\n",
236 | " model_name=f\"LSTM_Ch{n_channels:d}_Ts{timesteps:03d}_Bi{bi:d}_L{layers:d}_U{units:02d}_E{epochs:02d}_TB{train_batch:04d}\"\n",
237 | " model.save(os.path.join(parent_dir,'explore_models',model_name))\n",
238 | "\n",
239 | " model_name_arr.append(model_name)\n",
240 | " timesteps_arr_ploting.append(timesteps)"
241 | ]
242 | },
243 | {
244 | "attachments": {},
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "### Plot training results"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "# Plot training results\n",
258 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
259 | "\n",
260 | "for i in range(n_iters):\n",
261 | " axs[i,0].plot(perf_train_arr[i,:,0],perf_train_arr[i,:,1],'k.-')\n",
262 | " axs[i,0].plot(perf_test_arr[i,:,0],perf_test_arr[i,:,1],'b.-')\n",
263 | " axs[i,1].plot(th_arr,perf_train_arr[i,:,2],'k.-')\n",
264 | " axs[i,1].plot(th_arr,perf_test_arr[i,:,2],'b.-')\n",
265 | " axs[i,0].set_title(model_name_arr[i])\n",
266 | " axs[i,0].set_ylabel('Precision')\n",
267 | " axs[i,1].set_ylabel('F1')\n",
268 | "axs[-1,0].set_xlabel('Recall')\n",
269 | "axs[-1,1].set_xlabel('Threshold')\n",
270 | "axs[0,0].legend(['Training','Test'])\n",
271 | "plt.show()"
272 | ]
273 | },
274 | {
275 | "attachments": {},
276 | "cell_type": "markdown",
277 | "metadata": {},
278 | "source": [
279 | "### Validation"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "metadata": {},
286 | "outputs": [],
287 | "source": [
288 | "# For loop iterating over the models\n",
289 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
290 | "for n_m,model in enumerate(model_arr):\n",
291 | " F1_arr=np.zeros(shape=(len(x_val_list),len(th_arr))) #(n_val_sess x n_th) Array where the F1 val of each sesion will be stored\n",
292 | " for n_sess,LFP in enumerate(x_val_list):\n",
293 | " val_pred=rippl_AI.predict(LFP,sf=1250,arch='LSTM',new_model=model,n_channels=n_channels,n_timesteps=timesteps_arr_ploting[n_m])[0]\n",
294 | " for i,th in enumerate(th_arr):\n",
295 | " val_pred_ind=aux_fcn.get_predictions_index(val_pred,th)/sf\n",
296 | " F1_arr[n_sess,i]=aux_fcn.get_performance(val_pred_ind,GT_val_list[n_sess],verbose=False)[2]\n",
297 | " \n",
298 | " axs[n_m,0].plot(th_arr,perf_train_arr[n_m,:,2],'k.-')\n",
299 | " axs[n_m,0].plot(th_arr,perf_test_arr[n_m,:,2],'b.-')\n",
300 | " for F1 in F1_arr:\n",
301 | " axs[n_m,1].plot(th_arr,F1)\n",
302 | " axs[n_m,1].plot(th_arr,np.mean(F1_arr,axis=0),'k.-')\n",
303 | " axs[n_m,0].set_title(model_name_arr[n_m])\n",
304 | " axs[n_m,0].set_ylabel('Precision')\n",
305 | " axs[n_m,1].set_ylabel('F1')\n",
306 | "axs[-1,0].set_xlabel('Recall')\n",
307 | "axs[-1,1].set_xlabel('Threshold')\n",
308 | "plt.show()\n",
309 | " "
310 | ]
311 | }
312 | ],
313 | "metadata": {
314 | "kernelspec": {
315 | "display_name": "PublicBCG_d",
316 | "language": "python",
317 | "name": "python3"
318 | },
319 | "language_info": {
320 | "codemirror_mode": {
321 | "name": "ipython",
322 | "version": 3
323 | },
324 | "file_extension": ".py",
325 | "mimetype": "text/x-python",
326 | "name": "python",
327 | "nbconvert_exporter": "python",
328 | "pygments_lexer": "ipython3",
329 | "version": "3.9.15"
330 | },
331 | "orig_nbformat": 4
332 | },
333 | "nbformat": 4,
334 | "nbformat_minor": 2
335 | }
336 |
--------------------------------------------------------------------------------
/examples_explore/example_SVM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# SVM parameter exploration\n",
9 | "This notebook is a template for finding the SVM model best suited for your needs"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import numpy as np\n",
21 | "import sklearn as sk\n",
22 | "from sklearn import svm,calibration\n",
23 | "from imblearn.under_sampling import RandomUnderSampler\n",
24 | "import sys\n",
25 | "import inspect\n",
26 | "parent_dir=os.path.dirname(os.getcwd())\n",
27 | "sys.path.insert(0,parent_dir )\n",
28 | "import rippl_AI\n",
29 | "import aux_fcn\n"
30 | ]
31 | },
32 | {
33 | "attachments": {},
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "### Data download\n",
38 | "4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training ; Dlx1 and Thy7 for validation\n"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "from figshare.figshare.figshare import Figshare\n",
48 | "fshare = Figshare()\n",
49 | "\n",
50 | "article_ids = [16847521,16856137,14959449,14960085] \n",
51 | "sess=['Amigo2','Som2','Dlx1','Thy7'] \n",
52 | "for id,s in zip(article_ids,sess):\n",
53 | " datapath = os.path.join(parent_dir,'Downloaded_data', f'{s}')\n",
54 | " if os.path.isdir(datapath):\n",
55 | " print(f\"{s} session already exists. Moving on.\")\n",
56 | " else:\n",
57 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
58 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
59 | " print(\"Data downloaded!\")"
60 | ]
61 | },
62 | {
63 | "attachments": {},
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "### Data load\n",
68 | "The training sessions' LFP will be appended together in a list. The same will happen with the ripples detection times.\n",
69 | "That is the required input for the training parser"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "# The training sessions will be appended together. Replace this cell with your own data loading\n",
79 | "train_LFPs=[]\n",
80 | "train_GTs=[]\n",
81 | "# Amigo2\n",
82 | "path=os.path.join(parent_dir,'Downloaded_data','Amigo2','figshare_16847521')\n",
83 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
84 | "train_LFPs.append(LFP)\n",
85 | "train_GTs.append(GT)\n",
86 | "# Som2\n",
87 | "path=os.path.join(parent_dir,'Downloaded_data','Som2','figshare_16856137')\n",
88 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
89 | "train_LFPs.append(LFP)\n",
90 | "train_GTs.append(GT)\n",
91 | "\n",
92 | "## Append all your validation sessions\n",
93 | "val_LFPs=[]\n",
94 | "val_GTs=[]\n",
95 | "# Dlx1 Validation\n",
96 | "path=os.path.join(parent_dir,'Downloaded_data','Dlx1','figshare_14959449')\n",
97 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
98 | "val_LFPs.append(LFP)\n",
99 | "val_GTs.append(GT)\n",
100 | "# Thy07 Validation\n",
101 | "path=os.path.join(parent_dir,'Downloaded_data','Thy7','figshare_14960085')\n",
102 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
103 | "val_LFPs.append(LFP)\n",
104 | "val_GTs.append(GT)\n",
105 | "\n",
106 | "x_training,GT_training,x_val_list,GT_val_list=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)"
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "## SVM training parameters"
115 | ]
116 | },
117 | {
118 | "attachments": {},
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "#### Parameters:\n",
123 | "* Channels: Number of channels that will be used to train the model, extracted from the data shape defined in the previous cell\n",
124 | "* Timesteps: Number of samples that the will be used to generate a single prediction\n",
125 | "* Undersampler proportion: roportion of True/False samples. Using all the samples demands heavy resources usage, but the most data used, the better the model generalizes"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "conf= { 'timesteps':[1,2,4], # Possible values: 1,2,4,8,16,32... \n",
135 | " 'undersampler proportion':[1]} # Possible values: 1, 0.5, 0.1 ... (0,1]"
136 | ]
137 | },
138 | {
139 | "attachments": {},
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "### Training"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "# Desired sampling frequency of the models\n",
153 | "sf=1250\n",
154 | "th_arr=np.linspace(0.1,0.9,9)\n",
155 | "model_name_arr=[] # To plot in the next cell\n",
156 | "model_arr=[] # Actual model array, used in the next validation section\n",
157 | "n_channels=x_training.shape[1]\n",
158 | "timesteps_arr=conf['timesteps']\n",
159 | "undersampler_arr=conf['undersampler proportion']\n",
160 | "l_ts=len(timesteps_arr)\n",
161 | "l_us=len(undersampler_arr)\n",
162 | "n_iters=l_ts*l_us\n",
163 | "# GT is in the shape (n_events x 2), a y output signal with the same length as x is required\n",
164 | "perf_train_arr=np.zeros(shape=(n_iters,len(th_arr),3)) # Performance array, (n_models x n_th x 3 ) [P R F1]\n",
165 | "perf_test_arr=np.zeros_like(perf_train_arr)\n",
166 | "timesteps_arr_ploting=[] # Array that will be used in the validation, to be able to call the function predict\n",
167 | "print(f'{n_channels} will be used to train the SVM models')\n",
168 | "\n",
169 | "print(f'{n_iters} models will be trained')\n",
170 | "\n",
171 | "x_test_or,GT_test,x_train_or,GT_train=aux_fcn.split_data(x_training,GT_training,split=0.7,sf=sf)\n",
172 | "\n",
173 | "y_test_or= np.zeros(shape=(len(x_test_or)))\n",
174 | "for ev in GT_test:\n",
175 | " y_test_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
176 | "y_train_or= np.zeros(shape=(len(x_train_or)))\n",
177 | "for ev in GT_train:\n",
178 | " y_train_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
179 | "\n",
180 | "\n",
181 | "for i_ts,timesteps in enumerate(timesteps_arr):\n",
182 | "\n",
183 | " x_train=x_train_or[:len(x_train_or)-len(x_train_or)%timesteps].reshape(-1,timesteps*n_channels)\n",
184 | " y_train_aux=y_train_or[:len(y_train_or)-len(y_train_or)%timesteps].reshape(-1,timesteps)\n",
185 | " y_train=aux_fcn.rec_signal(y_train_aux) \n",
186 | " \n",
187 | " x_test=x_test_or[:len(x_test_or)-len(x_test_or)%timesteps].reshape(-1,timesteps*n_channels)\n",
188 | " y_test_aux=y_test_or[:len(y_test_or)-len(y_test_or)%timesteps].reshape(-1,timesteps)\n",
189 | " y_test=aux_fcn.rec_signal(y_test_aux)\n",
190 | "\n",
191 | "\n",
192 | " for i_us,undersampler_prop in enumerate(undersampler_arr):\n",
193 | " rus = RandomUnderSampler(sampling_strategy=undersampler_prop)\n",
194 | " x_train_us, y_train_us = rus.fit_resample(x_train, y_train)\n",
195 | " iter=i_ts*l_us+i_us\n",
196 | " print(f\"\\nIteration {iter+1} out of {n_iters}\")\n",
197 | " print(f'Time steps: {timesteps}, Undersampler proportion: {undersampler_prop}')\n",
198 | " clf = sk.calibration.CalibratedClassifierCV(svm.LinearSVC()) \n",
199 | "\n",
200 | " # Training \n",
201 | " clf.fit(x_train_us, y_train_us)\n",
202 | " model_arr.append(clf)\n",
203 | " # Prediction. One value per window\n",
204 | " test_signal = clf.predict_proba(x_test)[:,1]\n",
205 | " train_signal=clf.predict_proba(x_train)[:,1]\n",
206 | " # Not compatible with the functions that extract beginning and end times\n",
207 | " y_train_predict=np.empty(shape=(x_train.shape[0]*timesteps,1,1))\n",
208 | " for i,window in enumerate(train_signal):\n",
209 | " y_train_predict[i*timesteps:(i+1)*timesteps]=window\n",
210 | " \n",
211 | " y_test_predict=np.empty(shape=(x_test.shape[0]*timesteps,1,1))\n",
212 | " for i,window in enumerate(test_signal):\n",
213 | " y_test_predict[i*timesteps:(i+1)*timesteps]=window\n",
214 | " \n",
215 | " for i,th in enumerate(th_arr):\n",
216 | " # Test\n",
217 | " ytest_pred_ind=aux_fcn.get_predictions_index(y_test_predict,th)/sf\n",
218 | " perf_test_arr[iter,i]=aux_fcn.get_performance(ytest_pred_ind,GT_test,0)[0:3]\n",
219 | " # Train\n",
220 | " ytrain_pred_ind=aux_fcn.get_predictions_index(y_train_predict,th)/sf\n",
221 | " perf_train_arr[iter,i]=aux_fcn.get_performance(ytrain_pred_ind,GT_train,0)[0:3]\n",
222 | "\n",
223 | " # Saving the model\n",
224 | " model_name=f\"SVM_Ch{n_channels}_Ts{timesteps:03d}_Us{undersampler_prop:1.2f}\"\n",
225 | " \n",
226 | " aux_fcn.fcn_save_pickle(os.path.join(parent_dir,'explore_models',model_name),clf)\n",
227 | " model_name_arr.append(model_name)\n",
228 | " timesteps_arr_ploting.append(timesteps)"
229 | ]
230 | },
231 | {
232 | "attachments": {},
233 | "cell_type": "markdown",
234 | "metadata": {},
235 | "source": [
236 | "### Plot training results"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "# Plot training results\n",
246 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
247 | "\n",
248 | "for i in range(n_iters):\n",
249 | " axs[i,0].plot(perf_train_arr[i,:,0],perf_train_arr[i,:,1],'k.-')\n",
250 | " axs[i,0].plot(perf_test_arr[i,:,0],perf_test_arr[i,:,1],'b.-')\n",
251 | " axs[i,1].plot(th_arr,perf_train_arr[i,:,2],'k.-')\n",
252 | " axs[i,1].plot(th_arr,perf_test_arr[i,:,2],'b.-')\n",
253 | " axs[i,0].set_title(model_name_arr[i])\n",
254 | " axs[i,0].set_ylabel('Precision')\n",
255 | " axs[i,1].set_ylabel('F1')\n",
256 | "axs[-1,0].set_xlabel('Recall')\n",
257 | "axs[-1,1].set_xlabel('Threshold')\n",
258 | "axs[0,0].legend(['Training','Test'])\n",
259 | "plt.show()"
260 | ]
261 | },
262 | {
263 | "attachments": {},
264 | "cell_type": "markdown",
265 | "metadata": {},
266 | "source": [
267 | "### Validation"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "# For loop iterating over the models\n",
277 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
278 | "for n_m,model in enumerate(model_arr):\n",
279 | " F1_arr=np.zeros(shape=(len(x_val_list),len(th_arr))) #(n_val_sess x n_th) Array where the F1 val of each sesion will be stored\n",
280 | " for n_sess,LFP in enumerate(x_val_list):\n",
281 | " val_pred=rippl_AI.predict(LFP,sf=1250,arch='SVM',new_model=model,n_channels=n_channels,n_timesteps=timesteps_arr_ploting[n_m])[0]\n",
282 | " for i,th in enumerate(th_arr):\n",
283 | " val_pred_ind=aux_fcn.get_predictions_index(val_pred,th)/sf\n",
284 | " F1_arr[n_sess,i]=aux_fcn.get_performance(val_pred_ind,GT_val_list[n_sess],verbose=False)[2]\n",
285 | " \n",
286 | " axs[n_m,0].plot(th_arr,perf_train_arr[n_m,:,2],'k.-')\n",
287 | " axs[n_m,0].plot(th_arr,perf_test_arr[n_m,:,2],'b.-')\n",
288 | " for F1 in F1_arr:\n",
289 | " axs[n_m,1].plot(th_arr,F1)\n",
290 | " axs[n_m,1].plot(th_arr,np.mean(F1_arr,axis=0),'k.-')\n",
291 | " axs[n_m,0].set_title(model_name_arr[n_m])\n",
292 | " axs[n_m,0].set_ylabel('Precision')\n",
293 | " axs[n_m,1].set_ylabel('F1')\n",
294 | "axs[-1,0].set_xlabel('Recall')\n",
295 | "axs[-1,1].set_xlabel('Threshold')\n",
296 | "plt.show()\n",
297 | " "
298 | ]
299 | }
300 | ],
301 | "metadata": {
302 | "kernelspec": {
303 | "display_name": "PublicBCG_d",
304 | "language": "python",
305 | "name": "python3"
306 | },
307 | "language_info": {
308 | "codemirror_mode": {
309 | "name": "ipython",
310 | "version": 3
311 | },
312 | "file_extension": ".py",
313 | "mimetype": "text/x-python",
314 | "name": "python",
315 | "nbconvert_exporter": "python",
316 | "pygments_lexer": "ipython3",
317 | "version": "3.9.15"
318 | },
319 | "orig_nbformat": 4
320 | },
321 | "nbformat": 4,
322 | "nbformat_minor": 2
323 | }
324 |
--------------------------------------------------------------------------------
/examples_explore/example_XGBOOST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# XGBOOST parameter exploration\n",
9 | "This notebook is a template for finding the XGBOOST model best suited for your needs"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import numpy as np\n",
21 | "from xgboost import XGBClassifier\n",
22 | "import sys\n",
23 | "parent_dir=os.path.dirname(os.getcwd())\n",
24 | "sys.path.insert(0,parent_dir )\n",
25 | "import rippl_AI\n",
26 | "import aux_fcn"
27 | ]
28 | },
29 | {
30 | "attachments": {},
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "### Data download\n",
35 | "4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training ; Dlx1 and Thy7 for validation\n"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "from figshare.figshare.figshare import Figshare\n",
45 | "fshare = Figshare()\n",
46 | "\n",
47 | "article_ids = [16847521,16856137,14959449,14960085] \n",
48 | "sess=['Amigo2','Som2','Dlx1','Thy7'] \n",
49 | "for id,s in zip(article_ids,sess):\n",
50 | " datapath = os.path.join(parent_dir,'Downloaded_data', f'{s}')\n",
51 | " if os.path.isdir(datapath):\n",
52 | " print(f\"{s} session already exists. Moving on.\")\n",
53 | " else:\n",
54 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
55 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
56 | " print(\"Data downloaded!\")"
57 | ]
58 | },
59 | {
60 | "attachments": {},
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "### Data load\n",
65 | "The training sessions' LFP will be appended together in a list. The same will happen with the ripples detection times.\n",
66 | "That is the required input for the training parser"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# The training sessions will be appended together. Replace this cell with your own data loading\n",
76 | "train_LFPs=[]\n",
77 | "train_GTs=[]\n",
78 | "# Amigo2\n",
79 | "path=os.path.join(parent_dir,'Downloaded_data','Amigo2','figshare_16847521')\n",
80 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
81 | "train_LFPs.append(LFP)\n",
82 | "train_GTs.append(GT)\n",
83 | "# Som2\n",
84 | "path=os.path.join(parent_dir,'Downloaded_data','Som2','figshare_16856137')\n",
85 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
86 | "train_LFPs.append(LFP)\n",
87 | "train_GTs.append(GT)\n",
88 | "\n",
89 | "## Append all your validation sessions\n",
90 | "val_LFPs=[]\n",
91 | "val_GTs=[]\n",
92 | "# Dlx1 Validation\n",
93 | "path=os.path.join(parent_dir,'Downloaded_data','Dlx1','figshare_14959449')\n",
94 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
95 | "val_LFPs.append(LFP)\n",
96 | "val_GTs.append(GT)\n",
97 | "# Thy07 Validation\n",
98 | "path=os.path.join(parent_dir,'Downloaded_data','Thy7','figshare_14960085')\n",
99 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
100 | "val_LFPs.append(LFP)\n",
101 | "val_GTs.append(GT)\n",
102 | "\n",
103 | "x_training,GT_training,x_val_list,GT_val_list=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)"
104 | ]
105 | },
106 | {
107 | "attachments": {},
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "## XGBOOST training parameters"
112 | ]
113 | },
114 | {
115 | "attachments": {},
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "#### Parameters:\n",
120 | "* Channels: number of channels that will be used to train the model, extracted from the data shape defined in the previous cell\n",
121 | "* Timesteps: number of samples that the will be used to generate a single prediction\n",
122 | "* Max depth: number of max layers in each tree. Too many usually causes overfitting\n",
123 | "* Learning rate: similar to a weight used to update te predictor, a high value leads to faster computations but may not reaach a optimal value\n",
124 | "* Gamma: Minimum loss reduction required to make a partition on a leaf node. The larger gamma is, the more conservative the model will be\n",
125 | "* Reg lamda: L2 regularization term of weight updating. Increasing this value makes the model more conservative\n",
126 | "* Scale pos weight: controls the balance of positive and negative weights, useful for unbalanced clasess.\n",
127 | "* Subsample: subsample ratio of the training instances. Setting it to 0.5 means that XGBoost would randomly sample half of the training data prior to growing trees. Used to prevent overfitting\n",
128 | "#\n",
129 | "XGBOOST contains many more parameters, feel free to add your own modifications. Check the oficial documentation: https://xgboost.readthedocs.io/en/stable/parameter.html#parameters-for-tree-booster"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "conf= {\"timesteps\":[16], # 2,4,8,16,20,32 ...\n",
139 | " \"max_depth\": [4, 5], # 3,4,5,6,7 ...\n",
140 | " \"learning_rate\": [0.1], # 0.2, 0.1, 0.05, 0.01 ...\n",
141 | " \"gamma\": [1], # 0, 0.25, 1 ...\n",
142 | " \"reg_lambda\": [10], # 0, 1, 10 ...\n",
143 | " \"scale_pos_weight\": [1], # 1, 3, 5...\n",
144 | " \"subsample\": [0.8]} # 0.5, 0.8, 0.9 ..."
145 | ]
146 | },
147 | {
148 | "attachments": {},
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "### Training"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "# Desired sampling frequency of the models\n",
162 | "sf=1250\n",
163 | "th_arr=np.linspace(0.1,0.9,9)\n",
164 | "model_name_arr=[] # To plot in the next cell\n",
165 | "model_arr=[] # Actual model array, used in the next validation section\n",
166 | "n_channels=x_training.shape[1]\n",
167 | "timesteps_arr=conf['timesteps']\n",
168 | "\n",
169 | "max_depth_arr=conf[\"max_depth\"] \n",
170 | "lr_arr=conf[\"learning_rate\"]\n",
171 | "gamma_arr=conf[\"gamma\"] \n",
172 | "reg_lambda_arr=conf[\"reg_lambda\"] \n",
173 | "scale_arr=conf[\"scale_pos_weight\"]\n",
174 | "subsample_arr=conf[\"subsample\"] \n",
175 | "\n",
176 | "l_ts=len(timesteps_arr)\n",
177 | "\n",
178 | "l_maxd=len(max_depth_arr)\n",
179 | "l_lr =len(lr_arr)\n",
180 | "l_g =len(gamma_arr)\n",
181 | "l_reg =len(reg_lambda_arr)\n",
182 | "l_sc =len(scale_arr)\n",
183 | "l_sub =len(subsample_arr)\n",
184 | "n_iters=l_ts*l_maxd*l_lr*l_g*l_reg*l_sc*l_sub\n",
185 | "# GT is in the shape (n_events x 2), a y output signal with the same length as x is required\n",
186 | "perf_train_arr=np.zeros(shape=(n_iters,len(th_arr),3)) # Performance array, (n_models x n_th x 3 ) [P R F1]\n",
187 | "perf_test_arr=np.zeros_like(perf_train_arr)\n",
188 | "timesteps_arr_ploting=[] # Array that will be used in the validation, to be able to call the function predict\n",
189 | "\n",
190 | "print(f'{n_channels} channels will be used to train the XGBOOST models')\n",
191 | "\n",
192 | "print(f'{n_iters} models will be trained')\n",
193 | "\n",
194 | "x_test_or,GT_test,x_train_or,GT_train=aux_fcn.split_data(x_training,GT_training,split=0.7,sf=sf)\n",
195 | "\n",
196 | "y_test_or= np.zeros(shape=(len(x_test_or)))\n",
197 | "for ev in GT_test:\n",
198 | " y_test_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
199 | "y_train_or= np.zeros(shape=(len(x_train_or)))\n",
200 | "for ev in GT_train:\n",
201 | " y_train_or[int(sf*ev[0]):int(sf*ev[1])]=1\n",
202 | "\n",
203 | "\n",
204 | "for i_ts,timesteps in enumerate(timesteps_arr):\n",
205 | "\n",
206 | " x_train=x_train_or[:len(x_train_or)-len(x_train_or)%timesteps].reshape(-1,timesteps*n_channels)\n",
207 | " y_train_aux=y_train_or[:len(y_train_or)-len(y_train_or)%timesteps].reshape(-1,timesteps)\n",
208 | " y_train=aux_fcn.rec_signal(y_train_aux) # If any sample of the window contains a ripple, the desired output for the shape is 1\n",
209 | " \n",
210 | " x_test=x_test_or[:len(x_test_or)-len(x_test_or)%timesteps].reshape(-1,timesteps*n_channels)\n",
211 | " y_test_aux=y_test_or[:len(y_test_or)-len(y_test_or)%timesteps].reshape(-1,timesteps)\n",
212 | " y_test=aux_fcn.rec_signal(y_test_aux)\n",
213 | "\n",
214 | " for i_maxd,max_depth in enumerate(max_depth_arr):\n",
215 | " for i_lr,lr in enumerate(lr_arr):\n",
216 | " for i_g,g in enumerate(gamma_arr):\n",
217 | " for i_rg,reg_l in enumerate(reg_lambda_arr):\n",
218 | " for i_sc,scale in enumerate(scale_arr):\n",
219 | " for i_subs,subsample in enumerate(subsample_arr):\n",
220 | " iter=(((((i_ts*l_maxd+i_maxd)*l_lr+i_lr)*l_g+i_g)*l_reg+ i_rg)*l_sc + i_sc)*l_sub + i_subs\n",
221 | " print(f\"\\nIteration {iter+1} out of {n_iters}\")\n",
222 | " print(f'Number of channels: {n_channels:d}, Time steps: {timesteps:d}.\\nMax depth: {max_depth:d}, Lr: {lr:1.3f}, gamma: {g:1.2f}, reg_l: {reg_l:d}, scale: {scale:1.3f}, subsample: {subsample:0.3f}')\n",
223 | " xgb = XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
224 | " colsample_bynode=1, colsample_bytree=1, gamma=g, gpu_id=-1,\n",
225 | " importance_type='gain', interaction_constraints='',\n",
226 | " learning_rate=lr, max_delta_step=0, max_depth=max_depth,\n",
227 | " min_child_weight=1, monotone_constraints='()',\n",
228 | " n_estimators=100, n_jobs=-1, num_parallel_tree=1, random_state=0,\n",
229 | " reg_alpha=0, reg_lambda=reg_l, scale_pos_weight=scale, subsample=subsample,\n",
230 | " tree_method='exact', validate_parameters=1, verbosity=2)\n",
231 | "\n",
232 | " # Training\n",
233 | " xgb.fit(x_train, y_train,verbose=1,eval_metric=[\"logloss\"] ,eval_set = [(x_train,y_train),(x_test, y_test)])\n",
234 | " model_arr.append(xgb)\n",
235 | " # Prediction. One value for window\n",
236 | " test_signal = xgb.predict_proba(x_test)[:,1]\n",
237 | " train_signal=xgb.predict_proba(x_train)[:,1]\n",
238 | " # Not compatible with the functions that extract beginning and end times\n",
239 | " y_train_predict=np.empty(shape=(x_train.shape[0]*timesteps,1,1))\n",
240 | " for i,window in enumerate(train_signal):\n",
241 | " y_train_predict[i*timesteps:(i+1)*timesteps]=window\n",
242 | " \n",
243 | " y_test_predict=np.empty(shape=(x_test.shape[0]*timesteps,1,1))\n",
244 | " for i,window in enumerate(test_signal):\n",
245 | " y_test_predict[i*timesteps:(i+1)*timesteps]=window\n",
246 | " \n",
247 | " for i,th in enumerate(th_arr):\n",
248 | " # Test\n",
249 | " ytest_pred_ind=aux_fcn.get_predictions_index(y_test_predict,th)/sf\n",
250 | " perf_test_arr[iter,i]=aux_fcn.get_performance(ytest_pred_ind,GT_test,0)[0:3]\n",
251 | " # Train\n",
252 | " ytrain_pred_ind=aux_fcn.get_predictions_index(y_train_predict,th)/sf\n",
253 | " perf_train_arr[iter,i]=aux_fcn.get_performance(ytrain_pred_ind,GT_train,0)[0:3]\n",
254 | "\n",
255 | " # Saving the model\n",
256 | " model_name=f\"XGBOOST_Ch{n_channels:d}_Ts{timesteps:03d}_D{max_depth:d}_Lr{lr:1.3f}_G{g:1.2f}_regl{reg_l:02d}_SCALE{scale:03d}_Subs{subsample:1.3f}\"\n",
257 | " xgb.save_model(os.path.join(parent_dir,'explore_models',model_name))\n",
258 | "\n",
259 | " model_name_arr.append(model_name)\n",
260 | " timesteps_arr_ploting.append(timesteps)"
261 | ]
262 | },
263 | {
264 | "attachments": {},
265 | "cell_type": "markdown",
266 | "metadata": {},
267 | "source": [
268 | "### Plot training results"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "# Plot training results\n",
278 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
279 | "\n",
280 | "for i in range(n_iters):\n",
281 | " axs[i,0].plot(perf_train_arr[i,:,0],perf_train_arr[i,:,1],'k.-')\n",
282 | " axs[i,0].plot(perf_test_arr[i,:,0],perf_test_arr[i,:,1],'b.-')\n",
283 | " axs[i,1].plot(th_arr,perf_train_arr[i,:,2],'k.-')\n",
284 | " axs[i,1].plot(th_arr,perf_test_arr[i,:,2],'b.-')\n",
285 | " axs[i,0].set_title(model_name_arr[i])\n",
286 | " axs[i,0].set_ylabel('Precision')\n",
287 | " axs[i,1].set_ylabel('F1')\n",
288 | "axs[-1,0].set_xlabel('Recall')\n",
289 | "axs[-1,1].set_xlabel('Threshold')\n",
290 | "axs[0,0].legend(['Training','Test'])\n",
291 | "plt.show()"
292 | ]
293 | },
294 | {
295 | "attachments": {},
296 | "cell_type": "markdown",
297 | "metadata": {},
298 | "source": [
299 | "### Validation"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "# For loop iterating over the models\n",
309 | "fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col')\n",
310 | "for n_m,model in enumerate(model_arr):\n",
311 | " F1_arr=np.zeros(shape=(len(x_val_list),len(th_arr))) #(n_val_sess x n_th) Array where the F1 val of each sesion will be stored\n",
312 | " for n_sess,LFP in enumerate(x_val_list):\n",
313 | " val_pred=rippl_AI.predict(LFP,sf=1250,arch='XGBOOST',new_model=model,n_channels=n_channels,n_timesteps=timesteps_arr_ploting[n_m])[0]\n",
314 | " for i,th in enumerate(th_arr):\n",
315 | " val_pred_ind=aux_fcn.get_predictions_index(val_pred,th)/sf\n",
316 | " F1_arr[n_sess,i]=aux_fcn.get_performance(val_pred_ind,GT_val_list[n_sess],verbose=False)[2]\n",
317 | " \n",
318 | " axs[n_m,0].plot(th_arr,perf_train_arr[n_m,:,2],'k.-')\n",
319 | " axs[n_m,0].plot(th_arr,perf_test_arr[n_m,:,2],'b.-')\n",
320 | " for F1 in F1_arr:\n",
321 | " axs[n_m,1].plot(th_arr,F1)\n",
322 | " axs[n_m,1].plot(th_arr,np.mean(F1_arr,axis=0),'k.-')\n",
323 | " axs[n_m,0].set_title(model_name_arr[n_m])\n",
324 | " axs[n_m,0].set_ylabel('Precision')\n",
325 | " axs[n_m,1].set_ylabel('F1')\n",
326 | "axs[-1,0].set_xlabel('Recall')\n",
327 | "axs[-1,1].set_xlabel('Threshold')\n",
328 | "plt.show()\n",
329 | " "
330 | ]
331 | }
332 | ],
333 | "metadata": {
334 | "kernelspec": {
335 | "display_name": "PublicBCG_d",
336 | "language": "python",
337 | "name": "python3"
338 | },
339 | "language_info": {
340 | "codemirror_mode": {
341 | "name": "ipython",
342 | "version": 3
343 | },
344 | "file_extension": ".py",
345 | "mimetype": "text/x-python",
346 | "name": "python",
347 | "nbconvert_exporter": "python",
348 | "pygments_lexer": "ipython3",
349 | "version": "3.9.15"
350 | },
351 | "orig_nbformat": 4
352 | },
353 | "nbformat": 4,
354 | "nbformat_minor": 2
355 | }
356 |
--------------------------------------------------------------------------------
/examples_retraining.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import rippl_AI\n",
11 | "import aux_fcn\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "attachments": {},
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "### Data download\n",
22 | "4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training ; Dlx1 and Thy7 for validation\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from figshare.figshare import Figshare\n",
32 | "fshare = Figshare()\n",
33 | "\n",
34 | "article_ids = [16847521,16856137,14959449,14960085] \n",
35 | "sess=['Amigo2','Som2','Dlx1','Thy7'] \n",
36 | "for id,s in zip(article_ids,sess):\n",
37 | " datapath = os.path.join('Downloaded_data', f'{s}')\n",
38 | " if os.path.isdir(datapath):\n",
39 | " print(f\"{s} session already exists. Moving on.\")\n",
40 | " else:\n",
41 | " print(\"Downloading data... Please wait, this might take up some time\") # Can take up to 10 minutes\n",
42 | " fshare.retrieve_files_from_article(id,directory=datapath)\n",
43 | " print(\"Data downloaded!\")"
44 | ]
45 | },
46 | {
47 | "attachments": {},
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "### Data load\n",
52 | "The training sessions' LFP will be appended together in a list. The same will happen with the ripples detection times.\n",
53 | "That is the required input for the training parser"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# The training sessions will be appended together. Do the same with your training data\n",
63 | "train_LFPs=[]\n",
64 | "train_GTs=[]\n",
65 | "# Amigo2\n",
66 | "path=os.path.join('Downloaded_data','Amigo2','figshare_16847521')\n",
67 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
68 | "train_LFPs.append(LFP)\n",
69 | "train_GTs.append(GT)\n",
70 | "\n",
71 | "# Som2\n",
72 | "path=os.path.join('Downloaded_data','Som2','figshare_16856137')\n",
73 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
74 | "train_LFPs.append(LFP)\n",
75 | "train_GTs.append(GT)\n",
76 | "## Append all your validation sessions\n",
77 | "val_LFPs=[]\n",
78 | "val_GTs=[]\n",
79 | "# Dlx1 Validation\n",
80 | "path=os.path.join('Downloaded_data','Dlx1','figshare_14959449')\n",
81 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
82 | "val_LFPs.append(LFP)\n",
83 | "val_GTs.append(GT)\n",
84 | "# Thy07 Validation\n",
85 | "path=os.path.join('Downloaded_data','Thy7','figshare_14960085')\n",
86 | "LFP,GT=aux_fcn.load_lab_data(path)\n",
87 | "val_LFPs.append(LFP)\n",
88 | "val_GTs.append(GT)\n",
89 | "\n"
90 | ]
91 | },
92 | {
93 | "attachments": {},
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "The training sessions are concatenated, the validation sessions are kept as different sessions"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "retrain_LFP_norm,retrain_GT,val_LFP_norm,val_GT=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)"
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "# Retraining examples for the different models"
115 | ]
116 | },
117 | {
118 | "attachments": {},
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "### XGBOOST\n",
123 | "XGBOOST does not require further parameters"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "rippl_AI.retrain_model(retrain_LFP_norm,retrain_GT,val_LFP_norm,val_GT,arch='XGBOOST',\n",
133 | " save_path=os.path.join('retrained_models','XGBOOST_retrained1'))"
134 | ]
135 | },
136 | {
137 | "attachments": {},
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "### SVM\n",
142 | "Has only one parameter: \n",
143 | "'Undersampler proportion': It controls the number of windows with negatives (no ripples) that will be used to train the model. Following the formula: Undersampler proportion= (Positive windows)/(Negative windows). 1 means the same number of poitive and negative windows. Low values can lead to overfitting."
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "params={'Unsersampler proportion': 0.1}\n",
153 | "\n",
154 | "rippl_AI.retrain_model(retrain_LFP_norm,retrain_GT,val_LFP_norm,val_GT,arch='SVM',parameters=params,\n",
155 | " save_path=os.path.join('retrained_models','SVM_retrained1'))"
156 | ]
157 | },
158 | {
159 | "attachments": {},
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "### LSTM \n",
164 | "LSTM has two training parameters:\n",
165 | "'Epochs': is the number of times that the training data is fed to the model\n",
166 | "'Training batch': is the number of windows that are processed before updating the weights during training. Higher values prevent big weight oscillations."
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "params={'Epochs': 2,\n",
176 | " 'Training batch': 32}\n",
177 | "rippl_AI.retrain_model(retrain_LFP_norm,retrain_GT,val_LFP_norm,val_GT,arch='LSTM',parameters=params,\n",
178 | " save_path=os.path.join('retrained_models','LSTM_retrained1'))"
179 | ]
180 | },
181 | {
182 | "attachments": {},
183 | "cell_type": "markdown",
184 | "metadata": {},
185 | "source": [
186 | "### CNN2D\n",
187 | "CNN2D share training parameters with th LSTM architecture"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "params={'Epochs': 1,\n",
197 | " 'Training batch': 64}\n",
198 | "rippl_AI.retrain_model(retrain_LFP_norm,retrain_GT,val_LFP_norm,val_GT,arch='CNN2D',parameters=params,\n",
199 | " save_path=os.path.join('retrained_models','CNN2D_retrained1'))"
200 | ]
201 | },
202 | {
203 | "attachments": {},
204 | "cell_type": "markdown",
205 | "metadata": {},
206 | "source": [
207 | "### CNN1D\n",
208 | "CNN1D share training parameters with LSTM and CNN2D"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "metadata": {},
215 | "outputs": [],
216 | "source": [
217 | "params={'Epochs': 2,\n",
218 | " 'Training batch': 32}\n",
219 | "rippl_AI.retrain_model(retrain_LFP_norm,retrain_GT,val_LFP_norm,val_GT,arch='CNN1D',parameters=params,save_path=os.path.join('retrained_models','CNN1D_retrained1'))"
220 | ]
221 | }
222 | ],
223 | "metadata": {
224 | "kernelspec": {
225 | "display_name": "PublicBCG_d",
226 | "language": "python",
227 | "name": "python3"
228 | },
229 | "language_info": {
230 | "codemirror_mode": {
231 | "name": "ipython",
232 | "version": 3
233 | },
234 | "file_extension": ".py",
235 | "mimetype": "text/x-python",
236 | "name": "python",
237 | "nbconvert_exporter": "python",
238 | "pygments_lexer": "ipython3",
239 | "version": "3.9.15"
240 | },
241 | "orig_nbformat": 4
242 | },
243 | "nbformat": 4,
244 | "nbformat_minor": 2
245 | }
246 |
--------------------------------------------------------------------------------
/figures/CNNs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/CNNs.png
--------------------------------------------------------------------------------
/figures/LSTM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/LSTM.png
--------------------------------------------------------------------------------
/figures/SVM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/SVM.png
--------------------------------------------------------------------------------
/figures/XGBoost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/XGBoost.png
--------------------------------------------------------------------------------
/figures/detection-method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/detection-method.png
--------------------------------------------------------------------------------
/figures/manual-curation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/manual-curation.png
--------------------------------------------------------------------------------
/figures/output-probabilities.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/output-probabilities.png
--------------------------------------------------------------------------------
/figures/rippl-AI-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/rippl-AI-logo.png
--------------------------------------------------------------------------------
/figures/ripple-variability.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/ripple-variability.png
--------------------------------------------------------------------------------
/figures/threshold-selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/figures/threshold-selection.png
--------------------------------------------------------------------------------
/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_1_Ch8_W60_Ts16_OGmodel12/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_2_Ch8_W60_Ts32_Fp1.50_E10_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_3_Ch8_W60_Ts40_Fp1.50_E02_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_4_Ch8_W60_Ts40_Fp2.00_E50_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_5_Ch8_W60_Ts40_OGmodel32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN1D_6_Ch1_W60_Ts40_Fp1.50_E50_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_1_Ch8_W60_Ts40_OgModel/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_2_Ch8_W60_Ts32_C0_E30_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_3_Ch3_W60_Ts40_C0_E30_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_4_Ch3_W60_Ts32_C2_E30_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_5_Ch3_W60_Ts40_C1_E30_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/CNN2D_6_Ch1_W60_Ts16_C3_E30_TB32/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/ENS/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/ENS/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/ENS/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/ENS/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/ENS/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/ENS/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/ENS/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/ENS/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_1_Ch8_W60_Ts32_Bi0_L4_U11_E10_TB256/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_2_Ch8_W60_Ts16_Bi0_L4_U25_E05_TB256/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_3_Ch8_W60_Ts16_Bi0_L3_U11_E10_TB256/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_4_Ch8_W60_Ts16_Bi0_L3_U14_E05_TB256/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/keras_metadata.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/keras_metadata.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_5_Ch8_W60_Ts32_Bi1_L4_U20_E10_TB256/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256/saved_model.pb
--------------------------------------------------------------------------------
/optimized_models/LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/optimized_models/LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/LSTM_6_Ch1_W60_Ts40_Bi1_L3_U12_E10_TB256/variables/variables.index
--------------------------------------------------------------------------------
/optimized_models/SVM_1_Ch8_W60_Ts001_Us0.05:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/SVM_1_Ch8_W60_Ts001_Us0.05
--------------------------------------------------------------------------------
/optimized_models/SVM_2_Ch8_W60_Ts001_Us0.10:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/SVM_2_Ch8_W60_Ts001_Us0.10
--------------------------------------------------------------------------------
/optimized_models/SVM_3_Ch8_W60_Ts002_Us0.05:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/SVM_3_Ch8_W60_Ts002_Us0.05
--------------------------------------------------------------------------------
/optimized_models/SVM_4_Ch8_W60_Ts001_Us1.00:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/SVM_4_Ch8_W60_Ts001_Us1.00
--------------------------------------------------------------------------------
/optimized_models/SVM_5_Ch8_W60_Ts001_Us0.50:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/SVM_5_Ch8_W60_Ts001_Us0.50
--------------------------------------------------------------------------------
/optimized_models/SVM_6_Ch8_W60_Ts060_Us1.00:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/SVM_6_Ch8_W60_Ts060_Us1.00
--------------------------------------------------------------------------------
/optimized_models/XGBOOST_1_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/XGBOOST_1_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE1
--------------------------------------------------------------------------------
/optimized_models/XGBOOST_2_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/XGBOOST_2_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE5
--------------------------------------------------------------------------------
/optimized_models/XGBOOST_3_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/XGBOOST_3_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE3
--------------------------------------------------------------------------------
/optimized_models/XGBOOST_4_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/XGBOOST_4_Ch8_W60_Ts016_D7_Lr0.10_G0.25_L10_SCALE5
--------------------------------------------------------------------------------
/optimized_models/XGBOOST_5_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/XGBOOST_5_Ch8_W60_Ts016_D7_Lr0.10_G0.00_L10_SCALE3
--------------------------------------------------------------------------------
/optimized_models/XGBOOST_6_Ch1_W60_Ts032_D7_Lr0.10_G0.00_L10_SCALE3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PridaLab/rippl-AI/9a02b7bbf431cad148a14f85802f264f858b3fee/optimized_models/XGBOOST_6_Ch1_W60_Ts032_D7_Lr0.10_G0.00_L10_SCALE3
--------------------------------------------------------------------------------
/rippl_AI.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.style as mplstyle
3 | mplstyle.use('fast')
4 | from matplotlib.backend_bases import MouseButton
5 | import matplotlib.pyplot as plt
6 | from matplotlib.widgets import Button
7 | import keras
8 | import os
9 | from aux_fcn import process_LFP,prediction_parser, get_predictions_index, middle_stamps,get_click_th, format_predictions,split_data, retraining_parser, save_model,get_performance
10 |
11 | # Detection functions
12 |
13 | def predict(LFP,sf,d_sf=1250,arch='CNN1D',model_number=1,channels=np.arange(8),new_model=None,n_channels=None,n_timesteps=None):
14 | '''
15 | predict(LFP,sf,d_sf=1250,arch='CNN1D',model_number=1,channels=np.arange(8),new_model=None,n_channels=None,n_timesteps=None)
16 |
17 | Returns the requested architecture and model number output probability
18 |
19 | Mandatory inputs:
20 | -----------------
21 | LFP: (np.array: n_samples x n_channels). LFP_recorded data. Although there
22 | are no restrictions in n_channels, some considerations should be taken into
23 | account (see channels). Data does not need to be normalized, because it will
24 | be internally be z-scored (see aux_fcn.process_LFP())
25 | sf: (int) Original sampling frequency (in Hz)
26 |
27 | Optional inputs:
28 | ----------------
29 | d_sf: (int) Desired subsampling frequency (in Hz)
30 | arch: Name of the AI architecture to use (string).
31 | It can be: CNN1D, CNN2D, LSTM, SVM or XGBOOST.
32 | model_number: Number of the model to use (integer). There are six different models
33 | for each architecture, sorted by performance, 1 being the best, and 5 the last.
34 | A sixth model is included if single-channel data needs to be used.
35 | channels: Channels to be used for detection (np.array or list: 1 x 8). This is the most
36 | senstive parameter, because models will be looking for specific spatial features
37 | over all channels. Counting starts in 0.
38 | The two main remarks are:
39 | - All models have been trained to look at features in the pyramidal layer (SP),
40 | so for them to work at their maximum potential, the selected channels would
41 | ideally be centered in the SP, with a postive deflection on the first channels
42 | (upper channels) and a negative deflection on the last channels (lower channels).
43 | - For all combinations of architectures and model_numbers, channels has to be
44 | of size 8. There is only one exception, for architecture = 2D-CNN with
45 | models = {3, 4, 5}, that needs to have 3 channels.
46 | - If you are using a high-density probe, then we recommend to use equi-distant
47 | channels from the beginningto the end of the SP. For example, for Neuropixels
48 | in mice, a good set of channels would be pyr_channel + [-8,-6,-4,-2,0,2,4,6].
49 | - In the case of linear probes or tetrodes, there are not enough density to cover
50 | the SP with 8 channels. For that, interpolation or recorded channels can be
51 | done without compromising performance. New artificial interpolated channels
52 | will be add to the LFP wherever there is a -1 in channels.
53 | For example, if pyr_channel=11 in your linear probe, so that 10 is in stratum
54 | oriens and 12 in stratum radiatum, then we could define channels=[10,-1,-1,11,-1,-1,-1,12],
55 | where 2nd and 3rd channelswill be an interpolation of SO and SP channels, and
56 | 5th to 7th an interpolation of SP and SR channels.For tetrodes, organising
57 | channels according to their spatial profile is very convenient to assure best
58 | performance. These interpolations are done using the function aux_fcn.interpolate_channels().
59 | new_model: Other re-trained model you want to use for detection. If you have used our re-train function
60 | to adapt the optimized models to your own data (see rippl_AI.retrain() for more details),
61 | you can input the new_model here to use it to predict your events.
62 | IMPORTANT! If you are using new_model, the data wont be processed, so make sure to
63 | have your data z-scored, subsampled at your subsampling freq and with the
64 | correct channels before calling predict, for example using the process_LFPfunction
65 | IMPORTANT! If you are using a new_model, you have to pass as arguments its number of
66 | channels and the timesteps for window
67 | n_channels: (int) the number of channels of the new model
68 | timesteps: (int) the number of timesteps per window of the new model
69 |
70 | Output:
71 | -------
72 | SWR_prob: model output for every sample of the LFP (np.array: n_samples x 1).
73 | It can be interpreted as the confidence or probability of a SWR event, so values
74 | close to 0 mean that the model is certain that there are not SWRs, and values close
75 | to 1 that the model is very sure that there is a SWR hapenning.
76 | LFP_norm: LFP data used as an input to the model (np.array: n_samples x len(channels)).
77 | It is undersampled, z-scored, and transformed to used the channels specified in channels.
78 |
79 | A Rubio, 2023 LCN
80 | '''
81 |
82 | #channels=opt['channels']
83 | if new_model==None:
84 | norm_LFP=process_LFP(LFP,sf,d_sf,channels)
85 | else: # Data is supossedly already normalized when using new model
86 | norm_LFP=LFP
87 | prob=prediction_parser(norm_LFP,arch,model_number,new_model,n_channels,n_timesteps)
88 |
89 | return(prob,norm_LFP)
90 |
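# Example usage (illustrative sketch): assuming `LFP` is an (n_samples x n_channels) np.array
# recorded at 30 kHz and that channels 0-7 span the pyramidal layer, a default detection would be:
#   SWR_prob, LFP_norm = predict(LFP, sf=30000, d_sf=1250, arch='CNN1D', model_number=1, channels=np.arange(8))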
91 |
92 | def predict_ens(ens_input,model_name='ENS'):
93 | '''
94 | predict_ens(ens_input,model_name='ENS')
95 |
96 | Generates the output of the ensemble model specified with the model name
97 |
98 | Inputs:
99 | -------
100 | ens_input: (n_samples, 5) input of the model, consisting of the inputs of
101 | the 5 other different architectures
102 | model_name: str, name of the ens model found in the folder 'optimized_models'
103 |
104 | Outputs:
105 | --------
106 | prob: (n_samples) output of the model, the calculated probability of
107 | an event in each sample
108 |
109 | '''
110 | model = keras.models.load_model(os.path.join('optimized_models',model_name))
111 | prob = model.predict(ens_input,verbose=1)
112 | return prob
113 |
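# Example usage (illustrative sketch): `probs_per_model` is assumed to be a list holding the SWR_prob
# output of rippl_AI.predict() for the five architectures (CNN1D, CNN2D, LSTM, SVM, XGBOOST) on the
# same LFP; stacking them column-wise gives the (n_samples, 5) input the ensemble model expects:
#   ens_input = np.column_stack(probs_per_model)
#   ens_prob = predict_ens(ens_input)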
114 |
115 |
116 | def get_intervals(y,threshold=None,LFP_norm=None,sf=1250,win_size=100,file_path=None,merge_win=0):
117 | '''
118 | get_intervals(y,LFP_norm=None,sf=1250,win_size=100,threshold=None,file_path=None)
119 |
120 | Get events initial and end times, in seconds
121 | Displays a GUI to help you select the best threshold.
122 |
123 | Inputs:
124 | -------
125 | y: (n,) one dimensional output signal of the model
126 | threshold: (float), threshold of predictions
127 | LFP_norm: (n,n_channels), normalized input signal of the model
128 | sf: (int), sampling frequency (Hz) of the LFP_norm/model output.
129 | Change it if yours is different from 1250
130 | win_size: (int), length of the displayed ripples in milliseconds
131 | file_path: (str), absolute path of the folder where the .txt with the predictions
132 | will be generated. Leave empty if you don't want to generate the file
133 | merge_win: (float), minimum interval in milliseconds between predictions. If
134 | two detections are closer in time than this parameter, they will be merged together
135 |
136 |
137 | Output:
138 | -------
139 |
140 | predictions: (n_events,2), the times (seconds) of the beginning and end of each event
141 | 4 possible use cases, depending on which parameter combination is used when calling the function.
142 | 1.- (y): a histogram of the output is displayed; drag the vertical bar to select your th
143 | 2.- (y,th): no GUI is displayed, the predictions are generated automatically
144 | 3.- (y,LFP_norm): some examples of detected events are displayed next to the histogram
145 | 4.- (y,LFP_norm,th): same case as 3, but the initial location of the bar is th
146 |
147 | '''
148 | global predictions_index, line
149 | # Merge window converted from milliseconds to samples
150 | merge_s=round(sf*merge_win/1000)
151 | # If LFP_norm is passed, plot detected ripples
152 | if type(LFP_norm)==np.ndarray:
153 |
154 | timesteps=int((win_size*sf/1000)//2)
155 | if threshold==None:
156 | valinit=0.5
157 | else:
158 | valinit=threshold
159 | # The predictions_index with the initial th=0.5 is generated
160 |
161 | fig, axes = plt.subplot_mosaic("AAAAABCDEF;AAAAAGHIJK;AAAAALMNÑO;AAAAAPQRST;AAAAAUVWXY",figsize=(15,6))
162 | fig.subplots_adjust(wspace=0, hspace=0)
163 | fig.suptitle(f"Threshold selection")
164 | axes['A'].set_title(f'Th: {valinit}')
165 | for key in axes.keys():
166 | if key=='A':
167 | axes['A'].hist(y)
168 | axes['A'].set_yscale('log')
169 | line=axes['A'].axvline(valinit,c='k')
170 | continue
171 | axes[key].set_yticklabels([])
172 | axes[key].set_xticklabels([])
173 | axes[key].set_xticks([])
174 | axes[key].set_yticks([])
175 | axcolor = (20/255,175/255,245/255) # light blue
176 | hovercolor=(214/255,255/255,255/255)
177 |
178 | # Plot button definition
179 | plotax = plt.axes([0.4, 0.53, 0.035, 0.04])
180 | button_plot = Button(plotax, 'Plot', color=axcolor, hovercolor=hovercolor)
181 | # Save button definition
182 | Saveax = plt.axes([0.375, 0.47, 0.095, 0.04])
183 |
184 | button_save = Button(Saveax, f'Save: {len(get_predictions_index(y,valinit,merge_samples=merge_s))} events', color=axcolor, hovercolor=hovercolor)
185 |
186 | def plot_ripples():
187 | global predictions_index
188 |
189 | th=line.get_xdata()[0]
190 |
191 | predictions_index=get_predictions_index(y,th,merge_samples=merge_s)
192 | n_pred=len(predictions_index)
193 | # Clearing the axes
194 | for key in axes.keys():
195 | if key=='A':
196 | continue
197 | else:
198 | axes[key].clear()
199 | axes[key].set_yticklabels([])
200 | axes[key].set_xticklabels([])
201 | axes[key].set_xticks([])
202 | axes[key].set_yticks([])
203 |
204 |
205 | if n_pred==0:
206 | print("No predictions with this threshold")
207 | return
208 | else:
209 | mids=middle_stamps(predictions_index)
210 | pos_mat = list(range(LFP_norm.shape[1]-1, -1, -1)) * np.ones((timesteps*2, LFP_norm.shape[1]))
211 | if len(mids)<25:
212 | for i,key in enumerate(axes.keys()):
213 | if key=='A':
214 | continue
215 | if i>len(mids): # End of the events
216 | break
217 | # Per-window normalization removed for now; it breaks the plot too often
218 | extracted_window=LFP_norm[mids[i-1]-timesteps:mids[i-1]+timesteps,:]
219 | x=extracted_window*1/np.max(extracted_window)+pos_mat
220 |
221 | lines=axes[key].plot(x,c='0.6',linewidth=0.5)
222 | # Ripple fragment different color
223 | ini_rip=int(predictions_index[i-1,0]) # Absolute timestamps
224 | end_rip=int(predictions_index[i-1,1])
225 | small_pos_mat = list(range(LFP_norm.shape[1]-1, -1, -1)) * np.ones((end_rip- ini_rip, LFP_norm.shape[1]))
226 | ripple_window=LFP_norm[ini_rip:end_rip,:]
227 |
228 | x_ripple=ripple_window*1/np.max(extracted_window)+small_pos_mat
229 | samples_ripple=np.linspace(ini_rip,end_rip,end_rip-ini_rip,dtype=int)-(mids[i-1]-timesteps)
230 |
231 | rip_lines=axes[key].plot(samples_ripple,x_ripple,c='k',linewidth=0.5)
232 |
233 | else: # More than 25 events: Random selection of 25 events
234 | sample_index=np.random.permutation(len(mids))[:25]
235 | for i,key in enumerate(axes.keys()):
236 | if key=='A':
237 | continue
238 |
239 | extracted_window=LFP_norm[mids[sample_index[i-1]]-timesteps:mids[sample_index[i-1]]+timesteps,:]
240 |
241 | x=extracted_window*1/np.max(extracted_window)+pos_mat
242 |
243 | lines=axes[key].plot(x,c='0.6',linewidth=0.5)
244 | # Ripple fragment different color
245 | ini_rip=int(predictions_index[sample_index[i-1],0]) # Absolute timestamps
246 | end_rip=int(predictions_index[sample_index[i-1],1])
247 | small_pos_mat = list(range(LFP_norm.shape[1]-1, -1, -1)) * np.ones((end_rip- ini_rip, LFP_norm.shape[1]))
248 | ripple_window=LFP_norm[ini_rip:end_rip,:]
249 |
250 | x_ripple=ripple_window*1/np.max(extracted_window)+small_pos_mat
251 | samples_ripple=np.linspace(ini_rip,end_rip,end_rip-ini_rip,dtype=int)-(mids[sample_index[i-1]]-timesteps)
252 |
253 | rip_lines=axes[key].plot(samples_ripple,x_ripple,c='k',linewidth=0.5)
254 | plt.draw()
255 |
256 | plot_ripples()
257 | # Button that generates the events
258 |
259 | def generate_pred(event):
260 | global predictions_index
261 | # Generate the predictions with the currently selected threshold
262 | th=line.get_xdata()[0]
263 | predictions_index=get_predictions_index(y,th,merge_samples=merge_s)
264 |
265 | if file_path:
266 | format_predictions(file_path,predictions_index,sf)
267 | plt.close()
268 | return
269 | button_save.on_clicked(generate_pred)
270 | # Plot random ripples
271 | ############################
272 | # Click events
273 | def on_click_press(event):
274 | global line
275 | if event.button is MouseButton.LEFT:
276 | clicked_ax=event.inaxes
277 | if clicked_ax==axes['A']:
278 | th=get_click_th(event)
279 | line.remove()
280 | line=axes['A'].axvline(th,c='k')
281 | clicked_ax.set_title(f'Th: {th:1.3f}')
282 | n_events=len(get_predictions_index(y,th,merge_samples=merge_s))
283 | button_save.label.set_text(f"Save: {n_events} events")
284 |
285 | plt.connect('button_press_event',on_click_press)
286 | plt.connect('motion_notify_event',on_click_press)
287 | def on_click_release(event):
288 | if event.button is MouseButton.LEFT:
289 | clicked_ax=event.inaxes
290 | if clicked_ax==axes['A']:
291 | plot_ripples()
292 | plt.connect('button_release_event',on_click_release)
293 |
294 | def plot_button_click(event):
295 | # Generate the predictions again
296 | plot_ripples()
297 |
298 | button_plot.on_clicked(plot_button_click)
299 | plt.show(block=True)
300 |
301 | # If no threshold is defined, choose your own with the GUI, without LFP_norm plotting
302 | elif threshold==None:
303 | axcolor = (20/255,175/255,245/255) # light blue
304 | hovercolor=(214/255,255/255,255/255)
305 | valinit=0.5
306 | fig, ax = plt.subplots(1, 1, figsize=(8, 6))
307 | ax.hist(y)
308 | ax.set_yscale('log')
309 | fig.suptitle(f"Threshold selection")
310 | ax.set_title(f'Th: {valinit}')
311 |
312 | line=ax.axvline(valinit,c='k')
313 | # Button definition
314 | resetax = plt.axes([0.7, 0.5, 0.12, 0.075])
315 | button = Button(resetax, f'Save\n{len(get_predictions_index(y,valinit,merge_samples=merge_s))} events', color=axcolor, hovercolor=hovercolor)
316 |
317 |
318 | # Button definition
319 |
320 | def generate_pred(event):
321 | global predictions_index
322 | th=line.get_xdata()[0]
323 | predictions_index=get_predictions_index(y,th,merge_samples=merge_s)
324 | if file_path: # If the file path is not empty
325 | format_predictions(file_path,predictions_index,sf)
326 | plt.close()
327 |
328 | button.on_clicked(generate_pred)
329 |
330 |
331 | def on_click(event):
332 | global line
333 | if event.button is MouseButton.LEFT:
334 | clicked_ax=event.inaxes
335 | if clicked_ax==ax:
336 | th=get_click_th(event)
337 | line.remove()
338 | line=ax.axvline(th,c='k')
339 | ax.set_title(f'Th: {th:1.3f}')
340 |
341 | n_events=len(get_predictions_index(y,th,merge_samples=merge_s))
342 | button.label.set_text(f"Save\n{n_events} events")
343 | plt.draw()
344 | plt.connect('button_press_event',on_click)
345 |
346 | plt.connect('motion_notify_event', on_click)
347 | plt.show(block=True)
348 | # If a threshold is defined and no LFP_norm is passed, the function simply generates the predictions
349 | else:
350 | predictions_index=get_predictions_index(y,threshold,merge_samples=merge_s)
351 | if file_path:
352 | format_predictions(file_path,predictions_index,sf)
353 | return (predictions_index/sf)
354 |
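# Illustrative sketches of the get_intervals() use cases (placeholder names,
# kept as comments so nothing runs on import):
#
#   # Fixed threshold, no GUI; events are returned in seconds and optionally
#   # written to a .txt file if file_path is given
#   events_sec = get_intervals(SWR_prob, threshold=0.7, sf=1250, merge_win=20)
#
#   # Interactive threshold selection: histogram GUI plus example ripples
#   # drawn from the normalized LFP returned by predict()
#   events_sec = get_intervals(SWR_prob, LFP_norm=LFP_norm, sf=1250, win_size=100)
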
355 | # Prepares data for training, used in retraining and exploring notebooks
356 | def prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000,d_sf=1250,channels=np.arange(0,8)):
357 | '''
358 | prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000,d_sf=1250,channels=np.arange(0,8))
359 |
360 | Prepares data for training: subsamples, interpolates (if required), z-scores and concatenates
361 | the train/test data passed. Does the same for the validation data, but without concatenating
362 |
363 | Inputs:
364 | -------
365 | train_LFPs: (n_train_sessions) list with the raw LFP of n sessions that will be used to train
366 | train_GTs: (n_train_sessions) list with the GT events of n sessions, in the format [ini end] in seconds
367 | (A): note: the formatting of the validation data could perhaps be removed
368 | val_LFPs: (n_val_sessions) list: with the raw LFP of the sessions that will be used in validation
369 | val_GTs: (n_val_sessions) list: with the GT events of n validation sessions
370 | sf: (int) original sampling frequency of the data in Hz
371 | d_sf: (int) desired downsampling frequency of the data in Hz
372 | channels: (n_channels) np.array. Channels that will be used to generate data. Check interpolate_channels for more information
373 |
374 | Output:
375 | -------
376 | retrain_LFP: (n_samples x n_channels): subsampled, z-scored, interpolated and concatenated data from all the training sessions
377 | retrain_GT: (n_events x 2): concatenation of all the events in the training sessions
378 | norm_val_GT: (n_val_sessions) list with the normalized LFP of each validation session
379 | val_GTs: (n_val_sessions) list with the GT events of each validation session
380 |
381 | A Rubio LCN 2023
382 |
383 | '''
384 | assert len(train_LFPs) == len(train_GTs), "The number of train LFPs doesn't match the number of train GTs"
385 | assert len(val_LFPs) == len(val_GTs), "The number of test LFPs doesn't match the number of test GTs"
386 |
387 | # All the training sessions' data and GT will be concatenated into one data array and one GT array (n_events x 2)
388 | retrain_LFP=[]
389 | for LFP,GT in zip(train_LFPs,train_GTs):
390 | # 1st session in the array
391 | print('Original training data shape: ',LFP.shape)
392 | if len(retrain_LFP)==0:
393 | retrain_LFP=process_LFP(LFP,sf,d_sf,channels)
394 | offset=len(retrain_LFP)/d_sf
395 | retrain_GT=GT
396 | # Append the rest of the sessions, taking into account the length (in seconds)
397 | # of the previous sessions, to concatenate the events' times
398 | else:
399 | aux_LFP=process_LFP(LFP,sf,d_sf,channels)
400 | retrain_LFP=np.vstack([retrain_LFP,aux_LFP])
401 | retrain_GT=np.vstack([retrain_GT,GT+offset])
402 | offset+=len(aux_LFP)/d_sf
403 | # Each validation session LFP will be normalized, etc and stored in an array
404 | # the GT needs no further treatment
405 | norm_val_GT=[]
406 | for LFP in val_LFPs:
407 | print('Original validation data shape: ',LFP.shape)
408 | norm_val_GT.append(process_LFP(LFP,sf,d_sf,channels))
409 |
410 |
411 | return retrain_LFP, retrain_GT , norm_val_GT, val_GTs
412 |
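# Illustrative sketch (placeholder session names; sampling rates are examples):
#
#   train_LFPs = [lfp_A, lfp_B]      # raw LFP arrays, (n_samples x n_recorded_channels)
#   train_GTs  = [gt_A, gt_B]        # GT arrays, (n_events x 2), [ini end] in seconds
#   val_LFPs, val_GTs = [lfp_C], [gt_C]
#   LFP_tr, GT_tr, LFP_val_norm, GT_val = prepare_training_data(
#       train_LFPs, train_GTs, val_LFPs, val_GTs,
#       sf=30000, d_sf=1250, channels=np.arange(8))
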
413 | # Retrains the best model of the specified architecture and saves it in the path specified in save_path.
414 | # Also plots the train, test and validation performance
415 | def retrain_model(LFP_retrain,GT_retrain,LFP_val,GT_val,arch,parameters=None,save_path=None,d_sf=1250,merge_win=0):
416 | '''
417 | retrain_model(LFP_retrain,GT_retrain,LFP_val,GT_val,arch,parameters=None,save_path=None,d_sf=1250,merge_win=0)
418 |
419 | Retrains the best model of the specified architecture with the retrain data and the specified parameters. Performs validation if validation data is provided, and plots the train, test and validation performance.
420 |
421 | Mandatory inputs:
422 | -----------------
423 | LFP_retrain: (n_samples x n_channels). Concatenated LFP of all the training sessions
424 | GT_retrain: (n_events x 2). List with the concatenated GT events times of n sessions,
425 | in the format [ini end] in seconds
426 | LFP_val: (n_val_sessions). List with the normalized LFP of the sessions that will
427 | be used in validation
428 | GT_val: (n_val_sessions). List with the GT events of the validation sessions
429 | arch: (string). Architecture of the model to be retrained
430 |
431 | Optional inputs:
432 | ----------------
433 | parameters: (dictionary) Parameters that will be used in each specific architecture's retraining
434 | - For 'XGBOOST', not needed.
435 | - For 'SVM', one parameter is needed:
436 | - parameters['Undersampler proportion']: Any value between 0 and 1.
437 | This parameter eliminates samples where no ripple is present until the
438 | desired proportion is achieved:
439 | Undersampler proportion= Positive samples/Negative samples
440 | - For 'LSTM', 'CNN1D' and 'CNN2D', two things are needed:
441 | - parameters['Epochs']. The number of times the training data set will
442 | be used to train the model.
443 | - parameters['Training batch']. The number of windows that will be processed
444 | before updating the weights
445 | save_path: (string). Path where the retrained model will be saved
446 | d_sf: (int) Desired subsampling frequency (in Hz)
447 | merge_win: (float). Minimum interval in milliseconds between predictions. If
448 | two detections are closer in time than this parameter, they will be merged together
449 |
450 |
451 | Output:
452 | -------
453 | The retrained model is saved in save_path (if provided). A figure is displayed
454 | with the F1 score of the train and test splits (left panel) and of each
455 | validation session (right panel) as a function of the detection threshold,
456 | so you can assess how well the retrained model generalizes.
457 |
458 | A Rubio LCN 2023
459 |
460 | '''
461 | merge_s=round(d_sf*merge_win/1000)
462 | # Do the train/test split. Feel free to try other proportions
463 | LFP_test,events_test,LFP_train,events_train=split_data(LFP_retrain,GT_retrain,split=0.7,sf=d_sf)
464 |
465 | print(f'Number of validation sessions: {len(LFP_val)}') #TODO: also show length and number of events
466 | print(f'Shape of train data: {LFP_train.shape}, Number of train events: {events_train.shape[0]}')
467 | print(f'Shape of test data: {LFP_test.shape}, Number of test events: {events_test.shape[0]}')
468 |
469 | # retraining_parser returns the retrained model and the predicted probabilities for the train and test sets
470 | model,y_pred_train,y_pred_test=retraining_parser(arch,LFP_train,events_train,LFP_test,events_test,params=parameters,d_sf=d_sf)
471 |
472 | # Save model if save_path is not empty
473 | if save_path:
474 | save_model(model,arch,save_path)
475 |
476 | # Plot section #
477 | # for loop iterating over the validation data
478 | val_pred=[]
479 | # The correct n_channels and timesteps need to be passed to predict for the function to work when using new_model
480 | if arch=='XGBOOST':
481 | n_channels=8
482 | timesteps=16
483 | elif arch=='SVM':
484 | n_channels=8
485 | timesteps=1
486 | elif arch=='LSTM':
487 | n_channels=8
488 | timesteps=32
489 | elif arch=='CNN2D':
490 | n_channels=8
491 | timesteps=40
492 | elif arch=='CNN1D':
493 | n_channels=8
494 | timesteps=16
495 |
496 | for LFP in LFP_val:
497 | val_pred.append(predict(LFP,sf=d_sf,arch=arch,new_model=model,n_channels=n_channels,n_timesteps=timesteps)[0])
498 | # Extract and plot the train and test performance
499 | th_arr=np.linspace(0.1,0.9,9)
500 | F1_train=np.empty(shape=len(th_arr))
501 | F1_test=np.empty(shape=len(th_arr))
502 | for i,th in enumerate(th_arr):
503 | pred_train_events=get_predictions_index(y_pred_train,th,merge_samples=merge_s)/d_sf
504 | pred_test_events=get_predictions_index(y_pred_test,th,merge_samples=merge_s)/d_sf
505 | _,_,F1_train[i],_,_,_=get_performance(pred_train_events,events_train,verbose=False)
506 | _,_,F1_test[i],_,_,_=get_performance(pred_test_events,events_test,verbose=False)
507 |
508 |
509 | fig,axs=plt.subplots(1,2,figsize=(12,5),sharey='all')
510 | axs[0].plot(th_arr,F1_train,'k.-')
511 | axs[0].plot(th_arr,F1_test,'b.-')
512 | axs[0].legend(['Train','Test'])
513 | axs[0].set_ylim([0 ,max(max(F1_train), max(F1_test)) + 0.1])
514 | axs[0].set_title('F1 test and train')
515 | axs[0].set_ylabel('F1')
516 | axs[0].set_xlabel('Threshold')
517 |
518 |
519 | # Validation plot in the second ax
520 | F1_val=np.zeros(shape=(len(LFP_val),len(th_arr)))
521 | for j,pred in enumerate(val_pred):
522 | for i,th in enumerate(th_arr):
523 | pred_val_events=get_predictions_index(pred,th,merge_samples=merge_s)/d_sf
524 | _,_,F1_val[j,i],_,_,_=get_performance(pred_val_events,GT_val[j],verbose=False)
525 |
526 | for i in range(len(LFP_val)):
527 | axs[1].plot(th_arr,F1_val[i])
528 | axs[1].plot(th_arr,np.mean(F1_val,axis=0),'k.-')
529 | axs[1].set_title('Validation F1')
530 | axs[1].set_ylabel('F1')
531 | axs[1].set_xlabel('Threshold')
532 |
533 |
534 |
535 | plt.show()
536 |
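# Illustrative sketch of a retraining call (placeholder variable names; the
# parameter values and save path are examples, not recommendations):
#
#   params = {'Epochs': 5, 'Training batch': 32}        # for LSTM / CNN1D / CNN2D
#   # params = {'Undersampler proportion': 0.3}         # for SVM
#   # params = None                                     # for XGBOOST
#   retrain_model(LFP_tr, GT_tr, LFP_val_norm, GT_val, arch='CNN1D',
#                 parameters=params, save_path='my_retrained_models/CNN1D',
#                 d_sf=1250, merge_win=20)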
--------------------------------------------------------------------------------