├── .github │ └── workflows │ │ └── python-app.yml ├── .gitignore ├── LICENSE ├── README.md ├── deepinsight │ ├── __init__.py │ ├── analyse.py │ ├── architecture.py │ ├── preprocess.py │ ├── train.py │ ├── util │ │ ├── __init__.py │ │ ├── custom_losses.py │ │ ├── data_generator.py │ │ ├── hdf5.py │ │ ├── opts.py │ │ ├── stats.py │ │ ├── tetrode.py │ │ └── wavelet_transform.py │ └── visualize.py ├── media │ ├── colab_walkthrough.gif │ ├── decoding_error.gif │ └── model_architecture.png ├── notebooks │ ├── deepinsight_calcium_example.ipynb │ ├── example_data │ │ └── calcium │ │ │ └── calcium_rois.jpg │ └── static │ │ ├── calcium_example.ipynb │ │ └── ephys_example.ipynb ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── run_test.py └── tests.ipynb /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: build 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.7 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.7 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install pytest 27 | pip install -e git+https://github.com/CYHSM/DeepInsight.git#egg=DeepInsight 28 | pip install git+https://github.com/CYHSM/wavelets 29 | #- name: Lint with flake8 30 | # run: | 31 | # # stop the build if there are Python syntax errors or undefined names 32 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 33 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 34 | # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 35 | - name: Test with pytest 36 | run: | 37 | pytest tests/run_test.py 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.p 3 | *.h5 4 | *.hdf5 5 | logs* 6 | .ipynb_checkpoints 7 | *.mp4 8 | /data 9 | .vscode 10 | *.mat 11 | *.html 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | notebooks/private/ 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Markus Frey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg)](https://github.com/CYHSM/DeepInsight/blob/master/LICENSE.md) 2 | ![py36 status](https://img.shields.io/badge/python3.6-supported-green.svg) 3 | ![Build Status](https://github.com/CYHSM/DeepInsight/workflows/build/badge.svg) 4 | 5 | # DeepInsight: A general framework for interpreting wide-band neural activity 6 | 7 | DeepInsight is a toolbox for the analysis and interpretation of wide-band neural activity and can be applied to unsorted neural data. This means the traditional step of spike-sorting can be omitted and the raw data can be used directly as input, providing a more objective way of measuring decoding performance. 8 | ![Model Architecture](media/model_architecture.png) 9 | 10 | ## Google Colaboratory 11 | 12 | We created a Colab notebook to showcase how to analyse your own two-photon calcium imaging data. We provide the raw as well as the preprocessed dataset as downloads in case you just want to train the model. You can replace the code which loads the traces with your own data handling and directly train the model in the browser to decode your behaviour or stimuli. 13 | 14 | [![Two-Photon Imaging](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11RXK7JIgVM8Zy9M7xEtt1k62i3JXbZLU) 15 | ![Colab Walkthrough](media/colab_walkthrough.gif) 16 | 17 | ## Example Usage 18 | ```python 19 | import deepinsight 20 | 21 | # Load your electrophysiological or calcium-imaging data 22 | (raw_data, 23 | raw_timestamps, 24 | output, 25 | output_timestamps, 26 | info) = deepinsight.util.tetrode.read_tetrode_data(fp_raw_file) 27 | 28 | # Transform raw data to frequency domain 29 | deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=info['sampling_rate'], 30 | channels=info['channels']) 31 | 32 | # Prepare outputs 33 | deepinsight.util.tetrode.preprocess_output(fp_deepinsight, raw_timestamps, output, output_timestamps, 34 | sampling_rate=info['sampling_rate']) 35 | 36 | # Train the model 37 | deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights) 38 | 39 | # Get loss and shuffled loss for influence plot 40 | losses, output_predictions, indices, output_real = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=10) 41 | shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1, stepsize=10) 42 | 43 | # Plot influence across behaviours 44 | deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2) 45 | ``` 46 | 
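The snippet above references `fp_raw_file`, `fp_deepinsight`, `loss_functions` and `loss_weights` without defining them. A minimal, hypothetical configuration (paths and weights are placeholders you should adapt; the keys must match the output names written during preprocessing) could look like this:

```python
fp_raw_file = 'experiment.nwb'              # path to your raw recording (placeholder)
fp_deepinsight = 'experiment_processed.h5'  # HDF5 file DeepInsight writes to (placeholder)

# One entry per decoded behaviour; values name functions in
# deepinsight.util.custom_losses or standard Keras losses
loss_functions = {'position': 'euclidean_loss',
                  'head_direction': 'cyclical_mae_rad',
                  'speed': 'mae'}
# The relative weighting between losses is a free choice and typically needs tuning
loss_weights = {'position': 1, 'head_direction': 25, 'speed': 2}

# One way to summarise influence from losses (shape (N, outputs)) and
# shuffled_losses (shape (frequencies, N, outputs)): relative loss increase
# influence = (shuffled_losses.mean(axis=1) - losses.mean(axis=0)) / losses.mean(axis=0)
```
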
47 | See also the [jupyter notebook](notebooks/static/ephys_example.ipynb) for a full example of decoding behaviours from tetrode CA1 recordings. Note that the static notebook does not include the interactive plots shown in the Colab notebook above. The expected run time for a high-sampling-rate dataset (e.g. tetrode recordings) is highly dependent on the number of channels and the duration of the experiment. Preprocessing can take up to one day for a 128-channel, 1-hour experiment, while training the model takes between 6 and 12 hours. For calcium recordings, preprocessing time shrinks to minutes. 48 | 49 | The following video shows the performance of the model trained on position (left), head direction (top right) and speed (bottom right): 50 | ![Model Performance](media/decoding_error.gif) 51 | 52 | ## Installation 53 | Install DeepInsight with the following command (installation time ~2 minutes, depending on internet speed): 54 | ``` 55 | pip install git+https://github.com/CYHSM/DeepInsight.git 56 | ``` 57 | 58 | If you prefer to use DeepInsight from within your browser, we provide Colab notebooks to guide you through how to use DeepInsight with your own data. 59 | 60 | - How to use DeepInsight with two-photon calcium imaging data [![Two-Photon Imaging](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11RXK7JIgVM8Zy9M7xEtt1k62i3JXbZLU) 61 | 62 | - How to use DeepInsight with electrophysiology data [![Ephys Data](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1h3RYr3r0Zs2k6I53bTiYRq_6VQo38iMP) 63 | 64 | ## System Requirements 65 | 66 | ### Hardware requirements 67 | For preprocessing raw data with a high sampling rate it is recommended to use at least 4 parallel cores. For calcium recordings one core is enough. For training the model it is recommended to use a GPU with at least 6 GB of memory. 68 | 69 | ### Software requirements 70 | The following Python dependencies are installed automatically when installing DeepInsight (specified in requirements.txt): 71 | ``` 72 | tensorflow-gpu (2.1.0) 73 | numpy (1.18.1) 74 | pandas (1.0.1) 75 | joblib (0.14.1) 76 | seaborn (0.10.0) 77 | matplotlib (3.1.3) 78 | h5py (2.10.0) 79 | scipy (1.4.1) 80 | ipython (7.12.0) 81 | ``` 82 | Versions in parentheses indicate the ones used for testing the framework. It is extensively tested on Ubuntu 16.04 but should run on any OS (Windows, Mac, Linux) supporting Python >3.6 and pip. It is recommended to install the framework and its dependencies in a virtual environment (e.g. conda). -------------------------------------------------------------------------------- /deepinsight/__init__.py: -------------------------------------------------------------------------------- 1 | from . import util 2 | from . import preprocess 3 | from . import architecture 4 | from . import train 5 | from . import analyse 6 | from . import visualize 7 | -------------------------------------------------------------------------------- /deepinsight/analyse.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | from . 
import util 8 | import h5py 9 | import numpy as np 10 | import pandas as pd 11 | from scipy.stats import spearmanr 12 | from tensorflow.compat.v1.keras.backend import clear_session, placeholder, get_session 13 | import os 14 | 15 | import tensorflow as tf 16 | tf.compat.v1.disable_eager_execution() 17 | 18 | 19 | def get_model_loss(fp_hdf_out, stepsize=1, shuffles=None, axis=0, verbose=1, fp_test=None, timestep=-1): 20 | """ 21 | Loops across cross validated models and calculates loss and predictions for full experiment length 22 | 23 | Parameters 24 | ---------- 25 | fp_hdf_out : str 26 | File path to HDF5 file 27 | stepsize : int, optional 28 | Determines how many samples will be evaluated. 1 -> N samples evaluated, 29 | 2 -> N/2 samples evaluated, etc..., by default 1 30 | shuffles : dict, optional 31 | If wavelets should be shuffled, important for calculating influence scores, by default None 32 | 33 | Returns 34 | ------- 35 | losses : (N,1) array_like 36 | Loss between predicted and ground truth observation 37 | predictions : dict 38 | Dictionary with predictions for each behaviour, each item in dict has size (N, Z) with Z the dimensions of the sample (e.g. Z_position=2, Z_speed=1, ...) 39 | indices : (N,1) array_like 40 | Indices which were evaluated, important when taking stepsize unequal to 1 41 | """ 42 | dirname = os.path.dirname(fp_hdf_out) 43 | if fp_test is None: 44 | filename = os.path.basename(fp_hdf_out)[0:-3] 45 | else: 46 | filename = os.path.basename(fp_test)[0:-3] 47 | cv_results = [] 48 | (_, _, _, opts) = util.hdf5.load_model_with_opts( 49 | dirname + '/models/' + filename + '_model_{}.h5'.format(0)) 50 | loss_names = opts['loss_names'] 51 | time_shift = opts['model_timesteps'] 52 | if verbose > 0: 53 | progress_bar = tf.keras.utils.Progbar( 54 | opts['num_cvs'], width=30, verbose=1, interval=0.05, unit_name='run') 55 | for k in range(0, opts['num_cvs']): 56 | clear_session() 57 | # Find folders 58 | model_path = dirname + '/models/' + filename + '_model_{}.h5'.format(k) 59 | # Load model and generators 60 | (model, training_generator, testing_generator, opts) = util.hdf5.load_model_with_opts(model_path) 61 | if fp_test is not None: 62 | opts['fp_hdf_out'] = fp_hdf_out 63 | opts['handle_nan'] = False 64 | hdf5_file = h5py.File(fp_hdf_out, mode='r') 65 | wavelets = hdf5_file['inputs/wavelets'][()] 66 | hdf5_file.close() 67 | opts['training_indices'] = np.arange(0,wavelets.shape[0] - (opts['model_timesteps'] * opts['batch_size'])) 68 | opts['testing_indices'] = np.arange(0,wavelets.shape[0] - (opts['model_timesteps'] * opts['batch_size'])) 69 | (training_generator, testing_generator) = util.data_generator.create_train_and_test_generators(opts) 70 | #testing_generator.cv_indices = np.arange(0, testing_generator.wavelets.shape[0] - (opts['model_timesteps'] * opts['batch_size'])) 71 | # ----------------------------------------------------------------------------------------------- 72 | if shuffles is not None: 73 | testing_generator = shuffle_wavelets( 74 | training_generator, testing_generator, shuffles) 75 | losses, predictions, indices = calculate_losses_from_generator( 76 | testing_generator, model, verbose=verbose-1, stepsize=stepsize) 77 | # ----------------------------------------------------------------------------------------------- 78 | cv_results.append((losses, predictions, indices)) 79 | if verbose > 0: 80 | progress_bar.add(1) 81 | cv_results = np.array(cv_results, dtype='object') 82 | # Reshape cv_results 83 | losses = np.concatenate(cv_results[:, 0], 
axis=0) 84 | predictions = {k: [] for k in loss_names} 85 | for out in cv_results[:, 1]: 86 | for p, name in zip(out, loss_names): 87 | predictions[name].append(p) 88 | for key, item in predictions.items(): 89 | tmp_output = np.concatenate(predictions[key], axis=0)[:, timestep, :] 90 | predictions[key] = tmp_output 91 | indices = np.concatenate(cv_results[:, 2], axis=0) 92 | # We only take the last timestep for decoding, so decoder does not see any part of the future 93 | # We also need to shift the indices as we decode samples within the time window 94 | time_shifts = np.arange(0, testing_generator.model_timesteps + 1, testing_generator.average_output)[1::] - 1 95 | indices = indices + time_shifts[timestep] 96 | 97 | # Also save to HDF5 98 | hdf5_file = h5py.File(fp_hdf_out, mode='a') 99 | for key, item in predictions.items(): 100 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/predictions/{}_axis{}_stepsize{}".format(key, axis, stepsize), 101 | dataset_shape=item.shape, dataset_type=np.float32, dataset_value=item) 102 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/losses_axis{}_stepsize{}".format(axis, stepsize), 103 | dataset_shape=losses.shape, dataset_type=np.float32, dataset_value=losses) 104 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/indices_axis{}_stepsize{}".format(axis, stepsize), 105 | dataset_shape=indices.shape, dataset_type=np.int64, dataset_value=indices) 106 | 107 | # Add real to output of this function 108 | output_real = dict() 109 | for idx, (key, y_pred) in enumerate(predictions.items()): 110 | output_real[key] = np.array(hdf5_file['outputs/{}'.format(key)])[indices, ...] 111 | 112 | hdf5_file.close() 113 | 114 | # Report model performance 115 | if verbose > 0: 116 | df_stats = calculate_model_stats(losses, predictions, indices, output_real) 117 | print(df_stats) 118 | 119 | return losses, predictions, indices, output_real 120 | 121 | 122 | def get_shuffled_model_loss(fp_hdf_out, stepsize=1, axis=0, verbose=1): 123 | """ 124 | Shuffles the wavelets and recalculates error 125 | 126 | Parameters 127 | ---------- 128 | fp_hdf_out : str 129 | File path to HDF5 file 130 | stepsize : int, optional 131 | Determines how many samples will be evaluated. 
1 -> N samples evaluated, 132 | 2 -> N/2 samples evaluated, etc..., by default 1 133 | axis : int, optional 134 | Which axis to shuffle 135 | 136 | Returns 137 | ------- 138 | shuffled_losses : (N,1) array_like 139 | Loss between predicted and ground truth observation for shuffled wavelets on specified axis 140 | """ 141 | if axis == 0: 142 | raise ValueError( 143 | 'Shuffling across time dimension (axis=0) not supported yet.') 144 | hdf5_file = h5py.File(fp_hdf_out, mode='r') 145 | tmp_wavelets_shape = hdf5_file['inputs/wavelets'].shape 146 | hdf5_file.close() 147 | shuffled_losses = [] 148 | if verbose > 0: 149 | progress_bar = tf.keras.utils.Progbar( 150 | tmp_wavelets_shape[axis], width=30, verbose=1, interval=0.05, unit_name='run') 151 | for s in range(0, tmp_wavelets_shape[axis]): 152 | if axis == 1: 153 | losses = get_model_loss(fp_hdf_out, stepsize=stepsize, shuffles={'f': s}, axis=axis, verbose=0)[0] 154 | elif axis == 2: 155 | losses = get_model_loss(fp_hdf_out, stepsize=stepsize, shuffles={'c': s}, axis=axis, verbose=0)[0] 156 | shuffled_losses.append(losses) 157 | if verbose > 0: 158 | progress_bar.add(1) 159 | shuffled_losses = np.array(shuffled_losses) 160 | # Also save to HDF5 161 | hdf5_file = h5py.File(fp_hdf_out, mode='a') 162 | util.hdf5.create_or_update(hdf5_file, dataset_name="analysis/influence/shuffled_losses_axis{}_stepsize{}".format(axis, stepsize), 163 | dataset_shape=shuffled_losses.shape, dataset_type=np.float32, dataset_value=shuffled_losses) 164 | hdf5_file.close() 165 | 166 | return shuffled_losses 167 | 168 | 169 | def calculate_losses_from_generator(tg, model, num_steps=None, stepsize=1, verbose=0): 170 | """ 171 | Keras evaluate_generator only returns a scalar loss (mean) while predict_generator only returns the predictions but not the real labels 172 | TODO Make it batch size independent 173 | 174 | Parameters 175 | ---------- 176 | tg : object 177 | Data generator 178 | model : object 179 | Keras model 180 | num_steps : int, optional 181 | How many steps should be evaluated, by default None (runs through full experiment) 182 | stepsize : int, optional 183 | Determines how many samples will be evaluated. 1 -> N samples evaluated, 184 | 2 -> N/2 samples evaluated, etc..., by default 1 185 | verbose : int, optional 186 | Verbosity level 187 | 188 | Returns 189 | ------- 190 | losses : (N,1) array_like 191 | Loss between predicted and ground truth observation 192 | predictions : dict 193 | Dictionary with predictions for each behaviour, each item in dict has size (N, Z) with Z the dimensions of the sample (e.g. Z_position=2, Z_speed=1, ...) 194 | indices : (N,1) array_like 195 | Indices which were evaluated, important when taking stepsize unequal to 1 196 | """ 197 | # X.) Parse inputs 198 | if num_steps is None: 199 | num_steps = len(tg) 200 | 201 | # 1.) Make a copy and adjust attributes 202 | tmp_dict = tg.__dict__.copy() 203 | if tg.batch_size != 1: 204 | tg.batch_size = 1 205 | tg.random_batches = False 206 | tg.shuffle = False 207 | tg.sample_size = tg.model_timesteps * tg.batch_size 208 | 209 | # 2.) Get output tensors 210 | sess = get_session() 211 | (_, test_out) = tg.__getitem__(0) 212 | real_tensor, calc_tensors = placeholder(), [] 213 | for output_index in range(0, len(test_out)): 214 | prediction_tensor = model.outputs[output_index] 215 | loss_tensor = model.loss_functions[output_index].fn( 216 | real_tensor, prediction_tensor) 217 | calc_tensors.append((prediction_tensor, loss_tensor)) 218 | 219 | # 3.) 
Predict 220 | losses, predictions, indices = [], [], [] 221 | for i in range(0, num_steps, stepsize): 222 | (in_tg, out_tg) = tg.__getitem__(i) 223 | indices.append(tg.cv_indices[i]) 224 | loss, prediction = [], [] 225 | for o in range(0, len(out_tg)): 226 | evaluated = sess.run(calc_tensors[o], feed_dict={ 227 | model.input: in_tg, real_tensor: out_tg[o]}) 228 | prediction.append(evaluated[0][0, ...]) 229 | loss.append(evaluated[1][0, ...]) # Get rid of batch dimensions 230 | predictions.append(prediction) 231 | losses.append(loss) 232 | if verbose > 0 and not i % 50: 233 | print('{} / {}'.format(i, num_steps), end='\r') 234 | if verbose > 0: 235 | print('Performed {} evaluation steps'.format(num_steps // stepsize)) 236 | losses, predictions, indices = np.array( 237 | losses), swap_listaxes(predictions), np.array(indices) 238 | tg.__dict__.update(tmp_dict) 239 | 240 | return losses, predictions, indices 241 | 242 | 243 | def shuffle_wavelets(training_generator, testing_generator, shuffles): 244 | """ 245 | Shuffle procedure for model interpretation 246 | 247 | Parameters 248 | ---------- 249 | training_generator : object 250 | Data generator for training data 251 | testing_generator : object 252 | Data generator for testing data 253 | shuffles : dict 254 | Indicates which axis to shuffle and which index in selected dimension, e.g. {'f' : 5} shuffles frequency axis 5 255 | 256 | Returns 257 | ------- 258 | testing_generator : object 259 | Data generator for testing data with shuffled wavelets 260 | """ 261 | rolled_wavelets = training_generator.wavelets.copy() 262 | for key, item in shuffles.items(): 263 | if key == 'f': 264 | np.random.shuffle(rolled_wavelets[:, item, :]) # In place 265 | elif key == 'c': 266 | np.random.shuffle(rolled_wavelets[:, :, item]) # In place 267 | elif key == 't': 268 | np.random.shuffle(rolled_wavelets[item, :, :]) # In place 269 | testing_generator.wavelets = rolled_wavelets 270 | return testing_generator 271 | 272 | 273 | def swap_listaxes(list_in): 274 | list_out = [] 275 | for o in range(0, len(list_in[0])): 276 | list_out.append(np.array([out[o] for out in list_in])) 277 | return list_out 278 | 279 | 280 | def calculate_model_stats(losses, predictions, indices, real, additional_metrics=[spearmanr]): 281 | """ 282 | Calculates statistics on model predictions 283 | 284 | Parameters 285 | ---------- 286 | real : dict 287 | Ground truth observations for each behaviour, same structure as predictions 288 | losses : (N,1) array_like 289 | Loss between predicted and ground truth observation 290 | predictions : dict 291 | Dictionary with predictions for each behaviour, each item in dict has size (N, Z) with Z the dimensions of the sample (e.g. Z_position=2, Z_speed=1, ...) 
292 | indices : (N,1) array_like 293 | Indices which were evaluated, important when taking stepsize unequal to 1 294 | additional_metrics : list, optional 295 | Additional metrics besides Pearson and model loss to be evaluated; each should take arguments (y_true, y_pred) and return a scalar, or a tuple whose first element is used as the metric 296 | 297 | Returns 298 | ------- 299 | df_scores 300 | Dataframe of evaluated scores 301 | """ 302 | output_scores = [] 303 | for idx, ((key, y_pred), (key2, y_true)) in enumerate(zip(predictions.items(), real.items())): 304 | pearson_mean, additional_mean = 0, np.zeros((len(additional_metrics))) 305 | for p in range(y_pred.shape[1]): 306 | pearson_mean += np.corrcoef(y_true[:, p], y_pred[:, p])[0, 1] 307 | for add_idx, am in enumerate(additional_metrics): 308 | am_eval = am(y_true[:, p], y_pred[:, p]) 309 | if len(am_eval) > 1: 310 | am_eval = am_eval[0] 311 | additional_mean[add_idx] += am_eval 312 | additional_mean /= y_pred.shape[1] 313 | pearson_mean /= y_pred.shape[1] 314 | loss_mean = np.mean(losses[:, idx]) 315 | output_scores.append((pearson_mean, loss_mean, *additional_mean)) 316 | additional_columns = [f.__name__.title() for f in additional_metrics] 317 | df_scores = pd.DataFrame(output_scores, index=predictions.keys(), columns=['Pearson', 'Model Loss', *additional_columns]) 318 | 319 | return df_scores 320 | -------------------------------------------------------------------------------- /deepinsight/architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | from tensorflow.keras.layers import Conv2D, GaussianNoise, TimeDistributed, Input, Dense, Lambda, Flatten, Dropout 8 | from tensorflow.keras.models import Model 9 | import tensorflow.keras.backend as K 10 | 11 | 12 | def the_decoder(tg, show_summary=True): 13 | """ 14 | Model architecture used for decoding from wavelet transformed neural signals 15 | 16 | Parameters 17 | ---------- 18 | tg : object 19 | Data generator, holding all important options for creating and training the model 20 | show_summary : bool, optional 21 | Whether to show a summary of the model after creation, by default True 22 | 23 | Returns 24 | ------- 25 | model : object 26 | Keras model 27 | """ 28 | model_input = Input(shape=tg.input_shape) 29 | 30 | x = GaussianNoise(tg.gaussian_noise)(model_input) 31 | # timestep reductions 32 | for nct in range(0, tg.num_convs_tsr): 33 | x = TimeDistributed(Conv2D(filters=tg.filter_size, kernel_size=(tg.kernel_size, tg.kernel_size), strides=( 34 | 2, 1), padding=tg.conv_padding, activation=tg.act_conv, name='conv_tsr{}'.format(nct)))(x) 35 | x = TimeDistributed(Conv2D(filters=tg.filter_size, kernel_size=(tg.kernel_size, tg.kernel_size), strides=( 36 | 1, 2), padding=tg.conv_padding, activation=tg.act_conv, name='conv_fr{}'.format(nct)))(x) 37 | 38 | # batch x 128 x 60 x 11 39 | x = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1, 4)))(x) 40 | 41 | layer_counter = 0 42 | while (K.int_shape(x)[3] > tg.channel_lower_limit): 43 | x = TimeDistributed(Conv2D(filters=tg.filter_size * 2, kernel_size=(1, 2), strides=(1, 2), 44 | padding=tg.conv_padding, activation=tg.act_conv, name='conv_after_tsr{}'.format(layer_counter)))(x) 45 | layer_counter += 1 46 | 47 | # Flatten and fc 48 | x_flat = TimeDistributed(Flatten())(x) 49 | 50 | outputs = [] 51 | for (key, item), output in zip(tg.loss_functions.items(), tg.outputs): 52 | x = x_flat 53 | 
for d in range(0, tg.num_dense): 54 | x = Dense(tg.num_units_dense, activation=tg.act_fc, name='dense{}_combine{}'.format(d, key))(x) 55 | x = Dropout(tg.dropout_ratio)(x) 56 | out = Dense(output.shape[1], name='{}'.format(key), activation=tg.last_layer_activation_function)(x) 57 | outputs.append(out) 58 | 59 | model = Model(inputs=model_input, outputs=outputs) 60 | 61 | if show_summary: 62 | print(model.summary(line_length=100)) 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /deepinsight/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | import time 8 | from joblib import Parallel, delayed 9 | import numpy as np 10 | import h5py 11 | import tensorflow as tf # Progress bar only 12 | import deepinsight.util.wavelet_transform as wt 13 | from deepinsight.util import hdf5 14 | 15 | 16 | def preprocess_input(fp_hdf_out, raw_data, average_window=1000, channels=None, window_size=100000, 17 | gap_size=50000, sampling_rate=30000, scaling_factor=0.5, num_cores=4, **args): 18 | """ 19 | Transforms raw neural data to frequency space, via wavelet transform implemented currently with aaren-wavelets (https://github.com/aaren/wavelets) 20 | Saves wavelet transformed data to HDF5 file (N, P, M) - (Number of timepoints, Number of frequencies, Number of channels) 21 | 22 | Parameters 23 | ---------- 24 | fp_hdf_out : str 25 | File path to HDF5 file 26 | raw_data : (N, M) file or array_like 27 | Variable storing the raw_data (N data points, M channels), should allow indexing 28 | average_window : int, optional 29 | Average window to downsample wavelet transformed input, by default 1000 30 | channels : array_like, optional 31 | Which channels from raw_data to use, by default None 32 | window_size : int, optional 33 | Window size for calculating wavelet transformation, by default 100000 34 | gap_size : int, optional 35 | Gap size for calculating wavelet transformation, by default 50000 36 | sampling_rate : int, optional 37 | Sampling rate of raw_data, by default 30000 38 | scaling_factor : float, optional 39 | Determines the number of log-spaced frequencies P in the output, by default 0.5 40 | num_cores : int, optional 41 | Number of parallel cores to use to calculate wavelet transformation, by default 4 42 | """ 43 | # Get number of chunks 44 | if channels is None: 45 | channels = np.arange(0, raw_data.shape[1]) 46 | num_points = raw_data.shape[0] 47 | if window_size > num_points: 48 | num_chunks = len(channels) 49 | output_size = raw_data.shape[0] 50 | mean_signal = np.mean(raw_data, axis=1) 51 | average_window = 1 52 | full_transform = True 53 | else: 54 | num_chunks = (num_points // gap_size) - 1 55 | output_size = ((num_chunks + 1) * gap_size) // average_window 56 | full_transform = False 57 | 58 | # Get estimate for number of frequencies 59 | (_, wavelet_frequencies) = wt.wavelet_transform(np.ones(window_size), sampling_rate, average_window, scaling_factor, **args) 60 | num_fourier_frequencies = len(wavelet_frequencies) 61 | # Prepare output file 62 | hdf5_file = h5py.File(fp_hdf_out, mode='a') 63 | if "inputs/wavelets" not in hdf5_file: 64 | hdf5_file.create_dataset("inputs/wavelets", [output_size, num_fourier_frequencies, len(channels)], np.float32) 65 | hdf5_file.create_dataset("inputs/fourier_frequencies", [num_fourier_frequencies], np.float16) 66 | # Makes saving 5 times faster, as saving along the last index uses fancy indexing and is therefore slow 67 | hdf5_file.create_dataset("inputs/tmp_wavelets", [len(channels), output_size, num_fourier_frequencies], np.float32) 68 | 69 | # Prepare par pool 70 | par = Parallel(n_jobs=num_cores, verbose=0) 71 | 72 | # Start parallel wavelet transformation 73 | print('Starting wavelet transformation (n={}, chunks={}, frequencies={})'.format( 74 | num_points, num_chunks, num_fourier_frequencies)) 75 | progress_bar = tf.keras.utils.Progbar(num_chunks, width=30, verbose=1, interval=0.05, unit_name='chunk') 76 | for c in range(0, num_chunks): 77 | if full_transform: 78 | raw_chunk = raw_data[:, c] - mean_signal 79 | else: 80 | start = gap_size * c 81 | end = start + window_size 82 | raw_chunk = raw_data[start: end, channels] 83 | # Process raw chunk 84 | raw_chunk = preprocess_chunk(raw_chunk, subtract_mean=True, convert_to_milivolt=False) 85 | 86 | # Calculate wavelet transform 87 | if full_transform: 88 | (wavelet_power, wavelet_frequencies) = wt.wavelet_transform(raw_chunk, 89 | sampling_rate=sampling_rate, scaling_factor=scaling_factor, average_window=average_window, **args) 90 | else: 91 | wavelet_transformed = np.zeros((raw_chunk.shape[0] // average_window, num_fourier_frequencies, raw_chunk.shape[1])) 92 | for ind, (wavelet_power, wavelet_frequencies) in enumerate(par(delayed(wt.wavelet_transform)(raw_chunk[:, i], sampling_rate, average_window, scaling_factor, **args) for i in range(0, raw_chunk.shape[1]))): 93 | wavelet_transformed[:, :, ind] = wavelet_power 94 | 95 | # Save in output file 96 | if full_transform: 97 | hdf5_file["inputs/tmp_wavelets"][c, :, :] = wavelet_power 98 | else: 99 | wavelet_index_end = end // average_window 100 | wavelet_index_start = start // average_window 101 | index_gap = gap_size // 2 // average_window 102 | if c == 0: 103 | this_index_start = 0 104 | this_index_end = wavelet_index_end - index_gap 105 | hdf5_file["inputs/wavelets"][this_index_start:this_index_end, :, :] = wavelet_transformed[0: -index_gap, :, :] 106 | elif c == num_chunks - 1: # Make sure the last one fits fully 107 | this_index_start = wavelet_index_start + index_gap 108 | this_index_end = wavelet_index_end 109 | hdf5_file["inputs/wavelets"][this_index_start:this_index_end, :, :] = wavelet_transformed[index_gap::, :, :] 110 | else: 111 | this_index_start = wavelet_index_start + index_gap 112 | this_index_end = wavelet_index_end - index_gap 113 | hdf5_file["inputs/wavelets"][this_index_start:this_index_end, :, :] = wavelet_transformed[index_gap: -index_gap, :, :] 114 | hdf5_file.flush() 115 | progress_bar.add(1) 116 | 117 | # Put frequencies in and close file 118 | if full_transform: 119 | wavelet_power = np.transpose(hdf5_file["inputs/tmp_wavelets"], axes=(1, 2, 0)) 120 | del hdf5_file["inputs/tmp_wavelets"] 121 | hdf5_file["inputs/wavelets"][:] = wavelet_power 122 | hdf5_file["inputs/fourier_frequencies"][:] = wavelet_frequencies 123 | hdf5_file.flush() 124 | hdf5_file.close() 125 | 
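# Note on the chunking above (illustrative numbers, assuming the defaults):
# with window_size=100000, gap_size=50000 and average_window=1000, a new
# 100000-sample window starts every 50000 samples, so consecutive windows
# overlap by half. After the wavelet transform, each window is trimmed by
# index_gap = gap_size // 2 // average_window = 25 downsampled bins on its
# overlapping side(s), so the window edges (which are most affected by edge
# effects of the transform) are discarded and only the well-supported centre
# of each window is written to 'inputs/wavelets'.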
126 | 127 | def preprocess_chunk(raw_chunk, subtract_mean=True, convert_to_milivolt=False): 128 | """ 129 | Preprocesses a chunk of data. 
130 | 131 | Parameters 132 | ---------- 133 | raw_chunk : array_like 134 | Chunk of raw_data to preprocess 135 | subtract_mean : bool, optional 136 | Subtract mean over all other channels, by default True 137 | convert_to_milivolt : bool, optional 138 | Convert chunk to millivolt, by default False 139 | 140 | Returns 141 | ------- 142 | raw_chunk : array_like 143 | preprocessed_chunk 144 | """ 145 | # Subtract mean across all channels 146 | if subtract_mean: 147 | raw_chunk = raw_chunk.transpose() - np.mean(raw_chunk.transpose(), axis=0) 148 | raw_chunk = raw_chunk.transpose() 149 | # Convert to millivolt 150 | if convert_to_milivolt: 151 | raw_chunk = raw_chunk * (0.195 / 1000) 152 | return raw_chunk 153 | 154 | 155 | def preprocess_output(fp_hdf_out, raw_timestamps, output, output_timestamps, average_window=1000, dataset_name='aligned', dataset_type=np.float16): 156 | """ 157 | Base file for preprocessing outputs (handles M-D case as of March 2020). 158 | For more complex cases use specialized functions (see for example preprocess_output in util.tetrode module) 159 | 160 | Parameters 161 | ---------- 162 | fp_hdf_out : str 163 | File path to HDF5 file 164 | raw_timestamps : (N,1) array_like 165 | Timestamps for each sample in the continuous recording 166 | output : array_like 167 | M-dimensional output which will be aligned with the continuous recording 168 | output_timestamps : (N,1) array_like 169 | Timestamps for output 170 | average_window : int, optional 171 | Downsampling factor for raw data and output, by default 1000 172 | dataset_name : str, optional 173 | Field name for output stored in HDF5 file 174 | """ 175 | hdf5_file = h5py.File(fp_hdf_out, mode='a') 176 | 177 | # Get size of wavelets 178 | input_length = hdf5_file['inputs/wavelets'].shape[0] 179 | 180 | # Align output to the downsampled raw timestamps 181 | raw_timestamps = raw_timestamps[()] # Slightly faster than np.array 182 | if output.ndim == 1: 183 | output = output[..., np.newaxis] 184 | 185 | output_aligned = np.array([np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0], 186 | average_window)], output_timestamps, output[:, i]) for i in range(output.shape[1])]).transpose() 187 | 188 | # Create and save datasets in HDF5 File 189 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/{}".format(dataset_name), 190 | dataset_shape=[input_length, output_aligned.shape[1]], dataset_type=dataset_type, dataset_value=output_aligned[0: input_length, ...]) 191 | hdf5_file.flush() 192 | hdf5_file.close() 193 | print('Successfully written Dataset="{}" to {}'.format(dataset_name, fp_hdf_out)) 194 | -------------------------------------------------------------------------------- /deepinsight/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | import os 8 | import numpy as np 9 | import h5py 10 | 11 | from tensorflow.keras import optimizers 12 | from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau 13 | 14 | import tensorflow.keras.backend as K 15 | 16 | from . import architecture 17 | from . 
import util 18 | 19 | 20 | def train_model_on_generator(model, training_generator, testing_generator, loss_functions, loss_weights, steps_per_epoch=300, validation_steps=300, loss_metrics=[], 21 | epochs=10, tensorboard_logfolder='./', model_name='', verbose=1, reduce_lr=False, log_output=False, save_model_only=False, compile_only=False): 22 | """ 23 | Function for training a given model, with data provided by training and testing generators 24 | 25 | Parameters 26 | ---------- 27 | model : object 28 | Keras model 29 | training_generator : object 30 | Data generator for training data 31 | testing_generator : object 32 | Data generator for testing data 33 | loss_functions : dict 34 | Selected loss function for each behaviour 35 | loss_weights : dict 36 | Selected weights for each loss function 37 | steps_per_epoch : int, optional 38 | Number of steps for training the model, by default 300 39 | validation_steps : int, optional 40 | Number of steps for validating the model, by default 300 41 | epochs : int, optional 42 | Number of epochs to train model, by default 10 43 | tensorboard_logfolder : str, optional 44 | Where to store tensorboard logfiles, by default './' 45 | model_name : str, optional 46 | Name of selected model, used to return best model, by default '' 47 | verbose : int, optional 48 | Verbosity level, by default 1 49 | reduce_lr : bool, optional 50 | If True reduce learning rate on plateau, by default False 51 | log_output : bool, optional 52 | Log the output to tensorflow logfolder, by default False 53 | save_model_only : bool, optional 54 | Save best model after each epoch, by default False 55 | compile_only : bool, optional 56 | If true returns only compiled model, by default False 57 | 58 | Returns 59 | ------- 60 | model : object 61 | Keras model 62 | history : dict 63 | Dictionary containing training and validation performance 64 | """ 65 | # Compile model 66 | opt = optimizers.Adam(lr=training_generator.learning_rate, amsgrad=True) 67 | # Resolve loss names (strings) to function handles, either from util.custom_losses or as Keras built-ins 68 | for key, item in loss_functions.items(): 69 | try: 70 | function_handle = getattr(util.custom_losses, item) 71 | except (AttributeError, TypeError) as e: 72 | function_handle = item 73 | loss_functions[key] = function_handle 74 | model.compile(loss=loss_functions, optimizer=opt, loss_weights=loss_weights, metrics=loss_metrics) 75 | if compile_only: # What a hack. Keras bug from Oct 9 in saving/loading models. 76 | return model 77 | # Get model name for storing tmp files 78 | if model_name == '': 79 | model_name = training_generator.get_name() 80 | # Initiate callbacks 81 | callbacks = [] 82 | if reduce_lr: 83 | reduce_lr_cp = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, verbose=1) 84 | callbacks.append(reduce_lr_cp) 85 | if log_output: 86 | tensorboard_cp = TensorBoard(log_dir=tensorboard_logfolder) 87 | callbacks.append(tensorboard_cp) 88 | if save_model_only: 89 | file_name = model_name + '.hdf5' 90 | model_cp = ModelCheckpoint(filepath=file_name, save_best_only=True, save_weights_only=True) 91 | callbacks.append(model_cp) 92 | # Run model training 93 | try: 94 | history = model.fit(training_generator, steps_per_epoch=steps_per_epoch, epochs=epochs, shuffle=training_generator.shuffle, 95 | validation_steps=validation_steps, validation_data=testing_generator, verbose=verbose, callbacks=callbacks) 96 | except KeyboardInterrupt: 97 | print('-> Notebook interrupted') 98 | history = [] 99 | finally: 100 | if save_model_only: # Make sure interruption of jupyter notebook returns best model 101 | model.load_weights(file_name) 102 | print('-> Returning best Model') 103 | return (model, history) 104 | 105 | 
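# Example (sketch, not part of the original module): training a single
# train/test split directly with the function above, bypassing the
# cross-validation loop in train_model below. The path and the 'position'
# output are placeholders; assumes fp points to a preprocessed HDF5 file with
# an 'outputs/position' dataset and that a models/tmp folder exists next to it:
#
#   import numpy as np, h5py
#   from deepinsight import train
#   from deepinsight.util import opts as opts_util, data_generator
#   fp = 'experiment_processed.h5'
#   with h5py.File(fp, 'r') as f:
#       n = f['inputs/wavelets'].shape[0]
#   tmp = opts_util.get_opts(fp, train_test_times=(np.array([]), np.array([])))
#   idx = np.arange(0, n - tmp['model_timesteps'] * tmp['batch_size'])
#   split = int(0.8 * len(idx))
#   opts = opts_util.get_opts(fp, train_test_times=(idx[:split], idx[split:]))
#   opts['loss_functions'] = {'position': 'euclidean_loss'}
#   opts['loss_weights'] = {'position': 1}
#   opts['loss_names'] = ['position']
#   train_gen, test_gen = data_generator.create_train_and_test_generators(opts)
#   model = train.get_model_from_function(train_gen, show_summary=False)
#   model, history = train.train_model_on_generator(
#       model, train_gen, test_gen, loss_functions={'position': 'euclidean_loss'},
#       loss_weights={'position': 1}, epochs=2)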
106 | def train_model(model_path, path_in, tensorboard_logfolder, model_tmp_path, loss_functions, loss_weights, user_opts, num_cvs=5, verbose=0): 107 | """ 108 | Trains the model across the experiment using cross validation and saves the model files 109 | TODO Save models back to HDF5 to keep everything in one place 110 | 111 | Parameters 112 | ---------- 113 | model_path : str 114 | Path to where model should be stored 115 | path_in : str 116 | Path to HDF5 File 117 | tensorboard_logfolder : str 118 | Path to where tensorboard logs should be stored 119 | model_tmp_path : str 120 | Temporary file path used for returning best fit model 121 | loss_functions : dict 122 | For each output the corresponding loss function 123 | loss_weights : dict 124 | For each output the corresponding weight 125 | user_opts : dict 126 | Model parameters in case default opts should be changed 127 | num_cvs : int, optional 128 | Number of cross validation splits, by default 5 129 | """ 130 | # Get experiment length 131 | hdf5_file = h5py.File(path_in, mode='r') 132 | tmp_wavelets = hdf5_file['inputs/wavelets'] 133 | tmp_opts = util.opts.get_opts(path_in, train_test_times=(np.array([]), np.array([]))) 134 | # check for user options 135 | if user_opts is not None: 136 | for key, value in user_opts.items(): 137 | tmp_opts[key] = value 138 | exp_indices = np.arange(0, tmp_wavelets.shape[0] - (tmp_opts['model_timesteps'] * tmp_opts['batch_size'])) 139 | cv_splits = np.array_split(exp_indices, num_cvs) 140 | for cv_run, cvs in enumerate(cv_splits): 141 | K.clear_session() 142 | # For cv 143 | training_indices = np.setdiff1d(exp_indices, cvs) # All except the test indices 144 | testing_indices = cvs 145 | # opts -> generators -> model 146 | opts = util.opts.get_opts(path_in, train_test_times=(training_indices, testing_indices)) 147 | opts['loss_functions'] = loss_functions.copy() 148 | opts['loss_weights'] = loss_weights 149 | opts['loss_names'] = list(loss_functions.keys()) 150 | opts['num_cvs'] = num_cvs 151 | # check for user options 152 | if user_opts is not None: 153 | for key, value in user_opts.items(): 154 | opts[key] = value 155 | (training_generator, testing_generator) = util.data_generator.create_train_and_test_generators(opts) 156 | model = get_model_from_function(training_generator, 
show_summary=False) 157 | 158 | print('------------------------------------------------') 159 | print('-> Model and generators loaded') 160 | print('------------------------------------------------') 161 | if verbose > 0: 162 | print(model.summary()) 163 | 164 | (model, history) = train_model_on_generator(model, training_generator, testing_generator, loss_functions=loss_functions.copy(), loss_weights=loss_weights, reduce_lr=True, 165 | log_output=opts['log_output'], tensorboard_logfolder=tensorboard_logfolder, model_name=model_tmp_path, save_model_only=opts['save_model'], 166 | steps_per_epoch=opts['steps_per_epoch'], validation_steps=opts['validation_steps'], epochs=opts['epochs'], loss_metrics=opts['metrics']) 167 | # Save model and history 168 | if history: 169 | opts['history'] = history.history 170 | cv_model_path = model_path[0:-3] + '_' + str(cv_run) + '.h5' 171 | util.hdf5.save_model_with_opts(model, opts, cv_model_path) 172 | print('------------------------------------------------') 173 | print('-> Model_{} saved to {}'.format(cv_run, cv_model_path)) 174 | print('------------------------------------------------') 175 | hdf5_file.close() 176 | 177 | 178 | def run_from_path(path_in, loss_functions, loss_weights, user_opts=None, **args): 179 | """ 180 | Runs model training giving path to HDF5 file and loss dictionaries 181 | 182 | Parameters 183 | ---------- 184 | path_in : str 185 | Path to HDF5 186 | loss_functions : dict 187 | For each output the corresponding loss function 188 | loss_weights : dict 189 | For each output the corresponding weight 190 | """ 191 | dirname = os.path.dirname(path_in) 192 | filename = os.path.basename(path_in) 193 | # Define folders 194 | tensorboard_logfolder = dirname + '/logs/' + filename[0:-3] # Remove .h5 for logfolder 195 | model_tmp_path = dirname + '/models/tmp/tmp_model' 196 | model_path = dirname + '/models/' + filename[0:-3] + '_model.h5' 197 | # Create folders if needed 198 | for f in [os.path.dirname(model_tmp_path), os.path.dirname(model_path)]: 199 | if not os.path.exists(f): 200 | os.makedirs(f) 201 | print('------------------------------------------------') 202 | print('-> Running {} from {}'.format(filename, dirname)) 203 | print('- Logs : {} \n- Model temporary : {} \n- Model : {}'.format(tensorboard_logfolder, model_tmp_path, model_path)) 204 | print('------------------------------------------------') 205 | # Train model 206 | print('------------------------------------------------') 207 | print('Starting standard model') 208 | print('------------------------------------------------') 209 | train_model(model_path, path_in, tensorboard_logfolder, model_tmp_path, loss_functions, loss_weights, user_opts, **args) 210 | 211 | 212 | def get_model_from_function(training_generator, show_summary=True): 213 | model_function = getattr(architecture, training_generator.model_function) 214 | model = model_function(training_generator, show_summary=show_summary) 215 | 216 | return model 217 | -------------------------------------------------------------------------------- /deepinsight/util/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hdf5 2 | from . import tetrode 3 | from . import stats 4 | from . import custom_losses 5 | from . import wavelet_transform 6 | from . import opts 7 | from . 
import data_generator 8 | -------------------------------------------------------------------------------- /deepinsight/util/custom_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | from tensorflow.keras import backend as K 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | def euclidean_loss(y_true, y_pred): 13 | # We use tf.sqrt instead of K.sqrt as there is a bug in K.sqrt (as of March 14, 2018) 14 | res = tf.sqrt(K.sum(K.square(y_pred - y_true), axis=-1)) 15 | return res 16 | 17 | 18 | def cyclical_mae_rad(y_true, y_pred): 19 | return K.mean(K.minimum(K.abs(y_pred - y_true), K.minimum(K.abs(y_pred - y_true + 2*np.pi), K.abs(y_pred - y_true - 2*np.pi))), axis=-1) 20 | 21 | 22 | def mse(y_true, y_pred): 23 | return tf.keras.losses.MSE(y_true, y_pred) 24 | 25 | 26 | def mae(y_true, y_pred): 27 | return tf.keras.losses.MAE(y_true, y_pred) 28 | 
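# Example (sketch): the cyclical loss above wraps errors around 2*pi, so a
# ground truth of 0.1 rad and a prediction of 2*pi - 0.1 rad are only 0.2 rad
# apart. A plain-NumPy mirror of cyclical_mae_rad illustrates this:
#
#   import numpy as np
#   d = np.abs((2 * np.pi - 0.1) - 0.1)          # naive absolute error, ~6.08 rad
#   wrapped = np.minimum(d, np.minimum(np.abs(d + 2 * np.pi),
#                                      np.abs(d - 2 * np.pi)))
#   print(wrapped)                               # ~0.2 rad
#
# euclidean_loss reduces over the last axis, i.e. it returns one Euclidean
# distance per sample (e.g. sqrt(3**2 + 4**2) = 5 for a 2D position error).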
-------------------------------------------------------------------------------- /deepinsight/util/data_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | import pickle 8 | import os 9 | import numpy as np 10 | from tensorflow.keras.utils import Sequence 11 | 12 | from . import hdf5 13 | 14 | 15 | def create_train_and_test_generators(opts): 16 | """ 17 | Creates training and test generators given opts dictionary 18 | 19 | Parameters 20 | ---------- 21 | opts : dict 22 | Dictionary holding options for data creation and model training 23 | 24 | Returns 25 | ------- 26 | training_generator : object 27 | Sequence class used for generating training data 28 | testing_generator : object 29 | Sequence class used for generating testing data 30 | """ 31 | # 1.) Create training generator 32 | training_generator = RawWaveletSequence(opts, training=True) 33 | # 2.) Create testing generator 34 | testing_generator = RawWaveletSequence(opts, training=False) 35 | # 3.) Assert that training and testing data are different 36 | 37 | return (training_generator, testing_generator) 38 | 39 | 40 | class RawWaveletSequence(Sequence): 41 | """ 42 | Data Generator class. Important functions are get_input_sample and get_output_sample. 43 | Each call to __getitem__ will yield an (input, output) pair 44 | 45 | Parameters 46 | ---------- 47 | Sequence : object 48 | Keras sequence 49 | 50 | Yields 51 | ------- 52 | input_sample : array_like 53 | Batched input for model training 54 | output_sample : array_like 55 | Batched output for model optimization 56 | """ 57 | 58 | def __init__(self, opts, training): 59 | # 1.) Set all options as attributes 60 | self.set_opts_as_attribute(opts) 61 | 62 | # 2.) Load data memmaped for mean/std estimation and fast plotting 63 | self.wavelets = hdf5.read_hdf_memmapped(self.fp_hdf_out, 'inputs/wavelets') 64 | 65 | # Get output(s) 66 | outputs = [] 67 | for key, value in opts['loss_functions'].items(): 68 | tmp_out = hdf5.read_hdf_memmapped(self.fp_hdf_out, 'outputs/' + key) 69 | outputs.append(tmp_out) 70 | self.outputs = outputs 71 | 72 | # 3.) Prepare for training 73 | self.training = training 74 | self.prepare_data_generator(training=training) 75 | 76 | def __len__(self): 77 | return len(self.cv_indices) 78 | 79 | def __getitem__(self, idx): 80 | # 1.) Define start and end index 81 | if self.shuffle: 82 | idx = np.random.choice(self.cv_indices) 83 | else: 84 | idx = self.cv_indices[idx] 85 | cut_range = np.arange(idx, idx + self.sample_size) 86 | 87 | # 2.) Above takes consecutive batches, implement random batching here 88 | if self.random_batches: 89 | indices = np.random.choice(self.cv_indices, size=self.batch_size) 90 | cut_range = [np.arange(start_index, start_index + self.model_timesteps) for start_index in indices] 91 | cut_range = np.array(cut_range) 92 | else: 93 | cut_range = np.reshape(cut_range, (self.batch_size, cut_range.shape[0] // self.batch_size)) 94 | 95 | # 3.) Get input sample 96 | input_sample = self.get_input_sample(cut_range) 97 | 98 | # 4.) Get output sample 99 | output_sample = self.get_output_sample(cut_range) 100 | 101 | return (input_sample, output_sample) 102 | 103 | def get_input_sample(self, cut_range): 104 | # 1.) Cut Ephys / fancy indexing for memmap is planned, if fixed use: cut_data = self.wavelets[cut_range, self.fourier_frequencies, self.channels] 105 | cut_data = self.wavelets[cut_range, :, :] 106 | cut_data = np.reshape(cut_data, (cut_data.shape[0] * cut_data.shape[1], cut_data.shape[2], cut_data.shape[3])) 107 | 108 | # 2.) Normalize input 109 | cut_data = (cut_data - self.est_mean) / self.est_std 110 | 111 | # 3.) Reshape for model input 112 | cut_data = np.reshape(cut_data, (self.batch_size, self.model_timesteps, cut_data.shape[1], cut_data.shape[2])) 113 | 114 | # 4.) Take care of optional settings 115 | cut_data = np.transpose(cut_data, axes=(0, 3, 1, 2)) 116 | cut_data = cut_data[..., np.newaxis] 117 | 118 | return cut_data 119 | 120 | def get_output_sample(self, cut_range): 121 | # 1.) Cut Ephys 122 | out_sample = [] 123 | for out in self.outputs: 124 | cut_data = out[cut_range, ...] 125 | cut_data = np.reshape(cut_data, (cut_data.shape[0] * cut_data.shape[1], cut_data.shape[2])) 126 | 127 | # 2.) Reshape for model output 128 | if len(cut_data.shape) != self.batch_size: 129 | cut_data = np.reshape(cut_data, (self.batch_size, self.model_timesteps, cut_data.shape[1])) 130 | 131 | # 3.) Divide evenly and make sure last output is being decoded 132 | if self.average_output: 133 | cut_data = cut_data[:, np.arange(0, cut_data.shape[1] + 1, self.average_output)[1::] - 1] 134 | out_sample.append(cut_data) 135 | 136 | return out_sample 137 | 138 | def prepare_data_generator(self, training): 139 | # 1.) Define sample size and means 140 | self.sample_size = self.model_timesteps * self.batch_size 141 | 142 | if training: 143 | self.cv_indices = self.training_indices 144 | else: 145 | self.cv_indices = self.testing_indices 146 | 147 | # Make sure random choice takes from array not list 500x speedup 148 | self.cv_indices = np.array(self.cv_indices) 149 | 150 | # 9.) 
Calculate normalization for wavelets 151 | meanstd_path = os.path.dirname(self.fp_hdf_out) + '/models/tmp/' + os.path.basename(self.fp_hdf_out)[:-3] + '_meanstd_start{}_end{}_tstart{}_tend{}.p'.format( 152 | self.training_indices[0], self.training_indices[-1], self.testing_indices[0], self.testing_indices[-1]) 153 | 154 | if os.path.exists(meanstd_path): 155 | (self.est_mean, self.est_std) = pickle.load(open(meanstd_path, 'rb')) 156 | else: 157 | print('Calculating MAD normalization parameters') 158 | if len(self.training_indices) > 1e5: 159 | print('Downsampling wavelets for MAD calculation') 160 | self.est_mean = np.median(self.wavelets[self.training_indices[::100], :, :], axis=0) 161 | self.est_std = np.median(abs(self.wavelets[self.training_indices[::100], :, :] - self.est_mean), axis=0) 162 | else: 163 | self.est_mean = np.median(self.wavelets[self.training_indices, :, :], axis=0) 164 | self.est_std = np.median(abs(self.wavelets[self.training_indices, :, :] - self.est_mean), axis=0) 165 | pickle.dump((self.est_mean, self.est_std), open(meanstd_path, 'wb')) 166 | 167 | # Make sure indices contain no NaN values 168 | if self.handle_nan: 169 | self.cv_indices = self.check_for_nan() 170 | 171 | # 10.) Define output shape. Most robust way is to get a dummy input and take that shape as output shape 172 | (dummy_input, dummy_output) = self.__getitem__(0) 173 | # Corresponds to the output of this generator, aka input to model. Also remove batch shape, 174 | self.input_shape = dummy_input.shape[1:] 175 | 176 | def set_opts_as_attribute(self, opts): 177 | for k, v in opts.items(): 178 | setattr(self, k, v) 179 | 180 | def get_name(self): 181 | name = "" 182 | for attr in self.important_attributes: 183 | name += attr + ':{},'.format(getattr(self, attr)) 184 | return name[:-1] 185 | 186 | def check_for_nan(self): 187 | new_cv_indices, len_before = [], len(self.cv_indices) 188 | for idx, cv in enumerate(self.cv_indices): 189 | if not idx % 100000: 190 | print('{} / {}'.format(cv, self.cv_indices[-1]), end='\r') 191 | cut_range = np.arange(cv, cv + self.sample_size) 192 | cut_range = np.reshape(cut_range, (self.batch_size, cut_range.shape[0] // self.batch_size)) 193 | out_sample = self.get_output_sample(cut_range) 194 | nan_in_out = any([any(np.isnan(x.flatten())) for x in out_sample[0]]) 195 | if not nan_in_out: 196 | new_cv_indices.append(cv) 197 | print('Len before {}, after {} --- Diff {}'.format(len_before, len(new_cv_indices), len_before - len(new_cv_indices))) 198 | return np.array(new_cv_indices) 199 | -------------------------------------------------------------------------------- /deepinsight/util/hdf5.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | import h5py 8 | import numpy as np 9 | from . import data_generator 10 | 11 | 12 | def create_or_update(hdf5_file, dataset_name, dataset_shape, dataset_type, dataset_value): 13 | """ 14 | Create or update dataset in HDF5 file 15 | 16 | Parameters 17 | ---------- 18 | hdf5_file : File 19 | File identifier 20 | dataset_name : str 21 | Name of new dataset 22 | dataset_shape : array_like 23 | Shape of new dataset 24 | dataset_type : type 25 | Type of dataset (np.float16, np.float32, 'S', etc...) 
26 | dataset_value : array_like 27 | Data to store in HDF5 file 28 | """ 29 | if dataset_name not in hdf5_file: 30 | hdf5_file.create_dataset(dataset_name, dataset_shape, dataset_type) 31 | hdf5_file[dataset_name][:] = dataset_value 32 | else: 33 | if hdf5_file[dataset_name].shape != dataset_shape: 34 | del hdf5_file[dataset_name] 35 | hdf5_file.create_dataset(dataset_name, dataset_shape, dataset_type) 36 | hdf5_file[dataset_name][:] = dataset_value 37 | hdf5_file.flush() 38 | 39 | 40 | def save_model_with_opts(model, opts, file_name): 41 | """ 42 | Saves Keras model and training options to HDF5 file 43 | Uses Keras save_weights for creating the model HDF5 file and then inserts into that 44 | 45 | Parameters 46 | ---------- 47 | model : object 48 | Keras model 49 | opts : dict 50 | Dictionary used for training the model 51 | file_name : str 52 | Path to save to 53 | """ 54 | model.save_weights(file_name) 55 | hdf5_file = h5py.File(file_name, mode='a') 56 | hdf5_file['opts'] = str(opts) 57 | hdf5_file.flush() 58 | hdf5_file.close() 59 | 60 | 61 | def load_model_with_opts(file_name): 62 | """ 63 | Load Keras model and training options from HDF5 file 64 | TODO: Remove eval and find better way of storing dict in HDF5 (hickle, pytables, etc...) 65 | 66 | Parameters 67 | ---------- 68 | file_name : str 69 | Model path 70 | 71 | Returns 72 | ------- 73 | model : object 74 | Keras model 75 | training_generator : object 76 | Datagenerator used to create training samples on the fly 77 | testing_generator : object 78 | Datagenerator used to create testing samples on the fly 79 | opts : dict 80 | Dictionary used for training the model 81 | """ 82 | from .. import train 83 | # Get options from dictionary, stored as str in HDF5 (not recommended, TODO) 84 | hdf5_file = h5py.File(file_name, mode='r') 85 | opts = eval(hdf5_file['opts'][()]) 86 | opts['handle_nan'] = False 87 | hdf5_file.close() 88 | 89 | # Use options to create data generators and model weights 90 | (training_generator, testing_generator) = data_generator.create_train_and_test_generators(opts) 91 | 92 | model = train.get_model_from_function(training_generator, show_summary=False) 93 | model = train.train_model_on_generator(model, training_generator, testing_generator, 94 | loss_functions=opts['loss_functions'], loss_weights=opts['loss_weights'], compile_only=True) 95 | model.load_weights(file_name) 96 | 97 | return (model, training_generator, testing_generator, opts) 98 | 99 | 100 | def read_hdf_memmapped(fn_hdf, hdf_group): 101 | """ 102 | Reads the hdf file as a numpy memmapped file, makes slicing a bit faster 103 | (From https://gist.github.com/rossant/7b4704e8caeb8f173084) 104 | 105 | Parameters 106 | ---------- 107 | fn_hdf : str 108 | Path to preprocessed HDF5 109 | hdf_group : str 110 | Group to read from HDF5 111 | 112 | Returns 113 | ------- 114 | data : array_like 115 | Data as a memory mapped array 116 | """ 117 | # Define function for memmapping 118 | def _mmap_h5(path, h5path): 119 | with h5py.File(path, mode='r') as f: 120 | ds = f[h5path] 121 | # We get the dataset address in the HDF5 file. 122 | offset = ds.id.get_offset() 123 | # We ensure we have a non-compressed contiguous array. 
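            # np.memmap reads the raw bytes of the file directly, which only
            # works if the dataset is stored as one contiguous, uncompressed
            # block; for chunked or compressed datasets there is no single
            # byte offset (get_offset() returns None), hence the assertions.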
124 | assert ds.chunks is None 125 | assert ds.compression is None 126 | assert offset > 0 127 | dtype = ds.dtype 128 | shape = ds.shape 129 | arr = np.memmap(path, mode='r', shape=shape, 130 | offset=offset, dtype=dtype) 131 | return arr 132 | # Load data 133 | data = _mmap_h5(fn_hdf, hdf_group) 134 | 135 | return data 136 | -------------------------------------------------------------------------------- /deepinsight/util/opts.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | 8 | 9 | def get_opts(fp_hdf_out, train_test_times): 10 | """ 11 | Returns the options dictionary which contains all parameters needed to 12 | create DataGenerator and train the model 13 | TODO Find better method of parameter storing (config files, store in HDF5, etc...) 14 | 15 | Parameters 16 | ---------- 17 | fp_hdf_out : str 18 | File path to HDF5 file 19 | train_test_times : array_like 20 | Indices for training and testing generator 21 | 22 | Returns 23 | ------- 24 | opts : dict 25 | Dictionary containing all model and training parameters 26 | """ 27 | opts = dict() 28 | # -------- DATA ------------------------ 29 | opts['fp_hdf_out'] = fp_hdf_out # Filepath for hdf5 file storing wavelets and outputs 30 | opts['sampling_rate'] = 30 # Sampling rate of the wavelets 31 | opts['training_indices'] = train_test_times[0].tolist() # Indices into wavelets used for training the model, adjusted during CV 32 | opts['testing_indices'] = train_test_times[1].tolist() # Indices into wavelets used for testing the model, adjusted during CV 33 | 34 | # -------- MODEL PARAMETERS -------------- 35 | opts['model_function'] = 'the_decoder' # Model architecture used 36 | opts['model_timesteps'] = 64 # How many timesteps are used in the input layer, e.g. a sampling rate of 30 will yield 2.13s windows. Has to be divisible X times by 2. X='num_convs_tsr' 37 | opts['num_convs_tsr'] = 4 # Number of downsampling steps within the model, e.g. 
with model_timesteps=64, it will downsample 64->32->16->8->4 and output 4 timesteps
38 | opts['average_output'] = 2**opts['num_convs_tsr']  # What's the ratio between input and output shape
39 | opts['channel_lower_limit'] = 2
40 | 
41 | opts['optimizer'] = 'adam'  # Learning algorithm
42 | opts['learning_rate'] = 0.0007  # Learning rate
43 | opts['kernel_size'] = 3  # Kernel size for all convolutional layers
44 | opts['conv_padding'] = 'same'  # Which padding should be used for the convolutional layers
45 | opts['act_conv'] = 'elu'  # Activation function for convolutional layers
46 | opts['act_fc'] = 'elu'  # Activation function for fully connected layers
47 | opts['dropout_ratio'] = 0  # Dropout ratio for fully connected layers
48 | opts['filter_size'] = 64  # Number of filters in convolutional layers
49 | opts['num_units_dense'] = 1024  # Number of units in fully connected layer
50 | opts['num_dense'] = 2  # Number of fully connected layers
51 | opts['gaussian_noise'] = 1  # How much Gaussian noise is added (unit = standard deviation)
52 | 
53 | # -------- TRAINING----------------------
54 | opts['batch_size'] = 8  # Batch size used for training the model
55 | opts['steps_per_epoch'] = 250  # Number of steps per training epoch
56 | opts['validation_steps'] = 250  # Number of steps per validation epoch
57 | opts['epochs'] = 20  # Number of epochs
58 | opts['shuffle'] = True  # If input should be shuffled
59 | opts['random_batches'] = True  # If random batches in time are used
60 | opts['metrics'] = []
61 | opts['last_layer_activation_function'] = 'linear'
62 | opts['handle_nan'] = False
63 | 
64 | # -------- MISC --------------------------
65 | opts['tensorboard_logfolder'] = './'  # Logfolder for tensorboard
66 | opts['model_folder'] = './'  # Folder for saving the model
67 | opts['log_output'] = False  # If output should be logged
68 | opts['save_model'] = False  # If model should be saved
69 | 
70 | return opts
71 | -------------------------------------------------------------------------------- /deepinsight/util/stats.py: -------------------------------------------------------------------------------- 1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import numpy as np
8 | 
9 | 
10 | def calculate_speed_from_position(positions, interval, smoothing=False):
11 | """
12 | Calculate speed from X,Y coordinates
13 | 
14 | Parameters
15 | ----------
16 | positions : (N, 2) array_like
17 | N samples of observations, containing X and Y coordinates
18 | interval : int
19 | Duration between observations (in s, equal to 1 / sr)
20 | smoothing : bool or int, optional
21 | If speeds should be smoothed, by default False/0
22 | 
23 | Returns
24 | -------
25 | speed : (N, 1) array_like
26 | Instantaneous speed of the animal
27 | """
28 | X, Y = positions[:, 0], positions[:, 1]
29 | # Smooth diffs instead of speeds directly
30 | Xdiff = np.diff(X)
31 | Ydiff = np.diff(Y)
32 | if smoothing:
33 | Xdiff = smooth_signal(Xdiff, smoothing)
34 | Ydiff = smooth_signal(Ydiff, smoothing)
35 | speed = np.sqrt(Xdiff**2 + Ydiff**2) / interval
36 | speed = np.append(speed, speed[-1])
37 | 
38 | return speed
39 | 
40 | 
41 | def calculate_heading_direction_from_position(positions, smoothing=False, return_as_deg=False):
42 | """
43 | Calculates heading direction based on X and Y coordinates.
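Heading is computed from the frame-to-frame displacement as arctan2(dY, dX), i.e. it reflects the direction of travel rather than the orientation of the head.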
With one measurement we can only calculate heading direction 44 | 45 | Parameters 46 | ---------- 47 | positions : (N, 2) array_like 48 | N samples of observations, containing X and Y coordinates 49 | smoothing : bool or int, optional 50 | If speeds should be smoothed, by default False/0 51 | return_as_deg : bool 52 | Return heading in radians or degree 53 | 54 | Returns 55 | ------- 56 | heading_direction : (N, 1) array_like 57 | Heading direction of the animal 58 | """ 59 | X, Y = positions[:, 0], positions[:, 1] 60 | # Smooth diffs instead of speeds directly 61 | Xdiff = np.diff(X) 62 | Ydiff = np.diff(Y) 63 | if smoothing: 64 | Xdiff = smooth_signal(Xdiff, smoothing) 65 | Ydiff = smooth_signal(Ydiff, smoothing) 66 | # Calculate heading direction 67 | heading_direction = np.arctan2(Ydiff, Xdiff) 68 | heading_direction = np.append(heading_direction, heading_direction[-1]) 69 | if return_as_deg: 70 | heading_direction = heading_direction * (180 / np.pi) 71 | 72 | return heading_direction 73 | 74 | 75 | def calculate_head_direction_from_leds(positions, return_as_deg=False): 76 | """ 77 | Calculates head direction based on X and Y coordinates with two LEDs. 78 | 79 | Parameters 80 | ---------- 81 | positions : (N, 2) array_like 82 | N samples of observations, containing X and Y coordinates 83 | return_as_deg : bool 84 | Return heading in radians or degree 85 | 86 | Returns 87 | ------- 88 | head_direction : (N, 1) array_like 89 | Head direction of the animal 90 | """ 91 | X_led1, Y_led1, X_led2, Y_led2 = positions[:, 0], positions[:, 1], positions[:, 2], positions[:, 3] 92 | # Calculate head direction 93 | head_direction = np.arctan2(X_led1 - X_led2, Y_led1 - Y_led2) 94 | # Put in right perspective in relation to the environment 95 | offset = +np.pi/2 96 | head_direction = (head_direction + offset + np.pi) % (2*np.pi) - np.pi 97 | head_direction *= -1 98 | if return_as_deg: 99 | head_direction = head_direction * (180 / np.pi) 100 | 101 | return head_direction 102 | 103 | 104 | def smooth_signal(signal, N): 105 | """ 106 | Simple smoothing by convolving a filter with 1/N. 107 | 108 | Parameters 109 | ---------- 110 | signal : array_like 111 | Signal to be smoothed 112 | N : int 113 | smoothing_factor 114 | 115 | Returns 116 | ------- 117 | signal : array_like 118 | Smoothed signal 119 | """ 120 | # Preprocess edges 121 | signal = np.concatenate([signal[0:N], signal, signal[-N:]]) 122 | # Convolve 123 | signal = np.convolve(signal, np.ones((N,))/N, mode='same') 124 | # Postprocess edges 125 | signal = signal[N:-N] 126 | 127 | return signal 128 | -------------------------------------------------------------------------------- /deepinsight/util/tetrode.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | import numpy as np 8 | import pandas as pd 9 | import h5py 10 | 11 | from . import hdf5 12 | from . 
import stats
13 | 
14 | 
15 | def read_open_ephys(fp_raw_file):
16 | """
17 | Reads ST open ephys files
18 | 
19 | Parameters
20 | ----------
21 | fp_raw_file : str
22 | File path to open ephys file
23 | 
24 | Returns
25 | -------
26 | continuous : (N,M) array_like
27 | Continuous ephys with N timepoints and M channels
28 | timestamps : (N,1) array_like
29 | Timestamps for each sample in continuous
30 | positions : (N,5) array_like
31 | Position of animal with two LEDs and timestamps
32 | info : object
33 | Additional information about experiments
34 | """
35 | fid_ephys = h5py.File(fp_raw_file, mode='r')
36 | 
37 | # Load timestamps and continuous data, python 3 keys() returns view
38 | recording_key = list(fid_ephys['acquisition']['timeseries'].keys())[0]
39 | processor_key = list(fid_ephys['acquisition']['timeseries'][recording_key]['continuous'].keys())[0]
40 | 
41 | # Load raw ephys and timestamps
42 | # not converted to microvolts, need to multiply by 0.195. We don't multiply here as we can't load the full array into memory
43 | continuous = fid_ephys['acquisition']['timeseries'][recording_key]['continuous'][processor_key]['data']
44 | timestamps = fid_ephys['acquisition']['timeseries'][recording_key]['continuous'][processor_key]['timestamps']
45 | 
46 | # We can also read position directly from the raw file
47 | positions = fid_ephys['acquisition']['timeseries'][recording_key]['tracking']['ProcessedPos']
48 | 
49 | # Read general settings
50 | info = fid_ephys['general']['data_collection']['Settings']
51 | 
52 | return (continuous, timestamps, positions, info)
53 | 
54 | 
55 | def read_tetrode_data(fp_raw_file):
56 | """
57 | Read ST data from openEphys recording system
58 | 
59 | Parameters
60 | ----------
61 | fp_raw_file : str
62 | File path to open ephys file
63 | 
64 | Returns
65 | -------
66 | raw_data : (N,M) array_like
67 | Continuous ephys with N timepoints and M channels
68 | raw_timestamps : (N,1) array_like
69 | Timestamps for each sample in continuous
70 | output : (N,4) array_like
71 | Position of animal with two LEDs
72 | output_timestamps : (N,1) array_like
73 | Timestamps for positions
74 | info : object
75 | Additional information about experiments
76 | """
77 | (raw_data, raw_timestamps, positions, info) = read_open_ephys(fp_raw_file)
78 | output_timestamps = positions[:, 0]
79 | output = positions[:, 1:5]
80 | bad_channels = info['General']['badChan']
81 | bad_channels = [int(n) for n in bad_channels[()].decode('UTF-8').split(',')]
82 | good_channels = np.delete(np.arange(0, 128), bad_channels)
83 | info = {'channels': good_channels, 'bad_channels': bad_channels, 'sampling_rate': 30000}
84 | 
85 | return (raw_data, raw_timestamps, output, output_timestamps, info)
86 | 
87 | 
88 | def preprocess_output(fp_hdf_out, raw_timestamps, output, output_timestamps, average_window=1000, sampling_rate=30000):
89 | """
90 | Write behaviours to decode into HDF5 file
91 | 
92 | Parameters
93 | ----------
94 | fp_hdf_out : str
95 | File path to HDF5 file
96 | raw_timestamps : (N,1) array_like
97 | Timestamps for each sample in continuous
98 | output : (N,4) array_like
99 | Position of animal with two LEDs
100 | output_timestamps : (N,1) array_like
101 | Timestamps for positions
102 | average_window : int, optional
103 | Downsampling factor for raw data and positions, by default 1000
104 | sampling_rate : int, optional
105 | Sampling rate of raw ephys, by default 30000
106 | """
107 | hdf5_file = h5py.File(fp_hdf_out, mode='a')
108 | 
109 | # Get size of wavelets
110 | input_length = 
hdf5_file['inputs/wavelets'].shape[0]
111 | 
112 | # Get positions of both LEDs
113 | raw_timestamps = raw_timestamps[()]  # Slightly faster than np.array
114 | output_x_led1 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
115 | average_window)], output_timestamps, output[:, 0])
116 | output_y_led1 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
117 | average_window)], output_timestamps, output[:, 1])
118 | output_x_led2 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
119 | average_window)], output_timestamps, output[:, 2])
120 | output_y_led2 = np.interp(raw_timestamps[np.arange(0, raw_timestamps.shape[0],
121 | average_window)], output_timestamps, output[:, 3])
122 | raw_positions = np.array([output_x_led1, output_y_led1, output_x_led2, output_y_led2]).transpose()
123 | 
124 | # Clean raw_positions and get centre
125 | positions_smooth = pd.DataFrame(raw_positions.copy()).interpolate(
126 | limit_direction='both').rolling(5, min_periods=1).mean().values
127 | position = np.array([(positions_smooth[:, 0] + positions_smooth[:, 2]) / 2,
128 | (positions_smooth[:, 1] + positions_smooth[:, 3]) / 2]).transpose()
129 | 
130 | # Also get head direction and speed from positions
131 | speed = stats.calculate_speed_from_position(position, interval=1/(sampling_rate//average_window), smoothing=3)
132 | head_direction = stats.calculate_head_direction_from_leds(positions_smooth, return_as_deg=False)
133 | 
134 | # Create and save datasets in HDF5 File
135 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/raw_position",
136 | dataset_shape=[input_length, 4], dataset_type=np.float16, dataset_value=raw_positions[0: input_length, :])
137 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/position",
138 | dataset_shape=[input_length, 2], dataset_type=np.float16, dataset_value=position[0: input_length, :])
139 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/head_direction", dataset_shape=[
140 | input_length, 1], dataset_type=np.float16, dataset_value=head_direction[0: input_length, np.newaxis])
141 | hdf5.create_or_update(hdf5_file, dataset_name="outputs/speed",
142 | dataset_shape=[input_length, 1], dataset_type=np.float16, dataset_value=speed[0: input_length, np.newaxis])
143 | hdf5_file.flush()
144 | hdf5_file.close()
145 | -------------------------------------------------------------------------------- /deepinsight/util/wavelet_transform.py: -------------------------------------------------------------------------------- 1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | from wavelets import WaveletAnalysis
8 | import numpy as np
9 | 
10 | 
11 | def wavelet_transform(signal, sampling_rate, average_window=1000, scaling_factor=0.25, wave_highpass=2, wave_lowpass=30000):
12 | """
13 | Calculates the wavelet transform for each point in signal, then averages
14 | within each window and returns it together with the corresponding Fourier frequencies
15 | 
16 | Parameters
17 | ----------
18 | signal : (N,1) array_like
19 | Signal to be transformed
20 | sampling_rate : int
21 | Sampling rate of signal
22 | average_window : int, optional
23 | Average window to downsample wavelet transformed input, by default 1000
24 | scaling_factor : float, optional
25 | Determines amount of log-spaced frequencies M in output, by default 0.25
26 | wave_highpass : int, optional
27 | Cut off frequencies below this value, by default 2
28 | wave_lowpass : int, optional
29 | Cut off frequencies above this value, by default 30000
30 | 
31 | 
Returns
32 | -------
33 | wavelet_power : (N, M) array_like
34 | Wavelet transformed signal
35 | wavelet_frequencies : (M, 1) array_like
36 | Corresponding frequencies to wavelet_power
37 | """
38 | (wavelet_power, wavelet_frequencies, wavelet_obj) = simple_wavelet_transform(signal, sampling_rate,
39 | scaling_factor=scaling_factor, wave_highpass=wave_highpass, wave_lowpass=wave_lowpass)
40 | 
41 | # Average over window
42 | if average_window != 1:
43 | wavelet_power = np.reshape(
44 | wavelet_power, (wavelet_power.shape[0], wavelet_power.shape[1] // average_window, average_window))
45 | wavelet_power = np.mean(wavelet_power, axis=2).transpose()
46 | else:
47 | wavelet_power = wavelet_power.transpose()
48 | 
49 | return wavelet_power, wavelet_frequencies
50 | 
51 | 
52 | def simple_wavelet_transform(signal, sampling_rate, scaling_factor=0.25, wave_lowpass=None, wave_highpass=None):
53 | """
54 | Simple wavelet transformation of signal
55 | 
56 | Parameters
57 | ----------
58 | signal : (N,1) array_like
59 | Signal to be transformed
60 | sampling_rate : int
61 | Sampling rate of signal
62 | scaling_factor : float, optional
63 | Determines amount of log-spaced frequencies M in output, by default 0.25
64 | wave_highpass : int, optional
65 | Cut off frequencies below this value, by default None
66 | wave_lowpass : int, optional
67 | Cut off frequencies above this value, by default None
68 | 
69 | Returns
70 | -------
71 | wavelet_power : (N, M) array_like
72 | Wavelet transformed signal
73 | wavelet_frequencies : (M, 1) array_like
74 | Corresponding frequencies to wavelet_power
75 | wavelet_obj : object
76 | WaveletTransform Object
77 | """
78 | wavelet_obj = WaveletAnalysis(signal, dt=1 / sampling_rate, dj=scaling_factor)
79 | wavelet_power = wavelet_obj.wavelet_power
80 | wavelet_frequencies = wavelet_obj.fourier_frequencies
81 | 
82 | if wave_lowpass and wave_highpass:  # both cut-offs are needed to build the band mask below
83 | wavelet_power = wavelet_power[(wavelet_frequencies < wave_lowpass) & (wavelet_frequencies > wave_highpass), :]
84 | wavelet_frequencies = wavelet_frequencies[(wavelet_frequencies < wave_lowpass) & (wavelet_frequencies > wave_highpass)]
85 | 
86 | return (wavelet_power, wavelet_frequencies, wavelet_obj)
87 | -------------------------------------------------------------------------------- /deepinsight/visualize.py: -------------------------------------------------------------------------------- 1 | """
2 | DeepInsight Toolbox
3 | © Markus Frey
4 | https://github.com/CYHSM/DeepInsight
5 | Licensed under MIT License
6 | """
7 | import numpy as np
8 | import h5py
9 | import pandas as pd
10 | import matplotlib.pyplot as plt
11 | import seaborn as sns
12 | sns.set_style('white')
13 | 
14 | 
15 | def plot_residuals(fp_hdf_out, output_names, losses=None, shuffled_losses=None, aggregator=np.mean, frequency_spacing=1, offset=0):
16 | """
17 | Plots the influence plot for each decoded output
18 | 
19 | Parameters
20 | ----------
21 | fp_hdf_out : str
22 | File path to HDF5 file
output_names : list of str
Names of the decoded outputs, used as subplot titles
losses : array_like, optional
Model losses, read from analysis/losses if None
shuffled_losses : array_like, optional
Shuffled losses, read from analysis/influence/shuffled_losses if None
23 | aggregator : function handle, optional
24 | Which aggregator to use for plotting the lineplots, by default np.mean
25 | frequency_spacing : int, optional
26 | Spacing on x axis between frequencies, by default 1
offset : float, optional
Small constant added to the denominator to avoid division by zero, by default 0
27 | """
28 | # Read data from HDF5 file
29 | hdf5_file = h5py.File(fp_hdf_out, mode='r')
30 | if losses is None:
31 | losses = hdf5_file["analysis/losses"][()]
32 | if shuffled_losses is None:
33 | shuffled_losses = hdf5_file["analysis/influence/shuffled_losses"][()]
34 | frequencies = hdf5_file["inputs/fourier_frequencies"][()].astype(np.float32)
35 | hdf5_file.close()
36 | 
37 | # Calculate 
residuals, make sure there is no division by zero by adding small constant. TODO Should be relative to loss and only if needed 38 | residuals = (shuffled_losses - losses) / (losses + offset) 39 | 40 | # Plot 41 | fig, axes = plt.subplots(len(output_names), 1, figsize=(16, 8)) 42 | if len(output_names) > 1: 43 | axes = axes.flatten() 44 | else: 45 | axes = [axes] 46 | for all_residuals, ax, on in zip(residuals.transpose(), axes, output_names): 47 | residuals_mean = np.mean(all_residuals, axis=0) 48 | all_residuals = all_residuals / np.sum(residuals_mean) 49 | df_to_plot = pd.DataFrame(all_residuals).melt() 50 | sns.lineplot(x="variable", y="value", data=df_to_plot, ax=ax, estimator=aggregator, ci=68, marker='o', 51 | color='k').set(xlabel='Frequencies (Hz)', ylabel='Frequency Influence (%)') 52 | ax.set_xticks(np.arange(0, len(frequencies), frequency_spacing)) 53 | ax.set_xticklabels(np.round(frequencies[0::frequency_spacing], 2), fontsize=8, rotation=45) 54 | ax.set_title(on) 55 | for ax in axes: 56 | ax.invert_xaxis() 57 | sns.despine() 58 | fig.tight_layout() 59 | fig.show() 60 | -------------------------------------------------------------------------------- /media/colab_walkthrough.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/media/colab_walkthrough.gif -------------------------------------------------------------------------------- /media/decoding_error.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/media/decoding_error.gif -------------------------------------------------------------------------------- /media/model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/media/model_architecture.png -------------------------------------------------------------------------------- /notebooks/deepinsight_calcium_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "pn9vUyWfZTX6", 8 | "slideshow": { 9 | "slide_type": "slide" 10 | } 11 | }, 12 | "source": [ 13 | "---\n", 14 | "# **Introduction to DeepInsight - Decoding position from two-photon calcium recordings**\n", 15 | "---\n", 16 | "\n", 17 | "This notebook stands as an example of how to use DeepInsight on calcium data and can be used as a guide on how to adapt it to your own datasets. All methods are stored in the deepinsight library and can be called directly or in their respective submodules. 
A typical workflow might look like the following: \n", 18 | "\n", 19 | "- Load your dataset into a format which can be directly indexed (numpy array or pointer to a file on disk)\n", 20 | "- Preprocess the raw data (wavelet transformation)\n", 21 | "- Preprocess your outputs (the variable you want to decode)\n", 22 | "- Define appropriate loss functions for your output and train the model \n", 23 | "- Predict performance across all cross validated models\n", 24 | "- Visualize influence of different input frequencies on model output\n", 25 | "\n", 26 | "We use the calcium dataset here as it has lower sampling rate and is therefore faster to preprocess and train, which makes it suitable to also run the preprocessing in a Colab notebook.\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "colab_type": "text", 33 | "id": "9iwZvplEoO70", 34 | "slideshow": { 35 | "slide_type": "subslide" 36 | } 37 | }, 38 | "source": [ 39 | "---\n", 40 | "## **Install and import DeepInsight**\n", 41 | "---\n", 42 | "Make sure you are using a **GPU runtime** if you want to train your own models. Go to Runtime -> Change Runtime type to change from CPU to GPU.\n", 43 | "You can check the GPU which is used in Colab by running !nvidia-smi in a new cell " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "colab": { 51 | "base_uri": "https://localhost:8080/", 52 | "height": 51 53 | }, 54 | "colab_type": "code", 55 | "id": "Uguw1SjlZLRX", 56 | "outputId": "fa22ce71-f3ff-4ff7-ac55-491d7011f2e0", 57 | "slideshow": { 58 | "slide_type": "subslide" 59 | } 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "# Import DeepInsight\n", 64 | "import sys\n", 65 | "sys.path.insert(0, \"/home/marx/Documents/Github/DeepInsight\")\n", 66 | "import deepinsight\n", 67 | "\n", 68 | "# Other imports\n", 69 | "import os\n", 70 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", 71 | "import h5py\n", 72 | "import numpy as np\n", 73 | "import pandas as pd\n", 74 | "from scipy.io import loadmat\n", 75 | "import plotly.graph_objs as go\n", 76 | "from skimage import io\n", 77 | "\n", 78 | "# Initialize plotly figures\n", 79 | "from plotly.offline import init_notebook_mode \n", 80 | "init_notebook_mode(connected = True)\n", 81 | "\n", 82 | "# Make sure the output width is adjusted for better export as HTML\n", 83 | "from IPython.core.display import display, HTML\n", 84 | "display(HTML(\"\"))\n", 85 | "display(HTML(\"\"))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "colab_type": "text", 92 | "id": "INZc18eRZTYL", 93 | "slideshow": { 94 | "slide_type": "slide" 95 | } 96 | }, 97 | "source": [ 98 | "---\n", 99 | "## **Load and preprocess your data**\n", 100 | "---\n", 101 | "For this example we provide two-photon calcium imaging data from a mouse in a virtual environment. Calcium traces together with the variable of interest is stored in one .mat file. You can load it from whatever datasource you want, just make sure that the dimensions match. \n", 102 | "\n", 103 | "The input to the model in the form of (Timepoints x Number of Cells) is stored in `raw_data`\n", 104 | "\n", 105 | "The output to be decoded is in the form of (Timepoints x 1) and is stored in `output` together with timestamps for the output in `raw_timestamps`\n", 106 | "\n", 107 | "---\n", 108 | "\n", 109 | "**Run the next cells if you want to load the example data and preprocess it. 
You can also skip to 'Preprocess Data' to just load the preprocessed file and directly train the model.** " 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "colab": { 117 | "base_uri": "https://localhost:8080/", 118 | "height": 289 119 | }, 120 | "colab_type": "code", 121 | "id": "Q1GNMR41iqBf", 122 | "outputId": "42008674-03b7-4837-dbe8-372a89a36e03", 123 | "slideshow": { 124 | "slide_type": "subslide" 125 | } 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "base_path = './example_data/calcium/'\n", 130 | "fp_raw_file = base_path + 'traces_M1336.mat'\n", 131 | "if not os.path.exists(base_path):\n", 132 | " os.makedirs(base_path)\n", 133 | "if not os.path.exists(fp_raw_file): # Careful as next command is a colab command where parameters had to be hard coded. Keep in mind if changing fp_raw_file\n", 134 | " !wget https://ndownloader.figshare.com/files/24024683 -O ./example_data/calcium/traces_M1336.mat" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": { 141 | "colab": { 142 | "base_uri": "https://localhost:8080/", 143 | "height": 34 144 | }, 145 | "colab_type": "code", 146 | "id": "uCT9WQ56wTia", 147 | "outputId": "3a530bab-7aee-4675-b062-ba656c7f2c8b", 148 | "slideshow": { 149 | "slide_type": "subslide" 150 | } 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "# Set base variables\n", 155 | "sampling_rate = 30 # Might also be stored in above mat file for easier access\n", 156 | "channels = np.arange(0, 100) # For this recording channels corresponds to cells. We only use the first 100 cells to speed up preprocessing (Change this if you run it on your own dataset)\n", 157 | "\n", 158 | "# Also define Paths to access downloaded files\n", 159 | "base_path = './example_data/calcium/'\n", 160 | "fp_raw_file = base_path + 'traces_M1336.mat' # This is an example dataset containing calcium traces and linear position in a virtual track\n", 161 | "fp_deepinsight = base_path + 'processed_M1336.h5' # This will be the processed HDF5 file\n", 162 | "\n", 163 | "# Load data from mat file\n", 164 | "calcium_data = loadmat(fp_raw_file)['dataSave']\n", 165 | "raw_data = np.squeeze(calcium_data['df_f'][0][0])\n", 166 | "raw_timestamps = np.arange(0, raw_data.shape[0]) / sampling_rate\n", 167 | "output = np.squeeze(calcium_data['pos_dat'][0][0])\n", 168 | "\n", 169 | "print('Data loaded. Calcium traces: {}, Decoding target {}'.format(raw_data.shape, output.shape))" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "colab_type": "text", 176 | "id": "fNbPYTuf79de", 177 | "slideshow": { 178 | "slide_type": "slide" 179 | } 180 | }, 181 | "source": [ 182 | "---\n", 183 | "### Plot example calcium traces\n", 184 | "---\n", 185 | "To give a visual impression of the input to our model we can now plot calcium traces for a bunch of different cells. 
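Each trace is offset vertically (the y_offset variable in the code below) so that individual cells remain distinguishable.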
" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": { 192 | "colab": { 193 | "base_uri": "https://localhost:8080/", 194 | "height": 542 195 | }, 196 | "colab_type": "code", 197 | "id": "s7LG4uj8ywFw", 198 | "outputId": "05677e37-cfd3-4ea9-ff8a-f6c1107f5330", 199 | "slideshow": { 200 | "slide_type": "subslide" 201 | } 202 | }, 203 | "outputs": [], 204 | "source": [ 205 | "end_point, y_offset, num_cells = 10000, 400, 6\n", 206 | "fig = go.Figure()\n", 207 | "for i in range(0, num_cells):\n", 208 | " fig.add_trace(go.Scatter(x=np.arange(0, end_point) / sampling_rate, y=raw_data[0:end_point, i] + (i * y_offset), line=dict(color='rgba(0, 0, 0, 0.85)', width=2), name='Cell {}'.format(i+1)))\n", 209 | "# aesthetics\n", 210 | "fig.update_yaxes(visible=False)\n", 211 | "fig.update_layout(showlegend=False,plot_bgcolor=\"white\",width=1800, height=650,margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', font=dict(family='Open Sans', size=16, color='black'))\n", 212 | "fig.show()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": { 218 | "colab_type": "text", 219 | "id": "9ubJSXsC_yh9", 220 | "slideshow": { 221 | "slide_type": "slide" 222 | } 223 | }, 224 | "source": [ 225 | "---\n", 226 | "### Preprocess data \n", 227 | "---" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": { 234 | "colab": { 235 | "base_uri": "https://localhost:8080/", 236 | "height": 156 237 | }, 238 | "colab_type": "code", 239 | "id": "O61lbp1TZTYM", 240 | "outputId": "6aad0856-0424-4954-e985-1c3f8f672ff7", 241 | "slideshow": { 242 | "slide_type": "subslide" 243 | } 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "if not os.path.exists(fp_deepinsight):\n", 248 | " if os.path.exists(fp_raw_file): # Only do this if user downloaded raw files otherwise download preprocessed hdf5 file\n", 249 | " # Process output for use as decoding target\n", 250 | " # As the mouse is running on a virtual linear track we have a circular variable. 
We can solve this by either:\n", 251 | " # (1) Using a circular loss function or \n", 252 | " # (2) Using the sin and cos of the variable\n", 253 | " # For this dataset we choose method (2), see the loss calculation for head directionality on CA1 recordings for an example of (1)\n", 254 | " output = (output - np.nanmin(output)) / (np.nanmax(output) - np.nanmin(output))\n", 255 | " output = (output * 2*np.pi) - np.pi # Scaled to -pi / pi\n", 256 | " output = np.squeeze(np.column_stack([np.sin(output), np.cos(output)]))\n", 257 | " output = pd.DataFrame(output).ffill().bfill().values # Get rid of NaNs\n", 258 | " output_timestamps = raw_timestamps # In this recording timestamps are the same for output and raw_data, meaning they are already aligned to each other\n", 259 | "\n", 260 | " # Transform raw data to frequency domain\n", 261 | " # We use a small cutoff (1/500) for the low frequencies to keep the dimensions low & the model training fast\n", 262 | " deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=sampling_rate, average_window=1, wave_highpass=1/500, wave_lowpass=sampling_rate, channels=channels) \n", 263 | " # # Prepare outputs\n", 264 | " deepinsight.preprocess.preprocess_output(fp_deepinsight, raw_timestamps, output, output_timestamps, average_window=1, dataset_name='sin_cos')\n", 265 | " else:\n", 266 | " if not os.path.exists(base_path):\n", 267 | " os.makedirs(base_path)\n", 268 | " if not os.path.exists(fp_deepinsight):\n", 269 | " !wget https://ndownloader.figshare.com/files/23658674 -O ./example_data/calcium/processed_M1336.h5" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": { 275 | "colab_type": "text", 276 | "id": "PEjBzYQYo9d9", 277 | "slideshow": { 278 | "slide_type": "slide" 279 | } 280 | }, 281 | "source": [ 282 | "---\n", 283 | "### Plot preprocessed data\n", 284 | "---\n", 285 | "We plot examples to double check the wavelet preprocessing. 
Each plot shows the wavelet processed calcium traces for one cell" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": { 292 | "colab": {}, 293 | "colab_type": "code", 294 | "id": "xeS4h0_5k54q", 295 | "slideshow": { 296 | "slide_type": "skip" 297 | } 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "hdf5_file = h5py.File(fp_deepinsight, mode='r')\n", 302 | "wavelets = hdf5_file['inputs/wavelets']\n", 303 | "frequencies = np.round(hdf5_file['inputs/fourier_frequencies'], 3)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": { 310 | "colab": { 311 | "base_uri": "https://localhost:8080/", 312 | "height": 542 313 | }, 314 | "colab_type": "code", 315 | "id": "3Isfo5s5CIhO", 316 | "outputId": "122a960f-05fa-4831-f442-c65daf4a92f3", 317 | "slideshow": { 318 | "slide_type": "subslide" 319 | } 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "num_cells, gap = 20, 30\n", 324 | "fig = go.Figure()\n", 325 | "for i in range(0, num_cells):\n", 326 | " this_z = wavelets[0:wavelets.shape[0]//2:gap,:,i].transpose()\n", 327 | " fig.add_heatmap(x=np.arange(0, this_z.shape[0]) / (sampling_rate / gap), z=this_z,colorscale='Viridis',visible=False,showscale=False)\n", 328 | "fig.data[0].visible = True\n", 329 | "# aesthetics\n", 330 | "steps = []\n", 331 | "for i in range(len(fig.data)):\n", 332 | " step = dict(method=\"update\",label=\"Cell {}\".format(i+1),args=[{\"visible\": [False] * len(fig.data)}])\n", 333 | " step[\"args\"][0][\"visible\"][i] = True # Toggle i'th trace to \"visible\"\n", 334 | " steps.append(step)\n", 335 | "sliders = [dict(active=10,currentvalue={\"prefix\": \"Cell: \", \"visible\" : False},pad={\"t\": 70},steps=steps)]\n", 336 | "\n", 337 | "fig.update_layout(width=1800, height=650,sliders=sliders, yaxis = dict(tickvals=np.arange(0, len(frequencies)), ticktext = ['{:.3f}'.format(i) for i in frequencies], autorange='reversed'), yaxis_title='Frequency (Hz)',\n", 338 | " showlegend=False, plot_bgcolor=\"white\",margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', font=dict(family='Open Sans', size=16, color='black'))\n", 339 | "fig" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": { 346 | "colab": {}, 347 | "colab_type": "code", 348 | "id": "yIXVttGdEmzd", 349 | "slideshow": { 350 | "slide_type": "skip" 351 | } 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "hdf5_file.close()" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": { 361 | "colab_type": "text", 362 | "id": "R2hXQTtuZTYX", 363 | "slideshow": { 364 | "slide_type": "slide" 365 | } 366 | }, 367 | "source": [ 368 | "---\n", 369 | "## **Train the model**\n", 370 | "---\n", 371 | "The following command uses 5 cross validations to train the models and stores weights in HDF5 files" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": { 378 | "colab": { 379 | "base_uri": "https://localhost:8080/", 380 | "height": 1000 381 | }, 382 | "colab_type": "code", 383 | "id": "z3wSllkHZTYY", 384 | "outputId": "c6715e26-35b4-40bf-9623-1628a02600ae", 385 | "slideshow": { 386 | "slide_type": "subslide" 387 | } 388 | }, 389 | "outputs": [], 390 | "source": [ 391 | "# Define loss functions and train model, if more then one behaviour/stimuli needs to be decoded, define loss functions and weights for each of them here\n", 392 | "loss_functions = {'sin_cos' : 'mse'}\n", 393 | "loss_weights = {'sin_cos' : 1} \n", 
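    "# Note: the dict keys must match the output dataset name written to the HDF5 file above ('sin_cos'); values can be Keras loss names or custom losses from deepinsight.util.custom_losses\n",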
394 | "user_opts = {'epochs' : 10, 'steps_per_epoch' : 250} # Speed up for Colab, normally set to {'epochs' : 20, 'steps_per_epoch' : 250}\n",
395 | "deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights, user_opts=user_opts)"
396 | ]
397 | },
398 | {
399 | "cell_type": "markdown",
400 | "metadata": {
401 | "colab_type": "text",
402 | "id": "p6ybugPYcuRe",
403 | "slideshow": {
404 | "slide_type": "slide"
405 | }
406 | },
407 | "source": [
408 | "---\n",
409 | "## **Evaluate model performance**\n",
410 | "---\n",
411 | "Here we calculate the losses over the whole duration of the experiment. Step size indicates how many timesteps are skipped between samples. Note that each sample contains 64 timesteps, so setting step size to 64 will result in non-overlapping samples"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {
418 | "colab": {
419 | "base_uri": "https://localhost:8080/",
420 | "height": 156
421 | },
422 | "colab_type": "code",
423 | "id": "ZBrHYchnVVck",
424 | "outputId": "d310f7e3-e760-4382-87cc-4965aea6dbcc",
425 | "slideshow": {
426 | "slide_type": "subslide"
427 | }
428 | },
429 | "outputs": [],
430 | "source": [
431 | "step_size = 100\n",
432 | "\n",
433 | "# Get loss and shuffled loss for influence plot, both are also stored back to the HDF5 file\n",
434 | "losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=step_size)\n",
435 | "\n",
436 | "# Get real output from HDF5 file\n",
437 | "hdf5_file = h5py.File(fp_deepinsight, mode='r')\n",
438 | "output_real = hdf5_file['outputs/sin_cos'][indices,:]"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "metadata": {
444 | "colab_type": "text",
445 | "id": "2WUiAPxPdw6c",
446 | "slideshow": {
447 | "slide_type": "slide"
448 | }
449 | },
450 | "source": [
451 | "---\n",
452 | "### Visualize model performance\n",
453 | "---\n",
454 | "We plot the real output vs. the predicted output for the models trained above. The real output is linearized: in the virtual reality environment the start follows directly after the mouse reaches the end of the track, so we can treat position as a circular variable. Also note that the model below is only trained on a subset of channels (see the channels variable, default=100) and for a limited number of epochs (see user_opts, epochs=10), to make training in the Colab notebook faster. The performance on the fully evaluated dataset is higher. 
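Only the sine component of the circular position is shown below; the cosine component behaves analogously. 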
" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": { 461 | "colab": { 462 | "base_uri": "https://localhost:8080/", 463 | "height": 542 464 | }, 465 | "colab_type": "code", 466 | "id": "TrimxQpIX20O", 467 | "outputId": "c831ca7a-aa76-47b0-ed45-80fb788b8c40", 468 | "slideshow": { 469 | "slide_type": "subslide" 470 | } 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "fig = go.Figure()\n", 475 | "\n", 476 | "fig.add_trace(go.Scatter(x=np.arange(0, output_real.shape[0]) / (sampling_rate / step_size), y=output_real[:,0], line=dict(color='rgba(0, 0, 0, 0.85)', width=2), name='Real'))\n", 477 | "fig.add_trace(go.Scatter(x=np.arange(0, output_real.shape[0]) / (sampling_rate / step_size), y=output_predictions['sin_cos'][:,0], line=dict(color='rgb(67, 116, 144)', width=3), name='Predicted'))\n", 478 | "\n", 479 | "# aesthetics\n", 480 | "#fig.update_yaxes(visible=False)\n", 481 | "fig.update_layout(width=1800, height=650, plot_bgcolor=\"rgb(245, 245, 245)\",margin=dict(t=20,l=20,b=20,r=20),xaxis_title='Time (s)', yaxis_title='Decoding target (sin)', font=dict(family='Open Sans', size=16, color='black'))\n", 482 | "fig" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "metadata": { 489 | "colab": {}, 490 | "colab_type": "code", 491 | "id": "jPkRVpHPX25i", 492 | "slideshow": { 493 | "slide_type": "skip" 494 | } 495 | }, 496 | "outputs": [], 497 | "source": [ 498 | "hdf5_file.close()" 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": { 504 | "colab_type": "text", 505 | "id": "inqoPEb4eCmu", 506 | "slideshow": { 507 | "slide_type": "skip" 508 | } 509 | }, 510 | "source": [ 511 | "---\n", 512 | "### Get shuffled model performance\n", 513 | "---\n", 514 | "We use the shuffled loss to evaluate feature importance" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": { 521 | "colab": { 522 | "base_uri": "https://localhost:8080/", 523 | "height": 122 524 | }, 525 | "colab_type": "code", 526 | "id": "LM9SKb4FZTYc", 527 | "outputId": "066e9263-c598-428f-f811-3e9169ce9c6e", 528 | "scrolled": true, 529 | "slideshow": { 530 | "slide_type": "skip" 531 | } 532 | }, 533 | "outputs": [], 534 | "source": [ 535 | "shuffled_losses_ax1 = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1, stepsize=step_size)" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": { 542 | "colab": {}, 543 | "colab_type": "code", 544 | "id": "dCVfQGTxpm05", 545 | "slideshow": { 546 | "slide_type": "skip" 547 | } 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "# Calculate residuals, make sure there is no division by zero by adding small constant.\n", 552 | "residuals = (shuffled_losses_ax1 - losses) / (losses + 0.1)\n", 553 | "residuals_mean = np.mean(residuals, axis=1)[:,0]\n", 554 | "residuals_standarderror = np.std(residuals, axis=1)[:,0] / np.sqrt(residuals.shape[0])" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": { 560 | "colab_type": "text", 561 | "id": "1bSLxr8RJvnN", 562 | "slideshow": { 563 | "slide_type": "slide" 564 | } 565 | }, 566 | "source": [ 567 | "---\n", 568 | "### Show feature importance for frequency axis\n", 569 | "---\n", 570 | "This plot shows the relative influence of each frequency band on the decoding of the position in the virtual environment. We plot the mean across samples + the standard error for each frequency band. 
" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": { 577 | "colab": { 578 | "base_uri": "https://localhost:8080/", 579 | "height": 542 580 | }, 581 | "colab_type": "code", 582 | "id": "BCeTJQykprE7", 583 | "outputId": "cd340d3f-82d8-49b1-d77d-ec2bc2b93683", 584 | "slideshow": { 585 | "slide_type": "subslide" 586 | } 587 | }, 588 | "outputs": [], 589 | "source": [ 590 | "end_point, y_offset, num_cells = 1000, 400, 6\n", 591 | "fig = go.Figure()\n", 592 | "\n", 593 | "fig.add_trace(go.Scatter(x=np.arange(0, residuals_mean.shape[0]), y=residuals_mean, line=dict(color='rgba(0, 0, 0, 0.85)', width=3), name='Real',\n", 594 | " error_y=dict(type='data', array=residuals_standarderror, visible=True, color='rgb(67, 116, 144)', thickness=3)))\n", 595 | "\n", 596 | "# aesthetics\n", 597 | "#fig.update_yaxes(visible=False)\n", 598 | "fig.update_layout(width=1800, height=650, plot_bgcolor=\"rgb(245, 245, 245)\",margin=dict(t=20,l=20,b=20,r=20), xaxis = dict(tickvals=np.arange(0, len(frequencies)), ticktext = ['{:.3f}'.format(i) for i in frequencies], autorange='reversed'),\n", 599 | " xaxis_title='Frequency (Hz)', yaxis_title='Relative influence', font=dict(family='Open Sans', size=16, color='black',\n", 600 | "))\n", 601 | "fig" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": { 607 | "colab_type": "text", 608 | "id": "xW1fimfsJ2Hl", 609 | "slideshow": { 610 | "slide_type": "slide" 611 | } 612 | }, 613 | "source": [ 614 | "---\n", 615 | "### Show feature importance for cell axis\n", 616 | "---\n", 617 | "For this we shuffle across the cell dimension to see the influence each cell has on the decoding of position and then plot it back to the calcium ROIs. In the plot below the size of the dot is indicating the relative influence of this ROI (cell) on the decoding performance. 
Red dots indicate a high influence of this cell on the decoding of position and blue dots indicate a negative influence of this cell.\n" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": { 624 | "colab": { 625 | "base_uri": "https://localhost:8080/", 626 | "height": 122 627 | }, 628 | "colab_type": "code", 629 | "id": "34Cnoa_fpq9z", 630 | "outputId": "3989f4dc-20d7-4261-fcc1-fee2084058f2", 631 | "slideshow": { 632 | "slide_type": "skip" 633 | } 634 | }, 635 | "outputs": [], 636 | "source": [ 637 | "shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=2, stepsize=step_size)" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "metadata": { 644 | "colab": {}, 645 | "colab_type": "code", 646 | "id": "bh0ZvGpEKE4r", 647 | "slideshow": { 648 | "slide_type": "skip" 649 | } 650 | }, 651 | "outputs": [], 652 | "source": [ 653 | "# Calculate residuals, make sure there is no division by zero by adding small constant.\n", 654 | "residuals = (shuffled_losses - losses) / (losses + 0.1)\n", 655 | "residuals_mean = np.mean(residuals, axis=1)[:,0]" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "metadata": { 662 | "colab": { 663 | "base_uri": "https://localhost:8080/", 664 | "height": 785 665 | }, 666 | "colab_type": "code", 667 | "id": "AYtjoLGaF9Ml", 668 | "outputId": "1dab7de7-43dd-46f3-f6ae-b06600045e42", 669 | "slideshow": { 670 | "slide_type": "skip" 671 | } 672 | }, 673 | "outputs": [], 674 | "source": [ 675 | "# Get some files for plotting the importance of each cell back to brain anatomy\n", 676 | "if not os.path.exists('./example_data/calcium/centroid_YX.mat'):\n", 677 | " !wget https://www.dropbox.com/s/z8ynet2nkt9pe1u/centroid_YX.mat -O ./example_data/calcium/centroid_YX.mat\n", 678 | "if not os.path.exists('./example_data/calcium/calcium_rois.jpg'): \n", 679 | " !wget https://www.dropbox.com/s/czak7rphajslcr0/test_rois_F5.jpg -O ./example_data/calcium/calcium_rois.jpg\n", 680 | "roi_data = loadmat('./example_data/calcium/centroid_YX.mat')['xy_coords']" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": null, 686 | "metadata": { 687 | "colab": { 688 | "base_uri": "https://localhost:8080/", 689 | "height": 542 690 | }, 691 | "colab_type": "code", 692 | "id": "nMYWGuPb6AWK", 693 | "outputId": "a1e353e7-ad0b-4a9e-ffda-23bb74a952fb", 694 | "slideshow": { 695 | "slide_type": "slide" 696 | } 697 | }, 698 | "outputs": [], 699 | "source": [ 700 | "fig = go.Figure()\n", 701 | "point_size_adjustment = 1250\n", 702 | "all_pos = residuals_mean > 0\n", 703 | "all_pos_channels = channels[all_pos]\n", 704 | "all_neg_channels = channels[~all_pos]\n", 705 | "fig.add_trace(go.Image(z=io.imread('./example_data/calcium/calcium_rois.jpg')))\n", 706 | "fig.add_trace(go.Scatter(x=roi_data[:,1], y=roi_data[:,0], marker_symbol='circle', mode='markers', marker=dict(color='white', opacity=0.5, line=dict(color='white',width=0)), name='Cell centers'))\n", 707 | "fig.add_trace(go.Scatter(x=roi_data[all_pos_channels,1], y=roi_data[all_pos_channels,0], marker_symbol='circle', mode='markers', marker=dict(color='red', size=residuals_mean[all_pos]*point_size_adjustment, opacity=0.5, line=dict(color='black',width=3)), name='Pos. 
influence'))\n", 708 | "fig.add_trace(go.Scatter(x=roi_data[all_neg_channels,1], y=roi_data[all_neg_channels,0], marker_symbol='circle', mode='markers', marker=dict(color='blue', size=residuals_mean[~all_pos]*-point_size_adjustment, opacity=0.5, line=dict(color='black',width=3)), name='Neg. influence'))\n", 709 | "\n", 710 | "fig.update_layout(width=1800, height=650, showlegend=False, plot_bgcolor=\"white\",margin=dict(t=10,l=0,b=10,r=0), xaxis=dict(showticklabels=False), yaxis=dict(showticklabels=False))\n", 711 | "fig" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": null, 717 | "metadata": { 718 | "colab": {}, 719 | "colab_type": "code", 720 | "id": "xR73Csfr-U83", 721 | "slideshow": { 722 | "slide_type": "skip" 723 | } 724 | }, 725 | "outputs": [], 726 | "source": [] 727 | } 728 | ], 729 | "metadata": { 730 | "accelerator": "GPU", 731 | "celltoolbar": "Slideshow", 732 | "colab": { 733 | "collapsed_sections": [], 734 | "name": "deepinsight_calcium_example.ipynb", 735 | "provenance": [], 736 | "toc_visible": true 737 | }, 738 | "kernelspec": { 739 | "display_name": "Python 3.7.6 64-bit", 740 | "language": "python", 741 | "name": "python37664bit5fa017aec819437bacf63081b14c694c" 742 | }, 743 | "language_info": { 744 | "codemirror_mode": { 745 | "name": "ipython", 746 | "version": 3 747 | }, 748 | "file_extension": ".py", 749 | "mimetype": "text/x-python", 750 | "name": "python", 751 | "nbconvert_exporter": "python", 752 | "pygments_lexer": "ipython3", 753 | "version": "3.7.10" 754 | } 755 | }, 756 | "nbformat": 4, 757 | "nbformat_minor": 1 758 | } 759 | -------------------------------------------------------------------------------- /notebooks/example_data/calcium/calcium_rois.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/notebooks/example_data/calcium/calcium_rois.jpg -------------------------------------------------------------------------------- /notebooks/static/ephys_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "---\n", 8 | "---\n", 9 | "\n", 10 | "# Introduction to DeepInsight - Decoding position, speed and head direction from tetrode CA1 recordings\n", 11 | "\n", 12 | "This notebook stands as an example of how to use DeepInsight v0.5 on tetrode data and can be used as a guide on how to adapt it to your own datasets. All methods are stored in the deepinsight library and can be called directly or in their respective submodules. 
A typical workflow might look like the following: \n", 13 | "- Load your dataset into a format which can be directly indexed (numpy array or pointer to a file on disk)\n", 14 | "- Preprocess the raw data (wavelet transformation)\n", 15 | "- Preprocess your outputs (the variable you want to decode)\n", 16 | "- Define appropriate loss functions for your output and train the model \n", 17 | "- Predict performance across all cross validated models\n", 18 | "- Visualize influence of different input frequencies on model output\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Import DeepInsight\n", 28 | "import sys\n", 29 | "sys.path.insert(0, \"/home/marx/Documents/Github/DeepInsight\")\n", 30 | "import deepinsight\n", 31 | "# Choose GPU\n", 32 | "import os\n", 33 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "---\n", 41 | "---\n", 42 | "Here you can define the paths to your raw data files, and create file names for the preprocessed HDF5 datasets.\n", 43 | "\n", 44 | "The data we use here is usually relatively large in its raw format. Running it through the next lines takes roughly 24 hours for a 40 minute recording.\n", 45 | "\n", 46 | "We provide a preprocess file to play with the code. See next cell" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# Define base paths\n", 56 | "base_path = './example_data/'\n", 57 | "fp_raw_file = base_path + 'experiment_1.nwb' # This is your raw file\n", 58 | "fp_deepinsight = base_path + 'processed_R2478.h5' # This will be the processed HDF5 file\n", 59 | "\n", 60 | "if os.path.exists(fp_raw_file):\n", 61 | " # Load data \n", 62 | " (raw_data,\n", 63 | " raw_timestamps,\n", 64 | " output,\n", 65 | " output_timestamps,\n", 66 | " info) = deepinsight.util.tetrode.read_tetrode_data(fp_raw_file)\n", 67 | " # Transform raw data to frequency domain\n", 68 | " deepinsight.preprocess.preprocess_input(fp_deepinsight, raw_data, sampling_rate=info['sampling_rate'],\n", 69 | " channels=info['channels'])\n", 70 | " # Prepare outputs\n", 71 | " deepinsight.util.tetrode.preprocess_output(fp_deepinsight, raw_timestamps, output,\n", 72 | " output_timestamps, sampling_rate=info['sampling_rate'])" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "---\n", 80 | "---\n", 81 | "The above steps create a HDF5 file with all important data for training the model.\n", 82 | "\n", 83 | "You can download the preprocessed dataset by running the following command" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "!wget https://ndownloader.figshare.com/files/20150468 -O ./example_data/processed_R2478.h5" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "---\n", 100 | "---\n", 101 | "Now we can train the model. 
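Each output gets its own loss function and weight: a Euclidean loss for position, a cyclical MAE for head direction (both defined in deepinsight.util.custom_losses) and a standard MAE for speed, with the weights balancing the different output scales. 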
\n", 102 | "\n", 103 | "The following command uses 5 cross validations to train the models and stores weights in HDF5 files" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# Define loss functions and train model\n", 113 | "loss_functions = {'position' : 'euclidean_loss', \n", 114 | " 'head_direction' : 'cyclical_mae_rad', \n", 115 | " 'speed' : 'mae'}\n", 116 | "loss_weights = {'position' : 1, \n", 117 | " 'head_direction' : 25, \n", 118 | " 'speed' : 2}\n", 119 | "deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "scrolled": true 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "# Get loss and shuffled loss for influence plot, both is also stored back to HDF5 file\n", 131 | "losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight,\n", 132 | " stepsize=10)\n", 133 | "shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1,\n", 134 | " stepsize=10)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "---\n", 142 | "---\n", 143 | "Above line calculates the loss and shuffled loss across the full experiment and writes it back to the HDF5 file.\n", 144 | "\n", 145 | "Below command visualizes the influence across different frequency bands for all samples\n", 146 | "\n", 147 | "Note that Figure 3 in the manuscript shows influence across animals, while this plot shows the influence for one animal across the experiment" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# Plot influence across behaviours\n", 157 | "deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2,\n", 158 | " output_names=['Position', 'Head Direction', 'Speed'])" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "---\n", 166 | "---" 167 | ] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "Python 3", 173 | "language": "python", 174 | "name": "python3" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.7.10" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 2 191 | } 192 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu 2 | numpy 3 | pandas 4 | joblib 5 | seaborn 6 | matplotlib 7 | h5py 8 | scipy 9 | ipython -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepInsight Toolbox 3 | © Markus Frey 4 | https://github.com/CYHSM/DeepInsight 5 | Licensed under MIT License 6 | """ 7 | from setuptools import setup, find_packages 8 | 9 | long_description = open('README.md').read() 10 | with open('requirements.txt') as f: 11 | requirements = f.read().splitlines() 12 | 13 | setup( 14 | name='deepinsight', 15 | version='0.5', 16 | 
install_requires=requirements,
17 | author='Markus Frey',
18 | author_email='markus.frey1@gmail.com',
19 | description="A general framework for interpreting wide-band neural activity",
20 | long_description=long_description,
21 | url='https://github.com/CYHSM/DeepInsight/',
22 | license='MIT',
23 | classifiers=[
24 | 'Development Status :: 3 - Alpha',
25 | 'License :: OSI Approved :: MIT License',
26 | 'Intended Audience :: Developers',
27 | 'Natural Language :: English',
28 | 'Operating System :: OS Independent',
29 | 'Programming Language :: Python',
30 | 'Programming Language :: Python :: 3',
31 | ],
32 | packages=find_packages(),
33 | package_data={
34 | "": ["*.p", "*.h5", "*.csv", "*.gif", "*.png", "*.txt"],
35 | },
36 | )
37 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CYHSM/DeepInsight/e5a66be5dc3c671c37bd30ddf8f1f8ebae78ed2c/tests/__init__.py -------------------------------------------------------------------------------- /tests/run_test.py: -------------------------------------------------------------------------------- 1 | import time
2 | import os
3 | import h5py
4 | import deepinsight
5 | 
6 | import numpy as np
7 | import unittest
8 | unittest.TestLoader.sortTestMethodsUsing = None
9 | 
10 | 
11 | class TestDeepInsight(unittest.TestCase):
12 | """Simple Testing Class"""
13 | 
14 | def tearDown(self):
15 | time.sleep(0.1)
16 | 
17 | def setUp(self):
18 | unittest.TestCase.setUp(self)
19 | np.random.seed(0)
20 | self.fp_deepinsight_folder = os.getcwd() + '/tests/test_files/'
21 | self.fp_deepinsight = self.fp_deepinsight_folder + 'test.h5'
22 | if os.path.exists(self.fp_deepinsight):
23 | os.remove(self.fp_deepinsight)
24 | else:
25 | os.makedirs(self.fp_deepinsight_folder, exist_ok=True)
26 | self.input_length = int(3e5)
27 | self.input_channels = 5
28 | self.sampling_rate = 30000
29 | self.input_output_ratio = 100
30 | self.average_window = 10
31 | 
32 | self.rand_input = np.sin(np.random.rand(
33 | int(self.input_length), self.input_channels))
34 | self.rand_input_timesteps = np.arange(0, self.input_length)
35 | self.rand_output = np.random.rand(
36 | self.input_length // self.input_output_ratio)
37 | self.rand_timesteps = np.arange(
38 | 0, self.input_length, self.input_output_ratio)
39 | 
40 | def test_fullrun(self):
41 | """
42 | Tests the full pipeline (preprocessing, training and analysis) on a random signal
43 | """
44 | # Transform raw data to frequency domain
45 | deepinsight.preprocess.preprocess_input(
46 | self.fp_deepinsight, self.rand_input, sampling_rate=self.sampling_rate, average_window=self.average_window)
47 | hdf5_file = h5py.File(self.fp_deepinsight, mode='r')
48 | # Get wavelets from hdf5 file
49 | input_wavelets = hdf5_file['inputs/wavelets']
50 | # Check statistics of wavelets
51 | np.testing.assert_almost_equal(np.mean(input_wavelets), 0.048329710)
52 | np.testing.assert_almost_equal(np.std(input_wavelets), 0.04667989)
53 | np.testing.assert_almost_equal(np.median(input_wavelets), 0.03440293)
54 | np.testing.assert_almost_equal(np.max(input_wavelets), 0.60365933)
55 | np.testing.assert_almost_equal(np.min(input_wavelets), 3.78198024e-08)
56 | hdf5_file.close()
57 | 
58 | # Prepare outputs
59 | deepinsight.preprocess.preprocess_output(
60 | self.fp_deepinsight, self.rand_input_timesteps, self.rand_output, self.rand_timesteps, average_window=self.average_window)
61 | 
62 | # Define loss functions and train model
63 | loss_functions = 
{'aligned': 'mse'} 64 | loss_weights = {'aligned': 1} 65 | user_opts = {'epochs': 2, 'steps_per_epoch': 10, 66 | 'validation_steps': 10, 'log_output': False, 'save_model': True} 67 | 68 | deepinsight.train.run_from_path( 69 | self.fp_deepinsight, loss_functions, loss_weights, user_opts) 70 | 71 | # Get loss and shuffled loss for influence plot, both is also stored back to HDF5 file 72 | losses, output_predictions, indices, output_real = deepinsight.analyse.get_model_loss( 73 | self.fp_deepinsight, stepsize=10) 74 | 75 | shuffled_losses = deepinsight.analyse.get_shuffled_model_loss( 76 | self.fp_deepinsight, axis=1, stepsize=10) 77 | 78 | 79 | if __name__ == '__main__': 80 | unittest.main(warnings='ignore') 81 | -------------------------------------------------------------------------------- /tests/tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "seed_value = 0\n", 10 | "\n", 11 | "# Import DeepInsight\n", 12 | "import sys\n", 13 | "sys.path.insert(0, \"/home/marx/Documents/Github/DeepInsight\")\n", 14 | "import deepinsight\n", 15 | "# Choose GPU\n", 16 | "import os\n", 17 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", 18 | "os.environ['PYTHONHASHSEED']=str(0)\n", 19 | "import tensorflow as tf\n", 20 | "tf.random.set_seed(seed_value)\n", 21 | "# Also numpy random generator\n", 22 | "import numpy as np\n", 23 | "np.random.seed(seed_value)\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import h5py\n", 27 | "%load_ext autoreload\n", 28 | "%autoreload 2" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "#%run run_test.py" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "fp_deepinsight = './test_files/test.h5'\n", 54 | "if os.path.exists(fp_deepinsight):\n", 55 | " os.remove(fp_deepinsight)\n", 56 | "input_length = int(3e5)\n", 57 | "input_channels = 5\n", 58 | "sampling_rate = 30000\n", 59 | "input_output_ratio = 100\n", 60 | "\n", 61 | "np.random.seed(0)\n", 62 | "rand_input = np.sin(np.random.rand(int(input_length), input_channels))\n", 63 | "rand_input_timesteps = np.arange(0, input_length)\n", 64 | "rand_output = np.random.rand(input_length // input_output_ratio)\n", 65 | "rand_timesteps = np.arange(0, input_length, input_output_ratio)\n", 66 | "\n", 67 | "print(rand_input[0,0])\n", 68 | "print(rand_input_timesteps[0:10])\n", 69 | "print(rand_output[0])\n", 70 | "print(rand_timesteps[0:10])" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Transform raw data to frequency domain\n", 80 | "deepinsight.preprocess.preprocess_input(fp_deepinsight, rand_input, sampling_rate=sampling_rate, average_window=10)\n", 81 | "\n", 82 | "# Test cases\n", 83 | "hdf5_file = h5py.File(fp_deepinsight, mode='r')\n", 84 | "# Get size of wavelets\n", 85 | "input_wavelets = hdf5_file['inputs/wavelets']\n", 86 | "# Check statistics of wavelets\n", 87 | "np.testing.assert_almost_equal(np.mean(input_wavelets), 0.048329726)\n", 88 | "np.testing.assert_almost_equal(np.std(input_wavelets), 0.032383125)\n", 89 | 
"np.testing.assert_almost_equal(np.median(input_wavelets), 0.04608967)\n", 90 | "np.testing.assert_almost_equal(np.max(input_wavelets), 0.40853173)\n", 91 | "np.testing.assert_almost_equal(np.min(input_wavelets), 1.6544704e-05)\n", 92 | "hdf5_file.close()\n" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "print('Mean {:.10}, Std {:.10}, Median {:.10}, Max {:.10}, Min {:.10}'.format(np.mean(input_wavelets), np.std(input_wavelets), np.median(input_wavelets), np.max(input_wavelets), np.min(input_wavelets)))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# Prepare outputs\n", 118 | "deepinsight.preprocess.preprocess_output(fp_deepinsight, rand_input_timesteps, rand_output,\n", 119 | " rand_timesteps)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Define loss functions and train model\n", 129 | "loss_functions = {'output_aligned' : 'mse'}\n", 130 | "loss_weights = {'output_aligned' : 1}\n", 131 | "user_opts = {'epochs' : 2, 'steps_per_epoch' : 10, 'validation_steps' : 10, 'log_output' : False, 'save_model' : False}\n", 132 | "\n", 133 | "deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights, user_opts)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "# Get loss and shuffled loss for influence plot, both is also stored back to HDF5 file\n", 143 | "losses, output_predictions, indices = deepinsight.analyse.get_model_loss(fp_deepinsight, stepsize=10)\n", 144 | "\n", 145 | "# Test cases\n", 146 | "np.testing.assert_almost_equal(losses[-1], 1.0168755e-05)\n", 147 | "np.testing.assert_almost_equal(losses[0], 0.53577816)\n", 148 | "np.testing.assert_almost_equal(np.mean(losses), 0.09069238)\n", 149 | "np.testing.assert_almost_equal(np.std(losses), 0.13594063)\n", 150 | "np.testing.assert_almost_equal(np.median(losses), 0.045781307)\n", 151 | "np.testing.assert_almost_equal(np.max(losses), 0.53577816)\n", 152 | "np.testing.assert_almost_equal(np.min(losses), 1.0168755e-05)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(fp_deepinsight, axis=1,stepsize=10)\n", 169 | "\n", 170 | "# Test cases\n", 171 | "np.testing.assert_almost_equal(np.mean(shuffled_losses), 0.09304095)\n", 172 | "np.testing.assert_almost_equal(np.std(shuffled_losses), 0.13982493)\n", 173 | "np.testing.assert_almost_equal(np.median(shuffled_losses), 0.04165206)\n", 174 | "np.testing.assert_almost_equal(np.max(shuffled_losses), 0.7405345)\n", 175 | "np.testing.assert_almost_equal(np.min(shuffled_losses), 2.0834877e-07)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2,\n", 185 | " 
output_names=['output_aligned'])" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [] 194 | } 195 | ], 196 | "metadata": { 197 | "kernelspec": { 198 | "display_name": "Python 3.7.6 64-bit", 199 | "language": "python", 200 | "name": "python37664bit5fa017aec819437bacf63081b14c694c" 201 | }, 202 | "language_info": { 203 | "codemirror_mode": { 204 | "name": "ipython", 205 | "version": 3 206 | }, 207 | "file_extension": ".py", 208 | "mimetype": "text/x-python", 209 | "name": "python", 210 | "nbconvert_exporter": "python", 211 | "pygments_lexer": "ipython3", 212 | "version": "3.7.10" 213 | } 214 | }, 215 | "nbformat": 4, 216 | "nbformat_minor": 2 217 | } 218 | --------------------------------------------------------------------------------
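The two test files above exercise the complete DeepInsight pipeline end to end on synthetic data. The sketch below assembles those same calls into a standalone script; it is illustrative only and follows `tests/tests.ipynb`, so the loss key is assumed to be `'output_aligned'` (the unit test in `tests/run_test.py` uses `'aligned'` and unpacks a fourth return value, `output_real`, from `get_model_loss`, so the exact key names and return arity may depend on the installed version). The output path `./test_files/example.h5` is a hypothetical example.

```python
# Minimal end-to-end sketch on synthetic data, mirroring tests/tests.ipynb.
# Assumes deepinsight and its wavelet dependency are installed (see
# .github/workflows/python-app.yml for the pip commands used in CI).
import os
import numpy as np
import deepinsight

fp_deepinsight = './test_files/example.h5'  # hypothetical output path
os.makedirs(os.path.dirname(fp_deepinsight), exist_ok=True)

# Synthetic "recording": 5 channels at 30 kHz, one behaviour sampled 100x slower
np.random.seed(0)
input_length, input_channels, sampling_rate, ratio = int(3e5), 5, 30000, 100
rand_input = np.sin(np.random.rand(input_length, input_channels))
rand_input_timesteps = np.arange(0, input_length)
rand_output = np.random.rand(input_length // ratio)
rand_timesteps = np.arange(0, input_length, ratio)

# Wavelet-transform the raw traces, then align the outputs to them
deepinsight.preprocess.preprocess_input(fp_deepinsight, rand_input,
                                        sampling_rate=sampling_rate, average_window=10)
deepinsight.preprocess.preprocess_output(fp_deepinsight, rand_input_timesteps,
                                         rand_output, rand_timesteps)

# Short training run, then loss and shuffled loss for the influence plot
loss_functions = {'output_aligned': 'mse'}
loss_weights = {'output_aligned': 1}
user_opts = {'epochs': 2, 'steps_per_epoch': 10, 'validation_steps': 10,
             'log_output': False, 'save_model': False}
deepinsight.train.run_from_path(fp_deepinsight, loss_functions, loss_weights, user_opts)

losses, output_predictions, indices = deepinsight.analyse.get_model_loss(
    fp_deepinsight, stepsize=10)
shuffled_losses = deepinsight.analyse.get_shuffled_model_loss(
    fp_deepinsight, axis=1, stepsize=10)
deepinsight.visualize.plot_residuals(fp_deepinsight, frequency_spacing=2,
                                     output_names=['output_aligned'])
```

Note that every step takes the same path as its first argument: all intermediate results (wavelets, aligned outputs, model weights, losses) are written back to the single HDF5 file at `fp_deepinsight`.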